mirror of
https://github.com/trycua/computer.git
synced 2026-01-03 12:00:00 -06:00
Reorganize lib folder w/typescript and python roots, initialize core library.
This commit is contained in:
144
libs/python/computer/README.md
Normal file
144
libs/python/computer/README.md
Normal file
@@ -0,0 +1,144 @@
|
||||
<div align="center">
|
||||
<h1>
|
||||
<div class="image-wrapper" style="display: inline-block;">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="../../img/logo_white.png" style="display: block; margin: auto;">
|
||||
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="../../img/logo_black.png" style="display: block; margin: auto;">
|
||||
<img alt="Shows my svg">
|
||||
</picture>
|
||||
</div>
|
||||
|
||||
[](#)
|
||||
[](#)
|
||||
[](https://discord.com/invite/mVnXXpdE85)
|
||||
[](https://pypi.org/project/cua-computer/)
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
**cua-computer** is a Computer-Use Interface (CUI) framework powering Cua for interacting with local macOS and Linux sandboxes, PyAutoGUI-compatible, and pluggable with any AI agent systems (Cua, Langchain, CrewAI, AutoGen). Computer relies on [Lume](https://github.com/trycua/lume) for creating and managing sandbox environments.
|
||||
|
||||
### Get started with Computer
|
||||
|
||||
<div align="center">
|
||||
<img src="../../img/computer.png"/>
|
||||
</div>
|
||||
|
||||
```python
|
||||
from computer import Computer
|
||||
|
||||
computer = Computer(os_type="macos", display="1024x768", memory="8GB", cpu="4")
|
||||
try:
|
||||
await computer.run()
|
||||
|
||||
screenshot = await computer.interface.screenshot()
|
||||
with open("screenshot.png", "wb") as f:
|
||||
f.write(screenshot)
|
||||
|
||||
await computer.interface.move_cursor(100, 100)
|
||||
await computer.interface.left_click()
|
||||
await computer.interface.right_click(300, 300)
|
||||
await computer.interface.double_click(400, 400)
|
||||
|
||||
await computer.interface.type("Hello, World!")
|
||||
await computer.interface.press_key("enter")
|
||||
|
||||
await computer.interface.set_clipboard("Test clipboard")
|
||||
content = await computer.interface.copy_to_clipboard()
|
||||
print(f"Clipboard content: {content}")
|
||||
finally:
|
||||
await computer.stop()
|
||||
```
|
||||
|
||||
## Install
|
||||
|
||||
To install the Computer-Use Interface (CUI):
|
||||
|
||||
```bash
|
||||
pip install "cua-computer[all]"
|
||||
```
|
||||
|
||||
The `cua-computer` PyPi package pulls automatically the latest executable version of Lume through [pylume](https://github.com/trycua/pylume).
|
||||
|
||||
## Run
|
||||
|
||||
Refer to this notebook for a step-by-step guide on how to use the Computer-Use Interface (CUI):
|
||||
|
||||
- [Computer-Use Interface (CUI)](../../notebooks/computer_nb.ipynb)
|
||||
|
||||
## Using the Gradio Computer UI
|
||||
|
||||
The computer module includes a Gradio UI for creating and sharing demonstration data. We make it easy for people to build community datasets for better computer use models with an upload to Huggingface feature.
|
||||
|
||||
```bash
|
||||
# Install with UI support
|
||||
pip install "cua-computer[ui]"
|
||||
```
|
||||
|
||||
> **Note:** For precise control of the computer, we recommend using VNC or Screen Sharing instead of the Computer Gradio UI.
|
||||
|
||||
### Building and Sharing Demonstrations with Huggingface
|
||||
|
||||
Follow these steps to contribute your own demonstrations:
|
||||
|
||||
#### 1. Set up Huggingface Access
|
||||
|
||||
Set your HF_TOKEN in a .env file or in your environment variables:
|
||||
|
||||
```bash
|
||||
# In .env file
|
||||
HF_TOKEN=your_huggingface_token
|
||||
```
|
||||
|
||||
#### 2. Launch the Computer UI
|
||||
|
||||
```python
|
||||
# launch_ui.py
|
||||
from computer.ui.gradio.app import create_gradio_ui
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv('.env')
|
||||
|
||||
app = create_gradio_ui()
|
||||
app.launch(share=False)
|
||||
```
|
||||
|
||||
For examples, see [Computer UI Examples](../../examples/computer_ui_examples.py)
|
||||
|
||||
#### 3. Record Your Tasks
|
||||
|
||||
<details open>
|
||||
<summary>View demonstration video</summary>
|
||||
<video src="https://github.com/user-attachments/assets/de3c3477-62fe-413c-998d-4063e48de176" controls width="600"></video>
|
||||
</details>
|
||||
|
||||
Record yourself performing various computer tasks using the UI.
|
||||
|
||||
#### 4. Save Your Demonstrations
|
||||
|
||||
<details open>
|
||||
<summary>View demonstration video</summary>
|
||||
<video src="https://github.com/user-attachments/assets/5ad1df37-026a-457f-8b49-922ae805faef" controls width="600"></video>
|
||||
</details>
|
||||
|
||||
Save each task by picking a descriptive name and adding relevant tags (e.g., "office", "web-browsing", "coding").
|
||||
|
||||
#### 5. Record Additional Demonstrations
|
||||
|
||||
Repeat steps 3 and 4 until you have a good amount of demonstrations covering different tasks and scenarios.
|
||||
|
||||
#### 6. Upload to Huggingface
|
||||
|
||||
<details open>
|
||||
<summary>View demonstration video</summary>
|
||||
<video src="https://github.com/user-attachments/assets/c586d460-3877-4b5f-a736-3248886d2134" controls width="600"></video>
|
||||
</details>
|
||||
|
||||
Upload your dataset to Huggingface by:
|
||||
- Naming it as `{your_username}/{dataset_name}`
|
||||
- Choosing public or private visibility
|
||||
- Optionally selecting specific tags to upload only tasks with certain tags
|
||||
|
||||
#### Examples and Resources
|
||||
|
||||
- Example Dataset: [ddupont/test-dataset](https://huggingface.co/datasets/ddupont/test-dataset)
|
||||
- Find Community Datasets: 🔍 [Browse CUA Datasets on Huggingface](https://huggingface.co/datasets?other=cua)
|
||||
|
||||
51
libs/python/computer/computer/__init__.py
Normal file
51
libs/python/computer/computer/__init__.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""CUA Computer Interface for cross-platform computer control."""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
|
||||
__version__ = "0.1.0"
|
||||
|
||||
# Initialize logging
|
||||
logger = logging.getLogger("computer")
|
||||
|
||||
# Initialize telemetry when the package is imported
|
||||
try:
|
||||
# Import from core telemetry
|
||||
from core.telemetry import (
|
||||
flush,
|
||||
is_telemetry_enabled,
|
||||
record_event,
|
||||
)
|
||||
|
||||
# Check if telemetry is enabled
|
||||
if is_telemetry_enabled():
|
||||
logger.info("Telemetry is enabled")
|
||||
|
||||
# Record package initialization
|
||||
record_event(
|
||||
"module_init",
|
||||
{
|
||||
"module": "computer",
|
||||
"version": __version__,
|
||||
"python_version": sys.version,
|
||||
},
|
||||
)
|
||||
|
||||
# Flush events to ensure they're sent
|
||||
flush()
|
||||
else:
|
||||
logger.info("Telemetry is disabled")
|
||||
except ImportError as e:
|
||||
# Telemetry not available
|
||||
logger.warning(f"Telemetry not available: {e}")
|
||||
except Exception as e:
|
||||
# Other issues with telemetry
|
||||
logger.warning(f"Error initializing telemetry: {e}")
|
||||
|
||||
# Core components
|
||||
from .computer import Computer
|
||||
|
||||
# Provider components
|
||||
from .providers.base import VMProviderType
|
||||
|
||||
__all__ = ["Computer", "VMProviderType"]
|
||||
926
libs/python/computer/computer/computer.py
Normal file
926
libs/python/computer/computer/computer.py
Normal file
@@ -0,0 +1,926 @@
|
||||
from typing import Optional, List, Literal, Dict, Any, Union, TYPE_CHECKING, cast
|
||||
import asyncio
|
||||
from .models import Computer as ComputerConfig, Display
|
||||
from .interface.factory import InterfaceFactory
|
||||
import time
|
||||
from PIL import Image
|
||||
import io
|
||||
import re
|
||||
from .logger import Logger, LogLevel
|
||||
import json
|
||||
import logging
|
||||
from .telemetry import record_computer_initialization
|
||||
import os
|
||||
from . import helpers
|
||||
|
||||
# Import provider related modules
|
||||
from .providers.base import VMProviderType
|
||||
from .providers.factory import VMProviderFactory
|
||||
|
||||
OSType = Literal["macos", "linux", "windows"]
|
||||
|
||||
class Computer:
|
||||
"""Computer is the main class for interacting with the computer."""
|
||||
|
||||
def create_desktop_from_apps(self, apps):
|
||||
"""
|
||||
Create a virtual desktop from a list of app names, returning a DioramaComputer
|
||||
that proxies Diorama.Interface but uses diorama_cmds via the computer interface.
|
||||
|
||||
Args:
|
||||
apps (list[str]): List of application names to include in the desktop.
|
||||
Returns:
|
||||
DioramaComputer: A proxy object with the Diorama interface, but using diorama_cmds.
|
||||
"""
|
||||
assert "app-use" in self.experiments, "App Usage is an experimental feature. Enable it by passing experiments=['app-use'] to Computer()"
|
||||
from .diorama_computer import DioramaComputer
|
||||
return DioramaComputer(self, apps)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
display: Union[Display, Dict[str, int], str] = "1024x768",
|
||||
memory: str = "8GB",
|
||||
cpu: str = "4",
|
||||
os_type: OSType = "macos",
|
||||
name: str = "",
|
||||
image: str = "macos-sequoia-cua:latest",
|
||||
shared_directories: Optional[List[str]] = None,
|
||||
use_host_computer_server: bool = False,
|
||||
verbosity: Union[int, LogLevel] = logging.INFO,
|
||||
telemetry_enabled: bool = True,
|
||||
provider_type: Union[str, VMProviderType] = VMProviderType.LUME,
|
||||
port: Optional[int] = 7777,
|
||||
noVNC_port: Optional[int] = 8006,
|
||||
host: str = os.environ.get("PYLUME_HOST", "localhost"),
|
||||
storage: Optional[str] = None,
|
||||
ephemeral: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
experiments: Optional[List[str]] = None
|
||||
):
|
||||
"""Initialize a new Computer instance.
|
||||
|
||||
Args:
|
||||
display: The display configuration. Can be:
|
||||
- A Display object
|
||||
- A dict with 'width' and 'height'
|
||||
- A string in format "WIDTHxHEIGHT" (e.g. "1920x1080")
|
||||
Defaults to "1024x768"
|
||||
memory: The VM memory allocation. Defaults to "8GB"
|
||||
cpu: The VM CPU allocation. Defaults to "4"
|
||||
os_type: The operating system type ('macos' or 'linux')
|
||||
name: The VM name
|
||||
image: The VM image name
|
||||
shared_directories: Optional list of directory paths to share with the VM
|
||||
use_host_computer_server: If True, target localhost instead of starting a VM
|
||||
verbosity: Logging level (standard Python logging levels: logging.DEBUG, logging.INFO, etc.)
|
||||
LogLevel enum values are still accepted for backward compatibility
|
||||
telemetry_enabled: Whether to enable telemetry tracking. Defaults to True.
|
||||
provider_type: The VM provider type to use (lume, qemu, cloud)
|
||||
port: Optional port to use for the VM provider server
|
||||
noVNC_port: Optional port for the noVNC web interface (Lumier provider)
|
||||
host: Host to use for VM provider connections (e.g. "localhost", "host.docker.internal")
|
||||
storage: Optional path for persistent VM storage (Lumier provider)
|
||||
ephemeral: Whether to use ephemeral storage
|
||||
api_key: Optional API key for cloud providers
|
||||
experiments: Optional list of experimental features to enable (e.g. ["app-use"])
|
||||
"""
|
||||
|
||||
self.logger = Logger("computer", verbosity)
|
||||
self.logger.info("Initializing Computer...")
|
||||
|
||||
# Store original parameters
|
||||
self.image = image
|
||||
self.port = port
|
||||
self.noVNC_port = noVNC_port
|
||||
self.host = host
|
||||
self.os_type = os_type
|
||||
self.provider_type = provider_type
|
||||
self.ephemeral = ephemeral
|
||||
|
||||
self.api_key = api_key
|
||||
self.experiments = experiments or []
|
||||
|
||||
if "app-use" in self.experiments:
|
||||
assert self.os_type == "macos", "App use experiment is only supported on macOS"
|
||||
|
||||
# The default is currently to use non-ephemeral storage
|
||||
if storage and ephemeral and storage != "ephemeral":
|
||||
raise ValueError("Storage path and ephemeral flag cannot be used together")
|
||||
|
||||
# Windows Sandbox always uses ephemeral storage
|
||||
if self.provider_type == VMProviderType.WINSANDBOX:
|
||||
if not ephemeral and storage != None and storage != "ephemeral":
|
||||
self.logger.warning("Windows Sandbox storage is always ephemeral. Setting ephemeral=True.")
|
||||
self.ephemeral = True
|
||||
self.storage = "ephemeral"
|
||||
else:
|
||||
self.storage = "ephemeral" if ephemeral else storage
|
||||
|
||||
# For Lumier provider, store the first shared directory path to use
|
||||
# for VM file sharing
|
||||
self.shared_path = None
|
||||
if shared_directories and len(shared_directories) > 0:
|
||||
self.shared_path = shared_directories[0]
|
||||
self.logger.info(f"Using first shared directory for VM file sharing: {self.shared_path}")
|
||||
|
||||
# Store telemetry preference
|
||||
self._telemetry_enabled = telemetry_enabled
|
||||
|
||||
# Set initialization flag
|
||||
self._initialized = False
|
||||
self._running = False
|
||||
|
||||
# Configure root logger
|
||||
self.verbosity = verbosity
|
||||
self.logger = Logger("computer", verbosity)
|
||||
|
||||
# Configure component loggers with proper hierarchy
|
||||
self.vm_logger = Logger("computer.vm", verbosity)
|
||||
self.interface_logger = Logger("computer.interface", verbosity)
|
||||
|
||||
if not use_host_computer_server:
|
||||
if ":" not in image or len(image.split(":")) != 2:
|
||||
raise ValueError("Image must be in the format <image_name>:<tag>")
|
||||
|
||||
if not name:
|
||||
# Normalize the name to be used for the VM
|
||||
name = image.replace(":", "_")
|
||||
|
||||
# Convert display parameter to Display object
|
||||
if isinstance(display, str):
|
||||
# Parse string format "WIDTHxHEIGHT"
|
||||
match = re.match(r"(\d+)x(\d+)", display)
|
||||
if not match:
|
||||
raise ValueError(
|
||||
"Display string must be in format 'WIDTHxHEIGHT' (e.g. '1024x768')"
|
||||
)
|
||||
width, height = map(int, match.groups())
|
||||
display_config = Display(width=width, height=height)
|
||||
elif isinstance(display, dict):
|
||||
display_config = Display(**display)
|
||||
else:
|
||||
display_config = display
|
||||
|
||||
self.config = ComputerConfig(
|
||||
image=image.split(":")[0],
|
||||
tag=image.split(":")[1],
|
||||
name=name,
|
||||
display=display_config,
|
||||
memory=memory,
|
||||
cpu=cpu,
|
||||
)
|
||||
# Initialize VM provider but don't start it yet - we'll do that in run()
|
||||
self.config.vm_provider = None # Will be initialized in run()
|
||||
|
||||
# Store shared directories config
|
||||
self.shared_directories = shared_directories or []
|
||||
|
||||
# Placeholder for VM provider context manager
|
||||
self._provider_context = None
|
||||
|
||||
# Initialize with proper typing - None at first, will be set in run()
|
||||
self._interface = None
|
||||
self.use_host_computer_server = use_host_computer_server
|
||||
|
||||
# Record initialization in telemetry (if enabled)
|
||||
if telemetry_enabled:
|
||||
record_computer_initialization()
|
||||
else:
|
||||
self.logger.debug("Telemetry disabled - skipping initialization tracking")
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Start the computer."""
|
||||
await self.run()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Stop the computer."""
|
||||
await self.disconnect()
|
||||
|
||||
def __enter__(self):
|
||||
"""Start the computer."""
|
||||
# Run the event loop to call the async enter method
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self.__aenter__())
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Stop the computer."""
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self.__aexit__(exc_type, exc_val, exc_tb))
|
||||
|
||||
async def run(self) -> Optional[str]:
|
||||
"""Initialize the VM and computer interface."""
|
||||
if TYPE_CHECKING:
|
||||
from .interface.base import BaseComputerInterface
|
||||
|
||||
# If already initialized, just log and return
|
||||
if hasattr(self, "_initialized") and self._initialized:
|
||||
self.logger.info("Computer already initialized, skipping initialization")
|
||||
return
|
||||
|
||||
self.logger.info("Starting computer...")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# If using host computer server
|
||||
if self.use_host_computer_server:
|
||||
self.logger.info("Using host computer server")
|
||||
# Set ip_address for host computer server mode
|
||||
ip_address = "localhost"
|
||||
# Create the interface with explicit type annotation
|
||||
from .interface.base import BaseComputerInterface
|
||||
|
||||
self._interface = cast(
|
||||
BaseComputerInterface,
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type, ip_address=ip_address # type: ignore[arg-type]
|
||||
),
|
||||
)
|
||||
|
||||
self.logger.info("Waiting for host computer server to be ready...")
|
||||
await self._interface.wait_for_ready()
|
||||
self.logger.info("Host computer server ready")
|
||||
else:
|
||||
# Start or connect to VM
|
||||
self.logger.info(f"Starting VM: {self.image}")
|
||||
if not self._provider_context:
|
||||
try:
|
||||
provider_type_name = self.provider_type.name if isinstance(self.provider_type, VMProviderType) else self.provider_type
|
||||
self.logger.verbose(f"Initializing {provider_type_name} provider context...")
|
||||
|
||||
# Explicitly set provider parameters
|
||||
storage = "ephemeral" if self.ephemeral else self.storage
|
||||
verbose = self.verbosity >= LogLevel.DEBUG
|
||||
ephemeral = self.ephemeral
|
||||
port = self.port if self.port is not None else 7777
|
||||
host = self.host if self.host else "localhost"
|
||||
image = self.image
|
||||
shared_path = self.shared_path
|
||||
noVNC_port = self.noVNC_port
|
||||
|
||||
# Create VM provider instance with explicit parameters
|
||||
try:
|
||||
if self.provider_type == VMProviderType.LUMIER:
|
||||
self.logger.info(f"Using VM image for Lumier provider: {image}")
|
||||
if shared_path:
|
||||
self.logger.info(f"Using shared path for Lumier provider: {shared_path}")
|
||||
if noVNC_port:
|
||||
self.logger.info(f"Using noVNC port for Lumier provider: {noVNC_port}")
|
||||
self.config.vm_provider = VMProviderFactory.create_provider(
|
||||
self.provider_type,
|
||||
port=port,
|
||||
host=host,
|
||||
storage=storage,
|
||||
shared_path=shared_path,
|
||||
image=image,
|
||||
verbose=verbose,
|
||||
ephemeral=ephemeral,
|
||||
noVNC_port=noVNC_port,
|
||||
)
|
||||
elif self.provider_type == VMProviderType.LUME:
|
||||
self.config.vm_provider = VMProviderFactory.create_provider(
|
||||
self.provider_type,
|
||||
port=port,
|
||||
host=host,
|
||||
storage=storage,
|
||||
verbose=verbose,
|
||||
ephemeral=ephemeral,
|
||||
)
|
||||
elif self.provider_type == VMProviderType.CLOUD:
|
||||
self.config.vm_provider = VMProviderFactory.create_provider(
|
||||
self.provider_type,
|
||||
api_key=self.api_key,
|
||||
verbose=verbose,
|
||||
)
|
||||
elif self.provider_type == VMProviderType.WINSANDBOX:
|
||||
self.config.vm_provider = VMProviderFactory.create_provider(
|
||||
self.provider_type,
|
||||
port=port,
|
||||
host=host,
|
||||
storage=storage,
|
||||
verbose=verbose,
|
||||
ephemeral=ephemeral,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider type: {self.provider_type}")
|
||||
self._provider_context = await self.config.vm_provider.__aenter__()
|
||||
self.logger.verbose("VM provider context initialized successfully")
|
||||
except ImportError as ie:
|
||||
self.logger.error(f"Failed to import provider dependencies: {ie}")
|
||||
if str(ie).find("lume") >= 0 and str(ie).find("lumier") < 0:
|
||||
self.logger.error("Please install with: pip install cua-computer[lume]")
|
||||
elif str(ie).find("lumier") >= 0 or str(ie).find("docker") >= 0:
|
||||
self.logger.error("Please install with: pip install cua-computer[lumier] and make sure Docker is installed")
|
||||
elif str(ie).find("cloud") >= 0:
|
||||
self.logger.error("Please install with: pip install cua-computer[cloud]")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize provider context: {e}")
|
||||
raise RuntimeError(f"Failed to initialize VM provider: {e}")
|
||||
|
||||
# Check if VM exists or create it
|
||||
is_running = False
|
||||
try:
|
||||
if self.config.vm_provider is None:
|
||||
raise RuntimeError(f"VM provider not initialized for {self.config.name}")
|
||||
|
||||
vm = await self.config.vm_provider.get_vm(self.config.name)
|
||||
self.logger.verbose(f"Found existing VM: {self.config.name}")
|
||||
is_running = vm.get("status") == "running"
|
||||
except Exception as e:
|
||||
self.logger.error(f"VM not found: {self.config.name}")
|
||||
self.logger.error(f"Error: {e}")
|
||||
raise RuntimeError(
|
||||
f"VM {self.config.name} could not be found or created."
|
||||
)
|
||||
|
||||
# Start the VM if it's not running
|
||||
if not is_running:
|
||||
self.logger.info(f"VM {self.config.name} is not running, starting it...")
|
||||
|
||||
# Convert paths to dictionary format for shared directories
|
||||
shared_dirs = []
|
||||
for path in self.shared_directories:
|
||||
self.logger.verbose(f"Adding shared directory: {path}")
|
||||
path = os.path.abspath(os.path.expanduser(path))
|
||||
if os.path.exists(path):
|
||||
# Add path in format expected by Lume API
|
||||
shared_dirs.append({
|
||||
"hostPath": path,
|
||||
"readOnly": False
|
||||
})
|
||||
else:
|
||||
self.logger.warning(f"Shared directory does not exist: {path}")
|
||||
|
||||
# Prepare run options to pass to the provider
|
||||
run_opts = {}
|
||||
|
||||
# Add display information if available
|
||||
if self.config.display is not None:
|
||||
display_info = {
|
||||
"width": self.config.display.width,
|
||||
"height": self.config.display.height,
|
||||
}
|
||||
|
||||
# Check if scale_factor exists before adding it
|
||||
if hasattr(self.config.display, "scale_factor"):
|
||||
display_info["scale_factor"] = self.config.display.scale_factor
|
||||
|
||||
run_opts["display"] = display_info
|
||||
|
||||
# Add shared directories if available
|
||||
if self.shared_directories:
|
||||
run_opts["shared_directories"] = shared_dirs.copy()
|
||||
|
||||
# Run the VM with the provider
|
||||
try:
|
||||
if self.config.vm_provider is None:
|
||||
raise RuntimeError(f"VM provider not initialized for {self.config.name}")
|
||||
|
||||
# Use the complete run_opts we prepared earlier
|
||||
# Handle ephemeral storage for run_vm method too
|
||||
storage_param = "ephemeral" if self.ephemeral else self.storage
|
||||
|
||||
# Log the image being used
|
||||
self.logger.info(f"Running VM using image: {self.image}")
|
||||
|
||||
# Call provider.run_vm with explicit image parameter
|
||||
response = await self.config.vm_provider.run_vm(
|
||||
image=self.image,
|
||||
name=self.config.name,
|
||||
run_opts=run_opts,
|
||||
storage=storage_param
|
||||
)
|
||||
self.logger.info(f"VM run response: {response if response else 'None'}")
|
||||
except Exception as run_error:
|
||||
self.logger.error(f"Failed to run VM: {run_error}")
|
||||
raise RuntimeError(f"Failed to start VM: {run_error}")
|
||||
|
||||
# Wait for VM to be ready with a valid IP address
|
||||
self.logger.info("Waiting for VM to be ready with a valid IP address...")
|
||||
try:
|
||||
if self.provider_type == VMProviderType.LUMIER:
|
||||
max_retries = 60 # Increased for Lumier VM startup which takes longer
|
||||
retry_delay = 3 # 3 seconds between retries for Lumier
|
||||
else:
|
||||
max_retries = 30 # Default for other providers
|
||||
retry_delay = 2 # 2 seconds between retries
|
||||
|
||||
self.logger.info(f"Waiting up to {max_retries * retry_delay} seconds for VM to be ready...")
|
||||
ip = await self.get_ip(max_retries=max_retries, retry_delay=retry_delay)
|
||||
|
||||
# If we get here, we have a valid IP
|
||||
self.logger.info(f"VM is ready with IP: {ip}")
|
||||
ip_address = ip
|
||||
except TimeoutError as timeout_error:
|
||||
self.logger.error(str(timeout_error))
|
||||
raise RuntimeError(f"VM startup timed out: {timeout_error}")
|
||||
except Exception as wait_error:
|
||||
self.logger.error(f"Error waiting for VM: {wait_error}")
|
||||
raise RuntimeError(f"VM failed to become ready: {wait_error}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to initialize computer: {e}")
|
||||
raise RuntimeError(f"Failed to initialize computer: {e}")
|
||||
|
||||
try:
|
||||
# Verify we have a valid IP before initializing the interface
|
||||
if not ip_address or ip_address == "unknown" or ip_address == "0.0.0.0":
|
||||
raise RuntimeError(f"Cannot initialize interface - invalid IP address: {ip_address}")
|
||||
|
||||
# Initialize the interface using the factory with the specified OS
|
||||
self.logger.info(f"Initializing interface for {self.os_type} at {ip_address}")
|
||||
from .interface.base import BaseComputerInterface
|
||||
|
||||
# Pass authentication credentials if using cloud provider
|
||||
if self.provider_type == VMProviderType.CLOUD and self.api_key and self.config.name:
|
||||
self._interface = cast(
|
||||
BaseComputerInterface,
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type,
|
||||
ip_address=ip_address,
|
||||
api_key=self.api_key,
|
||||
vm_name=self.config.name
|
||||
),
|
||||
)
|
||||
else:
|
||||
self._interface = cast(
|
||||
BaseComputerInterface,
|
||||
InterfaceFactory.create_interface_for_os(
|
||||
os=self.os_type,
|
||||
ip_address=ip_address
|
||||
),
|
||||
)
|
||||
|
||||
# Wait for the WebSocket interface to be ready
|
||||
self.logger.info("Connecting to WebSocket interface...")
|
||||
|
||||
try:
|
||||
# Use a single timeout for the entire connection process
|
||||
# The VM should already be ready at this point, so we're just establishing the connection
|
||||
await self._interface.wait_for_ready(timeout=30)
|
||||
self.logger.info("WebSocket interface connected successfully")
|
||||
except TimeoutError as e:
|
||||
self.logger.error(f"Failed to connect to WebSocket interface at {ip_address}")
|
||||
raise TimeoutError(
|
||||
f"Could not connect to WebSocket interface at {ip_address}:8000/ws: {str(e)}"
|
||||
)
|
||||
# self.logger.warning(
|
||||
# f"Could not connect to WebSocket interface at {ip_address}:8000/ws: {str(e)}, expect missing functionality"
|
||||
# )
|
||||
|
||||
# Create an event to keep the VM running in background if needed
|
||||
if not self.use_host_computer_server:
|
||||
self._stop_event = asyncio.Event()
|
||||
self._keep_alive_task = asyncio.create_task(self._stop_event.wait())
|
||||
|
||||
self.logger.info("Computer is ready")
|
||||
|
||||
# Set the initialization flag and clear the initializing flag
|
||||
self._initialized = True
|
||||
|
||||
# Set this instance as the default computer for remote decorators
|
||||
helpers.set_default_computer(self)
|
||||
|
||||
self.logger.info("Computer successfully initialized")
|
||||
except Exception as e:
|
||||
raise
|
||||
finally:
|
||||
# Log initialization time for performance monitoring
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self.logger.debug(f"Computer initialization took {duration_ms:.2f}ms")
|
||||
return
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""Disconnect from the computer's WebSocket interface."""
|
||||
if self._interface:
|
||||
self._interface.close()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Disconnect from the computer's WebSocket interface and stop the computer."""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
self.logger.info("Stopping Computer...")
|
||||
|
||||
# In VM mode, first explicitly stop the VM, then exit the provider context
|
||||
if not self.use_host_computer_server and self._provider_context and self.config.vm_provider is not None:
|
||||
try:
|
||||
self.logger.info(f"Stopping VM {self.config.name}...")
|
||||
await self.config.vm_provider.stop_vm(
|
||||
name=self.config.name,
|
||||
storage=self.storage # Pass storage explicitly for clarity
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error stopping VM: {e}")
|
||||
|
||||
self.logger.verbose("Closing VM provider context...")
|
||||
await self.config.vm_provider.__aexit__(None, None, None)
|
||||
self._provider_context = None
|
||||
|
||||
await self.disconnect()
|
||||
self.logger.info("Computer stopped")
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Error during cleanup: {e}") # Log as debug since this might be expected
|
||||
finally:
|
||||
# Log stop time for performance monitoring
|
||||
duration_ms = (time.time() - start_time) * 1000
|
||||
self.logger.debug(f"Computer stop process took {duration_ms:.2f}ms")
|
||||
return
|
||||
|
||||
# @property
|
||||
async def get_ip(self, max_retries: int = 15, retry_delay: int = 3) -> str:
|
||||
"""Get the IP address of the VM or localhost if using host computer server.
|
||||
|
||||
This method delegates to the provider's get_ip method, which waits indefinitely
|
||||
until the VM has a valid IP address.
|
||||
|
||||
Args:
|
||||
max_retries: Unused parameter, kept for backward compatibility
|
||||
retry_delay: Delay between retries in seconds (default: 2)
|
||||
|
||||
Returns:
|
||||
IP address of the VM or localhost if using host computer server
|
||||
"""
|
||||
# For host computer server, always return localhost immediately
|
||||
if self.use_host_computer_server:
|
||||
return "127.0.0.1"
|
||||
|
||||
# Get IP from the provider - each provider implements its own waiting logic
|
||||
if self.config.vm_provider is None:
|
||||
raise RuntimeError("VM provider is not initialized")
|
||||
|
||||
# Log that we're waiting for the IP
|
||||
self.logger.info(f"Waiting for VM {self.config.name} to get an IP address...")
|
||||
|
||||
# Call the provider's get_ip method which will wait indefinitely
|
||||
storage_param = "ephemeral" if self.ephemeral else self.storage
|
||||
|
||||
# Log the image being used
|
||||
self.logger.info(f"Running VM using image: {self.image}")
|
||||
|
||||
# Call provider.get_ip with explicit image parameter
|
||||
ip = await self.config.vm_provider.get_ip(
|
||||
name=self.config.name,
|
||||
storage=storage_param,
|
||||
retry_delay=retry_delay
|
||||
)
|
||||
|
||||
# Log success
|
||||
self.logger.info(f"VM {self.config.name} has IP address: {ip}")
|
||||
return ip
|
||||
|
||||
|
||||
async def wait_vm_ready(self) -> Optional[Dict[str, Any]]:
|
||||
"""Wait for VM to be ready with an IP address.
|
||||
|
||||
Returns:
|
||||
VM status information or None if using host computer server.
|
||||
"""
|
||||
if self.use_host_computer_server:
|
||||
return None
|
||||
|
||||
timeout = 600 # 10 minutes timeout (increased from 4 minutes)
|
||||
interval = 2.0 # 2 seconds between checks (increased to reduce API load)
|
||||
start_time = time.time()
|
||||
last_status = None
|
||||
attempts = 0
|
||||
|
||||
self.logger.info(f"Waiting for VM {self.config.name} to be ready (timeout: {timeout}s)...")
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
attempts += 1
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
try:
|
||||
# Keep polling for VM info
|
||||
if self.config.vm_provider is None:
|
||||
self.logger.error("VM provider is not initialized")
|
||||
vm = None
|
||||
else:
|
||||
vm = await self.config.vm_provider.get_vm(self.config.name)
|
||||
|
||||
# Log full VM properties for debugging (every 30 attempts)
|
||||
if attempts % 30 == 0:
|
||||
self.logger.info(
|
||||
f"VM properties at attempt {attempts}: {vars(vm) if vm else 'None'}"
|
||||
)
|
||||
|
||||
# Get current status for logging
|
||||
current_status = getattr(vm, "status", None) if vm else None
|
||||
if current_status != last_status:
|
||||
self.logger.info(
|
||||
f"VM status changed to: {current_status} (after {elapsed:.1f}s)"
|
||||
)
|
||||
last_status = current_status
|
||||
|
||||
# Check for IP address - ensure it's not None or empty
|
||||
ip = getattr(vm, "ip_address", None) if vm else None
|
||||
if ip and ip.strip(): # Check for non-empty string
|
||||
self.logger.info(
|
||||
f"VM {self.config.name} got IP address: {ip} (after {elapsed:.1f}s)"
|
||||
)
|
||||
return vm
|
||||
|
||||
if attempts % 10 == 0: # Log every 10 attempts to avoid flooding
|
||||
self.logger.info(
|
||||
f"Still waiting for VM IP address... (elapsed: {elapsed:.1f}s)"
|
||||
)
|
||||
else:
|
||||
self.logger.debug(
|
||||
f"Waiting for VM IP address... Current IP: {ip}, Status: {current_status}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error checking VM status (attempt {attempts}): {str(e)}")
|
||||
# If we've been trying for a while and still getting errors, log more details
|
||||
if elapsed > 60: # After 1 minute of errors, log more details
|
||||
self.logger.error(f"Persistent error getting VM status: {str(e)}")
|
||||
self.logger.info("Trying to get VM list for debugging...")
|
||||
try:
|
||||
if self.config.vm_provider is not None:
|
||||
vms = await self.config.vm_provider.list_vms()
|
||||
self.logger.info(
|
||||
f"Available VMs: {[getattr(vm, 'name', None) for vm in vms if hasattr(vm, 'name')]}"
|
||||
)
|
||||
except Exception as list_error:
|
||||
self.logger.error(f"Failed to list VMs: {str(list_error)}")
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
# If we get here, we've timed out
|
||||
elapsed = time.time() - start_time
|
||||
self.logger.error(f"VM {self.config.name} not ready after {elapsed:.1f} seconds")
|
||||
|
||||
# Try to get final VM status for debugging
|
||||
try:
|
||||
if self.config.vm_provider is not None:
|
||||
vm = await self.config.vm_provider.get_vm(self.config.name)
|
||||
# VM data is returned as a dictionary from the Lumier provider
|
||||
status = vm.get('status', 'unknown') if vm else "unknown"
|
||||
ip = vm.get('ip_address') if vm else None
|
||||
else:
|
||||
status = "unknown"
|
||||
ip = None
|
||||
self.logger.error(f"Final VM status: {status}, IP: {ip}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get final VM status: {str(e)}")
|
||||
|
||||
raise TimeoutError(
|
||||
f"VM {self.config.name} not ready after {elapsed:.1f} seconds - IP address not assigned"
|
||||
)
|
||||
|
||||
async def update(self, cpu: Optional[int] = None, memory: Optional[str] = None):
|
||||
"""Update VM settings."""
|
||||
self.logger.info(
|
||||
f"Updating VM settings: CPU={cpu or self.config.cpu}, Memory={memory or self.config.memory}"
|
||||
)
|
||||
update_opts = {
|
||||
"cpu": cpu or int(self.config.cpu),
|
||||
"memory": memory or self.config.memory
|
||||
}
|
||||
if self.config.vm_provider is not None:
|
||||
await self.config.vm_provider.update_vm(
|
||||
name=self.config.name,
|
||||
update_opts=update_opts,
|
||||
storage=self.storage # Pass storage explicitly for clarity
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("VM provider not initialized")
|
||||
|
||||
def get_screenshot_size(self, screenshot: bytes) -> Dict[str, int]:
|
||||
"""Get the dimensions of a screenshot.
|
||||
|
||||
Args:
|
||||
screenshot: The screenshot bytes
|
||||
|
||||
Returns:
|
||||
Dict[str, int]: Dictionary containing 'width' and 'height' of the image
|
||||
"""
|
||||
image = Image.open(io.BytesIO(screenshot))
|
||||
width, height = image.size
|
||||
return {"width": width, "height": height}
|
||||
|
||||
@property
|
||||
def interface(self):
|
||||
"""Get the computer interface for interacting with the VM.
|
||||
|
||||
Returns:
|
||||
The computer interface
|
||||
"""
|
||||
if not hasattr(self, "_interface") or self._interface is None:
|
||||
error_msg = "Computer interface not initialized. Call run() first."
|
||||
self.logger.error(error_msg)
|
||||
self.logger.error(
|
||||
"Make sure to call await computer.run() before using any interface methods."
|
||||
)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
return self._interface
|
||||
|
||||
@property
|
||||
def telemetry_enabled(self) -> bool:
|
||||
"""Check if telemetry is enabled for this computer instance.
|
||||
|
||||
Returns:
|
||||
bool: True if telemetry is enabled, False otherwise
|
||||
"""
|
||||
return self._telemetry_enabled
|
||||
|
||||
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
|
||||
"""Convert normalized coordinates to screen coordinates.
|
||||
|
||||
Args:
|
||||
x: X coordinate between 0 and 1
|
||||
y: Y coordinate between 0 and 1
|
||||
|
||||
Returns:
|
||||
tuple[float, float]: Screen coordinates (x, y)
|
||||
"""
|
||||
return await self.interface.to_screen_coordinates(x, y)
|
||||
|
||||
async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]:
|
||||
"""Convert screen coordinates to screenshot coordinates.
|
||||
|
||||
Args:
|
||||
x: X coordinate in screen space
|
||||
y: Y coordinate in screen space
|
||||
|
||||
Returns:
|
||||
tuple[float, float]: (x, y) coordinates in screenshot space
|
||||
"""
|
||||
return await self.interface.to_screenshot_coordinates(x, y)
|
||||
|
||||
|
||||
# Add virtual environment management functions to computer interface
|
||||
async def venv_install(self, venv_name: str, requirements: list[str]) -> tuple[str, str]:
|
||||
"""Install packages in a virtual environment.
|
||||
|
||||
Args:
|
||||
venv_name: Name of the virtual environment
|
||||
requirements: List of package requirements to install
|
||||
|
||||
Returns:
|
||||
Tuple of (stdout, stderr) from the installation command
|
||||
"""
|
||||
requirements = requirements or []
|
||||
|
||||
# Create virtual environment if it doesn't exist
|
||||
venv_path = f"~/.venvs/{venv_name}"
|
||||
create_cmd = f"mkdir -p ~/.venvs && python3 -m venv {venv_path}"
|
||||
|
||||
# Check if venv exists, if not create it
|
||||
check_cmd = f"test -d {venv_path} || ({create_cmd})"
|
||||
_, _ = await self.interface.run_command(check_cmd)
|
||||
|
||||
# Install packages
|
||||
requirements_str = " ".join(requirements)
|
||||
install_cmd = f". {venv_path}/bin/activate && pip install {requirements_str}"
|
||||
return await self.interface.run_command(install_cmd)
|
||||
|
||||
async def venv_cmd(self, venv_name: str, command: str) -> tuple[str, str]:
|
||||
"""Execute a shell command in a virtual environment.
|
||||
|
||||
Args:
|
||||
venv_name: Name of the virtual environment
|
||||
command: Shell command to execute in the virtual environment
|
||||
|
||||
Returns:
|
||||
Tuple of (stdout, stderr) from the command execution
|
||||
"""
|
||||
venv_path = f"~/.venvs/{venv_name}"
|
||||
|
||||
# Check if virtual environment exists
|
||||
check_cmd = f"test -d {venv_path}"
|
||||
stdout, stderr = await self.interface.run_command(check_cmd)
|
||||
|
||||
if stderr or "test:" in stdout: # venv doesn't exist
|
||||
return "", f"Virtual environment '{venv_name}' does not exist. Create it first using venv_install."
|
||||
|
||||
# Activate virtual environment and run command
|
||||
full_command = f". {venv_path}/bin/activate && {command}"
|
||||
return await self.interface.run_command(full_command)
|
||||
|
||||
async def venv_exec(self, venv_name: str, python_func, *args, **kwargs):
|
||||
"""Execute Python function in a virtual environment using source code extraction.
|
||||
|
||||
Args:
|
||||
venv_name: Name of the virtual environment
|
||||
python_func: A callable function to execute
|
||||
*args: Positional arguments to pass to the function
|
||||
**kwargs: Keyword arguments to pass to the function
|
||||
|
||||
Returns:
|
||||
The result of the function execution, or raises any exception that occurred
|
||||
"""
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
try:
|
||||
# Get function source code using inspect.getsource
|
||||
source = inspect.getsource(python_func)
|
||||
# Remove common leading whitespace (dedent)
|
||||
func_source = textwrap.dedent(source).strip()
|
||||
|
||||
# Remove decorators
|
||||
while func_source.lstrip().startswith("@"):
|
||||
func_source = func_source.split("\n", 1)[1].strip()
|
||||
|
||||
# Get function name for execution
|
||||
func_name = python_func.__name__
|
||||
|
||||
# Serialize args and kwargs as JSON (safer than dill for cross-version compatibility)
|
||||
args_json = json.dumps(args, default=str)
|
||||
kwargs_json = json.dumps(kwargs, default=str)
|
||||
|
||||
except OSError as e:
|
||||
raise Exception(f"Cannot retrieve source code for function {python_func.__name__}: {e}")
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to reconstruct function source: {e}")
|
||||
|
||||
# Create Python code that will define and execute the function
|
||||
python_code = f'''
|
||||
import json
|
||||
import traceback
|
||||
|
||||
try:
|
||||
# Define the function from source
|
||||
{textwrap.indent(func_source, " ")}
|
||||
|
||||
# Deserialize args and kwargs from JSON
|
||||
args_json = """{args_json}"""
|
||||
kwargs_json = """{kwargs_json}"""
|
||||
args = json.loads(args_json)
|
||||
kwargs = json.loads(kwargs_json)
|
||||
|
||||
# Execute the function
|
||||
result = {func_name}(*args, **kwargs)
|
||||
|
||||
# Create success output payload
|
||||
output_payload = {{
|
||||
"success": True,
|
||||
"result": result,
|
||||
"error": None
|
||||
}}
|
||||
|
||||
except Exception as e:
|
||||
# Create error output payload
|
||||
output_payload = {{
|
||||
"success": False,
|
||||
"result": None,
|
||||
"error": {{
|
||||
"type": type(e).__name__,
|
||||
"message": str(e),
|
||||
"traceback": traceback.format_exc()
|
||||
}}
|
||||
}}
|
||||
|
||||
# Serialize the output payload as JSON
|
||||
import json
|
||||
output_json = json.dumps(output_payload, default=str)
|
||||
|
||||
# Print the JSON output with markers
|
||||
print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
|
||||
'''
|
||||
|
||||
# Encode the Python code in base64 to avoid shell escaping issues
|
||||
encoded_code = base64.b64encode(python_code.encode('utf-8')).decode('ascii')
|
||||
|
||||
# Execute the Python code in the virtual environment
|
||||
python_command = f"python -c \"import base64; exec(base64.b64decode('{encoded_code}').decode('utf-8'))\""
|
||||
stdout, stderr = await self.venv_cmd(venv_name, python_command)
|
||||
|
||||
# Parse the output to extract the payload
|
||||
start_marker = "<<<VENV_EXEC_START>>>"
|
||||
end_marker = "<<<VENV_EXEC_END>>>"
|
||||
|
||||
# Print original stdout
|
||||
print(stdout[:stdout.find(start_marker)])
|
||||
|
||||
if start_marker in stdout and end_marker in stdout:
|
||||
start_idx = stdout.find(start_marker) + len(start_marker)
|
||||
end_idx = stdout.find(end_marker)
|
||||
|
||||
if start_idx < end_idx:
|
||||
output_json = stdout[start_idx:end_idx]
|
||||
|
||||
try:
|
||||
# Decode and deserialize the output payload from JSON
|
||||
output_payload = json.loads(output_json)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to decode output payload: {e}")
|
||||
|
||||
if output_payload["success"]:
|
||||
return output_payload["result"]
|
||||
else:
|
||||
# Recreate and raise the original exception
|
||||
error_info = output_payload["error"]
|
||||
error_class = eval(error_info["type"])
|
||||
raise error_class(error_info["message"])
|
||||
else:
|
||||
raise Exception("Invalid output format: markers found but no content between them")
|
||||
else:
|
||||
# Fallback: return stdout/stderr if no payload markers found
|
||||
raise Exception(f"No output payload found. stdout: {stdout}, stderr: {stderr}")
|
||||
104
libs/python/computer/computer/diorama_computer.py
Normal file
104
libs/python/computer/computer/diorama_computer.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import asyncio
|
||||
from .interface.models import KeyType, Key
|
||||
|
||||
class DioramaComputer:
|
||||
"""
|
||||
A Computer-compatible proxy for Diorama that sends commands over the ComputerInterface.
|
||||
"""
|
||||
def __init__(self, computer, apps):
|
||||
self.computer = computer
|
||||
self.apps = apps
|
||||
self.interface = DioramaComputerInterface(computer, apps)
|
||||
self._initialized = False
|
||||
|
||||
async def __aenter__(self):
|
||||
self._initialized = True
|
||||
return self
|
||||
|
||||
async def run(self):
|
||||
if not self._initialized:
|
||||
await self.__aenter__()
|
||||
return self
|
||||
|
||||
class DioramaComputerInterface:
|
||||
"""
|
||||
Diorama Interface proxy that sends diorama_cmds via the Computer's interface.
|
||||
"""
|
||||
def __init__(self, computer, apps):
|
||||
self.computer = computer
|
||||
self.apps = apps
|
||||
self._scene_size = None
|
||||
|
||||
async def _send_cmd(self, action, arguments=None):
|
||||
arguments = arguments or {}
|
||||
arguments = {"app_list": self.apps, **arguments}
|
||||
# Use the computer's interface (must be initialized)
|
||||
iface = getattr(self.computer, "_interface", None)
|
||||
if iface is None:
|
||||
raise RuntimeError("Computer interface not initialized. Call run() first.")
|
||||
result = await iface.diorama_cmd(action, arguments)
|
||||
if not result.get("success"):
|
||||
raise RuntimeError(f"Diorama command failed: {result.get('error')}\n{result.get('trace')}")
|
||||
return result.get("result")
|
||||
|
||||
async def screenshot(self, as_bytes=True):
|
||||
from PIL import Image
|
||||
import base64
|
||||
result = await self._send_cmd("screenshot")
|
||||
# assume result is a b64 string of an image
|
||||
img_bytes = base64.b64decode(result)
|
||||
import io
|
||||
img = Image.open(io.BytesIO(img_bytes))
|
||||
self._scene_size = img.size
|
||||
return img_bytes if as_bytes else img
|
||||
|
||||
async def get_screen_size(self):
|
||||
if not self._scene_size:
|
||||
await self.screenshot(as_bytes=False)
|
||||
return {"width": self._scene_size[0], "height": self._scene_size[1]}
|
||||
|
||||
async def move_cursor(self, x, y):
|
||||
await self._send_cmd("move_cursor", {"x": x, "y": y})
|
||||
|
||||
async def left_click(self, x=None, y=None):
|
||||
await self._send_cmd("left_click", {"x": x, "y": y})
|
||||
|
||||
async def right_click(self, x=None, y=None):
|
||||
await self._send_cmd("right_click", {"x": x, "y": y})
|
||||
|
||||
async def double_click(self, x=None, y=None):
|
||||
await self._send_cmd("double_click", {"x": x, "y": y})
|
||||
|
||||
async def scroll_up(self, clicks=1):
|
||||
await self._send_cmd("scroll_up", {"clicks": clicks})
|
||||
|
||||
async def scroll_down(self, clicks=1):
|
||||
await self._send_cmd("scroll_down", {"clicks": clicks})
|
||||
|
||||
async def drag_to(self, x, y, duration=0.5):
|
||||
await self._send_cmd("drag_to", {"x": x, "y": y, "duration": duration})
|
||||
|
||||
async def get_cursor_position(self):
|
||||
return await self._send_cmd("get_cursor_position")
|
||||
|
||||
async def type_text(self, text):
|
||||
await self._send_cmd("type_text", {"text": text})
|
||||
|
||||
async def press_key(self, key):
|
||||
await self._send_cmd("press_key", {"key": key})
|
||||
|
||||
async def hotkey(self, *keys):
|
||||
actual_keys = []
|
||||
for key in keys:
|
||||
if isinstance(key, Key):
|
||||
actual_keys.append(key.value)
|
||||
elif isinstance(key, str):
|
||||
# Try to convert to enum if it matches a known key
|
||||
key_or_enum = Key.from_string(key)
|
||||
actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
|
||||
else:
|
||||
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
|
||||
await self._send_cmd("hotkey", {"keys": actual_keys})
|
||||
|
||||
async def to_screen_coordinates(self, x, y):
|
||||
return await self._send_cmd("to_screen_coordinates", {"x": x, "y": y})
|
||||
52
libs/python/computer/computer/helpers.py
Normal file
52
libs/python/computer/computer/helpers.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Helper functions and decorators for the Computer module.
|
||||
"""
|
||||
import logging
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, Optional, TypeVar, cast
|
||||
|
||||
# Global reference to the default computer instance
|
||||
_default_computer = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def set_default_computer(computer):
|
||||
"""
|
||||
Set the default computer instance to be used by the remote decorator.
|
||||
|
||||
Args:
|
||||
computer: The computer instance to use as default
|
||||
"""
|
||||
global _default_computer
|
||||
_default_computer = computer
|
||||
|
||||
|
||||
def sandboxed(venv_name: str = "default", computer: str = "default", max_retries: int = 3):
|
||||
"""
|
||||
Decorator that wraps a function to be executed remotely via computer.venv_exec
|
||||
|
||||
Args:
|
||||
venv_name: Name of the virtual environment to execute in
|
||||
computer: The computer instance to use, or "default" to use the globally set default
|
||||
max_retries: Maximum number of retries for the remote execution
|
||||
"""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Determine which computer instance to use
|
||||
comp = computer if computer != "default" else _default_computer
|
||||
|
||||
if comp is None:
|
||||
raise RuntimeError("No computer instance available. Either specify a computer instance or call set_default_computer() first.")
|
||||
|
||||
for i in range(max_retries):
|
||||
try:
|
||||
return await comp.venv_exec(venv_name, func, *args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"Attempt {i+1} failed: {e}")
|
||||
await asyncio.sleep(1)
|
||||
if i == max_retries - 1:
|
||||
raise e
|
||||
return wrapper
|
||||
return decorator
|
||||
13
libs/python/computer/computer/interface/__init__.py
Normal file
13
libs/python/computer/computer/interface/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
"""
|
||||
Interface package for Computer SDK.
|
||||
"""
|
||||
|
||||
from .factory import InterfaceFactory
|
||||
from .base import BaseComputerInterface
|
||||
from .macos import MacOSComputerInterface
|
||||
|
||||
__all__ = [
|
||||
"InterfaceFactory",
|
||||
"BaseComputerInterface",
|
||||
"MacOSComputerInterface",
|
||||
]
|
||||
271
libs/python/computer/computer/interface/base.py
Normal file
271
libs/python/computer/computer/interface/base.py
Normal file
@@ -0,0 +1,271 @@
|
||||
"""Base interface for computer control."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Dict, Any, Tuple, List
|
||||
from ..logger import Logger, LogLevel
|
||||
from .models import MouseButton
|
||||
|
||||
|
||||
class BaseComputerInterface(ABC):
|
||||
"""Base class for computer control interfaces."""
|
||||
|
||||
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
|
||||
"""Initialize interface.
|
||||
|
||||
Args:
|
||||
ip_address: IP address of the computer to control
|
||||
username: Username for authentication
|
||||
password: Password for authentication
|
||||
api_key: Optional API key for cloud authentication
|
||||
vm_name: Optional VM name for cloud authentication
|
||||
"""
|
||||
self.ip_address = ip_address
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.api_key = api_key
|
||||
self.vm_name = vm_name
|
||||
self.logger = Logger("cua.interface", LogLevel.NORMAL)
|
||||
|
||||
@abstractmethod
|
||||
async def wait_for_ready(self, timeout: int = 60) -> None:
|
||||
"""Wait for interface to be ready.
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait in seconds
|
||||
|
||||
Raises:
|
||||
TimeoutError: If interface is not ready within timeout
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
"""Close the interface connection."""
|
||||
pass
|
||||
|
||||
def force_close(self) -> None:
|
||||
"""Force close the interface connection.
|
||||
|
||||
By default, this just calls close(), but subclasses can override
|
||||
to provide more forceful cleanup.
|
||||
"""
|
||||
self.close()
|
||||
|
||||
# Mouse Actions
|
||||
@abstractmethod
|
||||
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: "MouseButton" = "left") -> None:
|
||||
"""Press and hold a mouse button."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: "MouseButton" = "left") -> None:
|
||||
"""Release a mouse button."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
"""Perform a left click."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
"""Perform a right click."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
"""Perform a double click."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def move_cursor(self, x: int, y: int) -> None:
|
||||
"""Move the cursor to specified position."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> None:
|
||||
"""Drag from current position to specified coordinates.
|
||||
|
||||
Args:
|
||||
x: The x coordinate to drag to
|
||||
y: The y coordinate to drag to
|
||||
button: The mouse button to use ('left', 'middle', 'right')
|
||||
duration: How long the drag should take in seconds
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> None:
|
||||
"""Drag the cursor along a path of coordinates.
|
||||
|
||||
Args:
|
||||
path: List of (x, y) coordinate tuples defining the drag path
|
||||
button: The mouse button to use ('left', 'middle', 'right')
|
||||
duration: Total time in seconds that the drag operation should take
|
||||
"""
|
||||
pass
|
||||
|
||||
# Keyboard Actions
|
||||
@abstractmethod
|
||||
async def key_down(self, key: str) -> None:
|
||||
"""Press and hold a key."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def key_up(self, key: str) -> None:
|
||||
"""Release a key."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def type_text(self, text: str) -> None:
|
||||
"""Type the specified text."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def press_key(self, key: str) -> None:
|
||||
"""Press a single key."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def hotkey(self, *keys: str) -> None:
|
||||
"""Press multiple keys simultaneously."""
|
||||
pass
|
||||
|
||||
# Scrolling Actions
|
||||
@abstractmethod
|
||||
async def scroll(self, x: int, y: int) -> None:
|
||||
"""Scroll the mouse wheel."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def scroll_down(self, clicks: int = 1) -> None:
|
||||
"""Scroll down."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def scroll_up(self, clicks: int = 1) -> None:
|
||||
"""Scroll up."""
|
||||
pass
|
||||
|
||||
# Screen Actions
|
||||
@abstractmethod
|
||||
async def screenshot(self) -> bytes:
|
||||
"""Take a screenshot.
|
||||
|
||||
Returns:
|
||||
Raw bytes of the screenshot image
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_screen_size(self) -> Dict[str, int]:
|
||||
"""Get the screen dimensions.
|
||||
|
||||
Returns:
|
||||
Dict with 'width' and 'height' keys
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_cursor_position(self) -> Dict[str, int]:
|
||||
"""Get current cursor position."""
|
||||
pass
|
||||
|
||||
# Clipboard Actions
|
||||
@abstractmethod
|
||||
async def copy_to_clipboard(self) -> str:
|
||||
"""Get clipboard content."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_clipboard(self, text: str) -> None:
|
||||
"""Set clipboard content."""
|
||||
pass
|
||||
|
||||
# File System Actions
|
||||
@abstractmethod
|
||||
async def file_exists(self, path: str) -> bool:
|
||||
"""Check if file exists."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def directory_exists(self, path: str) -> bool:
|
||||
"""Check if directory exists."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def list_dir(self, path: str) -> List[str]:
|
||||
"""List directory contents."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def read_text(self, path: str) -> str:
|
||||
"""Read file text contents."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def write_text(self, path: str, content: str) -> None:
|
||||
"""Write file text contents."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def read_bytes(self, path: str) -> bytes:
|
||||
"""Read file binary contents."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def write_bytes(self, path: str, content: bytes) -> None:
|
||||
"""Write file binary contents."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_file(self, path: str) -> None:
|
||||
"""Delete file."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def create_dir(self, path: str) -> None:
|
||||
"""Create directory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def delete_dir(self, path: str) -> None:
|
||||
"""Delete directory."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def run_command(self, command: str) -> Tuple[str, str]:
|
||||
"""Run shell command."""
|
||||
pass
|
||||
|
||||
# Accessibility Actions
|
||||
@abstractmethod
|
||||
async def get_accessibility_tree(self) -> Dict:
|
||||
"""Get the accessibility tree of the current screen."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
|
||||
"""Convert screenshot coordinates to screen coordinates.
|
||||
|
||||
Args:
|
||||
x: X coordinate in screenshot space
|
||||
y: Y coordinate in screenshot space
|
||||
|
||||
Returns:
|
||||
tuple[float, float]: (x, y) coordinates in screen space
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]:
|
||||
"""Convert screen coordinates to screenshot coordinates.
|
||||
|
||||
Args:
|
||||
x: X coordinate in screen space
|
||||
y: Y coordinate in screen space
|
||||
|
||||
Returns:
|
||||
tuple[float, float]: (x, y) coordinates in screenshot space
|
||||
"""
|
||||
pass
|
||||
42
libs/python/computer/computer/interface/factory.py
Normal file
42
libs/python/computer/computer/interface/factory.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""Factory for creating computer interfaces."""
|
||||
|
||||
from typing import Literal, Optional
|
||||
from .base import BaseComputerInterface
|
||||
|
||||
class InterfaceFactory:
|
||||
"""Factory for creating OS-specific computer interfaces."""
|
||||
|
||||
@staticmethod
|
||||
def create_interface_for_os(
|
||||
os: Literal['macos', 'linux', 'windows'],
|
||||
ip_address: str,
|
||||
api_key: Optional[str] = None,
|
||||
vm_name: Optional[str] = None
|
||||
) -> BaseComputerInterface:
|
||||
"""Create an interface for the specified OS.
|
||||
|
||||
Args:
|
||||
os: Operating system type ('macos', 'linux', or 'windows')
|
||||
ip_address: IP address of the computer to control
|
||||
api_key: Optional API key for cloud authentication
|
||||
vm_name: Optional VM name for cloud authentication
|
||||
|
||||
Returns:
|
||||
BaseComputerInterface: The appropriate interface for the OS
|
||||
|
||||
Raises:
|
||||
ValueError: If the OS type is not supported
|
||||
"""
|
||||
# Import implementations here to avoid circular imports
|
||||
from .macos import MacOSComputerInterface
|
||||
from .linux import LinuxComputerInterface
|
||||
from .windows import WindowsComputerInterface
|
||||
|
||||
if os == 'macos':
|
||||
return MacOSComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
|
||||
elif os == 'linux':
|
||||
return LinuxComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
|
||||
elif os == 'windows':
|
||||
return WindowsComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
|
||||
else:
|
||||
raise ValueError(f"Unsupported OS type: {os}")
|
||||
688
libs/python/computer/computer/interface/linux.py
Normal file
688
libs/python/computer/computer/interface/linux.py
Normal file
@@ -0,0 +1,688 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from PIL import Image
|
||||
|
||||
import websockets
|
||||
|
||||
from ..logger import Logger, LogLevel
|
||||
from .base import BaseComputerInterface
|
||||
from ..utils import decode_base64_image, encode_base64_image, bytes_to_image, draw_box, resize_image
|
||||
from .models import Key, KeyType, MouseButton
|
||||
|
||||
|
||||
class LinuxComputerInterface(BaseComputerInterface):
|
||||
"""Interface for Linux."""
|
||||
|
||||
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
|
||||
super().__init__(ip_address, username, password, api_key, vm_name)
|
||||
self._ws = None
|
||||
self._reconnect_task = None
|
||||
self._closed = False
|
||||
self._last_ping = 0
|
||||
self._ping_interval = 5 # Send ping every 5 seconds
|
||||
self._ping_timeout = 120 # Wait 120 seconds for pong response
|
||||
self._reconnect_delay = 1 # Start with 1 second delay
|
||||
self._max_reconnect_delay = 30 # Maximum delay between reconnection attempts
|
||||
self._log_connection_attempts = True # Flag to control connection attempt logging
|
||||
self._authenticated = False # Track authentication status
|
||||
self._command_lock = asyncio.Lock() # Lock to ensure only one command at a time
|
||||
|
||||
# Set logger name for Linux interface
|
||||
self.logger = Logger("computer.interface.linux", LogLevel.NORMAL)
|
||||
|
||||
@property
|
||||
def ws_uri(self) -> str:
|
||||
"""Get the WebSocket URI using the current IP address.
|
||||
|
||||
Returns:
|
||||
WebSocket URI for the Computer API Server
|
||||
"""
|
||||
protocol = "wss" if self.api_key else "ws"
|
||||
port = "8443" if self.api_key else "8000"
|
||||
return f"{protocol}://{self.ip_address}:{port}/ws"
|
||||
|
||||
async def _keep_alive(self):
|
||||
"""Keep the WebSocket connection alive with automatic reconnection."""
|
||||
retry_count = 0
|
||||
max_log_attempts = 1 # Only log the first attempt at INFO level
|
||||
log_interval = 500 # Then log every 500th attempt (significantly increased from 30)
|
||||
last_warning_time = 0
|
||||
min_warning_interval = 30 # Minimum seconds between connection lost warnings
|
||||
min_retry_delay = 0.5 # Minimum delay between connection attempts (500ms)
|
||||
|
||||
while not self._closed:
|
||||
try:
|
||||
if self._ws is None or (
|
||||
self._ws and self._ws.state == websockets.protocol.State.CLOSED
|
||||
):
|
||||
try:
|
||||
retry_count += 1
|
||||
|
||||
# Add a minimum delay between connection attempts to avoid flooding
|
||||
if retry_count > 1:
|
||||
await asyncio.sleep(min_retry_delay)
|
||||
|
||||
# Only log the first attempt at INFO level, then every Nth attempt
|
||||
if retry_count == 1:
|
||||
self.logger.info(f"Attempting WebSocket connection to {self.ws_uri}")
|
||||
elif retry_count % log_interval == 0:
|
||||
self.logger.info(
|
||||
f"Still attempting WebSocket connection (attempt {retry_count})..."
|
||||
)
|
||||
else:
|
||||
# All other attempts are logged at DEBUG level
|
||||
self.logger.debug(
|
||||
f"Attempting WebSocket connection to {self.ws_uri} (attempt {retry_count})"
|
||||
)
|
||||
|
||||
self._ws = await asyncio.wait_for(
|
||||
websockets.connect(
|
||||
self.ws_uri,
|
||||
max_size=1024 * 1024 * 10, # 10MB limit
|
||||
max_queue=32,
|
||||
ping_interval=self._ping_interval,
|
||||
ping_timeout=self._ping_timeout,
|
||||
close_timeout=5,
|
||||
compression=None, # Disable compression to reduce overhead
|
||||
),
|
||||
timeout=120,
|
||||
)
|
||||
self.logger.info("WebSocket connection established")
|
||||
|
||||
# Authentication will be handled by the first command that needs it
|
||||
# Don't do authentication here to avoid recv conflicts
|
||||
|
||||
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
|
||||
self._last_ping = time.time()
|
||||
retry_count = 0 # Reset retry count on successful connection
|
||||
self._authenticated = False # Reset auth status on new connection
|
||||
|
||||
except (asyncio.TimeoutError, websockets.exceptions.WebSocketException) as e:
|
||||
next_retry = self._reconnect_delay
|
||||
|
||||
# Only log the first error at WARNING level, then every Nth attempt
|
||||
if retry_count == 1:
|
||||
self.logger.warning(
|
||||
f"Computer API Server not ready yet. Will retry automatically."
|
||||
)
|
||||
elif retry_count % log_interval == 0:
|
||||
self.logger.warning(
|
||||
f"Still waiting for Computer API Server (attempt {retry_count})..."
|
||||
)
|
||||
else:
|
||||
# All other errors are logged at DEBUG level
|
||||
self.logger.debug(f"Connection attempt {retry_count} failed: {e}")
|
||||
|
||||
if self._ws:
|
||||
try:
|
||||
await self._ws.close()
|
||||
except:
|
||||
pass
|
||||
self._ws = None
|
||||
|
||||
# Regular ping to check connection
|
||||
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
|
||||
try:
|
||||
if time.time() - self._last_ping >= self._ping_interval:
|
||||
pong_waiter = await self._ws.ping()
|
||||
await asyncio.wait_for(pong_waiter, timeout=self._ping_timeout)
|
||||
self._last_ping = time.time()
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Ping failed: {e}")
|
||||
if self._ws:
|
||||
try:
|
||||
await self._ws.close()
|
||||
except:
|
||||
pass
|
||||
self._ws = None
|
||||
continue
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
current_time = time.time()
|
||||
# Only log connection lost warnings at most once every min_warning_interval seconds
|
||||
if current_time - last_warning_time >= min_warning_interval:
|
||||
self.logger.warning(
|
||||
f"Computer API Server connection lost. Will retry automatically."
|
||||
)
|
||||
last_warning_time = current_time
|
||||
else:
|
||||
# Log at debug level instead
|
||||
self.logger.debug(f"Connection lost: {e}")
|
||||
|
||||
if self._ws:
|
||||
try:
|
||||
await self._ws.close()
|
||||
except:
|
||||
pass
|
||||
self._ws = None
|
||||
|
||||
async def _ensure_connection(self):
|
||||
"""Ensure WebSocket connection is established."""
|
||||
if self._reconnect_task is None or self._reconnect_task.done():
|
||||
self._reconnect_task = asyncio.create_task(self._keep_alive())
|
||||
|
||||
retry_count = 0
|
||||
max_retries = 5
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
|
||||
return
|
||||
retry_count += 1
|
||||
await asyncio.sleep(1)
|
||||
except Exception as e:
|
||||
# Only log at ERROR level for the last retry attempt
|
||||
if retry_count == max_retries - 1:
|
||||
self.logger.error(
|
||||
f"Persistent connection check error after {retry_count} attempts: {e}"
|
||||
)
|
||||
else:
|
||||
self.logger.debug(f"Connection check error (attempt {retry_count}): {e}")
|
||||
retry_count += 1
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
raise ConnectionError("Failed to establish WebSocket connection after multiple retries")
|
||||
|
||||
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
"""Send command through WebSocket."""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
last_error = None
|
||||
|
||||
# Acquire lock to ensure only one command is processed at a time
|
||||
async with self._command_lock:
|
||||
self.logger.debug(f"Acquired lock for command: {command}")
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
await self._ensure_connection()
|
||||
if not self._ws:
|
||||
raise ConnectionError("WebSocket connection is not established")
|
||||
|
||||
# Handle authentication if needed
|
||||
if self.api_key and self.vm_name and not self._authenticated:
|
||||
self.logger.info("Performing authentication handshake...")
|
||||
auth_message = {
|
||||
"command": "authenticate",
|
||||
"params": {
|
||||
"api_key": self.api_key,
|
||||
"container_name": self.vm_name
|
||||
}
|
||||
}
|
||||
await self._ws.send(json.dumps(auth_message))
|
||||
|
||||
# Wait for authentication response
|
||||
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
|
||||
auth_result = json.loads(auth_response)
|
||||
|
||||
if not auth_result.get("success"):
|
||||
error_msg = auth_result.get("error", "Authentication failed")
|
||||
self.logger.error(f"Authentication failed: {error_msg}")
|
||||
self._authenticated = False
|
||||
raise ConnectionError(f"Authentication failed: {error_msg}")
|
||||
|
||||
self.logger.info("Authentication successful")
|
||||
self._authenticated = True
|
||||
|
||||
message = {"command": command, "params": params or {}}
|
||||
await self._ws.send(json.dumps(message))
|
||||
response = await asyncio.wait_for(self._ws.recv(), timeout=30)
|
||||
self.logger.debug(f"Completed command: {command}")
|
||||
return json.loads(response)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
retry_count += 1
|
||||
if retry_count < max_retries:
|
||||
# Only log at debug level for intermediate retries
|
||||
self.logger.debug(
|
||||
f"Command '{command}' failed (attempt {retry_count}/{max_retries}): {e}"
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
else:
|
||||
# Only log at error level for the final failure
|
||||
self.logger.error(
|
||||
f"Failed to send command '{command}' after {max_retries} retries"
|
||||
)
|
||||
self.logger.debug(f"Command failure details: {e}")
|
||||
raise last_error if last_error else RuntimeError("Failed to send command")
|
||||
|
||||
async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
|
||||
"""Wait for WebSocket connection to become available."""
|
||||
start_time = time.time()
|
||||
last_error = None
|
||||
attempt_count = 0
|
||||
progress_interval = 10 # Log progress every 10 seconds
|
||||
last_progress_time = start_time
|
||||
|
||||
# Disable detailed logging for connection attempts
|
||||
self._log_connection_attempts = False
|
||||
|
||||
try:
|
||||
self.logger.info(
|
||||
f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..."
|
||||
)
|
||||
|
||||
# Start the keep-alive task if it's not already running
|
||||
if self._reconnect_task is None or self._reconnect_task.done():
|
||||
self._reconnect_task = asyncio.create_task(self._keep_alive())
|
||||
|
||||
# Wait for the connection to be established
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
attempt_count += 1
|
||||
current_time = time.time()
|
||||
|
||||
# Log progress periodically without flooding logs
|
||||
if current_time - last_progress_time >= progress_interval:
|
||||
elapsed = current_time - start_time
|
||||
self.logger.info(
|
||||
f"Still waiting for Computer API Server... (elapsed: {elapsed:.1f}s, attempts: {attempt_count})"
|
||||
)
|
||||
last_progress_time = current_time
|
||||
|
||||
# Check if we have a connection
|
||||
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
|
||||
# Test the connection with a simple command
|
||||
try:
|
||||
await self._send_command("get_screen_size")
|
||||
elapsed = time.time() - start_time
|
||||
self.logger.info(
|
||||
f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
|
||||
)
|
||||
return # Connection is fully working
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
self.logger.debug(f"Connection test failed: {e}")
|
||||
|
||||
# Wait before trying again
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
self.logger.debug(f"Connection attempt {attempt_count} failed: {e}")
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
# If we get here, we've timed out
|
||||
error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds"
|
||||
if last_error:
|
||||
error_msg += f": {str(last_error)}"
|
||||
self.logger.error(error_msg)
|
||||
raise TimeoutError(error_msg)
|
||||
finally:
|
||||
# Reset to default logging behavior
|
||||
self._log_connection_attempts = False
|
||||
|
||||
def close(self):
|
||||
"""Close WebSocket connection.
|
||||
|
||||
Note: In host computer server mode, we leave the connection open
|
||||
to allow other clients to connect to the same server. The server
|
||||
will handle cleaning up idle connections.
|
||||
"""
|
||||
# Only cancel the reconnect task
|
||||
if self._reconnect_task:
|
||||
self._reconnect_task.cancel()
|
||||
|
||||
# Don't set closed flag or close websocket by default
|
||||
# This allows the server to stay connected for other clients
|
||||
# self._closed = True
|
||||
# if self._ws:
|
||||
# asyncio.create_task(self._ws.close())
|
||||
# self._ws = None
|
||||
|
||||
def force_close(self):
|
||||
"""Force close the WebSocket connection.
|
||||
|
||||
This method should be called when you want to completely
|
||||
shut down the connection, not just for regular cleanup.
|
||||
"""
|
||||
self._closed = True
|
||||
if self._reconnect_task:
|
||||
self._reconnect_task.cancel()
|
||||
if self._ws:
|
||||
asyncio.create_task(self._ws.close())
|
||||
self._ws = None
|
||||
|
||||
# Mouse Actions
|
||||
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> None:
|
||||
await self._send_command("mouse_down", {"x": x, "y": y, "button": button})
|
||||
|
||||
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> None:
|
||||
await self._send_command("mouse_up", {"x": x, "y": y, "button": button})
|
||||
|
||||
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
await self._send_command("left_click", {"x": x, "y": y})
|
||||
|
||||
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
await self._send_command("right_click", {"x": x, "y": y})
|
||||
|
||||
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
await self._send_command("double_click", {"x": x, "y": y})
|
||||
|
||||
async def move_cursor(self, x: int, y: int) -> None:
|
||||
await self._send_command("move_cursor", {"x": x, "y": y})
|
||||
|
||||
async def drag_to(self, x: int, y: int, button: "MouseButton" = "left", duration: float = 0.5) -> None:
|
||||
await self._send_command(
|
||||
"drag_to", {"x": x, "y": y, "button": button, "duration": duration}
|
||||
)
|
||||
|
||||
async def drag(self, path: List[Tuple[int, int]], button: "MouseButton" = "left", duration: float = 0.5) -> None:
|
||||
await self._send_command(
|
||||
"drag", {"path": path, "button": button, "duration": duration}
|
||||
)
|
||||
|
||||
# Keyboard Actions
|
||||
async def key_down(self, key: "KeyType") -> None:
|
||||
await self._send_command("key_down", {"key": key})
|
||||
|
||||
async def key_up(self, key: "KeyType") -> None:
|
||||
await self._send_command("key_up", {"key": key})
|
||||
|
||||
async def type_text(self, text: str) -> None:
|
||||
# Temporary fix for https://github.com/trycua/cua/issues/165
|
||||
# Check if text contains Unicode characters
|
||||
if any(ord(char) > 127 for char in text):
|
||||
# For Unicode text, use clipboard and paste
|
||||
await self.set_clipboard(text)
|
||||
await self.hotkey(Key.COMMAND, 'v')
|
||||
else:
|
||||
# For ASCII text, use the regular typing method
|
||||
await self._send_command("type_text", {"text": text})
|
||||
|
||||
async def press(self, key: "KeyType") -> None:
|
||||
"""Press a single key.
|
||||
|
||||
Args:
|
||||
key: The key to press. Can be any of:
|
||||
- A Key enum value (recommended), e.g. Key.PAGE_DOWN
|
||||
- A direct key value string, e.g. 'pagedown'
|
||||
- A single character string, e.g. 'a'
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Using enum (recommended)
|
||||
await interface.press(Key.PAGE_DOWN)
|
||||
await interface.press(Key.ENTER)
|
||||
|
||||
# Using direct values
|
||||
await interface.press('pagedown')
|
||||
await interface.press('enter')
|
||||
|
||||
# Using single characters
|
||||
await interface.press('a')
|
||||
```
|
||||
|
||||
Raises:
|
||||
ValueError: If the key type is invalid or the key is not recognized
|
||||
"""
|
||||
if isinstance(key, Key):
|
||||
actual_key = key.value
|
||||
elif isinstance(key, str):
|
||||
# Try to convert to enum if it matches a known key
|
||||
key_or_enum = Key.from_string(key)
|
||||
actual_key = key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum
|
||||
else:
|
||||
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
|
||||
|
||||
await self._send_command("press_key", {"key": actual_key})
|
||||
|
||||
async def press_key(self, key: "KeyType") -> None:
|
||||
"""DEPRECATED: Use press() instead.
|
||||
|
||||
This method is kept for backward compatibility but will be removed in a future version.
|
||||
Please use the press() method instead.
|
||||
"""
|
||||
await self.press(key)
|
||||
|
||||
async def hotkey(self, *keys: "KeyType") -> None:
|
||||
"""Press multiple keys simultaneously.
|
||||
|
||||
Args:
|
||||
*keys: Multiple keys to press simultaneously. Each key can be any of:
|
||||
- A Key enum value (recommended), e.g. Key.COMMAND
|
||||
- A direct key value string, e.g. 'command'
|
||||
- A single character string, e.g. 'a'
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Using enums (recommended)
|
||||
await interface.hotkey(Key.COMMAND, Key.C) # Copy
|
||||
await interface.hotkey(Key.COMMAND, Key.V) # Paste
|
||||
|
||||
# Using mixed formats
|
||||
await interface.hotkey(Key.COMMAND, 'a') # Select all
|
||||
```
|
||||
|
||||
Raises:
|
||||
ValueError: If any key type is invalid or not recognized
|
||||
"""
|
||||
actual_keys = []
|
||||
for key in keys:
|
||||
if isinstance(key, Key):
|
||||
actual_keys.append(key.value)
|
||||
elif isinstance(key, str):
|
||||
# Try to convert to enum if it matches a known key
|
||||
key_or_enum = Key.from_string(key)
|
||||
actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
|
||||
else:
|
||||
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
|
||||
|
||||
await self._send_command("hotkey", {"keys": actual_keys})
|
||||
|
||||
# Scrolling Actions
|
||||
async def scroll(self, x: int, y: int) -> None:
|
||||
await self._send_command("scroll", {"x": x, "y": y})
|
||||
|
||||
async def scroll_down(self, clicks: int = 1) -> None:
|
||||
await self._send_command("scroll_down", {"clicks": clicks})
|
||||
|
||||
async def scroll_up(self, clicks: int = 1) -> None:
|
||||
await self._send_command("scroll_up", {"clicks": clicks})
|
||||
|
||||
# Screen Actions
|
||||
async def screenshot(
|
||||
self,
|
||||
boxes: Optional[List[Tuple[int, int, int, int]]] = None,
|
||||
box_color: str = "#FF0000",
|
||||
box_thickness: int = 2,
|
||||
scale_factor: float = 1.0,
|
||||
) -> bytes:
|
||||
"""Take a screenshot with optional box drawing and scaling.
|
||||
|
||||
Args:
|
||||
boxes: Optional list of (x, y, width, height) tuples defining boxes to draw in screen coordinates
|
||||
box_color: Color of the boxes in hex format (default: "#FF0000" red)
|
||||
box_thickness: Thickness of the box borders in pixels (default: 2)
|
||||
scale_factor: Factor to scale the final image by (default: 1.0)
|
||||
Use > 1.0 to enlarge, < 1.0 to shrink (e.g., 0.5 for half size, 2.0 for double)
|
||||
|
||||
Returns:
|
||||
bytes: The screenshot image data, optionally with boxes drawn on it and scaled
|
||||
"""
|
||||
result = await self._send_command("screenshot")
|
||||
if not result.get("image_data"):
|
||||
raise RuntimeError("Failed to take screenshot")
|
||||
|
||||
screenshot = decode_base64_image(result["image_data"])
|
||||
|
||||
if boxes:
|
||||
# Get the natural scaling between screen and screenshot
|
||||
screen_size = await self.get_screen_size()
|
||||
screenshot_width, screenshot_height = bytes_to_image(screenshot).size
|
||||
width_scale = screenshot_width / screen_size["width"]
|
||||
height_scale = screenshot_height / screen_size["height"]
|
||||
|
||||
# Scale box coordinates from screen space to screenshot space
|
||||
for box in boxes:
|
||||
scaled_box = (
|
||||
int(box[0] * width_scale), # x
|
||||
int(box[1] * height_scale), # y
|
||||
int(box[2] * width_scale), # width
|
||||
int(box[3] * height_scale), # height
|
||||
)
|
||||
screenshot = draw_box(
|
||||
screenshot,
|
||||
x=scaled_box[0],
|
||||
y=scaled_box[1],
|
||||
width=scaled_box[2],
|
||||
height=scaled_box[3],
|
||||
color=box_color,
|
||||
thickness=box_thickness,
|
||||
)
|
||||
|
||||
if scale_factor != 1.0:
|
||||
screenshot = resize_image(screenshot, scale_factor)
|
||||
|
||||
return screenshot
|
||||
|
||||
async def get_screen_size(self) -> Dict[str, int]:
|
||||
result = await self._send_command("get_screen_size")
|
||||
if result["success"] and result["size"]:
|
||||
return result["size"]
|
||||
raise RuntimeError("Failed to get screen size")
|
||||
|
||||
async def get_cursor_position(self) -> Dict[str, int]:
|
||||
result = await self._send_command("get_cursor_position")
|
||||
if result["success"] and result["position"]:
|
||||
return result["position"]
|
||||
raise RuntimeError("Failed to get cursor position")
|
||||
|
||||
# Clipboard Actions
|
||||
async def copy_to_clipboard(self) -> str:
|
||||
result = await self._send_command("copy_to_clipboard")
|
||||
if result["success"] and result["content"]:
|
||||
return result["content"]
|
||||
raise RuntimeError("Failed to get clipboard content")
|
||||
|
||||
async def set_clipboard(self, text: str) -> None:
|
||||
await self._send_command("set_clipboard", {"text": text})
|
||||
|
||||
# File System Actions
|
||||
async def file_exists(self, path: str) -> bool:
|
||||
result = await self._send_command("file_exists", {"path": path})
|
||||
return result.get("exists", False)
|
||||
|
||||
async def directory_exists(self, path: str) -> bool:
|
||||
result = await self._send_command("directory_exists", {"path": path})
|
||||
return result.get("exists", False)
|
||||
|
||||
async def list_dir(self, path: str) -> list[str]:
|
||||
result = await self._send_command("list_dir", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to list directory"))
|
||||
return result.get("files", [])
|
||||
|
||||
async def read_text(self, path: str) -> str:
|
||||
result = await self._send_command("read_text", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to read file"))
|
||||
return result.get("content", "")
|
||||
|
||||
async def write_text(self, path: str, content: str) -> None:
|
||||
result = await self._send_command("write_text", {"path": path, "content": content})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to write file"))
|
||||
|
||||
async def read_bytes(self, path: str) -> bytes:
|
||||
result = await self._send_command("read_bytes", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to read file"))
|
||||
content_b64 = result.get("content_b64", "")
|
||||
return decode_base64_image(content_b64)
|
||||
|
||||
async def write_bytes(self, path: str, content: bytes) -> None:
|
||||
result = await self._send_command("write_bytes", {"path": path, "content_b64": encode_base64_image(content)})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to write file"))
|
||||
|
||||
async def delete_file(self, path: str) -> None:
|
||||
result = await self._send_command("delete_file", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to delete file"))
|
||||
|
||||
async def create_dir(self, path: str) -> None:
|
||||
result = await self._send_command("create_dir", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to create directory"))
|
||||
|
||||
async def delete_dir(self, path: str) -> None:
|
||||
result = await self._send_command("delete_dir", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to delete directory"))
|
||||
|
||||
async def run_command(self, command: str) -> Tuple[str, str]:
|
||||
result = await self._send_command("run_command", {"command": command})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to run command"))
|
||||
return result.get("stdout", ""), result.get("stderr", "")
|
||||
|
||||
# Accessibility Actions
|
||||
async def get_accessibility_tree(self) -> Dict[str, Any]:
|
||||
"""Get the accessibility tree of the current screen."""
|
||||
result = await self._send_command("get_accessibility_tree")
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to get accessibility tree"))
|
||||
return result
|
||||
|
||||
async def get_active_window_bounds(self) -> Dict[str, int]:
|
||||
"""Get the bounds of the currently active window."""
|
||||
result = await self._send_command("get_active_window_bounds")
|
||||
if result["success"] and result["bounds"]:
|
||||
return result["bounds"]
|
||||
raise RuntimeError("Failed to get active window bounds")
|
||||
|
||||
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
|
||||
"""Convert screenshot coordinates to screen coordinates.
|
||||
|
||||
Args:
|
||||
x: X coordinate in screenshot space
|
||||
y: Y coordinate in screenshot space
|
||||
|
||||
Returns:
|
||||
tuple[float, float]: (x, y) coordinates in screen space
|
||||
"""
|
||||
screen_size = await self.get_screen_size()
|
||||
screenshot = await self.screenshot()
|
||||
screenshot_img = bytes_to_image(screenshot)
|
||||
screenshot_width, screenshot_height = screenshot_img.size
|
||||
|
||||
# Calculate scaling factors
|
||||
width_scale = screen_size["width"] / screenshot_width
|
||||
height_scale = screen_size["height"] / screenshot_height
|
||||
|
||||
# Convert coordinates
|
||||
screen_x = x * width_scale
|
||||
screen_y = y * height_scale
|
||||
|
||||
return screen_x, screen_y
|
||||
|
||||
async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]:
|
||||
"""Convert screen coordinates to screenshot coordinates.
|
||||
|
||||
Args:
|
||||
x: X coordinate in screen space
|
||||
y: Y coordinate in screen space
|
||||
|
||||
Returns:
|
||||
tuple[float, float]: (x, y) coordinates in screenshot space
|
||||
"""
|
||||
screen_size = await self.get_screen_size()
|
||||
screenshot = await self.screenshot()
|
||||
screenshot_img = bytes_to_image(screenshot)
|
||||
screenshot_width, screenshot_height = screenshot_img.size
|
||||
|
||||
# Calculate scaling factors
|
||||
width_scale = screenshot_width / screen_size["width"]
|
||||
height_scale = screenshot_height / screen_size["height"]
|
||||
|
||||
# Convert coordinates
|
||||
screenshot_x = x * width_scale
|
||||
screenshot_y = y * height_scale
|
||||
|
||||
return screenshot_x, screenshot_y
|
||||
695
libs/python/computer/computer/interface/macos.py
Normal file
695
libs/python/computer/computer/interface/macos.py
Normal file
@@ -0,0 +1,695 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from PIL import Image
|
||||
|
||||
import websockets
|
||||
|
||||
from ..logger import Logger, LogLevel
|
||||
from .base import BaseComputerInterface
|
||||
from ..utils import decode_base64_image, encode_base64_image, bytes_to_image, draw_box, resize_image
|
||||
from .models import Key, KeyType, MouseButton
|
||||
|
||||
|
||||
class MacOSComputerInterface(BaseComputerInterface):
|
||||
"""Interface for macOS."""
|
||||
|
||||
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
|
||||
super().__init__(ip_address, username, password, api_key, vm_name)
|
||||
self._ws = None
|
||||
self._reconnect_task = None
|
||||
self._closed = False
|
||||
self._last_ping = 0
|
||||
self._ping_interval = 5 # Send ping every 5 seconds
|
||||
self._ping_timeout = 120 # Wait 120 seconds for pong response
|
||||
self._reconnect_delay = 1 # Start with 1 second delay
|
||||
self._max_reconnect_delay = 30 # Maximum delay between reconnection attempts
|
||||
self._log_connection_attempts = True # Flag to control connection attempt logging
|
||||
self._command_lock = asyncio.Lock() # Lock to ensure only one command at a time
|
||||
|
||||
# Set logger name for macOS interface
|
||||
self.logger = Logger("computer.interface.macos", LogLevel.NORMAL)
|
||||
|
||||
@property
|
||||
def ws_uri(self) -> str:
|
||||
"""Get the WebSocket URI using the current IP address.
|
||||
|
||||
Returns:
|
||||
WebSocket URI for the Computer API Server
|
||||
"""
|
||||
protocol = "wss" if self.api_key else "ws"
|
||||
port = "8443" if self.api_key else "8000"
|
||||
return f"{protocol}://{self.ip_address}:{port}/ws"
|
||||
|
||||
async def _keep_alive(self):
|
||||
"""Keep the WebSocket connection alive with automatic reconnection."""
|
||||
retry_count = 0
|
||||
max_log_attempts = 1 # Only log the first attempt at INFO level
|
||||
log_interval = 500 # Then log every 500th attempt (significantly increased from 30)
|
||||
last_warning_time = 0
|
||||
min_warning_interval = 30 # Minimum seconds between connection lost warnings
|
||||
min_retry_delay = 0.5 # Minimum delay between connection attempts (500ms)
|
||||
|
||||
while not self._closed:
|
||||
try:
|
||||
if self._ws is None or (
|
||||
self._ws and self._ws.state == websockets.protocol.State.CLOSED
|
||||
):
|
||||
try:
|
||||
retry_count += 1
|
||||
|
||||
# Add a minimum delay between connection attempts to avoid flooding
|
||||
if retry_count > 1:
|
||||
await asyncio.sleep(min_retry_delay)
|
||||
|
||||
# Only log the first attempt at INFO level, then every Nth attempt
|
||||
if retry_count == 1:
|
||||
self.logger.info(f"Attempting WebSocket connection to {self.ws_uri}")
|
||||
elif retry_count % log_interval == 0:
|
||||
self.logger.info(
|
||||
f"Still attempting WebSocket connection (attempt {retry_count})..."
|
||||
)
|
||||
else:
|
||||
# All other attempts are logged at DEBUG level
|
||||
self.logger.debug(
|
||||
f"Attempting WebSocket connection to {self.ws_uri} (attempt {retry_count})"
|
||||
)
|
||||
|
||||
self._ws = await asyncio.wait_for(
|
||||
websockets.connect(
|
||||
self.ws_uri,
|
||||
max_size=1024 * 1024 * 10, # 10MB limit
|
||||
max_queue=32,
|
||||
ping_interval=self._ping_interval,
|
||||
ping_timeout=self._ping_timeout,
|
||||
close_timeout=5,
|
||||
compression=None, # Disable compression to reduce overhead
|
||||
),
|
||||
timeout=120,
|
||||
)
|
||||
self.logger.info("WebSocket connection established")
|
||||
|
||||
# If api_key and vm_name are provided, perform authentication handshake
|
||||
if self.api_key and self.vm_name:
|
||||
self.logger.info("Performing authentication handshake...")
|
||||
auth_message = {
|
||||
"command": "authenticate",
|
||||
"params": {
|
||||
"api_key": self.api_key,
|
||||
"container_name": self.vm_name
|
||||
}
|
||||
}
|
||||
await self._ws.send(json.dumps(auth_message))
|
||||
|
||||
# Wait for authentication response
|
||||
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
|
||||
auth_result = json.loads(auth_response)
|
||||
|
||||
if not auth_result.get("success"):
|
||||
error_msg = auth_result.get("error", "Authentication failed")
|
||||
self.logger.error(f"Authentication failed: {error_msg}")
|
||||
await self._ws.close()
|
||||
self._ws = None
|
||||
raise ConnectionError(f"Authentication failed: {error_msg}")
|
||||
|
||||
self.logger.info("Authentication successful")
|
||||
|
||||
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
|
||||
self._last_ping = time.time()
|
||||
retry_count = 0 # Reset retry count on successful connection
|
||||
except (asyncio.TimeoutError, websockets.exceptions.WebSocketException) as e:
|
||||
next_retry = self._reconnect_delay
|
||||
|
||||
# Only log the first error at WARNING level, then every Nth attempt
|
||||
if retry_count == 1:
|
||||
self.logger.warning(
|
||||
f"Computer API Server not ready yet. Will retry automatically."
|
||||
)
|
||||
elif retry_count % log_interval == 0:
|
||||
self.logger.warning(
|
||||
f"Still waiting for Computer API Server (attempt {retry_count})..."
|
||||
)
|
||||
else:
|
||||
# All other errors are logged at DEBUG level
|
||||
self.logger.debug(f"Connection attempt {retry_count} failed: {e}")
|
||||
|
||||
if self._ws:
|
||||
try:
|
||||
await self._ws.close()
|
||||
except:
|
||||
pass
|
||||
self._ws = None
|
||||
|
||||
# Use exponential backoff for connection retries
|
||||
await asyncio.sleep(self._reconnect_delay)
|
||||
self._reconnect_delay = min(
|
||||
self._reconnect_delay * 2, self._max_reconnect_delay
|
||||
)
|
||||
continue
|
||||
|
||||
# Regular ping to check connection
|
||||
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
|
||||
try:
|
||||
if time.time() - self._last_ping >= self._ping_interval:
|
||||
pong_waiter = await self._ws.ping()
|
||||
await asyncio.wait_for(pong_waiter, timeout=self._ping_timeout)
|
||||
self._last_ping = time.time()
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Ping failed: {e}")
|
||||
if self._ws:
|
||||
try:
|
||||
await self._ws.close()
|
||||
except:
|
||||
pass
|
||||
self._ws = None
|
||||
continue
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
current_time = time.time()
|
||||
# Only log connection lost warnings at most once every min_warning_interval seconds
|
||||
if current_time - last_warning_time >= min_warning_interval:
|
||||
self.logger.warning(
|
||||
f"Computer API Server connection lost. Will retry automatically."
|
||||
)
|
||||
last_warning_time = current_time
|
||||
else:
|
||||
# Log at debug level instead
|
||||
self.logger.debug(f"Connection lost: {e}")
|
||||
|
||||
if self._ws:
|
||||
try:
|
||||
await self._ws.close()
|
||||
except:
|
||||
pass
|
||||
self._ws = None
|
||||
|
||||
async def _ensure_connection(self):
|
||||
"""Ensure WebSocket connection is established."""
|
||||
if self._reconnect_task is None or self._reconnect_task.done():
|
||||
self._reconnect_task = asyncio.create_task(self._keep_alive())
|
||||
|
||||
retry_count = 0
|
||||
max_retries = 5
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
|
||||
return
|
||||
retry_count += 1
|
||||
await asyncio.sleep(1)
|
||||
except Exception as e:
|
||||
# Only log at ERROR level for the last retry attempt
|
||||
if retry_count == max_retries - 1:
|
||||
self.logger.error(
|
||||
f"Persistent connection check error after {retry_count} attempts: {e}"
|
||||
)
|
||||
else:
|
||||
self.logger.debug(f"Connection check error (attempt {retry_count}): {e}")
|
||||
retry_count += 1
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
raise ConnectionError("Failed to establish WebSocket connection after multiple retries")
|
||||
|
||||
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
"""Send command through WebSocket."""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
last_error = None
|
||||
|
||||
# Acquire lock to ensure only one command is processed at a time
|
||||
async with self._command_lock:
|
||||
self.logger.debug(f"Acquired lock for command: {command}")
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
await self._ensure_connection()
|
||||
if not self._ws:
|
||||
raise ConnectionError("WebSocket connection is not established")
|
||||
|
||||
message = {"command": command, "params": params or {}}
|
||||
await self._ws.send(json.dumps(message))
|
||||
response = await asyncio.wait_for(self._ws.recv(), timeout=120)
|
||||
self.logger.debug(f"Completed command: {command}")
|
||||
return json.loads(response)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
retry_count += 1
|
||||
if retry_count < max_retries:
|
||||
# Only log at debug level for intermediate retries
|
||||
self.logger.debug(
|
||||
f"Command '{command}' failed (attempt {retry_count}/{max_retries}): {e}"
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
else:
|
||||
# Only log at error level for the final failure
|
||||
self.logger.error(
|
||||
f"Failed to send command '{command}' after {max_retries} retries"
|
||||
)
|
||||
self.logger.debug(f"Command failure details: {e}")
|
||||
raise
|
||||
|
||||
raise last_error if last_error else RuntimeError("Failed to send command")
|
||||
|
||||
async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
|
||||
"""Wait for WebSocket connection to become available."""
|
||||
start_time = time.time()
|
||||
last_error = None
|
||||
attempt_count = 0
|
||||
progress_interval = 10 # Log progress every 10 seconds
|
||||
last_progress_time = start_time
|
||||
|
||||
# Disable detailed logging for connection attempts
|
||||
self._log_connection_attempts = False
|
||||
|
||||
try:
|
||||
self.logger.info(
|
||||
f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..."
|
||||
)
|
||||
|
||||
# Start the keep-alive task if it's not already running
|
||||
if self._reconnect_task is None or self._reconnect_task.done():
|
||||
self._reconnect_task = asyncio.create_task(self._keep_alive())
|
||||
|
||||
# Wait for the connection to be established
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
attempt_count += 1
|
||||
current_time = time.time()
|
||||
|
||||
# Log progress periodically without flooding logs
|
||||
if current_time - last_progress_time >= progress_interval:
|
||||
elapsed = current_time - start_time
|
||||
self.logger.info(
|
||||
f"Still waiting for Computer API Server... (elapsed: {elapsed:.1f}s, attempts: {attempt_count})"
|
||||
)
|
||||
last_progress_time = current_time
|
||||
|
||||
# Check if we have a connection
|
||||
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
|
||||
# Test the connection with a simple command
|
||||
try:
|
||||
await self._send_command("get_screen_size")
|
||||
elapsed = time.time() - start_time
|
||||
self.logger.info(
|
||||
f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
|
||||
)
|
||||
return # Connection is fully working
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
self.logger.debug(f"Connection test failed: {e}")
|
||||
|
||||
# Wait before trying again
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
self.logger.debug(f"Connection attempt {attempt_count} failed: {e}")
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
# If we get here, we've timed out
|
||||
error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds"
|
||||
if last_error:
|
||||
error_msg += f": {str(last_error)}"
|
||||
self.logger.error(error_msg)
|
||||
raise TimeoutError(error_msg)
|
||||
finally:
|
||||
# Reset to default logging behavior
|
||||
self._log_connection_attempts = False
|
||||
|
||||
def close(self):
|
||||
"""Close WebSocket connection.
|
||||
|
||||
Note: In host computer server mode, we leave the connection open
|
||||
to allow other clients to connect to the same server. The server
|
||||
will handle cleaning up idle connections.
|
||||
"""
|
||||
# Only cancel the reconnect task
|
||||
if self._reconnect_task:
|
||||
self._reconnect_task.cancel()
|
||||
|
||||
# Don't set closed flag or close websocket by default
|
||||
# This allows the server to stay connected for other clients
|
||||
# self._closed = True
|
||||
# if self._ws:
|
||||
# asyncio.create_task(self._ws.close())
|
||||
# self._ws = None
|
||||
|
||||
def force_close(self):
|
||||
"""Force close the WebSocket connection.
|
||||
|
||||
This method should be called when you want to completely
|
||||
shut down the connection, not just for regular cleanup.
|
||||
"""
|
||||
self._closed = True
|
||||
if self._reconnect_task:
|
||||
self._reconnect_task.cancel()
|
||||
if self._ws:
|
||||
asyncio.create_task(self._ws.close())
|
||||
self._ws = None
|
||||
|
||||
async def diorama_cmd(self, action: str, arguments: Optional[dict] = None) -> dict:
|
||||
"""Send a diorama command to the server (macOS only)."""
|
||||
return await self._send_command("diorama_cmd", {"action": action, "arguments": arguments or {}})
|
||||
|
||||
# Mouse Actions
|
||||
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: "MouseButton" = "left") -> None:
|
||||
await self._send_command("mouse_down", {"x": x, "y": y, "button": button})
|
||||
|
||||
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: "MouseButton" = "left") -> None:
|
||||
await self._send_command("mouse_up", {"x": x, "y": y, "button": button})
|
||||
|
||||
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
await self._send_command("left_click", {"x": x, "y": y})
|
||||
|
||||
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
await self._send_command("right_click", {"x": x, "y": y})
|
||||
|
||||
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
await self._send_command("double_click", {"x": x, "y": y})
|
||||
|
||||
async def move_cursor(self, x: int, y: int) -> None:
|
||||
await self._send_command("move_cursor", {"x": x, "y": y})
|
||||
|
||||
async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> None:
|
||||
await self._send_command(
|
||||
"drag_to", {"x": x, "y": y, "button": button, "duration": duration}
|
||||
)
|
||||
|
||||
async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> None:
|
||||
await self._send_command(
|
||||
"drag", {"path": path, "button": button, "duration": duration}
|
||||
)
|
||||
|
||||
# Keyboard Actions
|
||||
async def key_down(self, key: "KeyType") -> None:
|
||||
await self._send_command("key_down", {"key": key})
|
||||
|
||||
async def key_up(self, key: "KeyType") -> None:
|
||||
await self._send_command("key_up", {"key": key})
|
||||
|
||||
async def type_text(self, text: str) -> None:
|
||||
# Temporary fix for https://github.com/trycua/cua/issues/165
|
||||
# Check if text contains Unicode characters
|
||||
if any(ord(char) > 127 for char in text):
|
||||
# For Unicode text, use clipboard and paste
|
||||
await self.set_clipboard(text)
|
||||
await self.hotkey(Key.COMMAND, 'v')
|
||||
else:
|
||||
# For ASCII text, use the regular typing method
|
||||
await self._send_command("type_text", {"text": text})
|
||||
|
||||
async def press(self, key: "KeyType") -> None:
|
||||
"""Press a single key.
|
||||
|
||||
Args:
|
||||
key: The key to press. Can be any of:
|
||||
- A Key enum value (recommended), e.g. Key.PAGE_DOWN
|
||||
- A direct key value string, e.g. 'pagedown'
|
||||
- A single character string, e.g. 'a'
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Using enum (recommended)
|
||||
await interface.press(Key.PAGE_DOWN)
|
||||
await interface.press(Key.ENTER)
|
||||
|
||||
# Using direct values
|
||||
await interface.press('pagedown')
|
||||
await interface.press('enter')
|
||||
|
||||
# Using single characters
|
||||
await interface.press('a')
|
||||
```
|
||||
|
||||
Raises:
|
||||
ValueError: If the key type is invalid or the key is not recognized
|
||||
"""
|
||||
if isinstance(key, Key):
|
||||
actual_key = key.value
|
||||
elif isinstance(key, str):
|
||||
# Try to convert to enum if it matches a known key
|
||||
key_or_enum = Key.from_string(key)
|
||||
actual_key = key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum
|
||||
else:
|
||||
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
|
||||
|
||||
await self._send_command("press_key", {"key": actual_key})
|
||||
|
||||
async def press_key(self, key: "KeyType") -> None:
|
||||
"""DEPRECATED: Use press() instead.
|
||||
|
||||
This method is kept for backward compatibility but will be removed in a future version.
|
||||
Please use the press() method instead.
|
||||
"""
|
||||
await self.press(key)
|
||||
|
||||
async def hotkey(self, *keys: "KeyType") -> None:
|
||||
"""Press multiple keys simultaneously.
|
||||
|
||||
Args:
|
||||
*keys: Multiple keys to press simultaneously. Each key can be any of:
|
||||
- A Key enum value (recommended), e.g. Key.COMMAND
|
||||
- A direct key value string, e.g. 'command'
|
||||
- A single character string, e.g. 'a'
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Using enums (recommended)
|
||||
await interface.hotkey(Key.COMMAND, Key.C) # Copy
|
||||
await interface.hotkey(Key.COMMAND, Key.V) # Paste
|
||||
|
||||
# Using mixed formats
|
||||
await interface.hotkey(Key.COMMAND, 'a') # Select all
|
||||
```
|
||||
|
||||
Raises:
|
||||
ValueError: If any key type is invalid or not recognized
|
||||
"""
|
||||
actual_keys = []
|
||||
for key in keys:
|
||||
if isinstance(key, Key):
|
||||
actual_keys.append(key.value)
|
||||
elif isinstance(key, str):
|
||||
# Try to convert to enum if it matches a known key
|
||||
key_or_enum = Key.from_string(key)
|
||||
actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
|
||||
else:
|
||||
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
|
||||
|
||||
await self._send_command("hotkey", {"keys": actual_keys})
|
||||
|
||||
# Scrolling Actions
|
||||
async def scroll(self, x: int, y: int) -> None:
|
||||
await self._send_command("scroll", {"x": x, "y": y})
|
||||
|
||||
async def scroll_down(self, clicks: int = 1) -> None:
|
||||
await self._send_command("scroll_down", {"clicks": clicks})
|
||||
|
||||
async def scroll_up(self, clicks: int = 1) -> None:
|
||||
await self._send_command("scroll_up", {"clicks": clicks})
|
||||
|
||||
# Screen Actions
|
||||
async def screenshot(
|
||||
self,
|
||||
boxes: Optional[List[Tuple[int, int, int, int]]] = None,
|
||||
box_color: str = "#FF0000",
|
||||
box_thickness: int = 2,
|
||||
scale_factor: float = 1.0,
|
||||
) -> bytes:
|
||||
"""Take a screenshot with optional box drawing and scaling.
|
||||
|
||||
Args:
|
||||
boxes: Optional list of (x, y, width, height) tuples defining boxes to draw in screen coordinates
|
||||
box_color: Color of the boxes in hex format (default: "#FF0000" red)
|
||||
box_thickness: Thickness of the box borders in pixels (default: 2)
|
||||
scale_factor: Factor to scale the final image by (default: 1.0)
|
||||
Use > 1.0 to enlarge, < 1.0 to shrink (e.g., 0.5 for half size, 2.0 for double)
|
||||
|
||||
Returns:
|
||||
bytes: The screenshot image data, optionally with boxes drawn on it and scaled
|
||||
"""
|
||||
result = await self._send_command("screenshot")
|
||||
if not result.get("image_data"):
|
||||
raise RuntimeError("Failed to take screenshot")
|
||||
|
||||
screenshot = decode_base64_image(result["image_data"])
|
||||
|
||||
if boxes:
|
||||
# Get the natural scaling between screen and screenshot
|
||||
screen_size = await self.get_screen_size()
|
||||
screenshot_width, screenshot_height = bytes_to_image(screenshot).size
|
||||
width_scale = screenshot_width / screen_size["width"]
|
||||
height_scale = screenshot_height / screen_size["height"]
|
||||
|
||||
# Scale box coordinates from screen space to screenshot space
|
||||
for box in boxes:
|
||||
scaled_box = (
|
||||
int(box[0] * width_scale), # x
|
||||
int(box[1] * height_scale), # y
|
||||
int(box[2] * width_scale), # width
|
||||
int(box[3] * height_scale), # height
|
||||
)
|
||||
screenshot = draw_box(
|
||||
screenshot,
|
||||
x=scaled_box[0],
|
||||
y=scaled_box[1],
|
||||
width=scaled_box[2],
|
||||
height=scaled_box[3],
|
||||
color=box_color,
|
||||
thickness=box_thickness,
|
||||
)
|
||||
|
||||
if scale_factor != 1.0:
|
||||
screenshot = resize_image(screenshot, scale_factor)
|
||||
|
||||
return screenshot
|
||||
|
||||
async def get_screen_size(self) -> Dict[str, int]:
|
||||
result = await self._send_command("get_screen_size")
|
||||
if result["success"] and result["size"]:
|
||||
return result["size"]
|
||||
raise RuntimeError("Failed to get screen size")
|
||||
|
||||
async def get_cursor_position(self) -> Dict[str, int]:
|
||||
result = await self._send_command("get_cursor_position")
|
||||
if result["success"] and result["position"]:
|
||||
return result["position"]
|
||||
raise RuntimeError("Failed to get cursor position")
|
||||
|
||||
# Clipboard Actions
|
||||
async def copy_to_clipboard(self) -> str:
|
||||
result = await self._send_command("copy_to_clipboard")
|
||||
if result["success"] and result["content"]:
|
||||
return result["content"]
|
||||
raise RuntimeError("Failed to get clipboard content")
|
||||
|
||||
async def set_clipboard(self, text: str) -> None:
|
||||
await self._send_command("set_clipboard", {"text": text})
|
||||
|
||||
# File System Actions
|
||||
async def file_exists(self, path: str) -> bool:
|
||||
result = await self._send_command("file_exists", {"path": path})
|
||||
return result.get("exists", False)
|
||||
|
||||
async def directory_exists(self, path: str) -> bool:
|
||||
result = await self._send_command("directory_exists", {"path": path})
|
||||
return result.get("exists", False)
|
||||
|
||||
async def list_dir(self, path: str) -> list[str]:
|
||||
result = await self._send_command("list_dir", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to list directory"))
|
||||
return result.get("files", [])
|
||||
|
||||
async def read_text(self, path: str) -> str:
|
||||
result = await self._send_command("read_text", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to read file"))
|
||||
return result.get("content", "")
|
||||
|
||||
async def write_text(self, path: str, content: str) -> None:
|
||||
result = await self._send_command("write_text", {"path": path, "content": content})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to write file"))
|
||||
|
||||
async def read_bytes(self, path: str) -> bytes:
|
||||
result = await self._send_command("read_bytes", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to read file"))
|
||||
content_b64 = result.get("content_b64", "")
|
||||
return decode_base64_image(content_b64)
|
||||
|
||||
async def write_bytes(self, path: str, content: bytes) -> None:
|
||||
result = await self._send_command("write_bytes", {"path": path, "content_b64": encode_base64_image(content)})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to write file"))
|
||||
|
||||
async def delete_file(self, path: str) -> None:
|
||||
result = await self._send_command("delete_file", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to delete file"))
|
||||
|
||||
async def create_dir(self, path: str) -> None:
|
||||
result = await self._send_command("create_dir", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to create directory"))
|
||||
|
||||
async def delete_dir(self, path: str) -> None:
|
||||
result = await self._send_command("delete_dir", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to delete directory"))
|
||||
|
||||
async def run_command(self, command: str) -> Tuple[str, str]:
|
||||
result = await self._send_command("run_command", {"command": command})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to run command"))
|
||||
return result.get("stdout", ""), result.get("stderr", "")
|
||||
|
||||
# Accessibility Actions
|
||||
async def get_accessibility_tree(self) -> Dict[str, Any]:
|
||||
"""Get the accessibility tree of the current screen."""
|
||||
result = await self._send_command("get_accessibility_tree")
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to get accessibility tree"))
|
||||
return result
|
||||
|
||||
async def get_active_window_bounds(self) -> Dict[str, int]:
|
||||
"""Get the bounds of the currently active window."""
|
||||
result = await self._send_command("get_active_window_bounds")
|
||||
if result["success"] and result["bounds"]:
|
||||
return result["bounds"]
|
||||
raise RuntimeError("Failed to get active window bounds")
|
||||
|
||||
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
|
||||
"""Convert screenshot coordinates to screen coordinates.
|
||||
|
||||
Args:
|
||||
x: X coordinate in screenshot space
|
||||
y: Y coordinate in screenshot space
|
||||
|
||||
Returns:
|
||||
tuple[float, float]: (x, y) coordinates in screen space
|
||||
"""
|
||||
screen_size = await self.get_screen_size()
|
||||
screenshot = await self.screenshot()
|
||||
screenshot_img = bytes_to_image(screenshot)
|
||||
screenshot_width, screenshot_height = screenshot_img.size
|
||||
|
||||
# Calculate scaling factors
|
||||
width_scale = screen_size["width"] / screenshot_width
|
||||
height_scale = screen_size["height"] / screenshot_height
|
||||
|
||||
# Convert coordinates
|
||||
screen_x = x * width_scale
|
||||
screen_y = y * height_scale
|
||||
|
||||
return screen_x, screen_y
|
||||
|
||||
async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]:
|
||||
"""Convert screen coordinates to screenshot coordinates.
|
||||
|
||||
Args:
|
||||
x: X coordinate in screen space
|
||||
y: Y coordinate in screen space
|
||||
|
||||
Returns:
|
||||
tuple[float, float]: (x, y) coordinates in screenshot space
|
||||
"""
|
||||
screen_size = await self.get_screen_size()
|
||||
screenshot = await self.screenshot()
|
||||
screenshot_img = bytes_to_image(screenshot)
|
||||
screenshot_width, screenshot_height = screenshot_img.size
|
||||
|
||||
# Calculate scaling factors
|
||||
width_scale = screenshot_width / screen_size["width"]
|
||||
height_scale = screenshot_height / screen_size["height"]
|
||||
|
||||
# Convert coordinates
|
||||
screenshot_x = x * width_scale
|
||||
screenshot_y = y * height_scale
|
||||
|
||||
return screenshot_x, screenshot_y
|
||||
124
libs/python/computer/computer/interface/models.py
Normal file
124
libs/python/computer/computer/interface/models.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Any, TypedDict, Union, Literal
|
||||
|
||||
# Navigation key literals
|
||||
NavigationKey = Literal['pagedown', 'pageup', 'home', 'end', 'left', 'right', 'up', 'down']
|
||||
|
||||
# Special key literals
|
||||
SpecialKey = Literal['enter', 'esc', 'tab', 'space', 'backspace', 'del']
|
||||
|
||||
# Modifier key literals
|
||||
ModifierKey = Literal['ctrl', 'alt', 'shift', 'win', 'command', 'option']
|
||||
|
||||
# Function key literals
|
||||
FunctionKey = Literal['f1', 'f2', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'f10', 'f11', 'f12']
|
||||
|
||||
class Key(Enum):
|
||||
"""Keyboard keys that can be used with press_key.
|
||||
|
||||
These key names map to PyAutoGUI's expected key names.
|
||||
"""
|
||||
# Navigation
|
||||
PAGE_DOWN = 'pagedown'
|
||||
PAGE_UP = 'pageup'
|
||||
HOME = 'home'
|
||||
END = 'end'
|
||||
LEFT = 'left'
|
||||
RIGHT = 'right'
|
||||
UP = 'up'
|
||||
DOWN = 'down'
|
||||
|
||||
# Special keys
|
||||
RETURN = 'enter'
|
||||
ENTER = 'enter'
|
||||
ESCAPE = 'esc'
|
||||
ESC = 'esc'
|
||||
TAB = 'tab'
|
||||
SPACE = 'space'
|
||||
BACKSPACE = 'backspace'
|
||||
DELETE = 'del'
|
||||
|
||||
# Modifier keys
|
||||
ALT = 'alt'
|
||||
CTRL = 'ctrl'
|
||||
SHIFT = 'shift'
|
||||
WIN = 'win'
|
||||
COMMAND = 'command'
|
||||
OPTION = 'option'
|
||||
|
||||
# Function keys
|
||||
F1 = 'f1'
|
||||
F2 = 'f2'
|
||||
F3 = 'f3'
|
||||
F4 = 'f4'
|
||||
F5 = 'f5'
|
||||
F6 = 'f6'
|
||||
F7 = 'f7'
|
||||
F8 = 'f8'
|
||||
F9 = 'f9'
|
||||
F10 = 'f10'
|
||||
F11 = 'f11'
|
||||
F12 = 'f12'
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, key: str) -> 'Key | str':
|
||||
"""Convert a string key name to a Key enum value.
|
||||
|
||||
Args:
|
||||
key: String key name to convert
|
||||
|
||||
Returns:
|
||||
Key enum value if the string matches a known key,
|
||||
otherwise returns the original string for single character keys
|
||||
"""
|
||||
# Map common alternative names to enum values
|
||||
key_mapping = {
|
||||
'page_down': cls.PAGE_DOWN,
|
||||
'page down': cls.PAGE_DOWN,
|
||||
'pagedown': cls.PAGE_DOWN,
|
||||
'page_up': cls.PAGE_UP,
|
||||
'page up': cls.PAGE_UP,
|
||||
'pageup': cls.PAGE_UP,
|
||||
'return': cls.RETURN,
|
||||
'enter': cls.ENTER,
|
||||
'escape': cls.ESCAPE,
|
||||
'esc': cls.ESC,
|
||||
'delete': cls.DELETE,
|
||||
'del': cls.DELETE,
|
||||
# Modifier key mappings
|
||||
'alt': cls.ALT,
|
||||
'ctrl': cls.CTRL,
|
||||
'control': cls.CTRL,
|
||||
'shift': cls.SHIFT,
|
||||
'win': cls.WIN,
|
||||
'windows': cls.WIN,
|
||||
'super': cls.WIN,
|
||||
'command': cls.COMMAND,
|
||||
'cmd': cls.COMMAND,
|
||||
'⌘': cls.COMMAND,
|
||||
'option': cls.OPTION,
|
||||
'⌥': cls.OPTION,
|
||||
}
|
||||
|
||||
normalized = key.lower().strip()
|
||||
return key_mapping.get(normalized, key)
|
||||
|
||||
# Combined key type
|
||||
KeyType = Union[Key, NavigationKey, SpecialKey, ModifierKey, FunctionKey, str]
|
||||
|
||||
# Key type for mouse actions
|
||||
MouseButton = Literal['left', 'right', 'middle']
|
||||
|
||||
class AccessibilityWindow(TypedDict):
|
||||
"""Information about a window in the accessibility tree."""
|
||||
app_name: str
|
||||
pid: int
|
||||
frontmost: bool
|
||||
has_windows: bool
|
||||
windows: List[Dict[str, Any]]
|
||||
|
||||
class AccessibilityTree(TypedDict):
|
||||
"""Complete accessibility tree information."""
|
||||
success: bool
|
||||
frontmost_application: str
|
||||
windows: List[AccessibilityWindow]
|
||||
687
libs/python/computer/computer/interface/windows.py
Normal file
687
libs/python/computer/computer/interface/windows.py
Normal file
@@ -0,0 +1,687 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from PIL import Image
|
||||
|
||||
import websockets
|
||||
|
||||
from ..logger import Logger, LogLevel
|
||||
from .base import BaseComputerInterface
|
||||
from ..utils import decode_base64_image, encode_base64_image, bytes_to_image, draw_box, resize_image
|
||||
from .models import Key, KeyType, MouseButton
|
||||
|
||||
|
||||
class WindowsComputerInterface(BaseComputerInterface):
|
||||
"""Interface for Windows."""
|
||||
|
||||
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
|
||||
super().__init__(ip_address, username, password, api_key, vm_name)
|
||||
self._ws = None
|
||||
self._reconnect_task = None
|
||||
self._closed = False
|
||||
self._last_ping = 0
|
||||
self._ping_interval = 5 # Send ping every 5 seconds
|
||||
self._ping_timeout = 120 # Wait 120 seconds for pong response
|
||||
self._reconnect_delay = 1 # Start with 1 second delay
|
||||
self._max_reconnect_delay = 30 # Maximum delay between reconnection attempts
|
||||
self._log_connection_attempts = True # Flag to control connection attempt logging
|
||||
self._authenticated = False # Track authentication status
|
||||
self._command_lock = asyncio.Lock() # Lock to ensure only one command at a time
|
||||
|
||||
# Set logger name for Windows interface
|
||||
self.logger = Logger("computer.interface.windows", LogLevel.NORMAL)
|
||||
|
||||
@property
|
||||
def ws_uri(self) -> str:
|
||||
"""Get the WebSocket URI using the current IP address.
|
||||
|
||||
Returns:
|
||||
WebSocket URI for the Computer API Server
|
||||
"""
|
||||
protocol = "wss" if self.api_key else "ws"
|
||||
port = "8443" if self.api_key else "8000"
|
||||
return f"{protocol}://{self.ip_address}:{port}/ws"
|
||||
|
||||
async def _keep_alive(self):
|
||||
"""Keep the WebSocket connection alive with automatic reconnection."""
|
||||
retry_count = 0
|
||||
max_log_attempts = 1 # Only log the first attempt at INFO level
|
||||
log_interval = 500 # Then log every 500th attempt (significantly increased from 30)
|
||||
last_warning_time = 0
|
||||
min_warning_interval = 30 # Minimum seconds between connection lost warnings
|
||||
min_retry_delay = 0.5 # Minimum delay between connection attempts (500ms)
|
||||
|
||||
while not self._closed:
|
||||
try:
|
||||
if self._ws is None or (
|
||||
self._ws and self._ws.state == websockets.protocol.State.CLOSED
|
||||
):
|
||||
try:
|
||||
retry_count += 1
|
||||
|
||||
# Add a minimum delay between connection attempts to avoid flooding
|
||||
if retry_count > 1:
|
||||
await asyncio.sleep(min_retry_delay)
|
||||
|
||||
# Only log the first attempt at INFO level, then every Nth attempt
|
||||
if retry_count == 1:
|
||||
self.logger.info(f"Attempting WebSocket connection to {self.ws_uri}")
|
||||
elif retry_count % log_interval == 0:
|
||||
self.logger.info(
|
||||
f"Still attempting WebSocket connection (attempt {retry_count})..."
|
||||
)
|
||||
else:
|
||||
# All other attempts are logged at DEBUG level
|
||||
self.logger.debug(
|
||||
f"Attempting WebSocket connection to {self.ws_uri} (attempt {retry_count})"
|
||||
)
|
||||
|
||||
self._ws = await asyncio.wait_for(
|
||||
websockets.connect(
|
||||
self.ws_uri,
|
||||
max_size=1024 * 1024 * 10, # 10MB limit
|
||||
max_queue=32,
|
||||
ping_interval=self._ping_interval,
|
||||
ping_timeout=self._ping_timeout,
|
||||
close_timeout=5,
|
||||
compression=None, # Disable compression to reduce overhead
|
||||
),
|
||||
timeout=120,
|
||||
)
|
||||
self.logger.info("WebSocket connection established")
|
||||
|
||||
# Authentication will be handled by the first command that needs it
|
||||
# Don't do authentication here to avoid recv conflicts
|
||||
|
||||
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
|
||||
self._last_ping = time.time()
|
||||
retry_count = 0 # Reset retry count on successful connection
|
||||
self._authenticated = False # Reset auth status on new connection
|
||||
|
||||
except (asyncio.TimeoutError, websockets.exceptions.WebSocketException) as e:
|
||||
next_retry = self._reconnect_delay
|
||||
|
||||
# Only log the first error at WARNING level, then every Nth attempt
|
||||
if retry_count == 1:
|
||||
self.logger.warning(
|
||||
f"Computer API Server not ready yet. Will retry automatically."
|
||||
)
|
||||
elif retry_count % log_interval == 0:
|
||||
self.logger.warning(
|
||||
f"Still waiting for Computer API Server (attempt {retry_count})..."
|
||||
)
|
||||
else:
|
||||
# All other errors are logged at DEBUG level
|
||||
self.logger.debug(f"Connection attempt {retry_count} failed: {e}")
|
||||
|
||||
if self._ws:
|
||||
try:
|
||||
await self._ws.close()
|
||||
except:
|
||||
pass
|
||||
self._ws = None
|
||||
|
||||
# Regular ping to check connection
|
||||
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
|
||||
try:
|
||||
if time.time() - self._last_ping >= self._ping_interval:
|
||||
pong_waiter = await self._ws.ping()
|
||||
await asyncio.wait_for(pong_waiter, timeout=self._ping_timeout)
|
||||
self._last_ping = time.time()
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Ping failed: {e}")
|
||||
if self._ws:
|
||||
try:
|
||||
await self._ws.close()
|
||||
except:
|
||||
pass
|
||||
self._ws = None
|
||||
continue
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
current_time = time.time()
|
||||
# Only log connection lost warnings at most once every min_warning_interval seconds
|
||||
if current_time - last_warning_time >= min_warning_interval:
|
||||
self.logger.warning(
|
||||
f"Computer API Server connection lost. Will retry automatically."
|
||||
)
|
||||
last_warning_time = current_time
|
||||
else:
|
||||
# Log at debug level instead
|
||||
self.logger.debug(f"Connection lost: {e}")
|
||||
|
||||
if self._ws:
|
||||
try:
|
||||
await self._ws.close()
|
||||
except:
|
||||
pass
|
||||
self._ws = None
|
||||
|
||||
async def _ensure_connection(self):
|
||||
"""Ensure WebSocket connection is established."""
|
||||
if self._reconnect_task is None or self._reconnect_task.done():
|
||||
self._reconnect_task = asyncio.create_task(self._keep_alive())
|
||||
|
||||
retry_count = 0
|
||||
max_retries = 5
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
|
||||
return
|
||||
retry_count += 1
|
||||
await asyncio.sleep(1)
|
||||
except Exception as e:
|
||||
# Only log at ERROR level for the last retry attempt
|
||||
if retry_count == max_retries - 1:
|
||||
self.logger.error(
|
||||
f"Persistent connection check error after {retry_count} attempts: {e}"
|
||||
)
|
||||
else:
|
||||
self.logger.debug(f"Connection check error (attempt {retry_count}): {e}")
|
||||
retry_count += 1
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
raise ConnectionError("Failed to establish WebSocket connection after multiple retries")
|
||||
|
||||
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
|
||||
"""Send command through WebSocket."""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
last_error = None
|
||||
|
||||
# Acquire lock to ensure only one command is processed at a time
|
||||
async with self._command_lock:
|
||||
self.logger.debug(f"Acquired lock for command: {command}")
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
await self._ensure_connection()
|
||||
if not self._ws:
|
||||
raise ConnectionError("WebSocket connection is not established")
|
||||
|
||||
# Handle authentication if needed
|
||||
if self.api_key and self.vm_name and not self._authenticated:
|
||||
self.logger.info("Performing authentication handshake...")
|
||||
auth_message = {
|
||||
"command": "authenticate",
|
||||
"params": {
|
||||
"api_key": self.api_key,
|
||||
"container_name": self.vm_name
|
||||
}
|
||||
}
|
||||
await self._ws.send(json.dumps(auth_message))
|
||||
|
||||
# Wait for authentication response
|
||||
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
|
||||
auth_result = json.loads(auth_response)
|
||||
|
||||
if not auth_result.get("success"):
|
||||
error_msg = auth_result.get("error", "Authentication failed")
|
||||
self.logger.error(f"Authentication failed: {error_msg}")
|
||||
self._authenticated = False
|
||||
raise ConnectionError(f"Authentication failed: {error_msg}")
|
||||
|
||||
self.logger.info("Authentication successful")
|
||||
self._authenticated = True
|
||||
|
||||
message = {"command": command, "params": params or {}}
|
||||
await self._ws.send(json.dumps(message))
|
||||
response = await asyncio.wait_for(self._ws.recv(), timeout=30)
|
||||
self.logger.debug(f"Completed command: {command}")
|
||||
return json.loads(response)
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
retry_count += 1
|
||||
if retry_count < max_retries:
|
||||
# Only log at debug level for intermediate retries
|
||||
self.logger.debug(
|
||||
f"Command '{command}' failed (attempt {retry_count}/{max_retries}): {e}"
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
else:
|
||||
# Only log at error level for the final failure
|
||||
self.logger.error(
|
||||
f"Failed to send command '{command}' after {max_retries} retries"
|
||||
)
|
||||
self.logger.debug(f"Command failure details: {e}")
|
||||
raise last_error if last_error else RuntimeError("Failed to send command")
|
||||
|
||||
async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
|
||||
"""Wait for WebSocket connection to become available."""
|
||||
start_time = time.time()
|
||||
last_error = None
|
||||
attempt_count = 0
|
||||
progress_interval = 10 # Log progress every 10 seconds
|
||||
last_progress_time = start_time
|
||||
|
||||
# Disable detailed logging for connection attempts
|
||||
self._log_connection_attempts = False
|
||||
|
||||
try:
|
||||
self.logger.info(
|
||||
f"Waiting for Computer API Server to be ready (timeout: {timeout}s)..."
|
||||
)
|
||||
|
||||
# Start the keep-alive task if it's not already running
|
||||
if self._reconnect_task is None or self._reconnect_task.done():
|
||||
self._reconnect_task = asyncio.create_task(self._keep_alive())
|
||||
|
||||
# Wait for the connection to be established
|
||||
while time.time() - start_time < timeout:
|
||||
try:
|
||||
attempt_count += 1
|
||||
current_time = time.time()
|
||||
|
||||
# Log progress periodically without flooding logs
|
||||
if current_time - last_progress_time >= progress_interval:
|
||||
elapsed = current_time - start_time
|
||||
self.logger.info(
|
||||
f"Still waiting for Computer API Server... (elapsed: {elapsed:.1f}s, attempts: {attempt_count})"
|
||||
)
|
||||
last_progress_time = current_time
|
||||
|
||||
# Check if we have a connection
|
||||
if self._ws and self._ws.state == websockets.protocol.State.OPEN:
|
||||
# Test the connection with a simple command
|
||||
try:
|
||||
await self._send_command("get_screen_size")
|
||||
elapsed = time.time() - start_time
|
||||
self.logger.info(
|
||||
f"Computer API Server is ready (after {elapsed:.1f}s, {attempt_count} attempts)"
|
||||
)
|
||||
return # Connection is fully working
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
self.logger.debug(f"Connection test failed: {e}")
|
||||
|
||||
# Wait before trying again
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
self.logger.debug(f"Connection attempt {attempt_count} failed: {e}")
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
# If we get here, we've timed out
|
||||
error_msg = f"Could not connect to {self.ip_address} after {timeout} seconds"
|
||||
if last_error:
|
||||
error_msg += f": {str(last_error)}"
|
||||
self.logger.error(error_msg)
|
||||
raise TimeoutError(error_msg)
|
||||
finally:
|
||||
# Reset to default logging behavior
|
||||
self._log_connection_attempts = False
|
||||
|
||||
def close(self):
|
||||
"""Close WebSocket connection.
|
||||
|
||||
Note: In host computer server mode, we leave the connection open
|
||||
to allow other clients to connect to the same server. The server
|
||||
will handle cleaning up idle connections.
|
||||
"""
|
||||
# Only cancel the reconnect task
|
||||
if self._reconnect_task:
|
||||
self._reconnect_task.cancel()
|
||||
|
||||
# Don't set closed flag or close websocket by default
|
||||
# This allows the server to stay connected for other clients
|
||||
# self._closed = True
|
||||
# if self._ws:
|
||||
# asyncio.create_task(self._ws.close())
|
||||
# self._ws = None
|
||||
|
||||
def force_close(self):
|
||||
"""Force close the WebSocket connection.
|
||||
|
||||
This method should be called when you want to completely
|
||||
shut down the connection, not just for regular cleanup.
|
||||
"""
|
||||
self._closed = True
|
||||
if self._reconnect_task:
|
||||
self._reconnect_task.cancel()
|
||||
if self._ws:
|
||||
asyncio.create_task(self._ws.close())
|
||||
self._ws = None
|
||||
|
||||
# Mouse Actions
|
||||
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> None:
|
||||
await self._send_command("mouse_down", {"x": x, "y": y, "button": button})
|
||||
|
||||
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> None:
|
||||
await self._send_command("mouse_up", {"x": x, "y": y, "button": button})
|
||||
|
||||
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
await self._send_command("left_click", {"x": x, "y": y})
|
||||
|
||||
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
await self._send_command("right_click", {"x": x, "y": y})
|
||||
|
||||
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
|
||||
await self._send_command("double_click", {"x": x, "y": y})
|
||||
|
||||
async def move_cursor(self, x: int, y: int) -> None:
|
||||
await self._send_command("move_cursor", {"x": x, "y": y})
|
||||
|
||||
async def drag_to(self, x: int, y: int, button: "MouseButton" = "left", duration: float = 0.5) -> None:
|
||||
await self._send_command(
|
||||
"drag_to", {"x": x, "y": y, "button": button, "duration": duration}
|
||||
)
|
||||
|
||||
async def drag(self, path: List[Tuple[int, int]], button: "MouseButton" = "left", duration: float = 0.5) -> None:
|
||||
await self._send_command(
|
||||
"drag", {"path": path, "button": button, "duration": duration}
|
||||
)
|
||||
|
||||
# Keyboard Actions
|
||||
async def key_down(self, key: "KeyType") -> None:
|
||||
await self._send_command("key_down", {"key": key})
|
||||
|
||||
async def key_up(self, key: "KeyType") -> None:
|
||||
await self._send_command("key_up", {"key": key})
|
||||
|
||||
async def type_text(self, text: str) -> None:
|
||||
# For Windows, use clipboard for Unicode text like Linux
|
||||
if any(ord(char) > 127 for char in text):
|
||||
# For Unicode text, use clipboard and paste
|
||||
await self.set_clipboard(text)
|
||||
await self.hotkey(Key.CTRL, 'v') # Windows uses Ctrl+V instead of Cmd+V
|
||||
else:
|
||||
# For ASCII text, use the regular typing method
|
||||
await self._send_command("type_text", {"text": text})
|
||||
|
||||
async def press(self, key: "KeyType") -> None:
|
||||
"""Press a single key.
|
||||
|
||||
Args:
|
||||
key: The key to press. Can be any of:
|
||||
- A Key enum value (recommended), e.g. Key.PAGE_DOWN
|
||||
- A direct key value string, e.g. 'pagedown'
|
||||
- A single character string, e.g. 'a'
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Using enum (recommended)
|
||||
await interface.press(Key.PAGE_DOWN)
|
||||
await interface.press(Key.ENTER)
|
||||
|
||||
# Using direct values
|
||||
await interface.press('pagedown')
|
||||
await interface.press('enter')
|
||||
|
||||
# Using single characters
|
||||
await interface.press('a')
|
||||
```
|
||||
|
||||
Raises:
|
||||
ValueError: If the key type is invalid or the key is not recognized
|
||||
"""
|
||||
if isinstance(key, Key):
|
||||
actual_key = key.value
|
||||
elif isinstance(key, str):
|
||||
# Try to convert to enum if it matches a known key
|
||||
key_or_enum = Key.from_string(key)
|
||||
actual_key = key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum
|
||||
else:
|
||||
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
|
||||
|
||||
await self._send_command("press_key", {"key": actual_key})
|
||||
|
||||
async def press_key(self, key: "KeyType") -> None:
|
||||
"""DEPRECATED: Use press() instead.
|
||||
|
||||
This method is kept for backward compatibility but will be removed in a future version.
|
||||
Please use the press() method instead.
|
||||
"""
|
||||
await self.press(key)
|
||||
|
||||
async def hotkey(self, *keys: "KeyType") -> None:
|
||||
"""Press multiple keys simultaneously.
|
||||
|
||||
Args:
|
||||
*keys: Multiple keys to press simultaneously. Each key can be any of:
|
||||
- A Key enum value (recommended), e.g. Key.CTRL
|
||||
- A direct key value string, e.g. 'ctrl'
|
||||
- A single character string, e.g. 'a'
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Using enums (recommended)
|
||||
await interface.hotkey(Key.CTRL, Key.C) # Copy
|
||||
await interface.hotkey(Key.CTRL, Key.V) # Paste
|
||||
|
||||
# Using mixed formats
|
||||
await interface.hotkey(Key.CTRL, 'a') # Select all
|
||||
```
|
||||
|
||||
Raises:
|
||||
ValueError: If any key type is invalid or not recognized
|
||||
"""
|
||||
actual_keys = []
|
||||
for key in keys:
|
||||
if isinstance(key, Key):
|
||||
actual_keys.append(key.value)
|
||||
elif isinstance(key, str):
|
||||
# Try to convert to enum if it matches a known key
|
||||
key_or_enum = Key.from_string(key)
|
||||
actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
|
||||
else:
|
||||
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
|
||||
|
||||
await self._send_command("hotkey", {"keys": actual_keys})
|
||||
|
||||
# Scrolling Actions
|
||||
async def scroll(self, x: int, y: int) -> None:
|
||||
await self._send_command("scroll", {"x": x, "y": y})
|
||||
|
||||
async def scroll_down(self, clicks: int = 1) -> None:
|
||||
await self._send_command("scroll_down", {"clicks": clicks})
|
||||
|
||||
async def scroll_up(self, clicks: int = 1) -> None:
|
||||
await self._send_command("scroll_up", {"clicks": clicks})
|
||||
|
||||
# Screen Actions
|
||||
async def screenshot(
|
||||
self,
|
||||
boxes: Optional[List[Tuple[int, int, int, int]]] = None,
|
||||
box_color: str = "#FF0000",
|
||||
box_thickness: int = 2,
|
||||
scale_factor: float = 1.0,
|
||||
) -> bytes:
|
||||
"""Take a screenshot with optional box drawing and scaling.
|
||||
|
||||
Args:
|
||||
boxes: Optional list of (x, y, width, height) tuples defining boxes to draw in screen coordinates
|
||||
box_color: Color of the boxes in hex format (default: "#FF0000" red)
|
||||
box_thickness: Thickness of the box borders in pixels (default: 2)
|
||||
scale_factor: Factor to scale the final image by (default: 1.0)
|
||||
Use > 1.0 to enlarge, < 1.0 to shrink (e.g., 0.5 for half size, 2.0 for double)
|
||||
|
||||
Returns:
|
||||
bytes: The screenshot image data, optionally with boxes drawn on it and scaled
|
||||
"""
|
||||
result = await self._send_command("screenshot")
|
||||
if not result.get("image_data"):
|
||||
raise RuntimeError("Failed to take screenshot")
|
||||
|
||||
screenshot = decode_base64_image(result["image_data"])
|
||||
|
||||
if boxes:
|
||||
# Get the natural scaling between screen and screenshot
|
||||
screen_size = await self.get_screen_size()
|
||||
screenshot_width, screenshot_height = bytes_to_image(screenshot).size
|
||||
width_scale = screenshot_width / screen_size["width"]
|
||||
height_scale = screenshot_height / screen_size["height"]
|
||||
|
||||
# Scale box coordinates from screen space to screenshot space
|
||||
for box in boxes:
|
||||
scaled_box = (
|
||||
int(box[0] * width_scale), # x
|
||||
int(box[1] * height_scale), # y
|
||||
int(box[2] * width_scale), # width
|
||||
int(box[3] * height_scale), # height
|
||||
)
|
||||
screenshot = draw_box(
|
||||
screenshot,
|
||||
x=scaled_box[0],
|
||||
y=scaled_box[1],
|
||||
width=scaled_box[2],
|
||||
height=scaled_box[3],
|
||||
color=box_color,
|
||||
thickness=box_thickness,
|
||||
)
|
||||
|
||||
if scale_factor != 1.0:
|
||||
screenshot = resize_image(screenshot, scale_factor)
|
||||
|
||||
return screenshot
|
||||
|
||||
async def get_screen_size(self) -> Dict[str, int]:
|
||||
result = await self._send_command("get_screen_size")
|
||||
if result["success"] and result["size"]:
|
||||
return result["size"]
|
||||
raise RuntimeError("Failed to get screen size")
|
||||
|
||||
async def get_cursor_position(self) -> Dict[str, int]:
|
||||
result = await self._send_command("get_cursor_position")
|
||||
if result["success"] and result["position"]:
|
||||
return result["position"]
|
||||
raise RuntimeError("Failed to get cursor position")
|
||||
|
||||
# Clipboard Actions
|
||||
async def copy_to_clipboard(self) -> str:
|
||||
result = await self._send_command("copy_to_clipboard")
|
||||
if result["success"] and result["content"]:
|
||||
return result["content"]
|
||||
raise RuntimeError("Failed to get clipboard content")
|
||||
|
||||
async def set_clipboard(self, text: str) -> None:
|
||||
await self._send_command("set_clipboard", {"text": text})
|
||||
|
||||
# File System Actions
|
||||
async def file_exists(self, path: str) -> bool:
|
||||
result = await self._send_command("file_exists", {"path": path})
|
||||
return result.get("exists", False)
|
||||
|
||||
async def directory_exists(self, path: str) -> bool:
|
||||
result = await self._send_command("directory_exists", {"path": path})
|
||||
return result.get("exists", False)
|
||||
|
||||
async def list_dir(self, path: str) -> list[str]:
|
||||
result = await self._send_command("list_dir", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to list directory"))
|
||||
return result.get("files", [])
|
||||
|
||||
async def read_text(self, path: str) -> str:
|
||||
result = await self._send_command("read_text", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to read file"))
|
||||
return result.get("content", "")
|
||||
|
||||
async def write_text(self, path: str, content: str) -> None:
|
||||
result = await self._send_command("write_text", {"path": path, "content": content})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to write file"))
|
||||
|
||||
async def read_bytes(self, path: str) -> bytes:
|
||||
result = await self._send_command("read_bytes", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to read file"))
|
||||
content_b64 = result.get("content_b64", "")
|
||||
return decode_base64_image(content_b64)
|
||||
|
||||
async def write_bytes(self, path: str, content: bytes) -> None:
|
||||
result = await self._send_command("write_bytes", {"path": path, "content_b64": encode_base64_image(content)})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to write file"))
|
||||
|
||||
async def delete_file(self, path: str) -> None:
|
||||
result = await self._send_command("delete_file", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to delete file"))
|
||||
|
||||
async def create_dir(self, path: str) -> None:
|
||||
result = await self._send_command("create_dir", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to create directory"))
|
||||
|
||||
async def delete_dir(self, path: str) -> None:
|
||||
result = await self._send_command("delete_dir", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to delete directory"))
|
||||
|
||||
async def run_command(self, command: str) -> Tuple[str, str]:
|
||||
result = await self._send_command("run_command", {"command": command})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to run command"))
|
||||
return result.get("stdout", ""), result.get("stderr", "")
|
||||
|
||||
# Accessibility Actions
|
||||
async def get_accessibility_tree(self) -> Dict[str, Any]:
|
||||
"""Get the accessibility tree of the current screen."""
|
||||
result = await self._send_command("get_accessibility_tree")
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to get accessibility tree"))
|
||||
return result
|
||||
|
||||
async def get_active_window_bounds(self) -> Dict[str, int]:
|
||||
"""Get the bounds of the currently active window."""
|
||||
result = await self._send_command("get_active_window_bounds")
|
||||
if result["success"] and result["bounds"]:
|
||||
return result["bounds"]
|
||||
raise RuntimeError("Failed to get active window bounds")
|
||||
|
||||
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
|
||||
"""Convert screenshot coordinates to screen coordinates.
|
||||
|
||||
Args:
|
||||
x: X coordinate in screenshot space
|
||||
y: Y coordinate in screenshot space
|
||||
|
||||
Returns:
|
||||
tuple[float, float]: (x, y) coordinates in screen space
|
||||
"""
|
||||
screen_size = await self.get_screen_size()
|
||||
screenshot = await self.screenshot()
|
||||
screenshot_img = bytes_to_image(screenshot)
|
||||
screenshot_width, screenshot_height = screenshot_img.size
|
||||
|
||||
# Calculate scaling factors
|
||||
width_scale = screen_size["width"] / screenshot_width
|
||||
height_scale = screen_size["height"] / screenshot_height
|
||||
|
||||
# Convert coordinates
|
||||
screen_x = x * width_scale
|
||||
screen_y = y * height_scale
|
||||
|
||||
return screen_x, screen_y
|
||||
|
||||
async def to_screenshot_coordinates(self, x: float, y: float) -> tuple[float, float]:
|
||||
"""Convert screen coordinates to screenshot coordinates.
|
||||
|
||||
Args:
|
||||
x: X coordinate in screen space
|
||||
y: Y coordinate in screen space
|
||||
|
||||
Returns:
|
||||
tuple[float, float]: (x, y) coordinates in screenshot space
|
||||
"""
|
||||
screen_size = await self.get_screen_size()
|
||||
screenshot = await self.screenshot()
|
||||
screenshot_img = bytes_to_image(screenshot)
|
||||
screenshot_width, screenshot_height = screenshot_img.size
|
||||
|
||||
# Calculate scaling factors
|
||||
width_scale = screenshot_width / screen_size["width"]
|
||||
height_scale = screenshot_height / screen_size["height"]
|
||||
|
||||
# Convert coordinates
|
||||
screenshot_x = x * width_scale
|
||||
screenshot_y = y * height_scale
|
||||
|
||||
return screenshot_x, screenshot_y
|
||||
84
libs/python/computer/computer/logger.py
Normal file
84
libs/python/computer/computer/logger.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Logging utilities for the Computer module."""
|
||||
|
||||
import logging
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
# Keep LogLevel for backward compatibility, but it will be deprecated
|
||||
class LogLevel(IntEnum):
|
||||
"""Log levels for logging. Deprecated - use standard logging levels instead."""
|
||||
|
||||
QUIET = 0 # Only warnings and errors
|
||||
NORMAL = 1 # Info level, standard output
|
||||
VERBOSE = 2 # More detailed information
|
||||
DEBUG = 3 # Full debug information
|
||||
|
||||
|
||||
# Map LogLevel to standard logging levels for backward compatibility
|
||||
LOGLEVEL_MAP = {
|
||||
LogLevel.QUIET: logging.WARNING,
|
||||
LogLevel.NORMAL: logging.INFO,
|
||||
LogLevel.VERBOSE: logging.DEBUG,
|
||||
LogLevel.DEBUG: logging.DEBUG,
|
||||
}
|
||||
|
||||
|
||||
class Logger:
|
||||
"""Logger class for Computer."""
|
||||
|
||||
def __init__(self, name: str, verbosity: int):
|
||||
"""Initialize the logger.
|
||||
|
||||
Args:
|
||||
name: The name of the logger.
|
||||
verbosity: The log level (use standard logging levels like logging.INFO).
|
||||
For backward compatibility, LogLevel enum values are also accepted.
|
||||
"""
|
||||
self.logger = logging.getLogger(name)
|
||||
|
||||
# Convert LogLevel enum to standard logging level if needed
|
||||
if isinstance(verbosity, LogLevel):
|
||||
self.verbosity = LOGLEVEL_MAP.get(verbosity, logging.INFO)
|
||||
else:
|
||||
self.verbosity = verbosity
|
||||
|
||||
self._configure()
|
||||
|
||||
def _configure(self):
|
||||
"""Configure the logger based on log level."""
|
||||
# Set the logging level directly
|
||||
self.logger.setLevel(self.verbosity)
|
||||
|
||||
# Log the verbosity level that was set
|
||||
if self.verbosity <= logging.DEBUG:
|
||||
self.logger.info("Logger set to DEBUG level")
|
||||
elif self.verbosity <= logging.INFO:
|
||||
self.logger.info("Logger set to INFO level")
|
||||
elif self.verbosity <= logging.WARNING:
|
||||
self.logger.warning("Logger set to WARNING level")
|
||||
elif self.verbosity <= logging.ERROR:
|
||||
self.logger.warning("Logger set to ERROR level")
|
||||
elif self.verbosity <= logging.CRITICAL:
|
||||
self.logger.warning("Logger set to CRITICAL level")
|
||||
|
||||
def debug(self, message: str):
|
||||
"""Log a debug message if log level is DEBUG or lower."""
|
||||
self.logger.debug(message)
|
||||
|
||||
def info(self, message: str):
|
||||
"""Log an info message if log level is INFO or lower."""
|
||||
self.logger.info(message)
|
||||
|
||||
def verbose(self, message: str):
|
||||
"""Log a verbose message between INFO and DEBUG levels."""
|
||||
# Since there's no standard verbose level,
|
||||
# use debug level with [VERBOSE] prefix for backward compatibility
|
||||
self.logger.debug(f"[VERBOSE] {message}")
|
||||
|
||||
def warning(self, message: str):
|
||||
"""Log a warning message."""
|
||||
self.logger.warning(message)
|
||||
|
||||
def error(self, message: str):
|
||||
"""Log an error message."""
|
||||
self.logger.error(message)
|
||||
47
libs/python/computer/computer/models.py
Normal file
47
libs/python/computer/computer/models.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Models for computer configuration."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any, Dict
|
||||
|
||||
# Import base provider interface
|
||||
from .providers.base import BaseVMProvider
|
||||
|
||||
@dataclass
|
||||
class Display:
|
||||
"""Display configuration."""
|
||||
width: int
|
||||
height: int
|
||||
|
||||
@dataclass
|
||||
class Image:
|
||||
"""VM image configuration."""
|
||||
image: str
|
||||
tag: str
|
||||
name: str
|
||||
|
||||
@dataclass
|
||||
class Computer:
|
||||
"""Computer configuration."""
|
||||
image: str
|
||||
tag: str
|
||||
name: str
|
||||
display: Display
|
||||
memory: str
|
||||
cpu: str
|
||||
vm_provider: Optional[BaseVMProvider] = None
|
||||
|
||||
# @property # Remove the property decorator
|
||||
async def get_ip(self) -> Optional[str]:
|
||||
"""Get the IP address of the VM."""
|
||||
if not self.vm_provider:
|
||||
return None
|
||||
|
||||
vm = await self.vm_provider.get_vm(self.name)
|
||||
# Handle both object attribute and dictionary access for ip_address
|
||||
if vm:
|
||||
if isinstance(vm, dict):
|
||||
return vm.get("ip_address")
|
||||
else:
|
||||
# Access as attribute for object-based return values
|
||||
return getattr(vm, "ip_address", None)
|
||||
return None
|
||||
4
libs/python/computer/computer/providers/__init__.py
Normal file
4
libs/python/computer/computer/providers/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""Provider implementations for different VM backends."""
|
||||
|
||||
# Import specific providers only when needed to avoid circular imports
|
||||
__all__ = [] # Let each provider module handle its own exports
|
||||
106
libs/python/computer/computer/providers/base.py
Normal file
106
libs/python/computer/computer/providers/base.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Base provider interface for VM backends."""
|
||||
|
||||
import abc
|
||||
from enum import StrEnum
|
||||
from typing import Dict, List, Optional, Any, AsyncContextManager
|
||||
|
||||
|
||||
class VMProviderType(StrEnum):
|
||||
"""Enum of supported VM provider types."""
|
||||
LUME = "lume"
|
||||
LUMIER = "lumier"
|
||||
CLOUD = "cloud"
|
||||
WINSANDBOX = "winsandbox"
|
||||
UNKNOWN = "unknown"
|
||||
|
||||
|
||||
class BaseVMProvider(AsyncContextManager):
|
||||
"""Base interface for VM providers.
|
||||
|
||||
All VM provider implementations must implement this interface.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def provider_type(self) -> VMProviderType:
|
||||
"""Get the provider type."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get VM information by name.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to get information for
|
||||
storage: Optional storage path override. If provided, this will be used
|
||||
instead of the provider's default storage path.
|
||||
|
||||
Returns:
|
||||
Dictionary with VM information including status, IP address, etc.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_vms(self) -> List[Dict[str, Any]]:
|
||||
"""List all available VMs."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Run a VM by name with the given options.
|
||||
|
||||
Args:
|
||||
image: Name/tag of the image to use
|
||||
name: Name of the VM to run
|
||||
run_opts: Dictionary of run options (memory, cpu, etc.)
|
||||
storage: Optional storage path override. If provided, this will be used
|
||||
instead of the provider's default storage path.
|
||||
|
||||
Returns:
|
||||
Dictionary with VM run status and information
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Stop a VM by name.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to stop
|
||||
storage: Optional storage path override. If provided, this will be used
|
||||
instead of the provider's default storage path.
|
||||
|
||||
Returns:
|
||||
Dictionary with VM stop status and information
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Update VM configuration.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to update
|
||||
update_opts: Dictionary of update options (memory, cpu, etc.)
|
||||
storage: Optional storage path override. If provided, this will be used
|
||||
instead of the provider's default storage path.
|
||||
|
||||
Returns:
|
||||
Dictionary with VM update status and information
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
|
||||
"""Get the IP address of a VM, waiting indefinitely until it's available.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to get the IP for
|
||||
storage: Optional storage path override. If provided, this will be used
|
||||
instead of the provider's default storage path.
|
||||
retry_delay: Delay between retries in seconds (default: 2)
|
||||
|
||||
Returns:
|
||||
IP address of the VM when it becomes available
|
||||
"""
|
||||
pass
|
||||
@@ -0,0 +1,5 @@
|
||||
"""CloudProvider module for interacting with cloud-based virtual machines."""
|
||||
|
||||
from .provider import CloudProvider
|
||||
|
||||
__all__ = ["CloudProvider"]
|
||||
75
libs/python/computer/computer/providers/cloud/provider.py
Normal file
75
libs/python/computer/computer/providers/cloud/provider.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Cloud VM provider implementation.
|
||||
|
||||
This module contains a stub implementation for a future cloud VM provider.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
from ..base import BaseVMProvider, VMProviderType
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
import asyncio
|
||||
import aiohttp
|
||||
from urllib.parse import urlparse
|
||||
|
||||
class CloudProvider(BaseVMProvider):
|
||||
"""Cloud VM Provider implementation."""
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
verbose: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
api_key: API key for authentication
|
||||
name: Name of the VM
|
||||
verbose: Enable verbose logging
|
||||
"""
|
||||
assert api_key, "api_key required for CloudProvider"
|
||||
self.api_key = api_key
|
||||
self.verbose = verbose
|
||||
|
||||
@property
|
||||
def provider_type(self) -> VMProviderType:
|
||||
return VMProviderType.CLOUD
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get VM VNC URL by name using the cloud API."""
|
||||
return {"name": name, "hostname": f"{name}.containers.cloud.trycua.com"}
|
||||
|
||||
async def list_vms(self) -> List[Dict[str, Any]]:
|
||||
logger.warning("CloudProvider.list_vms is not implemented")
|
||||
return []
|
||||
|
||||
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
# logger.warning("CloudProvider.run_vm is not implemented")
|
||||
return {"name": name, "status": "unavailable", "message": "CloudProvider.run_vm is not implemented"}
|
||||
|
||||
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
logger.warning("CloudProvider.stop_vm is not implemented. To clean up resources, please use Computer.disconnect()")
|
||||
return {"name": name, "status": "stopped", "message": "CloudProvider is not implemented"}
|
||||
|
||||
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
logger.warning("CloudProvider.update_vm is not implemented")
|
||||
return {"name": name, "status": "unchanged", "message": "CloudProvider is not implemented"}
|
||||
|
||||
async def get_ip(self, name: Optional[str] = None, storage: Optional[str] = None, retry_delay: int = 2) -> str:
|
||||
"""
|
||||
Return the VM's IP address as '{container_name}.containers.cloud.trycua.com'.
|
||||
Uses the provided 'name' argument (the VM name requested by the caller),
|
||||
falling back to self.name only if 'name' is None.
|
||||
Retries up to 3 times with retry_delay seconds if hostname is not available.
|
||||
"""
|
||||
if name is None:
|
||||
raise ValueError("VM name is required for CloudProvider.get_ip")
|
||||
return f"{name}.containers.cloud.trycua.com"
|
||||
138
libs/python/computer/computer/providers/factory.py
Normal file
138
libs/python/computer/computer/providers/factory.py
Normal file
@@ -0,0 +1,138 @@
|
||||
"""Factory for creating VM providers."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Optional, Any, Type, Union
|
||||
|
||||
from .base import BaseVMProvider, VMProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VMProviderFactory:
|
||||
"""Factory for creating VM providers based on provider type."""
|
||||
|
||||
@staticmethod
|
||||
def create_provider(
|
||||
provider_type: Union[str, VMProviderType],
|
||||
port: int = 7777,
|
||||
host: str = "localhost",
|
||||
bin_path: Optional[str] = None,
|
||||
storage: Optional[str] = None,
|
||||
shared_path: Optional[str] = None,
|
||||
image: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
ephemeral: bool = False,
|
||||
noVNC_port: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> BaseVMProvider:
|
||||
"""Create a VM provider of the specified type.
|
||||
|
||||
Args:
|
||||
provider_type: Type of VM provider to create
|
||||
port: Port for the API server
|
||||
host: Hostname for the API server
|
||||
bin_path: Path to provider binary if needed
|
||||
storage: Path for persistent VM storage
|
||||
shared_path: Path for shared folder between host and VM
|
||||
image: VM image to use (for Lumier provider)
|
||||
verbose: Enable verbose logging
|
||||
ephemeral: Use ephemeral (temporary) storage
|
||||
noVNC_port: Specific port for noVNC interface (for Lumier provider)
|
||||
|
||||
Returns:
|
||||
An instance of the requested VM provider
|
||||
|
||||
Raises:
|
||||
ImportError: If the required dependencies for the provider are not installed
|
||||
ValueError: If the provider type is not supported
|
||||
"""
|
||||
# Convert string to enum if needed
|
||||
if isinstance(provider_type, str):
|
||||
try:
|
||||
provider_type = VMProviderType(provider_type.lower())
|
||||
except ValueError:
|
||||
provider_type = VMProviderType.UNKNOWN
|
||||
|
||||
if provider_type == VMProviderType.LUME:
|
||||
try:
|
||||
from .lume import LumeProvider, HAS_LUME
|
||||
if not HAS_LUME:
|
||||
raise ImportError(
|
||||
"The pylume package is required for LumeProvider. "
|
||||
"Please install it with 'pip install cua-computer[lume]'"
|
||||
)
|
||||
return LumeProvider(
|
||||
port=port,
|
||||
host=host,
|
||||
storage=storage,
|
||||
verbose=verbose,
|
||||
ephemeral=ephemeral
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import LumeProvider: {e}")
|
||||
raise ImportError(
|
||||
"The pylume package is required for LumeProvider. "
|
||||
"Please install it with 'pip install cua-computer[lume]'"
|
||||
) from e
|
||||
elif provider_type == VMProviderType.LUMIER:
|
||||
try:
|
||||
from .lumier import LumierProvider, HAS_LUMIER
|
||||
if not HAS_LUMIER:
|
||||
raise ImportError(
|
||||
"Docker is required for LumierProvider. "
|
||||
"Please install Docker for Apple Silicon and Lume CLI before using this provider."
|
||||
)
|
||||
return LumierProvider(
|
||||
port=port,
|
||||
host=host,
|
||||
storage=storage,
|
||||
shared_path=shared_path,
|
||||
image=image or "macos-sequoia-cua:latest",
|
||||
verbose=verbose,
|
||||
ephemeral=ephemeral,
|
||||
noVNC_port=noVNC_port
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import LumierProvider: {e}")
|
||||
raise ImportError(
|
||||
"Docker and Lume CLI are required for LumierProvider. "
|
||||
"Please install Docker for Apple Silicon and run the Lume installer script."
|
||||
) from e
|
||||
|
||||
elif provider_type == VMProviderType.CLOUD:
|
||||
try:
|
||||
from .cloud import CloudProvider
|
||||
return CloudProvider(
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import CloudProvider: {e}")
|
||||
raise ImportError(
|
||||
"The CloudProvider is not fully implemented yet. "
|
||||
"Please use LUME or LUMIER provider instead."
|
||||
) from e
|
||||
elif provider_type == VMProviderType.WINSANDBOX:
|
||||
try:
|
||||
from .winsandbox import WinSandboxProvider, HAS_WINSANDBOX
|
||||
if not HAS_WINSANDBOX:
|
||||
raise ImportError(
|
||||
"pywinsandbox is required for WinSandboxProvider. "
|
||||
"Please install it with 'pip install -U git+https://github.com/karkason/pywinsandbox.git'"
|
||||
)
|
||||
return WinSandboxProvider(
|
||||
port=port,
|
||||
host=host,
|
||||
storage=storage,
|
||||
verbose=verbose,
|
||||
ephemeral=ephemeral,
|
||||
**kwargs
|
||||
)
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import WinSandboxProvider: {e}")
|
||||
raise ImportError(
|
||||
"pywinsandbox is required for WinSandboxProvider. "
|
||||
"Please install it with 'pip install -U git+https://github.com/karkason/pywinsandbox.git'"
|
||||
) from e
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider type: {provider_type}")
|
||||
9
libs/python/computer/computer/providers/lume/__init__.py
Normal file
9
libs/python/computer/computer/providers/lume/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Lume VM provider implementation."""
|
||||
|
||||
try:
|
||||
from .provider import LumeProvider
|
||||
HAS_LUME = True
|
||||
__all__ = ["LumeProvider"]
|
||||
except ImportError:
|
||||
HAS_LUME = False
|
||||
__all__ = []
|
||||
541
libs/python/computer/computer/providers/lume/provider.py
Normal file
541
libs/python/computer/computer/providers/lume/provider.py
Normal file
@@ -0,0 +1,541 @@
|
||||
"""Lume VM provider implementation using curl commands.
|
||||
|
||||
This provider uses direct curl commands to interact with the Lume API,
|
||||
removing the dependency on the pylume Python package.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import urllib.parse
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
|
||||
from ..base import BaseVMProvider, VMProviderType
|
||||
from ...logger import Logger, LogLevel
|
||||
from ..lume_api import (
|
||||
lume_api_get,
|
||||
lume_api_run,
|
||||
lume_api_stop,
|
||||
lume_api_update,
|
||||
lume_api_pull,
|
||||
HAS_CURL,
|
||||
parse_memory
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class LumeProvider(BaseVMProvider):
|
||||
"""Lume VM provider implementation using direct curl commands.
|
||||
|
||||
This provider uses curl to interact with the Lume API server,
|
||||
removing the dependency on the pylume Python package.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: int = 7777,
|
||||
host: str = "localhost",
|
||||
storage: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
ephemeral: bool = False,
|
||||
):
|
||||
"""Initialize the Lume provider.
|
||||
|
||||
Args:
|
||||
port: Port for the Lume API server (default: 7777)
|
||||
host: Host to use for API connections (default: localhost)
|
||||
storage: Path to store VM data
|
||||
verbose: Enable verbose logging
|
||||
"""
|
||||
if not HAS_CURL:
|
||||
raise ImportError(
|
||||
"curl is required for LumeProvider. "
|
||||
"Please ensure it is installed and in your PATH."
|
||||
)
|
||||
|
||||
self.host = host
|
||||
self.port = port # Default port for Lume API
|
||||
self.storage = storage
|
||||
self.verbose = verbose
|
||||
self.ephemeral = ephemeral # If True, VMs will be deleted after stopping
|
||||
|
||||
# Base API URL for Lume API calls
|
||||
self.api_base_url = f"http://{self.host}:{self.port}"
|
||||
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
@property
|
||||
def provider_type(self) -> VMProviderType:
|
||||
"""Get the provider type."""
|
||||
return VMProviderType.LUME
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Enter async context manager."""
|
||||
# No initialization needed, just return self
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit async context manager."""
|
||||
# No cleanup needed
|
||||
pass
|
||||
|
||||
def _lume_api_get(self, vm_name: str = "", storage: Optional[str] = None, debug: bool = False) -> Dict[str, Any]:
|
||||
"""Get VM information using shared lume_api function.
|
||||
|
||||
Args:
|
||||
vm_name: Optional name of the VM to get info for.
|
||||
If empty, lists all VMs.
|
||||
storage: Optional storage path override. If provided, this will be used instead of self.storage
|
||||
debug: Whether to show debug output
|
||||
|
||||
Returns:
|
||||
Dictionary with VM status information parsed from JSON response
|
||||
"""
|
||||
# Use the shared implementation from lume_api module
|
||||
return lume_api_get(
|
||||
vm_name=vm_name,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
storage=storage if storage is not None else self.storage,
|
||||
debug=debug,
|
||||
verbose=self.verbose
|
||||
)
|
||||
|
||||
def _lume_api_run(self, vm_name: str, run_opts: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
|
||||
"""Run a VM using shared lume_api function.
|
||||
|
||||
Args:
|
||||
vm_name: Name of the VM to run
|
||||
run_opts: Dictionary of run options
|
||||
debug: Whether to show debug output
|
||||
|
||||
Returns:
|
||||
Dictionary with API response or error information
|
||||
"""
|
||||
# Use the shared implementation from lume_api module
|
||||
return lume_api_run(
|
||||
vm_name=vm_name,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
run_opts=run_opts,
|
||||
storage=self.storage,
|
||||
debug=debug,
|
||||
verbose=self.verbose
|
||||
)
|
||||
|
||||
def _lume_api_stop(self, vm_name: str, debug: bool = False) -> Dict[str, Any]:
|
||||
"""Stop a VM using shared lume_api function.
|
||||
|
||||
Args:
|
||||
vm_name: Name of the VM to stop
|
||||
debug: Whether to show debug output
|
||||
|
||||
Returns:
|
||||
Dictionary with API response or error information
|
||||
"""
|
||||
# Use the shared implementation from lume_api module
|
||||
return lume_api_stop(
|
||||
vm_name=vm_name,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
storage=self.storage,
|
||||
debug=debug,
|
||||
verbose=self.verbose
|
||||
)
|
||||
|
||||
def _lume_api_update(self, vm_name: str, update_opts: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
|
||||
"""Update VM configuration using shared lume_api function.
|
||||
|
||||
Args:
|
||||
vm_name: Name of the VM to update
|
||||
update_opts: Dictionary of update options
|
||||
debug: Whether to show debug output
|
||||
|
||||
Returns:
|
||||
Dictionary with API response or error information
|
||||
"""
|
||||
# Use the shared implementation from lume_api module
|
||||
return lume_api_update(
|
||||
vm_name=vm_name,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
update_opts=update_opts,
|
||||
storage=self.storage,
|
||||
debug=debug,
|
||||
verbose=self.verbose
|
||||
)
|
||||
|
||||
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get VM information by name.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to get information for
|
||||
storage: Optional storage path override. If provided, this will be used
|
||||
instead of the provider's default storage path.
|
||||
|
||||
Returns:
|
||||
Dictionary with VM information including status, IP address, etc.
|
||||
|
||||
Note:
|
||||
If storage is not provided, the provider's default storage path will be used.
|
||||
The storage parameter allows overriding the storage location for this specific call.
|
||||
"""
|
||||
if not HAS_CURL:
|
||||
logger.error("curl is not available. Cannot get VM status.")
|
||||
return {
|
||||
"name": name,
|
||||
"status": "unavailable",
|
||||
"error": "curl is not available"
|
||||
}
|
||||
|
||||
# First try to get detailed VM info from the API
|
||||
try:
|
||||
# Query the Lume API for VM status using the provider's storage_path
|
||||
vm_info = self._lume_api_get(
|
||||
vm_name=name,
|
||||
storage=storage if storage is not None else self.storage,
|
||||
debug=self.verbose
|
||||
)
|
||||
|
||||
# Check for API errors
|
||||
if "error" in vm_info:
|
||||
logger.debug(f"API request error: {vm_info['error']}")
|
||||
# If we got an error from the API, report the VM as not ready yet
|
||||
return {
|
||||
"name": name,
|
||||
"status": "starting", # VM is still starting - do not attempt to connect yet
|
||||
"api_status": "error",
|
||||
"error": vm_info["error"]
|
||||
}
|
||||
|
||||
# Process the VM status information
|
||||
vm_status = vm_info.get("status", "unknown")
|
||||
|
||||
# Check if VM is stopped or not running - don't wait for IP in this case
|
||||
if vm_status == "stopped":
|
||||
logger.info(f"VM {name} is in '{vm_status}' state - not waiting for IP address")
|
||||
# Return the status as-is without waiting for an IP
|
||||
result = {
|
||||
"name": name,
|
||||
"status": vm_status,
|
||||
**vm_info # Include all original fields from the API response
|
||||
}
|
||||
return result
|
||||
|
||||
# Handle field name differences between APIs
|
||||
# Some APIs use camelCase, others use snake_case
|
||||
if "vncUrl" in vm_info:
|
||||
vnc_url = vm_info["vncUrl"]
|
||||
elif "vnc_url" in vm_info:
|
||||
vnc_url = vm_info["vnc_url"]
|
||||
else:
|
||||
vnc_url = ""
|
||||
|
||||
if "ipAddress" in vm_info:
|
||||
ip_address = vm_info["ipAddress"]
|
||||
elif "ip_address" in vm_info:
|
||||
ip_address = vm_info["ip_address"]
|
||||
else:
|
||||
# If no IP address is provided and VM is supposed to be running,
|
||||
# report it as still starting
|
||||
ip_address = None
|
||||
logger.info(f"VM {name} is in '{vm_status}' state but no IP address found - reporting as still starting")
|
||||
|
||||
logger.info(f"VM {name} status: {vm_status}")
|
||||
|
||||
# Return the complete status information
|
||||
result = {
|
||||
"name": name,
|
||||
"status": vm_status if vm_status else "running",
|
||||
"ip_address": ip_address,
|
||||
"vnc_url": vnc_url,
|
||||
"api_status": "ok"
|
||||
}
|
||||
|
||||
# Include all original fields from the API response
|
||||
if isinstance(vm_info, dict):
|
||||
for key, value in vm_info.items():
|
||||
if key not in result: # Don't override our carefully processed fields
|
||||
result[key] = value
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get VM status: {e}")
|
||||
# Return a fallback status that indicates the VM is not ready yet
|
||||
return {
|
||||
"name": name,
|
||||
"status": "initializing", # VM is still initializing
|
||||
"error": f"Failed to get VM status: {str(e)}"
|
||||
}
|
||||
|
||||
async def list_vms(self) -> List[Dict[str, Any]]:
|
||||
"""List all available VMs."""
|
||||
result = self._lume_api_get(debug=self.verbose)
|
||||
|
||||
# Extract the VMs list from the response
|
||||
if "vms" in result and isinstance(result["vms"], list):
|
||||
return result["vms"]
|
||||
elif "error" in result:
|
||||
logger.error(f"Error listing VMs: {result['error']}")
|
||||
return []
|
||||
else:
|
||||
return []
|
||||
|
||||
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Run a VM with the given options.
|
||||
|
||||
If the VM does not exist in the storage location, this will attempt to pull it
|
||||
from the Lume registry first.
|
||||
|
||||
Args:
|
||||
image: Image name to use when pulling the VM if it doesn't exist
|
||||
name: Name of the VM to run
|
||||
run_opts: Dictionary of run options (memory, cpu, etc.)
|
||||
storage: Optional storage path override. If provided, this will be used
|
||||
instead of the provider's default storage path.
|
||||
|
||||
Returns:
|
||||
Dictionary with VM run status and information
|
||||
"""
|
||||
# First check if VM exists by trying to get its info
|
||||
vm_info = await self.get_vm(name, storage=storage)
|
||||
|
||||
if "error" in vm_info:
|
||||
# VM doesn't exist, try to pull it
|
||||
self.logger.info(f"VM {name} not found, attempting to pull image {image} from registry...")
|
||||
|
||||
# Call pull_vm with the image parameter
|
||||
pull_result = await self.pull_vm(
|
||||
name=name,
|
||||
image=image,
|
||||
storage=storage
|
||||
)
|
||||
|
||||
# Check if pull was successful
|
||||
if "error" in pull_result:
|
||||
self.logger.error(f"Failed to pull VM image: {pull_result['error']}")
|
||||
return pull_result # Return the error from pull
|
||||
|
||||
self.logger.info(f"Successfully pulled VM image {image} as {name}")
|
||||
|
||||
# Now run the VM with the given options
|
||||
self.logger.info(f"Running VM {name} with options: {run_opts}")
|
||||
|
||||
from ..lume_api import lume_api_run
|
||||
return lume_api_run(
|
||||
vm_name=name,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
run_opts=run_opts,
|
||||
storage=storage if storage is not None else self.storage,
|
||||
debug=self.verbose,
|
||||
verbose=self.verbose
|
||||
)
|
||||
|
||||
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Stop a running VM.
|
||||
|
||||
If this provider was initialized with ephemeral=True, the VM will also
|
||||
be deleted after it is stopped.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to stop
|
||||
storage: Optional storage path override
|
||||
|
||||
Returns:
|
||||
Dictionary with stop status and information
|
||||
"""
|
||||
# Stop the VM first
|
||||
stop_result = self._lume_api_stop(name, debug=self.verbose)
|
||||
|
||||
# Log ephemeral status for debugging
|
||||
self.logger.info(f"Ephemeral mode status: {self.ephemeral}")
|
||||
|
||||
# If ephemeral mode is enabled, delete the VM after stopping
|
||||
if self.ephemeral and (stop_result.get("success", False) or "error" not in stop_result):
|
||||
self.logger.info(f"Ephemeral mode enabled - deleting VM {name} after stopping")
|
||||
try:
|
||||
delete_result = await self.delete_vm(name, storage=storage)
|
||||
|
||||
# Return combined result
|
||||
return {
|
||||
**stop_result, # Include all stop result info
|
||||
"deleted": True,
|
||||
"delete_result": delete_result
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to delete ephemeral VM {name}: {e}")
|
||||
# Include the error but still return stop result
|
||||
return {
|
||||
**stop_result,
|
||||
"deleted": False,
|
||||
"delete_error": str(e)
|
||||
}
|
||||
|
||||
# Just return the stop result if not ephemeral
|
||||
return stop_result
|
||||
|
||||
async def pull_vm(
|
||||
self,
|
||||
name: str,
|
||||
image: str,
|
||||
storage: Optional[str] = None,
|
||||
registry: str = "ghcr.io",
|
||||
organization: str = "trycua",
|
||||
pull_opts: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Pull a VM image from the registry.
|
||||
|
||||
Args:
|
||||
name: Name for the VM after pulling
|
||||
image: The image name to pull (e.g. 'macos-sequoia-cua:latest')
|
||||
storage: Optional storage path to use
|
||||
registry: Registry to pull from (default: ghcr.io)
|
||||
organization: Organization in registry (default: trycua)
|
||||
pull_opts: Additional options for pulling the VM (optional)
|
||||
|
||||
Returns:
|
||||
Dictionary with information about the pulled VM
|
||||
|
||||
Raises:
|
||||
RuntimeError: If pull operation fails or image is not provided
|
||||
"""
|
||||
# Validate image parameter
|
||||
if not image:
|
||||
raise ValueError("Image parameter is required for pull_vm")
|
||||
|
||||
self.logger.info(f"Pulling VM image '{image}' as '{name}'")
|
||||
self.logger.info("You can check the pull progress using: lume logs -f")
|
||||
|
||||
# Set default pull_opts if not provided
|
||||
if pull_opts is None:
|
||||
pull_opts = {}
|
||||
|
||||
# Log information about the operation
|
||||
self.logger.debug(f"Pull storage location: {storage or 'default'}")
|
||||
|
||||
try:
|
||||
# Call the lume_api_pull function from lume_api.py
|
||||
from ..lume_api import lume_api_pull
|
||||
|
||||
result = lume_api_pull(
|
||||
image=image,
|
||||
name=name,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
storage=storage if storage is not None else self.storage,
|
||||
registry=registry,
|
||||
organization=organization,
|
||||
debug=self.verbose,
|
||||
verbose=self.verbose
|
||||
)
|
||||
|
||||
# Check for errors in the result
|
||||
if "error" in result:
|
||||
self.logger.error(f"Failed to pull VM image: {result['error']}")
|
||||
return result
|
||||
|
||||
self.logger.info(f"Successfully pulled VM image '{image}' as '{name}'")
|
||||
return result
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to pull VM image '{image}': {e}")
|
||||
return {"error": f"Failed to pull VM: {str(e)}"}
|
||||
|
||||
async def delete_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Delete a VM permanently.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to delete
|
||||
storage: Optional storage path override
|
||||
|
||||
Returns:
|
||||
Dictionary with delete status and information
|
||||
"""
|
||||
self.logger.info(f"Deleting VM {name}...")
|
||||
|
||||
try:
|
||||
# Call the lume_api_delete function we created
|
||||
from ..lume_api import lume_api_delete
|
||||
|
||||
result = lume_api_delete(
|
||||
vm_name=name,
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
storage=storage if storage is not None else self.storage,
|
||||
debug=self.verbose,
|
||||
verbose=self.verbose
|
||||
)
|
||||
|
||||
# Check for errors in the result
|
||||
if "error" in result:
|
||||
self.logger.error(f"Failed to delete VM: {result['error']}")
|
||||
return result
|
||||
|
||||
self.logger.info(f"Successfully deleted VM '{name}'")
|
||||
return result
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to delete VM '{name}': {e}")
|
||||
return {"error": f"Failed to delete VM: {str(e)}"}
|
||||
|
||||
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Update VM configuration."""
|
||||
return self._lume_api_update(name, update_opts, debug=self.verbose)
|
||||
|
||||
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
|
||||
"""Get the IP address of a VM, waiting indefinitely until it's available.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to get the IP for
|
||||
storage: Optional storage path override
|
||||
retry_delay: Delay between retries in seconds (default: 2)
|
||||
|
||||
Returns:
|
||||
IP address of the VM when it becomes available
|
||||
"""
|
||||
# Track total attempts for logging purposes
|
||||
total_attempts = 0
|
||||
|
||||
# Loop indefinitely until we get a valid IP
|
||||
while True:
|
||||
total_attempts += 1
|
||||
|
||||
# Log retry message but not on first attempt
|
||||
if total_attempts > 1:
|
||||
self.logger.info(f"Waiting for VM {name} IP address (attempt {total_attempts})...")
|
||||
|
||||
try:
|
||||
# Get VM information
|
||||
vm_info = await self.get_vm(name, storage=storage)
|
||||
|
||||
# Check if we got a valid IP
|
||||
ip = vm_info.get("ip_address", None)
|
||||
if ip and ip != "unknown" and not ip.startswith("0.0.0.0"):
|
||||
self.logger.info(f"Got valid VM IP address: {ip}")
|
||||
return ip
|
||||
|
||||
# Check the VM status
|
||||
status = vm_info.get("status", "unknown")
|
||||
|
||||
# If VM is not running yet, log and wait
|
||||
if status != "running":
|
||||
self.logger.info(f"VM is not running yet (status: {status}). Waiting...")
|
||||
# If VM is running but no IP yet, wait and retry
|
||||
else:
|
||||
self.logger.info("VM is running but no valid IP address yet. Waiting...")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error getting VM {name} IP: {e}, continuing to wait...")
|
||||
|
||||
# Wait before next retry
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
# Add progress log every 10 attempts
|
||||
if total_attempts % 10 == 0:
|
||||
self.logger.info(f"Still waiting for VM {name} IP after {total_attempts} attempts...")
|
||||
|
||||
|
||||
546
libs/python/computer/computer/providers/lume_api.py
Normal file
546
libs/python/computer/computer/providers/lume_api.py
Normal file
@@ -0,0 +1,546 @@
|
||||
"""Shared API utilities for Lume and Lumier providers.
|
||||
|
||||
This module contains shared functions for interacting with the Lume API,
|
||||
used by both the LumeProvider and LumierProvider classes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import subprocess
|
||||
import urllib.parse
|
||||
from typing import Dict, List, Optional, Any
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Check if curl is available
|
||||
try:
|
||||
subprocess.run(["curl", "--version"], capture_output=True, check=True)
|
||||
HAS_CURL = True
|
||||
except (subprocess.SubprocessError, FileNotFoundError):
|
||||
HAS_CURL = False
|
||||
|
||||
|
||||
def lume_api_get(
|
||||
vm_name: str,
|
||||
host: str,
|
||||
port: int,
|
||||
storage: Optional[str] = None,
|
||||
debug: bool = False,
|
||||
verbose: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Use curl to get VM information from Lume API.
|
||||
|
||||
Args:
|
||||
vm_name: Name of the VM to get info for
|
||||
host: API host
|
||||
port: API port
|
||||
storage: Storage path for the VM
|
||||
debug: Whether to show debug output
|
||||
verbose: Enable verbose logging
|
||||
|
||||
Returns:
|
||||
Dictionary with VM status information parsed from JSON response
|
||||
"""
|
||||
# URL encode the storage parameter for the query
|
||||
encoded_storage = ""
|
||||
storage_param = ""
|
||||
|
||||
if storage:
|
||||
# First encode the storage path properly
|
||||
encoded_storage = urllib.parse.quote(storage, safe='')
|
||||
storage_param = f"?storage={encoded_storage}"
|
||||
|
||||
# Construct API URL with encoded storage parameter if needed
|
||||
api_url = f"http://{host}:{port}/lume/vms/{vm_name}{storage_param}"
|
||||
|
||||
# Construct the curl command with increased timeouts for more reliability
|
||||
# --connect-timeout: Time to establish connection (15 seconds)
|
||||
# --max-time: Maximum time for the whole operation (20 seconds)
|
||||
# -f: Fail silently (no output at all) on server errors
|
||||
# Add single quotes around URL to ensure special characters are handled correctly
|
||||
cmd = ["curl", "--connect-timeout", "15", "--max-time", "20", "-s", "-f", f"'{api_url}'"]
|
||||
|
||||
# For logging and display, show the properly escaped URL
|
||||
display_cmd = ["curl", "--connect-timeout", "15", "--max-time", "20", "-s", "-f", api_url]
|
||||
|
||||
# Only print the curl command when debug is enabled
|
||||
display_curl_string = ' '.join(display_cmd)
|
||||
logger.debug(f"Executing API request: {display_curl_string}")
|
||||
|
||||
# Execute the command - for execution we need to use shell=True to handle URLs with special characters
|
||||
try:
|
||||
# Use a single string with shell=True for proper URL handling
|
||||
shell_cmd = ' '.join(cmd)
|
||||
result = subprocess.run(shell_cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# Handle curl exit codes
|
||||
if result.returncode != 0:
|
||||
curl_error = "Unknown error"
|
||||
|
||||
# Map common curl error codes to helpful messages
|
||||
if result.returncode == 7:
|
||||
curl_error = "Failed to connect to the API server - it might still be starting up"
|
||||
elif result.returncode == 22:
|
||||
curl_error = "HTTP error returned from API server"
|
||||
elif result.returncode == 28:
|
||||
curl_error = "Operation timeout - the API server is taking too long to respond"
|
||||
elif result.returncode == 52:
|
||||
curl_error = "Empty reply from server - the API server is starting but not fully ready yet"
|
||||
elif result.returncode == 56:
|
||||
curl_error = "Network problem during data transfer - check container networking"
|
||||
|
||||
# Only log at debug level to reduce noise during retries
|
||||
logger.debug(f"API request failed with code {result.returncode}: {curl_error}")
|
||||
|
||||
# Return a more useful error message
|
||||
return {
|
||||
"error": f"API request failed: {curl_error}",
|
||||
"curl_code": result.returncode,
|
||||
"vm_name": vm_name,
|
||||
"status": "unknown" # We don't know the actual status due to API error
|
||||
}
|
||||
|
||||
# Try to parse the response as JSON
|
||||
if result.stdout and result.stdout.strip():
|
||||
try:
|
||||
vm_status = json.loads(result.stdout)
|
||||
if debug or verbose:
|
||||
logger.info(f"Successfully parsed VM status: {vm_status.get('status', 'unknown')}")
|
||||
return vm_status
|
||||
except json.JSONDecodeError as e:
|
||||
# Return the raw response if it's not valid JSON
|
||||
logger.warning(f"Invalid JSON response: {e}")
|
||||
if "Virtual machine not found" in result.stdout:
|
||||
return {"status": "not_found", "message": "VM not found in Lume API"}
|
||||
|
||||
return {"error": f"Invalid JSON response: {result.stdout[:100]}...", "status": "unknown"}
|
||||
else:
|
||||
return {"error": "Empty response from API", "status": "unknown"}
|
||||
except subprocess.SubprocessError as e:
|
||||
logger.error(f"Failed to execute API request: {e}")
|
||||
return {"error": f"Failed to execute API request: {str(e)}", "status": "unknown"}
|
||||
|
||||
|
||||
def lume_api_run(
|
||||
vm_name: str,
|
||||
host: str,
|
||||
port: int,
|
||||
run_opts: Dict[str, Any],
|
||||
storage: Optional[str] = None,
|
||||
debug: bool = False,
|
||||
verbose: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Run a VM using curl.
|
||||
|
||||
Args:
|
||||
vm_name: Name of the VM to run
|
||||
host: API host
|
||||
port: API port
|
||||
run_opts: Dictionary of run options
|
||||
storage: Storage path for the VM
|
||||
debug: Whether to show debug output
|
||||
verbose: Enable verbose logging
|
||||
|
||||
Returns:
|
||||
Dictionary with API response or error information
|
||||
"""
|
||||
# Construct API URL
|
||||
api_url = f"http://{host}:{port}/lume/vms/{vm_name}/run"
|
||||
|
||||
# Prepare JSON payload with required parameters
|
||||
payload = {}
|
||||
|
||||
# Add CPU cores if specified
|
||||
if "cpu" in run_opts:
|
||||
payload["cpu"] = run_opts["cpu"]
|
||||
|
||||
# Add memory if specified
|
||||
if "memory" in run_opts:
|
||||
payload["memory"] = run_opts["memory"]
|
||||
|
||||
# Add storage parameter if specified
|
||||
if storage:
|
||||
payload["storage"] = storage
|
||||
elif "storage" in run_opts:
|
||||
payload["storage"] = run_opts["storage"]
|
||||
|
||||
# Add shared directories if specified
|
||||
if "shared_directories" in run_opts and run_opts["shared_directories"]:
|
||||
payload["sharedDirectories"] = run_opts["shared_directories"]
|
||||
|
||||
# Log the payload for debugging
|
||||
logger.debug(f"API payload: {json.dumps(payload, indent=2)}")
|
||||
|
||||
# Construct the curl command
|
||||
cmd = [
|
||||
"curl", "--connect-timeout", "30", "--max-time", "30",
|
||||
"-s", "-X", "POST", "-H", "Content-Type: application/json",
|
||||
"-d", json.dumps(payload),
|
||||
api_url
|
||||
]
|
||||
|
||||
# Execute the command
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.warning(f"API request failed with code {result.returncode}: {result.stderr}")
|
||||
return {"error": f"API request failed: {result.stderr}"}
|
||||
|
||||
# Try to parse the response as JSON
|
||||
if result.stdout and result.stdout.strip():
|
||||
try:
|
||||
response = json.loads(result.stdout)
|
||||
return response
|
||||
except json.JSONDecodeError:
|
||||
# Return the raw response if it's not valid JSON
|
||||
return {"success": True, "message": "VM started successfully", "raw_response": result.stdout}
|
||||
else:
|
||||
return {"success": True, "message": "VM started successfully"}
|
||||
except subprocess.SubprocessError as e:
|
||||
logger.error(f"Failed to execute run request: {e}")
|
||||
return {"error": f"Failed to execute run request: {str(e)}"}
|
||||
|
||||
|
||||
def lume_api_stop(
|
||||
vm_name: str,
|
||||
host: str,
|
||||
port: int,
|
||||
storage: Optional[str] = None,
|
||||
debug: bool = False,
|
||||
verbose: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Stop a VM using curl.
|
||||
|
||||
Args:
|
||||
vm_name: Name of the VM to stop
|
||||
host: API host
|
||||
port: API port
|
||||
storage: Storage path for the VM
|
||||
debug: Whether to show debug output
|
||||
verbose: Enable verbose logging
|
||||
|
||||
Returns:
|
||||
Dictionary with API response or error information
|
||||
"""
|
||||
# Construct API URL
|
||||
api_url = f"http://{host}:{port}/lume/vms/{vm_name}/stop"
|
||||
|
||||
# Prepare JSON payload with required parameters
|
||||
payload = {}
|
||||
|
||||
# Add storage path if specified
|
||||
if storage:
|
||||
payload["storage"] = storage
|
||||
|
||||
# Construct the curl command
|
||||
cmd = [
|
||||
"curl", "--connect-timeout", "15", "--max-time", "20",
|
||||
"-s", "-X", "POST", "-H", "Content-Type: application/json",
|
||||
"-d", json.dumps(payload),
|
||||
api_url
|
||||
]
|
||||
|
||||
# Execute the command
|
||||
try:
|
||||
if debug or verbose:
|
||||
logger.info(f"Executing: {' '.join(cmd)}")
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.warning(f"API request failed with code {result.returncode}: {result.stderr}")
|
||||
return {"error": f"API request failed: {result.stderr}"}
|
||||
|
||||
# Try to parse the response as JSON
|
||||
if result.stdout and result.stdout.strip():
|
||||
try:
|
||||
response = json.loads(result.stdout)
|
||||
return response
|
||||
except json.JSONDecodeError:
|
||||
# Return the raw response if it's not valid JSON
|
||||
return {"success": True, "message": "VM stopped successfully", "raw_response": result.stdout}
|
||||
else:
|
||||
return {"success": True, "message": "VM stopped successfully"}
|
||||
except subprocess.SubprocessError as e:
|
||||
logger.error(f"Failed to execute stop request: {e}")
|
||||
return {"error": f"Failed to execute stop request: {str(e)}"}
|
||||
|
||||
|
||||
def lume_api_update(
|
||||
vm_name: str,
|
||||
host: str,
|
||||
port: int,
|
||||
update_opts: Dict[str, Any],
|
||||
storage: Optional[str] = None,
|
||||
debug: bool = False,
|
||||
verbose: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Update VM settings using curl.
|
||||
|
||||
Args:
|
||||
vm_name: Name of the VM to update
|
||||
host: API host
|
||||
port: API port
|
||||
update_opts: Dictionary of update options
|
||||
storage: Storage path for the VM
|
||||
debug: Whether to show debug output
|
||||
verbose: Enable verbose logging
|
||||
|
||||
Returns:
|
||||
Dictionary with API response or error information
|
||||
"""
|
||||
# Construct API URL
|
||||
api_url = f"http://{host}:{port}/lume/vms/{vm_name}/update"
|
||||
|
||||
# Prepare JSON payload with required parameters
|
||||
payload = {}
|
||||
|
||||
# Add CPU cores if specified
|
||||
if "cpu" in update_opts:
|
||||
payload["cpu"] = update_opts["cpu"]
|
||||
|
||||
# Add memory if specified
|
||||
if "memory" in update_opts:
|
||||
payload["memory"] = update_opts["memory"]
|
||||
|
||||
# Add storage path if specified
|
||||
if storage:
|
||||
payload["storage"] = storage
|
||||
|
||||
# Construct the curl command
|
||||
cmd = [
|
||||
"curl", "--connect-timeout", "15", "--max-time", "20",
|
||||
"-s", "-X", "POST", "-H", "Content-Type: application/json",
|
||||
"-d", json.dumps(payload),
|
||||
api_url
|
||||
]
|
||||
|
||||
# Execute the command
|
||||
try:
|
||||
if debug:
|
||||
logger.info(f"Executing: {' '.join(cmd)}")
|
||||
|
||||
result = subprocess.run(cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode != 0:
|
||||
logger.warning(f"API request failed with code {result.returncode}: {result.stderr}")
|
||||
return {"error": f"API request failed: {result.stderr}"}
|
||||
|
||||
# Try to parse the response as JSON
|
||||
if result.stdout and result.stdout.strip():
|
||||
try:
|
||||
response = json.loads(result.stdout)
|
||||
return response
|
||||
except json.JSONDecodeError:
|
||||
# Return the raw response if it's not valid JSON
|
||||
return {"success": True, "message": "VM updated successfully", "raw_response": result.stdout}
|
||||
else:
|
||||
return {"success": True, "message": "VM updated successfully"}
|
||||
except subprocess.SubprocessError as e:
|
||||
logger.error(f"Failed to execute update request: {e}")
|
||||
return {"error": f"Failed to execute update request: {str(e)}"}
|
||||
|
||||
|
||||
def lume_api_pull(
|
||||
image: str,
|
||||
name: str,
|
||||
host: str,
|
||||
port: int,
|
||||
storage: Optional[str] = None,
|
||||
registry: str = "ghcr.io",
|
||||
organization: str = "trycua",
|
||||
debug: bool = False,
|
||||
verbose: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Pull a VM image from a registry using curl.
|
||||
|
||||
Args:
|
||||
image: Name/tag of the image to pull
|
||||
name: Name to give the VM after pulling
|
||||
host: API host
|
||||
port: API port
|
||||
storage: Storage path for the VM
|
||||
registry: Registry to pull from (default: ghcr.io)
|
||||
organization: Organization in registry (default: trycua)
|
||||
debug: Whether to show debug output
|
||||
verbose: Enable verbose logging
|
||||
|
||||
Returns:
|
||||
Dictionary with pull status and information
|
||||
"""
|
||||
# Prepare pull request payload
|
||||
pull_payload = {
|
||||
"image": image, # Use provided image name
|
||||
"name": name, # Always use name as the target VM name
|
||||
"registry": registry,
|
||||
"organization": organization
|
||||
}
|
||||
|
||||
if storage:
|
||||
pull_payload["storage"] = storage
|
||||
|
||||
# Construct pull command with proper JSON payload
|
||||
pull_cmd = [
|
||||
"curl"
|
||||
]
|
||||
|
||||
if not verbose:
|
||||
pull_cmd.append("-s")
|
||||
|
||||
pull_cmd.extend([
|
||||
"-X", "POST",
|
||||
"-H", "Content-Type: application/json",
|
||||
"-d", json.dumps(pull_payload),
|
||||
f"http://{host}:{port}/lume/pull"
|
||||
])
|
||||
|
||||
logger.debug(f"Executing API request: {' '.join(pull_cmd)}")
|
||||
|
||||
try:
|
||||
# Execute pull command
|
||||
result = subprocess.run(pull_cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode != 0:
|
||||
error_msg = f"Failed to pull VM {name}: {result.stderr}"
|
||||
logger.error(error_msg)
|
||||
return {"error": error_msg}
|
||||
|
||||
try:
|
||||
response = json.loads(result.stdout)
|
||||
logger.info(f"Successfully initiated pull for VM {name}")
|
||||
return response
|
||||
except json.JSONDecodeError:
|
||||
if result.stdout:
|
||||
logger.info(f"Pull response: {result.stdout}")
|
||||
return {"success": True, "message": f"Successfully initiated pull for VM {name}"}
|
||||
|
||||
except subprocess.SubprocessError as e:
|
||||
error_msg = f"Failed to execute pull command: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return {"error": error_msg}
|
||||
|
||||
|
||||
def lume_api_delete(
|
||||
vm_name: str,
|
||||
host: str,
|
||||
port: int,
|
||||
storage: Optional[str] = None,
|
||||
debug: bool = False,
|
||||
verbose: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""Delete a VM using curl.
|
||||
|
||||
Args:
|
||||
vm_name: Name of the VM to delete
|
||||
host: API host
|
||||
port: API port
|
||||
storage: Storage path for the VM
|
||||
debug: Whether to show debug output
|
||||
verbose: Enable verbose logging
|
||||
|
||||
Returns:
|
||||
Dictionary with API response or error information
|
||||
"""
|
||||
# URL encode the storage parameter for the query
|
||||
encoded_storage = ""
|
||||
storage_param = ""
|
||||
|
||||
if storage:
|
||||
# First encode the storage path properly
|
||||
encoded_storage = urllib.parse.quote(storage, safe='')
|
||||
storage_param = f"?storage={encoded_storage}"
|
||||
|
||||
# Construct API URL with encoded storage parameter if needed
|
||||
api_url = f"http://{host}:{port}/lume/vms/{vm_name}{storage_param}"
|
||||
|
||||
# Construct the curl command for DELETE operation - using much longer timeouts matching shell implementation
|
||||
cmd = ["curl", "--connect-timeout", "6000", "--max-time", "5000", "-s", "-X", "DELETE", f"'{api_url}'"]
|
||||
|
||||
# For logging and display, show the properly escaped URL
|
||||
display_cmd = ["curl", "--connect-timeout", "6000", "--max-time", "5000", "-s", "-X", "DELETE", api_url]
|
||||
|
||||
# Only print the curl command when debug is enabled
|
||||
display_curl_string = ' '.join(display_cmd)
|
||||
logger.debug(f"Executing API request: {display_curl_string}")
|
||||
|
||||
# Execute the command - for execution we need to use shell=True to handle URLs with special characters
|
||||
try:
|
||||
# Use a single string with shell=True for proper URL handling
|
||||
shell_cmd = ' '.join(cmd)
|
||||
result = subprocess.run(shell_cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# Handle curl exit codes
|
||||
if result.returncode != 0:
|
||||
curl_error = "Unknown error"
|
||||
|
||||
# Map common curl error codes to helpful messages
|
||||
if result.returncode == 7:
|
||||
curl_error = "Failed to connect to the API server - it might still be starting up"
|
||||
elif result.returncode == 22:
|
||||
curl_error = "HTTP error returned from API server"
|
||||
elif result.returncode == 28:
|
||||
curl_error = "Operation timeout - the API server is taking too long to respond"
|
||||
elif result.returncode == 52:
|
||||
curl_error = "Empty reply from server - the API server is starting but not fully ready yet"
|
||||
elif result.returncode == 56:
|
||||
curl_error = "Network problem during data transfer - check container networking"
|
||||
|
||||
# Only log at debug level to reduce noise during retries
|
||||
logger.debug(f"API request failed with code {result.returncode}: {curl_error}")
|
||||
|
||||
# Return a more useful error message
|
||||
return {
|
||||
"error": f"API request failed: {curl_error}",
|
||||
"curl_code": result.returncode,
|
||||
"vm_name": vm_name,
|
||||
"storage": storage
|
||||
}
|
||||
|
||||
# Try to parse the response as JSON
|
||||
if result.stdout and result.stdout.strip():
|
||||
try:
|
||||
response = json.loads(result.stdout)
|
||||
return response
|
||||
except json.JSONDecodeError:
|
||||
# Return the raw response if it's not valid JSON
|
||||
return {"success": True, "message": "VM deleted successfully", "raw_response": result.stdout}
|
||||
else:
|
||||
return {"success": True, "message": "VM deleted successfully"}
|
||||
except subprocess.SubprocessError as e:
|
||||
logger.error(f"Failed to execute delete request: {e}")
|
||||
return {"error": f"Failed to execute delete request: {str(e)}"}
|
||||
|
||||
|
||||
def parse_memory(memory_str: str) -> int:
|
||||
"""Parse memory string to MB integer.
|
||||
|
||||
Examples:
|
||||
"8GB" -> 8192
|
||||
"1024MB" -> 1024
|
||||
"512" -> 512
|
||||
|
||||
Returns:
|
||||
Memory value in MB
|
||||
"""
|
||||
if isinstance(memory_str, int):
|
||||
return memory_str
|
||||
|
||||
if isinstance(memory_str, str):
|
||||
# Extract number and unit
|
||||
import re
|
||||
match = re.match(r"(\d+)([A-Za-z]*)", memory_str)
|
||||
if match:
|
||||
value, unit = match.groups()
|
||||
value = int(value)
|
||||
unit = unit.upper()
|
||||
|
||||
if unit == "GB" or unit == "G":
|
||||
return value * 1024
|
||||
elif unit == "MB" or unit == "M" or unit == "":
|
||||
return value
|
||||
|
||||
# Default fallback
|
||||
logger.warning(f"Could not parse memory string '{memory_str}', using 8GB default")
|
||||
return 8192 # Default to 8GB
|
||||
@@ -0,0 +1,8 @@
|
||||
"""Lumier VM provider implementation."""
|
||||
|
||||
try:
|
||||
# Use the same import approach as in the Lume provider
|
||||
from .provider import LumierProvider
|
||||
HAS_LUMIER = True
|
||||
except ImportError:
|
||||
HAS_LUMIER = False
|
||||
942
libs/python/computer/computer/providers/lumier/provider.py
Normal file
942
libs/python/computer/computer/providers/lumier/provider.py
Normal file
@@ -0,0 +1,942 @@
|
||||
"""
|
||||
Lumier VM provider implementation.
|
||||
|
||||
This provider uses Docker containers running the Lumier image to create
|
||||
macOS and Linux VMs. It handles VM lifecycle operations through Docker
|
||||
commands and container management.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Dict, List, Optional, Any
|
||||
import subprocess
|
||||
import time
|
||||
import re
|
||||
|
||||
from ..base import BaseVMProvider, VMProviderType
|
||||
from ..lume_api import (
|
||||
lume_api_get,
|
||||
lume_api_run,
|
||||
lume_api_stop,
|
||||
lume_api_update
|
||||
)
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Check if Docker is available
|
||||
try:
|
||||
subprocess.run(["docker", "--version"], capture_output=True, check=True)
|
||||
HAS_LUMIER = True
|
||||
except (subprocess.SubprocessError, FileNotFoundError):
|
||||
HAS_LUMIER = False
|
||||
|
||||
|
||||
class LumierProvider(BaseVMProvider):
|
||||
"""
|
||||
Lumier VM Provider implementation using Docker containers.
|
||||
|
||||
This provider uses Docker to run Lumier containers that can create
|
||||
macOS and Linux VMs through containerization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: Optional[int] = 7777,
|
||||
host: str = "localhost",
|
||||
storage: Optional[str] = None, # Can be a path or 'ephemeral'
|
||||
shared_path: Optional[str] = None,
|
||||
image: str = "macos-sequoia-cua:latest", # VM image to use
|
||||
verbose: bool = False,
|
||||
ephemeral: bool = False,
|
||||
noVNC_port: Optional[int] = 8006,
|
||||
):
|
||||
"""Initialize the Lumier VM Provider.
|
||||
|
||||
Args:
|
||||
port: Port for the API server (default: 7777)
|
||||
host: Hostname for the API server (default: localhost)
|
||||
storage: Path for persistent VM storage
|
||||
shared_path: Path for shared folder between host and VM
|
||||
image: VM image to use (e.g. "macos-sequoia-cua:latest")
|
||||
verbose: Enable verbose logging
|
||||
ephemeral: Use ephemeral (temporary) storage
|
||||
noVNC_port: Specific port for noVNC interface (default: 8006)
|
||||
"""
|
||||
self.host = host
|
||||
# Always ensure api_port has a valid value (7777 is the default)
|
||||
self.api_port = 7777 if port is None else port
|
||||
self.vnc_port = noVNC_port # User-specified noVNC port, will be set in run_vm if provided
|
||||
self.ephemeral = ephemeral
|
||||
|
||||
# Handle ephemeral storage (temporary directory)
|
||||
if ephemeral:
|
||||
self.storage = "ephemeral"
|
||||
else:
|
||||
self.storage = storage
|
||||
|
||||
self.shared_path = shared_path
|
||||
self.image = image # Store the VM image name to use
|
||||
# The container_name will be set in run_vm using the VM name
|
||||
self.verbose = verbose
|
||||
self._container_id = None
|
||||
self._api_url = None # Will be set after container starts
|
||||
|
||||
@property
|
||||
def provider_type(self) -> VMProviderType:
|
||||
"""Return the provider type."""
|
||||
return VMProviderType.LUMIER
|
||||
|
||||
def _parse_memory(self, memory_str: str) -> int:
|
||||
"""Parse memory string to MB integer.
|
||||
|
||||
Examples:
|
||||
"8GB" -> 8192
|
||||
"1024MB" -> 1024
|
||||
"512" -> 512
|
||||
"""
|
||||
if isinstance(memory_str, int):
|
||||
return memory_str
|
||||
|
||||
if isinstance(memory_str, str):
|
||||
# Extract number and unit
|
||||
match = re.match(r"(\d+)([A-Za-z]*)", memory_str)
|
||||
if match:
|
||||
value, unit = match.groups()
|
||||
value = int(value)
|
||||
unit = unit.upper()
|
||||
|
||||
if unit == "GB" or unit == "G":
|
||||
return value * 1024
|
||||
elif unit == "MB" or unit == "M" or unit == "":
|
||||
return value
|
||||
|
||||
# Default fallback
|
||||
logger.warning(f"Could not parse memory string '{memory_str}', using 8GB default")
|
||||
return 8192 # Default to 8GB
|
||||
|
||||
# Helper methods for interacting with the Lumier API through curl
|
||||
# These methods handle the various VM operations via API calls
|
||||
|
||||
def _get_curl_error_message(self, return_code: int) -> str:
|
||||
"""Get a descriptive error message for curl return codes.
|
||||
|
||||
Args:
|
||||
return_code: The curl return code
|
||||
|
||||
Returns:
|
||||
A descriptive error message
|
||||
"""
|
||||
# Map common curl error codes to helpful messages
|
||||
if return_code == 7:
|
||||
return "Failed to connect - API server is starting up"
|
||||
elif return_code == 22:
|
||||
return "HTTP error returned from API server"
|
||||
elif return_code == 28:
|
||||
return "Operation timeout - API server is slow to respond"
|
||||
elif return_code == 52:
|
||||
return "Empty reply from server - API is starting but not ready"
|
||||
elif return_code == 56:
|
||||
return "Network problem during data transfer"
|
||||
else:
|
||||
return f"Unknown curl error code: {return_code}"
|
||||
|
||||
|
||||
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get VM information by name.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to get information for
|
||||
storage: Optional storage path override. If provided, this will be used
|
||||
instead of the provider's default storage path.
|
||||
|
||||
Returns:
|
||||
Dictionary with VM information including status, IP address, etc.
|
||||
"""
|
||||
if not HAS_LUMIER:
|
||||
logger.error("Docker is not available. Cannot get VM status.")
|
||||
return {
|
||||
"name": name,
|
||||
"status": "unavailable",
|
||||
"error": "Docker is not available"
|
||||
}
|
||||
|
||||
# Store the current name for API requests
|
||||
self.container_name = name
|
||||
|
||||
try:
|
||||
# Check if the container exists and is running
|
||||
check_cmd = ["docker", "ps", "-a", "--filter", f"name={name}", "--format", "{{.Status}}"]
|
||||
check_result = subprocess.run(check_cmd, capture_output=True, text=True)
|
||||
container_status = check_result.stdout.strip()
|
||||
|
||||
if not container_status:
|
||||
logger.info(f"Container {name} does not exist. Will create when run_vm is called.")
|
||||
return {
|
||||
"name": name,
|
||||
"status": "not_found",
|
||||
"message": "Container doesn't exist yet"
|
||||
}
|
||||
|
||||
# Container exists, check if it's running
|
||||
is_running = container_status.startswith("Up")
|
||||
|
||||
if not is_running:
|
||||
logger.info(f"Container {name} exists but is not running. Status: {container_status}")
|
||||
return {
|
||||
"name": name,
|
||||
"status": "stopped",
|
||||
"container_status": container_status,
|
||||
}
|
||||
|
||||
# Container is running, get the IP address and API status from Lumier API
|
||||
logger.info(f"Container {name} is running. Getting VM status from API.")
|
||||
|
||||
# Use the shared lume_api_get function directly
|
||||
vm_info = lume_api_get(
|
||||
vm_name=name,
|
||||
host=self.host,
|
||||
port=self.api_port,
|
||||
storage=storage if storage is not None else self.storage,
|
||||
debug=self.verbose,
|
||||
verbose=self.verbose
|
||||
)
|
||||
|
||||
# Check for API errors
|
||||
if "error" in vm_info:
|
||||
# Use debug level instead of warning to reduce log noise during polling
|
||||
logger.debug(f"API request error: {vm_info['error']}")
|
||||
return {
|
||||
"name": name,
|
||||
"status": "running", # Container is running even if API is not responsive
|
||||
"api_status": "error",
|
||||
"error": vm_info["error"],
|
||||
"container_status": container_status
|
||||
}
|
||||
|
||||
# Process the VM status information
|
||||
vm_status = vm_info.get("status", "unknown")
|
||||
vnc_url = vm_info.get("vncUrl", "")
|
||||
ip_address = vm_info.get("ipAddress", "")
|
||||
|
||||
# IMPORTANT: Always ensure we have a valid IP address for connectivity
|
||||
# If the API doesn't return an IP address, default to localhost (127.0.0.1)
|
||||
# This makes the behavior consistent with LumeProvider
|
||||
if not ip_address and vm_status == "running":
|
||||
ip_address = "127.0.0.1"
|
||||
logger.info(f"No IP address returned from API, defaulting to {ip_address}")
|
||||
vm_info["ipAddress"] = ip_address
|
||||
|
||||
logger.info(f"VM {name} status: {vm_status}")
|
||||
|
||||
if ip_address and vnc_url:
|
||||
logger.info(f"VM {name} has IP: {ip_address} and VNC URL: {vnc_url}")
|
||||
elif not ip_address and not vnc_url and vm_status != "running":
|
||||
# Not running is expected in this case
|
||||
logger.info(f"VM {name} is not running yet. Status: {vm_status}")
|
||||
else:
|
||||
# Missing IP or VNC but status is running - this is unusual but handled with default IP
|
||||
logger.warning(f"VM {name} is running but missing expected fields. API response: {vm_info}")
|
||||
|
||||
# Return the full status information
|
||||
return {
|
||||
"name": name,
|
||||
"status": vm_status,
|
||||
"ip_address": ip_address,
|
||||
"vnc_url": vnc_url,
|
||||
"api_status": "ok",
|
||||
"container_status": container_status,
|
||||
**vm_info # Include all fields from the API response
|
||||
}
|
||||
except subprocess.SubprocessError as e:
|
||||
logger.error(f"Failed to check container status: {e}")
|
||||
return {
|
||||
"name": name,
|
||||
"status": "error",
|
||||
"error": f"Failed to check container status: {str(e)}"
|
||||
}
|
||||
|
||||
async def list_vms(self) -> List[Dict[str, Any]]:
|
||||
"""List all VMs managed by this provider.
|
||||
|
||||
For Lumier provider, there is only one VM per container.
|
||||
"""
|
||||
try:
|
||||
status = await self.get_vm("default")
|
||||
return [status] if status.get("status") != "unknown" else []
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to list VMs: {e}")
|
||||
return []
|
||||
|
||||
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Run a VM with the given options.
|
||||
|
||||
Args:
|
||||
image: Name/tag of the image to use
|
||||
name: Name of the VM to run (used for the container name and Docker image tag)
|
||||
run_opts: Options for running the VM, including:
|
||||
- cpu: Number of CPU cores
|
||||
- memory: Amount of memory (e.g. "8GB")
|
||||
- noVNC_port: Specific port for noVNC interface
|
||||
|
||||
Returns:
|
||||
Dictionary with VM status information
|
||||
"""
|
||||
# Set the container name using the VM name for consistency
|
||||
self.container_name = name
|
||||
try:
|
||||
# First, check if container already exists and remove it
|
||||
try:
|
||||
check_cmd = ["docker", "ps", "-a", "--filter", f"name={self.container_name}", "--format", "{{.ID}}"]
|
||||
check_result = subprocess.run(check_cmd, capture_output=True, text=True)
|
||||
existing_container = check_result.stdout.strip()
|
||||
|
||||
if existing_container:
|
||||
logger.info(f"Removing existing container: {self.container_name}")
|
||||
remove_cmd = ["docker", "rm", "-f", self.container_name]
|
||||
subprocess.run(remove_cmd, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.warning(f"Error removing existing container: {e}")
|
||||
# Continue anyway, next steps will fail if there's a real problem
|
||||
|
||||
# Prepare the Docker run command
|
||||
cmd = ["docker", "run", "-d", "--name", self.container_name]
|
||||
|
||||
cmd.extend(["-p", f"{self.vnc_port}:8006"])
|
||||
logger.debug(f"Using specified noVNC_port: {self.vnc_port}")
|
||||
|
||||
# Set API URL using the API port
|
||||
self._api_url = f"http://{self.host}:{self.api_port}"
|
||||
|
||||
# Parse memory setting
|
||||
memory_mb = self._parse_memory(run_opts.get("memory", "8GB"))
|
||||
|
||||
# Add storage volume mount if storage is specified (for persistent VM storage)
|
||||
if self.storage and self.storage != "ephemeral":
|
||||
# Create storage directory if it doesn't exist
|
||||
storage_dir = os.path.abspath(os.path.expanduser(self.storage or ""))
|
||||
os.makedirs(storage_dir, exist_ok=True)
|
||||
|
||||
# Add volume mount for storage
|
||||
cmd.extend([
|
||||
"-v", f"{storage_dir}:/storage",
|
||||
"-e", f"HOST_STORAGE_PATH={storage_dir}"
|
||||
])
|
||||
logger.debug(f"Using persistent storage at: {storage_dir}")
|
||||
|
||||
# Add shared folder volume mount if shared_path is specified
|
||||
if self.shared_path:
|
||||
# Create shared directory if it doesn't exist
|
||||
shared_dir = os.path.abspath(os.path.expanduser(self.shared_path or ""))
|
||||
os.makedirs(shared_dir, exist_ok=True)
|
||||
|
||||
# Add volume mount for shared folder
|
||||
cmd.extend([
|
||||
"-v", f"{shared_dir}:/shared",
|
||||
"-e", f"HOST_SHARED_PATH={shared_dir}"
|
||||
])
|
||||
logger.debug(f"Using shared folder at: {shared_dir}")
|
||||
|
||||
# Add environment variables
|
||||
# Always use the container_name as the VM_NAME for consistency
|
||||
# Use the VM image passed from the Computer class
|
||||
logger.debug(f"Using VM image: {self.image}")
|
||||
|
||||
# If ghcr.io is in the image, use the full image name
|
||||
if "ghcr.io" in self.image:
|
||||
vm_image = self.image
|
||||
else:
|
||||
vm_image = f"ghcr.io/trycua/{self.image}"
|
||||
|
||||
cmd.extend([
|
||||
"-e", f"VM_NAME={self.container_name}",
|
||||
"-e", f"VERSION={vm_image}",
|
||||
"-e", f"CPU_CORES={run_opts.get('cpu', '4')}",
|
||||
"-e", f"RAM_SIZE={memory_mb}",
|
||||
])
|
||||
|
||||
# Specify the Lumier image with the full image name
|
||||
lumier_image = "trycua/lumier:latest"
|
||||
|
||||
# First check if the image exists locally
|
||||
try:
|
||||
logger.debug(f"Checking if Docker image {lumier_image} exists locally...")
|
||||
check_image_cmd = ["docker", "image", "inspect", lumier_image]
|
||||
subprocess.run(check_image_cmd, capture_output=True, check=True)
|
||||
logger.debug(f"Docker image {lumier_image} found locally.")
|
||||
except subprocess.CalledProcessError:
|
||||
# Image doesn't exist locally
|
||||
logger.warning(f"\nWARNING: Docker image {lumier_image} not found locally.")
|
||||
logger.warning("The system will attempt to pull it from Docker Hub, which may fail if you have network connectivity issues.")
|
||||
logger.warning("If the Docker pull fails, you may need to manually pull the image first with:")
|
||||
logger.warning(f" docker pull {lumier_image}\n")
|
||||
|
||||
# Add the image to the command
|
||||
cmd.append(lumier_image)
|
||||
|
||||
# Print the Docker command for debugging
|
||||
logger.debug(f"DOCKER COMMAND: {' '.join(cmd)}")
|
||||
|
||||
# Run the container with improved error handling
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
if "no route to host" in str(e.stderr).lower() or "failed to resolve reference" in str(e.stderr).lower():
|
||||
error_msg = (f"Network error while trying to pull Docker image '{lumier_image}'\n"
|
||||
f"Error: {e.stderr}\n\n"
|
||||
f"SOLUTION: Please try one of the following:\n"
|
||||
f"1. Check your internet connection\n"
|
||||
f"2. Pull the image manually with: docker pull {lumier_image}\n"
|
||||
f"3. Check if Docker is running properly\n")
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
raise
|
||||
|
||||
# Container started, now check VM status with polling
|
||||
logger.debug("Container started, checking VM status...")
|
||||
logger.debug("NOTE: This may take some time while the VM image is being pulled and initialized")
|
||||
|
||||
# Start a background thread to show container logs in real-time
|
||||
import threading
|
||||
|
||||
def show_container_logs():
|
||||
# Give the container a moment to start generating logs
|
||||
time.sleep(1)
|
||||
logger.debug(f"\n---- CONTAINER LOGS FOR '{name}' (LIVE) ----")
|
||||
logger.debug("Showing logs as they are generated. Press Ctrl+C to stop viewing logs...\n")
|
||||
|
||||
try:
|
||||
# Use docker logs with follow option
|
||||
log_cmd = ["docker", "logs", "--tail", "30", "--follow", name]
|
||||
process = subprocess.Popen(log_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
text=True, bufsize=1, universal_newlines=True)
|
||||
|
||||
# Read and print logs line by line
|
||||
for line in process.stdout:
|
||||
logger.debug(line, end='')
|
||||
|
||||
# Break if process has exited
|
||||
if process.poll() is not None:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"\nError showing container logs: {e}")
|
||||
if self.verbose:
|
||||
logger.error(f"Error in log streaming thread: {e}")
|
||||
finally:
|
||||
logger.debug("\n---- LOG STREAMING ENDED ----")
|
||||
# Make sure process is terminated
|
||||
if 'process' in locals() and process.poll() is None:
|
||||
process.terminate()
|
||||
|
||||
# Start log streaming in a background thread if verbose mode is enabled
|
||||
log_thread = threading.Thread(target=show_container_logs)
|
||||
log_thread.daemon = True # Thread will exit when main program exits
|
||||
log_thread.start()
|
||||
|
||||
# Skip waiting for container readiness and just poll get_vm directly
|
||||
# Poll the get_vm method indefinitely until the VM is ready with an IP address
|
||||
attempt = 0
|
||||
consecutive_errors = 0
|
||||
vm_running = False
|
||||
|
||||
while True: # Wait indefinitely
|
||||
try:
|
||||
# Use longer delays to give the system time to initialize
|
||||
if attempt > 0:
|
||||
# Start with 5s delay, then increase gradually up to 30s for later attempts
|
||||
# But use shorter delays while we're getting API errors
|
||||
if consecutive_errors > 0 and consecutive_errors < 5:
|
||||
wait_time = 3 # Use shorter delays when we're getting API errors
|
||||
else:
|
||||
wait_time = min(30, 5 + (attempt * 2))
|
||||
|
||||
logger.debug(f"Waiting {wait_time}s before retry #{attempt+1}...")
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
# Try to get VM status
|
||||
logger.debug(f"Checking VM status (attempt {attempt+1})...")
|
||||
vm_status = await self.get_vm(name)
|
||||
|
||||
# Check for API errors
|
||||
if 'error' in vm_status:
|
||||
consecutive_errors += 1
|
||||
error_msg = vm_status.get('error', 'Unknown error')
|
||||
|
||||
# Only print a user-friendly status message, not the raw error
|
||||
# since _lume_api_get already logged the technical details
|
||||
if consecutive_errors == 1 or attempt % 5 == 0:
|
||||
if 'Empty reply from server' in error_msg:
|
||||
logger.info("API server is starting up - container is running, but API isn't fully initialized yet.")
|
||||
logger.info("This is expected during the initial VM setup - will continue polling...")
|
||||
else:
|
||||
# Don't repeat the exact same error message each time
|
||||
logger.warning(f"API request error (attempt {attempt+1}): {error_msg}")
|
||||
# Just log that we're still working on it
|
||||
if attempt > 3:
|
||||
logger.debug("Still waiting for the API server to become available...")
|
||||
|
||||
# If we're getting errors but container is running, that's normal during startup
|
||||
if vm_status.get('status') == 'running':
|
||||
if not vm_running:
|
||||
logger.info("Container is running, waiting for the VM within it to become fully ready...")
|
||||
logger.info("This might take a minute while the VM initializes...")
|
||||
vm_running = True
|
||||
|
||||
# Increase counter and continue
|
||||
attempt += 1
|
||||
continue
|
||||
|
||||
# Reset consecutive error counter when we get a successful response
|
||||
consecutive_errors = 0
|
||||
|
||||
# If the VM is running, check if it has an IP address (which means it's fully ready)
|
||||
if vm_status.get('status') == 'running':
|
||||
vm_running = True
|
||||
|
||||
# Check if we have an IP address, which means the VM is fully ready
|
||||
if 'ip_address' in vm_status and vm_status['ip_address']:
|
||||
logger.info(f"VM is now fully running with IP: {vm_status.get('ip_address')}")
|
||||
if 'vnc_url' in vm_status and vm_status['vnc_url']:
|
||||
logger.info(f"VNC URL: {vm_status.get('vnc_url')}")
|
||||
return vm_status
|
||||
else:
|
||||
logger.debug("VM is running but still initializing network interfaces...")
|
||||
logger.debug("Waiting for IP address to be assigned...")
|
||||
else:
|
||||
# VM exists but might still be starting up
|
||||
status = vm_status.get('status', 'unknown')
|
||||
logger.debug(f"VM found but status is: {status}. Continuing to poll...")
|
||||
|
||||
# Increase counter for next iteration's delay calculation
|
||||
attempt += 1
|
||||
|
||||
# If we reach a very large number of attempts, give a reassuring message but continue
|
||||
if attempt % 10 == 0:
|
||||
logger.debug(f"Still waiting after {attempt} attempts. This might take several minutes for first-time setup.")
|
||||
if not vm_running and attempt >= 20:
|
||||
logger.warning("\nNOTE: First-time VM initialization can be slow as images are downloaded.")
|
||||
logger.warning("If this continues for more than 10 minutes, you may want to check:")
|
||||
logger.warning(" 1. Docker logs with: docker logs " + name)
|
||||
logger.warning(" 2. If your network can access container registries")
|
||||
logger.warning("Press Ctrl+C to abort if needed.\n")
|
||||
|
||||
# After 150 attempts (likely over 30-40 minutes), return current status
|
||||
if attempt >= 150:
|
||||
logger.debug(f"Reached 150 polling attempts. VM status is: {vm_status.get('status', 'unknown')}")
|
||||
logger.debug("Returning current VM status, but please check Docker logs if there are issues.")
|
||||
return vm_status
|
||||
|
||||
except Exception as e:
|
||||
# Always continue retrying, but with increasing delays
|
||||
logger.warning(f"Error checking VM status (attempt {attempt+1}): {e}. Will retry.")
|
||||
consecutive_errors += 1
|
||||
|
||||
# If we've had too many consecutive errors, might be a deeper problem
|
||||
if consecutive_errors >= 10:
|
||||
logger.warning(f"\nWARNING: Encountered {consecutive_errors} consecutive errors while checking VM status.")
|
||||
logger.warning("You may need to check the Docker container logs or restart the process.")
|
||||
logger.warning(f"Error details: {str(e)}\n")
|
||||
|
||||
# Increase attempt counter for next iteration
|
||||
attempt += 1
|
||||
|
||||
# After many consecutive errors, add a delay to avoid hammering the system
|
||||
if attempt > 5:
|
||||
error_delay = min(30, 10 + attempt)
|
||||
logger.warning(f"Multiple connection errors, waiting {error_delay}s before next attempt...")
|
||||
await asyncio.sleep(error_delay)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_msg = f"Failed to start Lumier container: {e.stderr if hasattr(e, 'stderr') else str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
async def _wait_for_container_ready(self, container_name: str, timeout: int = 90) -> bool:
|
||||
"""Wait for the Lumier container to be fully ready with a valid API response.
|
||||
|
||||
Args:
|
||||
container_name: Name of the Docker container to check
|
||||
timeout: Maximum time to wait in seconds (default: 90 seconds)
|
||||
|
||||
Returns:
|
||||
True if the container is running, even if API is not fully ready.
|
||||
This allows operations to continue with appropriate fallbacks.
|
||||
"""
|
||||
start_time = time.time()
|
||||
api_ready = False
|
||||
container_running = False
|
||||
|
||||
logger.debug(f"Waiting for container {container_name} to be ready (timeout: {timeout}s)...")
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
# Check if container is running
|
||||
try:
|
||||
check_cmd = ["docker", "ps", "--filter", f"name={container_name}", "--format", "{{.Status}}"]
|
||||
result = subprocess.run(check_cmd, capture_output=True, text=True, check=True)
|
||||
container_status = result.stdout.strip()
|
||||
|
||||
if container_status and container_status.startswith("Up"):
|
||||
container_running = True
|
||||
logger.info(f"Container {container_name} is running with status: {container_status}")
|
||||
else:
|
||||
logger.warning(f"Container {container_name} not yet running, status: {container_status}")
|
||||
# container is not running yet, wait and try again
|
||||
await asyncio.sleep(2) # Longer sleep to give Docker time
|
||||
continue
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.warning(f"Error checking container status: {e}")
|
||||
await asyncio.sleep(2)
|
||||
continue
|
||||
|
||||
# Container is running, check if API is responsive
|
||||
try:
|
||||
# First check the health endpoint
|
||||
api_url = f"http://{self.host}:{self.api_port}/health"
|
||||
logger.info(f"Checking API health at: {api_url}")
|
||||
|
||||
# Use longer timeout for API health check since it may still be initializing
|
||||
curl_cmd = ["curl", "-s", "--connect-timeout", "5", "--max-time", "10", api_url]
|
||||
result = subprocess.run(curl_cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode == 0 and "ok" in result.stdout.lower():
|
||||
api_ready = True
|
||||
logger.info(f"API is ready at {api_url}")
|
||||
break
|
||||
else:
|
||||
# API health check failed, now let's check if the VM status endpoint is responsive
|
||||
# This covers cases where the health endpoint isn't implemented but the VM API is working
|
||||
vm_api_url = f"http://{self.host}:{self.api_port}/lume/vms/{container_name}"
|
||||
if self.storage:
|
||||
import urllib.parse
|
||||
encoded_storage = urllib.parse.quote_plus(self.storage)
|
||||
vm_api_url += f"?storage={encoded_storage}"
|
||||
|
||||
curl_vm_cmd = ["curl", "-s", "--connect-timeout", "5", "--max-time", "10", vm_api_url]
|
||||
vm_result = subprocess.run(curl_vm_cmd, capture_output=True, text=True)
|
||||
|
||||
if vm_result.returncode == 0 and vm_result.stdout.strip():
|
||||
# VM API responded with something - consider the API ready
|
||||
api_ready = True
|
||||
logger.info(f"VM API is ready at {vm_api_url}")
|
||||
break
|
||||
else:
|
||||
curl_code = result.returncode
|
||||
if curl_code == 0:
|
||||
curl_code = vm_result.returncode
|
||||
|
||||
# Map common curl error codes to helpful messages
|
||||
if curl_code == 7:
|
||||
curl_error = "Failed to connect - API server is starting up"
|
||||
elif curl_code == 22:
|
||||
curl_error = "HTTP error returned from API server"
|
||||
elif curl_code == 28:
|
||||
curl_error = "Operation timeout - API server is slow to respond"
|
||||
elif curl_code == 52:
|
||||
curl_error = "Empty reply from server - API is starting but not ready"
|
||||
elif curl_code == 56:
|
||||
curl_error = "Network problem during data transfer"
|
||||
else:
|
||||
curl_error = f"Unknown curl error code: {curl_code}"
|
||||
|
||||
logger.info(f"API not ready yet: {curl_error}")
|
||||
except subprocess.SubprocessError as e:
|
||||
logger.warning(f"Error checking API status: {e}")
|
||||
|
||||
# If the container is running but API is not ready, that's OK - we'll just wait
|
||||
# a bit longer before checking again, as the container may still be initializing
|
||||
elapsed_seconds = time.time() - start_time
|
||||
if int(elapsed_seconds) % 5 == 0: # Only print status every 5 seconds to reduce verbosity
|
||||
logger.debug(f"Waiting for API to initialize... ({elapsed_seconds:.1f}s / {timeout}s)")
|
||||
|
||||
await asyncio.sleep(3) # Longer sleep between API checks
|
||||
|
||||
# Handle timeout - if the container is running but API is not ready, that's not
|
||||
# necessarily an error - the API might just need more time to start up
|
||||
if not container_running:
|
||||
logger.warning(f"Timed out waiting for container {container_name} to start")
|
||||
return False
|
||||
|
||||
if not api_ready:
|
||||
logger.warning(f"Container {container_name} is running, but API is not fully ready yet.")
|
||||
logger.warning(f"NOTE: You may see some 'API request failed' messages while the API initializes.")
|
||||
|
||||
# Return True if container is running, even if API isn't ready yet
|
||||
# This allows VM operations to proceed, with appropriate retries for API calls
|
||||
return container_running
|
||||
|
||||
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Stop a running VM by stopping the Lumier container."""
|
||||
try:
|
||||
# Use Docker commands to stop the container directly
|
||||
if hasattr(self, '_container_id') and self._container_id:
|
||||
logger.info(f"Stopping Lumier container: {self.container_name}")
|
||||
cmd = ["docker", "stop", self.container_name]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
||||
logger.info(f"Container stopped: {result.stdout.strip()}")
|
||||
|
||||
# Return minimal status info
|
||||
return {
|
||||
"name": name,
|
||||
"status": "stopped",
|
||||
"container_id": self._container_id,
|
||||
}
|
||||
else:
|
||||
# Try to find the container by name
|
||||
check_cmd = ["docker", "ps", "-a", "--filter", f"name={self.container_name}", "--format", "{{.ID}}"]
|
||||
check_result = subprocess.run(check_cmd, capture_output=True, text=True)
|
||||
container_id = check_result.stdout.strip()
|
||||
|
||||
if container_id:
|
||||
logger.info(f"Found container ID: {container_id}")
|
||||
cmd = ["docker", "stop", self.container_name]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
||||
logger.info(f"Container stopped: {result.stdout.strip()}")
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"status": "stopped",
|
||||
"container_id": container_id,
|
||||
}
|
||||
else:
|
||||
logger.warning(f"No container found with name {self.container_name}")
|
||||
return {
|
||||
"name": name,
|
||||
"status": "unknown",
|
||||
}
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_msg = f"Failed to stop container: {e.stderr if hasattr(e, 'stderr') else str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise RuntimeError(f"Failed to stop Lumier container: {error_msg}")
|
||||
|
||||
# update_vm is not implemented as it's not needed for Lumier
|
||||
# The BaseVMProvider requires it, so we provide a minimal implementation
|
||||
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Not implemented for Lumier provider."""
|
||||
logger.warning("update_vm is not implemented for Lumier provider")
|
||||
return {"name": name, "status": "unchanged"}
|
||||
|
||||
async def get_logs(self, name: str, num_lines: int = 100, follow: bool = False, timeout: Optional[int] = None) -> str:
|
||||
"""Get the logs from the Lumier container.
|
||||
|
||||
Args:
|
||||
name: Name of the VM/container to get logs for
|
||||
num_lines: Number of recent log lines to return (default: 100)
|
||||
follow: If True, follow the logs (stream new logs as they are generated)
|
||||
timeout: Optional timeout in seconds for follow mode (None means no timeout)
|
||||
|
||||
Returns:
|
||||
Container logs as a string
|
||||
|
||||
Note:
|
||||
If follow=True, this function will continuously stream logs until timeout
|
||||
or until interrupted. The output will be printed to console in real-time.
|
||||
"""
|
||||
if not HAS_LUMIER:
|
||||
error_msg = "Docker is not available. Cannot get container logs."
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
# Make sure we have a container name
|
||||
container_name = name
|
||||
|
||||
# Check if the container exists and is running
|
||||
try:
|
||||
# Check if the container exists
|
||||
inspect_cmd = ["docker", "container", "inspect", container_name]
|
||||
result = subprocess.run(inspect_cmd, capture_output=True, text=True)
|
||||
|
||||
if result.returncode != 0:
|
||||
error_msg = f"Container '{container_name}' does not exist or is not accessible"
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Error checking container status: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
# Base docker logs command
|
||||
log_cmd = ["docker", "logs"]
|
||||
|
||||
# Add tail parameter to limit the number of lines
|
||||
log_cmd.extend(["--tail", str(num_lines)])
|
||||
|
||||
# Handle follow mode with or without timeout
|
||||
if follow:
|
||||
log_cmd.append("--follow")
|
||||
|
||||
if timeout is not None:
|
||||
# For follow mode with timeout, we'll run the command and handle the timeout
|
||||
log_cmd.append(container_name)
|
||||
logger.info(f"Following logs for container '{container_name}' with timeout {timeout}s")
|
||||
logger.info(f"\n---- CONTAINER LOGS FOR '{container_name}' (LIVE) ----")
|
||||
logger.info(f"Press Ctrl+C to stop following logs\n")
|
||||
|
||||
try:
|
||||
# Run with timeout
|
||||
process = subprocess.Popen(log_cmd, text=True)
|
||||
|
||||
# Wait for the specified timeout
|
||||
if timeout:
|
||||
try:
|
||||
process.wait(timeout=timeout)
|
||||
except subprocess.TimeoutExpired:
|
||||
process.terminate() # Stop after timeout
|
||||
logger.info(f"\n---- LOG FOLLOWING STOPPED (timeout {timeout}s reached) ----")
|
||||
else:
|
||||
# Without timeout, wait for user interruption
|
||||
process.wait()
|
||||
|
||||
return "Logs were displayed to console in follow mode"
|
||||
except KeyboardInterrupt:
|
||||
process.terminate()
|
||||
logger.info("\n---- LOG FOLLOWING STOPPED (user interrupted) ----")
|
||||
return "Logs were displayed to console in follow mode (interrupted)"
|
||||
else:
|
||||
# For follow mode without timeout, we'll print a helpful message
|
||||
log_cmd.append(container_name)
|
||||
logger.info(f"Following logs for container '{container_name}' indefinitely")
|
||||
logger.info(f"\n---- CONTAINER LOGS FOR '{container_name}' (LIVE) ----")
|
||||
logger.info(f"Press Ctrl+C to stop following logs\n")
|
||||
|
||||
try:
|
||||
# Run the command and let it run until interrupted
|
||||
process = subprocess.Popen(log_cmd, text=True)
|
||||
process.wait() # Wait indefinitely (until user interrupts)
|
||||
return "Logs were displayed to console in follow mode"
|
||||
except KeyboardInterrupt:
|
||||
process.terminate()
|
||||
logger.info("\n---- LOG FOLLOWING STOPPED (user interrupted) ----")
|
||||
return "Logs were displayed to console in follow mode (interrupted)"
|
||||
else:
|
||||
# For non-follow mode, capture and return the logs as a string
|
||||
log_cmd.append(container_name)
|
||||
logger.info(f"Getting {num_lines} log lines for container '{container_name}'")
|
||||
|
||||
try:
|
||||
result = subprocess.run(log_cmd, capture_output=True, text=True, check=True)
|
||||
logs = result.stdout
|
||||
|
||||
# Only print header and logs if there's content
|
||||
if logs.strip():
|
||||
logger.info(f"\n---- CONTAINER LOGS FOR '{container_name}' (LAST {num_lines} LINES) ----\n")
|
||||
logger.info(logs)
|
||||
logger.info(f"\n---- END OF LOGS ----")
|
||||
else:
|
||||
logger.info(f"\nNo logs available for container '{container_name}'")
|
||||
|
||||
return logs
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_msg = f"Error getting logs: {e.stderr}"
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error getting logs: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return error_msg
|
||||
|
||||
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
|
||||
"""Get the IP address of a VM, waiting indefinitely until it's available.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to get the IP for
|
||||
storage: Optional storage path override
|
||||
retry_delay: Delay between retries in seconds (default: 2)
|
||||
|
||||
Returns:
|
||||
IP address of the VM when it becomes available
|
||||
"""
|
||||
# Use container_name = name for consistency
|
||||
self.container_name = name
|
||||
|
||||
# Track total attempts for logging purposes
|
||||
total_attempts = 0
|
||||
|
||||
# Loop indefinitely until we get a valid IP
|
||||
while True:
|
||||
total_attempts += 1
|
||||
|
||||
# Log retry message but not on first attempt
|
||||
if total_attempts > 1:
|
||||
logger.info(f"Waiting for VM {name} IP address (attempt {total_attempts})...")
|
||||
|
||||
try:
|
||||
# Get VM information
|
||||
vm_info = await self.get_vm(name, storage=storage)
|
||||
|
||||
# Check if we got a valid IP
|
||||
ip = vm_info.get("ip_address", None)
|
||||
if ip and ip != "unknown" and not ip.startswith("0.0.0.0"):
|
||||
logger.info(f"Got valid VM IP address: {ip}")
|
||||
return ip
|
||||
|
||||
# Check the VM status
|
||||
status = vm_info.get("status", "unknown")
|
||||
|
||||
# Special handling for Lumier: it may report "stopped" even when the VM is starting
|
||||
# If the VM information contains an IP but status is stopped, it might be a race condition
|
||||
if status == "stopped" and "ip_address" in vm_info:
|
||||
ip = vm_info.get("ip_address")
|
||||
if ip and ip != "unknown" and not ip.startswith("0.0.0.0"):
|
||||
logger.info(f"Found valid IP {ip} despite VM status being {status}")
|
||||
return ip
|
||||
logger.info(f"VM status is {status}, but still waiting for IP to be assigned")
|
||||
# If VM is not running yet, log and wait
|
||||
elif status != "running":
|
||||
logger.info(f"VM is not running yet (status: {status}). Waiting...")
|
||||
# If VM is running but no IP yet, wait and retry
|
||||
else:
|
||||
logger.info("VM is running but no valid IP address yet. Waiting...")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting VM {name} IP: {e}, continuing to wait...")
|
||||
|
||||
# Wait before next retry
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
# Add progress log every 10 attempts
|
||||
if total_attempts % 10 == 0:
|
||||
logger.info(f"Still waiting for VM {name} IP after {total_attempts} attempts...")
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Async context manager entry.
|
||||
|
||||
This method is called when entering an async context manager block.
|
||||
Returns self to be used in the context.
|
||||
"""
|
||||
logger.debug("Entering LumierProvider context")
|
||||
|
||||
# Initialize the API URL with the default value if not already set
|
||||
# This ensures get_vm can work before run_vm is called
|
||||
if not hasattr(self, '_api_url') or not self._api_url:
|
||||
self._api_url = f"http://{self.host}:{self.api_port}"
|
||||
logger.info(f"Initialized default Lumier API URL: {self._api_url}")
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Async context manager exit.
|
||||
|
||||
This method is called when exiting an async context manager block.
|
||||
It handles proper cleanup of resources, including stopping any running containers.
|
||||
"""
|
||||
logger.debug(f"Exiting LumierProvider context, handling exceptions: {exc_type}")
|
||||
try:
|
||||
# If we have a container ID, we should stop it to clean up resources
|
||||
if hasattr(self, '_container_id') and self._container_id:
|
||||
logger.info(f"Stopping Lumier container on context exit: {self.container_name}")
|
||||
try:
|
||||
cmd = ["docker", "stop", self.container_name]
|
||||
subprocess.run(cmd, capture_output=True, text=True, check=True)
|
||||
logger.info(f"Container stopped during context exit: {self.container_name}")
|
||||
except subprocess.CalledProcessError as e:
|
||||
logger.warning(f"Failed to stop container during cleanup: {e.stderr}")
|
||||
# Don't raise an exception here, we want to continue with cleanup
|
||||
except Exception as e:
|
||||
logger.error(f"Error during LumierProvider cleanup: {e}")
|
||||
# We don't want to suppress the original exception if there was one
|
||||
if exc_type is None:
|
||||
raise
|
||||
# Return False to indicate that any exception should propagate
|
||||
return False
|
||||
@@ -0,0 +1,11 @@
|
||||
"""Windows Sandbox provider for CUA Computer."""
|
||||
|
||||
try:
|
||||
import winsandbox
|
||||
HAS_WINSANDBOX = True
|
||||
except ImportError:
|
||||
HAS_WINSANDBOX = False
|
||||
|
||||
from .provider import WinSandboxProvider
|
||||
|
||||
__all__ = ["WinSandboxProvider", "HAS_WINSANDBOX"]
|
||||
468
libs/python/computer/computer/providers/winsandbox/provider.py
Normal file
468
libs/python/computer/computer/providers/winsandbox/provider.py
Normal file
@@ -0,0 +1,468 @@
|
||||
"""Windows Sandbox VM provider implementation using pywinsandbox."""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
from ..base import BaseVMProvider, VMProviderType
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import winsandbox
|
||||
HAS_WINSANDBOX = True
|
||||
except ImportError:
|
||||
HAS_WINSANDBOX = False
|
||||
|
||||
|
||||
class WinSandboxProvider(BaseVMProvider):
|
||||
"""Windows Sandbox VM provider implementation using pywinsandbox.
|
||||
|
||||
This provider uses Windows Sandbox to create isolated Windows environments.
|
||||
Storage is always ephemeral with Windows Sandbox.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
port: int = 7777,
|
||||
host: str = "localhost",
|
||||
storage: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
ephemeral: bool = True, # Windows Sandbox is always ephemeral
|
||||
memory_mb: int = 4096,
|
||||
networking: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
"""Initialize the Windows Sandbox provider.
|
||||
|
||||
Args:
|
||||
port: Port for the computer server (default: 7777)
|
||||
host: Host to use for connections (default: localhost)
|
||||
storage: Storage path (ignored - Windows Sandbox is always ephemeral)
|
||||
verbose: Enable verbose logging
|
||||
ephemeral: Always True for Windows Sandbox
|
||||
memory_mb: Memory allocation in MB (default: 4096)
|
||||
networking: Enable networking in sandbox (default: True)
|
||||
"""
|
||||
if not HAS_WINSANDBOX:
|
||||
raise ImportError(
|
||||
"pywinsandbox is required for WinSandboxProvider. "
|
||||
"Please install it with 'pip install pywinsandbox'"
|
||||
)
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.verbose = verbose
|
||||
self.memory_mb = memory_mb
|
||||
self.networking = networking
|
||||
|
||||
# Windows Sandbox is always ephemeral
|
||||
if not ephemeral:
|
||||
logger.warning("Windows Sandbox storage is always ephemeral. Ignoring ephemeral=False.")
|
||||
self.ephemeral = True
|
||||
|
||||
# Storage is always ephemeral for Windows Sandbox
|
||||
if storage and storage != "ephemeral":
|
||||
logger.warning("Windows Sandbox does not support persistent storage. Using ephemeral storage.")
|
||||
self.storage = "ephemeral"
|
||||
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Track active sandboxes
|
||||
self._active_sandboxes: Dict[str, Any] = {}
|
||||
|
||||
@property
|
||||
def provider_type(self) -> VMProviderType:
|
||||
"""Get the provider type."""
|
||||
return VMProviderType.WINSANDBOX
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Enter async context manager."""
|
||||
# Verify Windows Sandbox is available
|
||||
if not HAS_WINSANDBOX:
|
||||
raise ImportError("pywinsandbox is not available")
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit async context manager."""
|
||||
# Clean up any active sandboxes
|
||||
for name, sandbox in self._active_sandboxes.items():
|
||||
try:
|
||||
sandbox.shutdown()
|
||||
self.logger.info(f"Terminated sandbox: {name}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error terminating sandbox {name}: {e}")
|
||||
|
||||
self._active_sandboxes.clear()
|
||||
|
||||
async def get_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Get VM information by name.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to get information for
|
||||
storage: Ignored for Windows Sandbox (always ephemeral)
|
||||
|
||||
Returns:
|
||||
Dictionary with VM information including status, IP address, etc.
|
||||
"""
|
||||
if name not in self._active_sandboxes:
|
||||
return {
|
||||
"name": name,
|
||||
"status": "stopped",
|
||||
"ip_address": None,
|
||||
"storage": "ephemeral"
|
||||
}
|
||||
|
||||
sandbox = self._active_sandboxes[name]
|
||||
|
||||
# Check if sandbox is still running
|
||||
try:
|
||||
# Try to ping the sandbox to see if it's responsive
|
||||
try:
|
||||
sandbox.rpyc.modules.os.getcwd()
|
||||
sandbox_responsive = True
|
||||
except Exception:
|
||||
sandbox_responsive = False
|
||||
|
||||
if not sandbox_responsive:
|
||||
return {
|
||||
"name": name,
|
||||
"status": "starting",
|
||||
"ip_address": None,
|
||||
"storage": "ephemeral",
|
||||
"memory_mb": self.memory_mb,
|
||||
"networking": self.networking
|
||||
}
|
||||
|
||||
# Check for computer server address file
|
||||
server_address_file = r"C:\Users\WDAGUtilityAccount\Desktop\shared_windows_sandbox_dir\server_address"
|
||||
|
||||
try:
|
||||
# Check if the server address file exists
|
||||
file_exists = sandbox.rpyc.modules.os.path.exists(server_address_file)
|
||||
|
||||
if file_exists:
|
||||
# Read the server address file
|
||||
with sandbox.rpyc.builtin.open(server_address_file, 'r') as f:
|
||||
server_address = f.read().strip()
|
||||
|
||||
if server_address and ':' in server_address:
|
||||
# Parse IP:port from the file
|
||||
ip_address, port = server_address.split(':', 1)
|
||||
|
||||
# Verify the server is actually responding
|
||||
try:
|
||||
import socket
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.settimeout(3)
|
||||
result = sock.connect_ex((ip_address, int(port)))
|
||||
sock.close()
|
||||
|
||||
if result == 0:
|
||||
# Server is responding
|
||||
status = "running"
|
||||
self.logger.debug(f"Computer server found at {ip_address}:{port}")
|
||||
else:
|
||||
# Server file exists but not responding
|
||||
status = "starting"
|
||||
ip_address = None
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Error checking server connectivity: {e}")
|
||||
status = "starting"
|
||||
ip_address = None
|
||||
else:
|
||||
# File exists but doesn't contain valid address
|
||||
status = "starting"
|
||||
ip_address = None
|
||||
else:
|
||||
# Server address file doesn't exist yet
|
||||
status = "starting"
|
||||
ip_address = None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Error checking server address file: {e}")
|
||||
status = "starting"
|
||||
ip_address = None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error checking sandbox status: {e}")
|
||||
status = "error"
|
||||
ip_address = None
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"status": status,
|
||||
"ip_address": ip_address,
|
||||
"storage": "ephemeral",
|
||||
"memory_mb": self.memory_mb,
|
||||
"networking": self.networking
|
||||
}
|
||||
|
||||
async def list_vms(self) -> List[Dict[str, Any]]:
|
||||
"""List all available VMs."""
|
||||
vms = []
|
||||
for name in self._active_sandboxes.keys():
|
||||
vm_info = await self.get_vm(name)
|
||||
vms.append(vm_info)
|
||||
return vms
|
||||
|
||||
async def run_vm(self, image: str, name: str, run_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Run a VM with the given options.
|
||||
|
||||
Args:
|
||||
image: Image name (ignored for Windows Sandbox - always uses host Windows)
|
||||
name: Name of the VM to run
|
||||
run_opts: Dictionary of run options (memory, cpu, etc.)
|
||||
storage: Ignored for Windows Sandbox (always ephemeral)
|
||||
|
||||
Returns:
|
||||
Dictionary with VM run status and information
|
||||
"""
|
||||
if name in self._active_sandboxes:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Sandbox {name} is already running"
|
||||
}
|
||||
|
||||
try:
|
||||
# Extract options from run_opts
|
||||
memory_mb = run_opts.get("memory_mb", self.memory_mb)
|
||||
if isinstance(memory_mb, str):
|
||||
# Convert memory string like "4GB" to MB
|
||||
if memory_mb.upper().endswith("GB"):
|
||||
memory_mb = int(float(memory_mb[:-2]) * 1024)
|
||||
elif memory_mb.upper().endswith("MB"):
|
||||
memory_mb = int(memory_mb[:-2])
|
||||
else:
|
||||
memory_mb = self.memory_mb
|
||||
|
||||
networking = run_opts.get("networking", self.networking)
|
||||
|
||||
# Create folder mappers if shared directories are specified
|
||||
folder_mappers = []
|
||||
shared_directories = run_opts.get("shared_directories", [])
|
||||
for shared_dir in shared_directories:
|
||||
if isinstance(shared_dir, dict):
|
||||
host_path = shared_dir.get("hostPath", "")
|
||||
elif isinstance(shared_dir, str):
|
||||
host_path = shared_dir
|
||||
else:
|
||||
continue
|
||||
|
||||
if host_path and os.path.exists(host_path):
|
||||
folder_mappers.append(winsandbox.FolderMapper(host_path))
|
||||
|
||||
self.logger.info(f"Creating Windows Sandbox: {name}")
|
||||
self.logger.info(f"Memory: {memory_mb}MB, Networking: {networking}")
|
||||
if folder_mappers:
|
||||
self.logger.info(f"Shared directories: {len(folder_mappers)}")
|
||||
|
||||
# Create the sandbox without logon script
|
||||
sandbox = winsandbox.new_sandbox(
|
||||
memory_mb=str(memory_mb),
|
||||
networking=networking,
|
||||
folder_mappers=folder_mappers
|
||||
)
|
||||
|
||||
# Store the sandbox
|
||||
self._active_sandboxes[name] = sandbox
|
||||
|
||||
self.logger.info(f"Windows Sandbox {name} created successfully")
|
||||
|
||||
# Setup the computer server in the sandbox
|
||||
await self._setup_computer_server(sandbox, name)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"name": name,
|
||||
"status": "starting",
|
||||
"memory_mb": memory_mb,
|
||||
"networking": networking,
|
||||
"storage": "ephemeral"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to create Windows Sandbox {name}: {e}")
|
||||
# stack trace
|
||||
import traceback
|
||||
self.logger.error(f"Stack trace: {traceback.format_exc()}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to create sandbox: {str(e)}"
|
||||
}
|
||||
|
||||
async def stop_vm(self, name: str, storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Stop a running VM.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to stop
|
||||
storage: Ignored for Windows Sandbox
|
||||
|
||||
Returns:
|
||||
Dictionary with stop status and information
|
||||
"""
|
||||
if name not in self._active_sandboxes:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Sandbox {name} is not running"
|
||||
}
|
||||
|
||||
try:
|
||||
sandbox = self._active_sandboxes[name]
|
||||
|
||||
# Terminate the sandbox
|
||||
sandbox.shutdown()
|
||||
|
||||
# Remove from active sandboxes
|
||||
del self._active_sandboxes[name]
|
||||
|
||||
self.logger.info(f"Windows Sandbox {name} stopped successfully")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"name": name,
|
||||
"status": "stopped"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to stop Windows Sandbox {name}: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to stop sandbox: {str(e)}"
|
||||
}
|
||||
|
||||
async def update_vm(self, name: str, update_opts: Dict[str, Any], storage: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Update VM configuration.
|
||||
|
||||
Note: Windows Sandbox does not support runtime configuration updates.
|
||||
The sandbox must be stopped and restarted with new configuration.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to update
|
||||
update_opts: Dictionary of update options
|
||||
storage: Ignored for Windows Sandbox
|
||||
|
||||
Returns:
|
||||
Dictionary with update status and information
|
||||
"""
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Windows Sandbox does not support runtime configuration updates. "
|
||||
"Please stop and restart the sandbox with new configuration."
|
||||
}
|
||||
|
||||
async def get_ip(self, name: str, storage: Optional[str] = None, retry_delay: int = 2) -> str:
|
||||
"""Get the IP address of a VM, waiting indefinitely until it's available.
|
||||
|
||||
Args:
|
||||
name: Name of the VM to get the IP for
|
||||
storage: Ignored for Windows Sandbox
|
||||
retry_delay: Delay between retries in seconds (default: 2)
|
||||
|
||||
Returns:
|
||||
IP address of the VM when it becomes available
|
||||
"""
|
||||
total_attempts = 0
|
||||
|
||||
# Loop indefinitely until we get a valid IP
|
||||
while True:
|
||||
total_attempts += 1
|
||||
|
||||
# Log retry message but not on first attempt
|
||||
if total_attempts > 1:
|
||||
self.logger.info(f"Waiting for Windows Sandbox {name} IP address (attempt {total_attempts})...")
|
||||
|
||||
try:
|
||||
# Get VM information
|
||||
vm_info = await self.get_vm(name, storage=storage)
|
||||
|
||||
# Check if we got a valid IP
|
||||
ip = vm_info.get("ip_address", None)
|
||||
if ip and ip != "unknown" and not ip.startswith("0.0.0.0"):
|
||||
self.logger.info(f"Got valid Windows Sandbox IP address: {ip}")
|
||||
return ip
|
||||
|
||||
# Check the VM status
|
||||
status = vm_info.get("status", "unknown")
|
||||
|
||||
# If VM is not running yet, log and wait
|
||||
if status != "running":
|
||||
self.logger.info(f"Windows Sandbox is not running yet (status: {status}). Waiting...")
|
||||
# If VM is running but no IP yet, wait and retry
|
||||
else:
|
||||
self.logger.info("Windows Sandbox is running but no valid IP address yet. Waiting...")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error getting Windows Sandbox {name} IP: {e}, continuing to wait...")
|
||||
|
||||
# Wait before next retry
|
||||
await asyncio.sleep(retry_delay)
|
||||
|
||||
# Add progress log every 10 attempts
|
||||
if total_attempts % 10 == 0:
|
||||
self.logger.info(f"Still waiting for Windows Sandbox {name} IP after {total_attempts} attempts...")
|
||||
|
||||
async def _setup_computer_server(self, sandbox, name: str, visible: bool = False):
|
||||
"""Setup the computer server in the Windows Sandbox using RPyC.
|
||||
|
||||
Args:
|
||||
sandbox: The Windows Sandbox instance
|
||||
name: Name of the sandbox
|
||||
visible: Whether the opened process should be visible (default: False)
|
||||
"""
|
||||
try:
|
||||
self.logger.info(f"Setting up computer server in sandbox {name}...")
|
||||
|
||||
# Read the PowerShell setup script
|
||||
script_path = os.path.join(os.path.dirname(__file__), "setup_script.ps1")
|
||||
with open(script_path, 'r', encoding='utf-8') as f:
|
||||
setup_script_content = f.read()
|
||||
|
||||
# Write the setup script to the sandbox using RPyC
|
||||
script_dest_path = r"C:\Users\WDAGUtilityAccount\setup_cua.ps1"
|
||||
|
||||
self.logger.info(f"Writing setup script to {script_dest_path}")
|
||||
with sandbox.rpyc.builtin.open(script_dest_path, 'w') as f:
|
||||
f.write(setup_script_content)
|
||||
|
||||
# Execute the PowerShell script in the background
|
||||
self.logger.info("Executing setup script in sandbox...")
|
||||
|
||||
# Use subprocess to run PowerShell script
|
||||
import subprocess
|
||||
powershell_cmd = [
|
||||
"powershell.exe",
|
||||
"-ExecutionPolicy", "Bypass",
|
||||
"-NoExit", # Keep window open after script completes
|
||||
"-File", script_dest_path
|
||||
]
|
||||
|
||||
# Set creation flags based on visibility preference
|
||||
if visible:
|
||||
# CREATE_NEW_CONSOLE - creates a new console window (visible)
|
||||
creation_flags = 0x00000010
|
||||
else:
|
||||
creation_flags = 0x08000000 # CREATE_NO_WINDOW
|
||||
|
||||
# Start the process using RPyC
|
||||
process = sandbox.rpyc.modules.subprocess.Popen(
|
||||
powershell_cmd,
|
||||
creationflags=creation_flags,
|
||||
shell=False
|
||||
)
|
||||
|
||||
# # Sleep for 30 seconds
|
||||
# await asyncio.sleep(30)
|
||||
|
||||
ip = await self.get_ip(name)
|
||||
self.logger.info(f"Sandbox IP: {ip}")
|
||||
self.logger.info(f"Setup script started in background in sandbox {name} with PID: {process.pid}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to setup computer server in sandbox {name}: {e}")
|
||||
import traceback
|
||||
self.logger.error(f"Stack trace: {traceback.format_exc()}")
|
||||
@@ -0,0 +1,124 @@
|
||||
# Setup script for Windows Sandbox CUA Computer provider
|
||||
# This script runs when the sandbox starts
|
||||
|
||||
Write-Host "Starting CUA Computer setup in Windows Sandbox..."
|
||||
|
||||
# Function to find the mapped Python installation from pywinsandbox
|
||||
function Find-MappedPython {
|
||||
Write-Host "Looking for mapped Python installation from pywinsandbox..."
|
||||
|
||||
# pywinsandbox maps the host Python installation to the sandbox
|
||||
# Look for mapped shared folders on the desktop (common pywinsandbox pattern)
|
||||
$desktopPath = "C:\Users\WDAGUtilityAccount\Desktop"
|
||||
$sharedFolders = Get-ChildItem -Path $desktopPath -Directory -ErrorAction SilentlyContinue
|
||||
|
||||
foreach ($folder in $sharedFolders) {
|
||||
# Look for Python executables in shared folders
|
||||
$pythonPaths = @(
|
||||
"$($folder.FullName)\python.exe",
|
||||
"$($folder.FullName)\Scripts\python.exe",
|
||||
"$($folder.FullName)\bin\python.exe"
|
||||
)
|
||||
|
||||
foreach ($pythonPath in $pythonPaths) {
|
||||
if (Test-Path $pythonPath) {
|
||||
try {
|
||||
$version = & $pythonPath --version 2>&1
|
||||
if ($version -match "Python") {
|
||||
Write-Host "Found mapped Python: $pythonPath - $version"
|
||||
return $pythonPath
|
||||
}
|
||||
} catch {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Also check subdirectories that might contain Python
|
||||
$subDirs = Get-ChildItem -Path $folder.FullName -Directory -ErrorAction SilentlyContinue
|
||||
foreach ($subDir in $subDirs) {
|
||||
$pythonPath = "$($subDir.FullName)\python.exe"
|
||||
if (Test-Path $pythonPath) {
|
||||
try {
|
||||
$version = & $pythonPath --version 2>&1
|
||||
if ($version -match "Python") {
|
||||
Write-Host "Found mapped Python in subdirectory: $pythonPath - $version"
|
||||
return $pythonPath
|
||||
}
|
||||
} catch {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Fallback: try common Python commands that might be available
|
||||
$pythonCommands = @("python", "py", "python3")
|
||||
foreach ($cmd in $pythonCommands) {
|
||||
try {
|
||||
$version = & $cmd --version 2>&1
|
||||
if ($version -match "Python") {
|
||||
Write-Host "Found Python via command '$cmd': $version"
|
||||
return $cmd
|
||||
}
|
||||
} catch {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
throw "Could not find any Python installation (mapped or otherwise)"
|
||||
}
|
||||
|
||||
try {
|
||||
# Step 1: Find the mapped Python installation
|
||||
Write-Host "Step 1: Finding mapped Python installation..."
|
||||
$pythonExe = Find-MappedPython
|
||||
Write-Host "Using Python: $pythonExe"
|
||||
|
||||
# Verify Python works and show version
|
||||
$pythonVersion = & $pythonExe --version 2>&1
|
||||
Write-Host "Python version: $pythonVersion"
|
||||
|
||||
# Step 2: Install cua-computer-server directly
|
||||
Write-Host "Step 2: Installing cua-computer-server..."
|
||||
|
||||
Write-Host "Upgrading pip..."
|
||||
& $pythonExe -m pip install --upgrade pip --quiet
|
||||
|
||||
Write-Host "Installing cua-computer-server..."
|
||||
& $pythonExe -m pip install cua-computer-server --quiet
|
||||
|
||||
Write-Host "cua-computer-server installation completed."
|
||||
|
||||
# Step 3: Start computer server in background
|
||||
Write-Host "Step 3: Starting computer server in background..."
|
||||
Write-Host "Starting computer server with: $pythonExe"
|
||||
|
||||
# Start the computer server in the background
|
||||
$serverProcess = Start-Process -FilePath $pythonExe -ArgumentList "-m", "computer_server.main" -WindowStyle Hidden -PassThru
|
||||
Write-Host "Computer server started in background with PID: $($serverProcess.Id)"
|
||||
|
||||
# Give it a moment to start
|
||||
Start-Sleep -Seconds 3
|
||||
|
||||
# Check if the process is still running
|
||||
if (Get-Process -Id $serverProcess.Id -ErrorAction SilentlyContinue) {
|
||||
Write-Host "Computer server is running successfully in background"
|
||||
} else {
|
||||
throw "Computer server failed to start or exited immediately"
|
||||
}
|
||||
|
||||
} catch {
|
||||
Write-Error "Setup failed: $_"
|
||||
Write-Host "Error details: $($_.Exception.Message)"
|
||||
Write-Host "Stack trace: $($_.ScriptStackTrace)"
|
||||
Write-Host ""
|
||||
Write-Host "Press any key to close this window..."
|
||||
$null = $Host.UI.RawUI.ReadKey("NoEcho,IncludeKeyDown")
|
||||
exit 1
|
||||
}
|
||||
|
||||
Write-Host ""
|
||||
Write-Host "Setup completed successfully!"
|
||||
Write-Host "Press any key to close this window..."
|
||||
$null = $Host.UI.RawUI.ReadKey("NoEcho,IncludeKeyDown")
|
||||
116
libs/python/computer/computer/telemetry.py
Normal file
116
libs/python/computer/computer/telemetry.py
Normal file
@@ -0,0 +1,116 @@
|
||||
"""Computer telemetry for tracking anonymous usage and feature usage."""
|
||||
|
||||
import logging
|
||||
import platform
|
||||
from typing import Any
|
||||
|
||||
# Import the core telemetry module
|
||||
TELEMETRY_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from core.telemetry import (
|
||||
increment,
|
||||
is_telemetry_enabled,
|
||||
is_telemetry_globally_disabled,
|
||||
record_event,
|
||||
)
|
||||
|
||||
def increment_counter(counter_name: str, value: int = 1) -> None:
|
||||
"""Wrapper for increment to maintain backward compatibility."""
|
||||
if is_telemetry_enabled():
|
||||
increment(counter_name, value)
|
||||
|
||||
def set_dimension(name: str, value: Any) -> None:
|
||||
"""Set a dimension that will be attached to all events."""
|
||||
logger = logging.getLogger("computer.telemetry")
|
||||
logger.debug(f"Setting dimension {name}={value}")
|
||||
|
||||
TELEMETRY_AVAILABLE = True
|
||||
logger = logging.getLogger("computer.telemetry")
|
||||
logger.info("Successfully imported telemetry")
|
||||
except ImportError as e:
|
||||
logger = logging.getLogger("computer.telemetry")
|
||||
logger.warning(f"Could not import telemetry: {e}")
|
||||
TELEMETRY_AVAILABLE = False
|
||||
|
||||
|
||||
# Local fallbacks in case core telemetry isn't available
|
||||
def _noop(*args: Any, **kwargs: Any) -> None:
|
||||
"""No-op function for when telemetry is not available."""
|
||||
pass
|
||||
|
||||
|
||||
logger = logging.getLogger("computer.telemetry")
|
||||
|
||||
# If telemetry isn't available, use no-op functions
|
||||
if not TELEMETRY_AVAILABLE:
|
||||
logger.debug("Telemetry not available, using no-op functions")
|
||||
record_event = _noop # type: ignore
|
||||
increment_counter = _noop # type: ignore
|
||||
set_dimension = _noop # type: ignore
|
||||
get_telemetry_client = lambda: None # type: ignore
|
||||
flush = _noop # type: ignore
|
||||
is_telemetry_enabled = lambda: False # type: ignore
|
||||
is_telemetry_globally_disabled = lambda: True # type: ignore
|
||||
|
||||
# Get system info once to use in telemetry
|
||||
SYSTEM_INFO = {
|
||||
"os": platform.system().lower(),
|
||||
"os_version": platform.release(),
|
||||
"python_version": platform.python_version(),
|
||||
}
|
||||
|
||||
|
||||
def enable_telemetry() -> bool:
|
||||
"""Enable telemetry if available.
|
||||
|
||||
Returns:
|
||||
bool: True if telemetry was successfully enabled, False otherwise
|
||||
"""
|
||||
global TELEMETRY_AVAILABLE
|
||||
|
||||
# Check if globally disabled using core function
|
||||
if TELEMETRY_AVAILABLE and is_telemetry_globally_disabled():
|
||||
logger.info("Telemetry is globally disabled via environment variable - cannot enable")
|
||||
return False
|
||||
|
||||
# Already enabled
|
||||
if TELEMETRY_AVAILABLE:
|
||||
return True
|
||||
|
||||
# Try to import and enable
|
||||
try:
|
||||
# Verify we can import core telemetry
|
||||
from core.telemetry import record_event # type: ignore
|
||||
|
||||
TELEMETRY_AVAILABLE = True
|
||||
logger.info("Telemetry successfully enabled")
|
||||
return True
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not enable telemetry: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def is_telemetry_enabled() -> bool:
|
||||
"""Check if telemetry is enabled.
|
||||
|
||||
Returns:
|
||||
bool: True if telemetry is enabled, False otherwise
|
||||
"""
|
||||
# Use the core function if available, otherwise use our local flag
|
||||
if TELEMETRY_AVAILABLE:
|
||||
from core.telemetry import is_telemetry_enabled as core_is_enabled
|
||||
|
||||
return core_is_enabled()
|
||||
return False
|
||||
|
||||
|
||||
def record_computer_initialization() -> None:
|
||||
"""Record when a computer instance is initialized."""
|
||||
if TELEMETRY_AVAILABLE and is_telemetry_enabled():
|
||||
record_event("computer_initialized", SYSTEM_INFO)
|
||||
|
||||
# Set dimensions that will be attached to all events
|
||||
set_dimension("os", SYSTEM_INFO["os"])
|
||||
set_dimension("os_version", SYSTEM_INFO["os_version"])
|
||||
set_dimension("python_version", SYSTEM_INFO["python_version"])
|
||||
1
libs/python/computer/computer/ui/__init__.py
Normal file
1
libs/python/computer/computer/ui/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""UI modules for the Computer Interface."""
|
||||
15
libs/python/computer/computer/ui/__main__.py
Normal file
15
libs/python/computer/computer/ui/__main__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
Main entry point for computer.ui module.
|
||||
|
||||
This allows running the computer UI with:
|
||||
python -m computer.ui
|
||||
|
||||
Instead of:
|
||||
python -m computer.ui.gradio.app
|
||||
"""
|
||||
|
||||
from .gradio.app import create_gradio_ui
|
||||
|
||||
if __name__ == "__main__":
|
||||
app = create_gradio_ui()
|
||||
app.launch()
|
||||
6
libs/python/computer/computer/ui/gradio/__init__.py
Normal file
6
libs/python/computer/computer/ui/gradio/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Gradio UI for Computer UI."""
|
||||
|
||||
import gradio as gr
|
||||
from typing import Optional
|
||||
|
||||
from .app import create_gradio_ui
|
||||
1651
libs/python/computer/computer/ui/gradio/app.py
Normal file
1651
libs/python/computer/computer/ui/gradio/app.py
Normal file
File diff suppressed because it is too large
Load Diff
101
libs/python/computer/computer/utils.py
Normal file
101
libs/python/computer/computer/utils.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import base64
|
||||
from typing import Tuple, Optional, Dict, Any
|
||||
from PIL import Image, ImageDraw
|
||||
import io
|
||||
|
||||
def decode_base64_image(base64_str: str) -> bytes:
|
||||
"""Decode a base64 string into image bytes."""
|
||||
return base64.b64decode(base64_str)
|
||||
|
||||
def encode_base64_image(image_bytes: bytes) -> str:
|
||||
"""Encode image bytes to base64 string."""
|
||||
return base64.b64encode(image_bytes).decode('utf-8')
|
||||
|
||||
def bytes_to_image(image_bytes: bytes) -> Image.Image:
|
||||
"""Convert bytes to PIL Image.
|
||||
|
||||
Args:
|
||||
image_bytes: Raw image bytes
|
||||
|
||||
Returns:
|
||||
PIL.Image: The converted image
|
||||
"""
|
||||
return Image.open(io.BytesIO(image_bytes))
|
||||
|
||||
def image_to_bytes(image: Image.Image, format: str = 'PNG') -> bytes:
|
||||
"""Convert PIL Image to bytes."""
|
||||
buf = io.BytesIO()
|
||||
image.save(buf, format=format)
|
||||
return buf.getvalue()
|
||||
|
||||
def resize_image(image_bytes: bytes, scale_factor: float) -> bytes:
|
||||
"""Resize an image by a scale factor.
|
||||
|
||||
Args:
|
||||
image_bytes: The original image as bytes
|
||||
scale_factor: Factor to scale the image by (e.g., 0.5 for half size, 2.0 for double)
|
||||
|
||||
Returns:
|
||||
bytes: The resized image as bytes
|
||||
"""
|
||||
image = bytes_to_image(image_bytes)
|
||||
if scale_factor != 1.0:
|
||||
new_size = (int(image.width * scale_factor), int(image.height * scale_factor))
|
||||
image = image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
return image_to_bytes(image)
|
||||
|
||||
def draw_box(
|
||||
image_bytes: bytes,
|
||||
x: int,
|
||||
y: int,
|
||||
width: int,
|
||||
height: int,
|
||||
color: str = "#FF0000",
|
||||
thickness: int = 2
|
||||
) -> bytes:
|
||||
"""Draw a box on an image.
|
||||
|
||||
Args:
|
||||
image_bytes: The original image as bytes
|
||||
x: X coordinate of top-left corner
|
||||
y: Y coordinate of top-left corner
|
||||
width: Width of the box
|
||||
height: Height of the box
|
||||
color: Color of the box in hex format
|
||||
thickness: Thickness of the box border in pixels
|
||||
|
||||
Returns:
|
||||
bytes: The modified image as bytes
|
||||
"""
|
||||
# Convert bytes to PIL Image
|
||||
image = bytes_to_image(image_bytes)
|
||||
|
||||
# Create drawing context
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
# Draw rectangle
|
||||
draw.rectangle(
|
||||
[(x, y), (x + width, y + height)],
|
||||
outline=color,
|
||||
width=thickness
|
||||
)
|
||||
|
||||
# Convert back to bytes
|
||||
return image_to_bytes(image)
|
||||
|
||||
def get_image_size(image_bytes: bytes) -> Tuple[int, int]:
|
||||
"""Get the dimensions of an image.
|
||||
|
||||
Args:
|
||||
image_bytes: The image as bytes
|
||||
|
||||
Returns:
|
||||
Tuple[int, int]: Width and height of the image
|
||||
"""
|
||||
image = bytes_to_image(image_bytes)
|
||||
return image.size
|
||||
|
||||
def parse_vm_info(vm_info: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Parse VM info from pylume response."""
|
||||
if not vm_info:
|
||||
return None
|
||||
2
libs/python/computer/poetry.toml
Normal file
2
libs/python/computer/poetry.toml
Normal file
@@ -0,0 +1,2 @@
|
||||
[virtualenvs]
|
||||
in-project = true
|
||||
73
libs/python/computer/pyproject.toml
Normal file
73
libs/python/computer/pyproject.toml
Normal file
@@ -0,0 +1,73 @@
|
||||
[build-system]
|
||||
requires = ["pdm-backend"]
|
||||
build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
name = "cua-computer"
|
||||
version = "0.2.0"
|
||||
description = "Computer-Use Interface (CUI) framework powering Cua"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "TryCua", email = "gh@trycua.com" }
|
||||
]
|
||||
dependencies = [
|
||||
"pillow>=10.0.0",
|
||||
"websocket-client>=1.8.0",
|
||||
"websockets>=12.0",
|
||||
"aiohttp>=3.9.0",
|
||||
"cua-core>=0.1.0,<0.2.0",
|
||||
"pydantic>=2.11.1"
|
||||
]
|
||||
requires-python = ">=3.11"
|
||||
|
||||
[project.optional-dependencies]
|
||||
lume = [
|
||||
]
|
||||
lumier = [
|
||||
]
|
||||
ui = [
|
||||
"gradio>=5.23.3,<6.0.0",
|
||||
"python-dotenv>=1.0.1,<2.0.0",
|
||||
"datasets>=3.6.0,<4.0.0",
|
||||
]
|
||||
all = [
|
||||
# Include all optional dependencies
|
||||
"gradio>=5.23.3,<6.0.0",
|
||||
"python-dotenv>=1.0.1,<2.0.0",
|
||||
"datasets>=3.6.0,<4.0.0",
|
||||
]
|
||||
|
||||
[tool.pdm]
|
||||
distribution = true
|
||||
|
||||
[tool.pdm.build]
|
||||
includes = ["computer/"]
|
||||
source-includes = ["tests/", "README.md", "LICENSE"]
|
||||
|
||||
[tool.black]
|
||||
line-length = 100
|
||||
target-version = ["py311"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
target-version = "py311"
|
||||
select = ["E", "F", "B", "I"]
|
||||
fix = true
|
||||
|
||||
[tool.ruff.format]
|
||||
docstring-code-format = true
|
||||
|
||||
[tool.mypy]
|
||||
strict = true
|
||||
python_version = "3.11"
|
||||
ignore_missing_imports = true
|
||||
disallow_untyped_defs = true
|
||||
check_untyped_defs = true
|
||||
warn_return_any = true
|
||||
show_error_codes = true
|
||||
warn_unused_ignores = false
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
python_files = "test_*.py"
|
||||
Reference in New Issue
Block a user