Merge branch 'main' into feat/api_key_overrides

This commit is contained in:
Dillon DuPont
2025-10-30 12:34:40 -04:00
142 changed files with 7173 additions and 4438 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.35
current_version = 0.4.37
commit = True
tag = True
tag_name = agent-v{new_version}

View File

@@ -73,8 +73,8 @@ if __name__ == "__main__":
## Docs
- [Agent Loops](https://trycua.com/docs/agent-sdk/agent-loops)
- [Supported Agents](https://trycua.com/docs/agent-sdk/supported-agents)
- [Supported Models](https://trycua.com/docs/agent-sdk/supported-models)
- [Supported Agents](https://trycua.com/docs/agent-sdk/supported-agents/computer-use-agents)
- [Supported Models](https://trycua.com/docs/agent-sdk/supported-model-providers)
- [Chat History](https://trycua.com/docs/agent-sdk/chat-history)
- [Callbacks](https://trycua.com/docs/agent-sdk/callbacks)
- [Custom Tools](https://trycua.com/docs/agent-sdk/custom-tools)

View File

@@ -28,8 +28,12 @@ class AsyncComputerHandler(Protocol):
"""Get screen dimensions as (width, height)."""
...
async def screenshot(self) -> str:
"""Take a screenshot and return as base64 string."""
async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.
Args:
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
"""
...
async def click(self, x: int, y: int, button: str = "left") -> None:

View File

@@ -36,8 +36,12 @@ class cuaComputerHandler(AsyncComputerHandler):
screen_size = await self.interface.get_screen_size()
return screen_size["width"], screen_size["height"]
async def screenshot(self) -> str:
"""Take a screenshot and return as base64 string."""
async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.
Args:
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
"""
assert self.interface is not None
screenshot_bytes = await self.interface.screenshot()
return base64.b64encode(screenshot_bytes).decode("utf-8")

View File

@@ -122,8 +122,12 @@ class CustomComputerHandler(AsyncComputerHandler):
return self._last_screenshot_size
async def screenshot(self) -> str:
"""Take a screenshot and return as base64 string."""
async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.
Args:
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
"""
result = await self._call_function(self.functions["screenshot"])
b64_str = self._to_b64_str(result) # type: ignore

View File

@@ -14,67 +14,73 @@ import litellm
from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..responses import (
convert_completion_messages_to_responses_items,
convert_responses_items_to_completion_messages,
)
from ..types import AgentCapability, AgentResponse, Messages, Tools
SOM_TOOL_SCHEMA = {
"type": "function",
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment",
],
"description": "The action to perform",
},
"element_id": {
"type": "integer",
"description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)",
},
"start_element_id": {
"type": "integer",
"description": "The ID of the element to start dragging from (required for drag action)",
},
"end_element_id": {
"type": "integer",
"description": "The ID of the element to drag to (required for drag action)",
},
"text": {
"type": "string",
"description": "The text to type (required for type action)",
},
"keys": {
"type": "string",
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')",
},
"button": {
"type": "string",
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
"function": {
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment",
],
"description": "The action to perform",
},
"element_id": {
"type": "integer",
"description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)",
},
"start_element_id": {
"type": "integer",
"description": "The ID of the element to start dragging from (required for drag action)",
},
"end_element_id": {
"type": "integer",
"description": "The ID of the element to drag to (required for drag action)",
},
"text": {
"type": "string",
"description": "The text to type (required for type action)",
},
"keys": {
"type": "string",
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')",
},
"button": {
"type": "string",
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
},
},
"required": ["action", "element_id"],
},
"required": ["action"],
},
}
@@ -243,18 +249,20 @@ async def replace_computer_call_with_function(
"id": item.get("id"),
"call_id": item.get("call_id"),
"status": "completed",
# Fall back to string representation
"content": f"Used tool: {action_data.get("type")}({json.dumps(fn_args)})",
}
]
elif item_type == "computer_call_output":
# Simple conversion: computer_call_output -> function_call_output
output = item.get("output")
if isinstance(output, dict):
output = [output]
return [
{
"type": "function_call_output",
"call_id": item.get("call_id"),
"content": [item.get("output")],
"output": item.get("output"),
"id": item.get("id"),
"status": "completed",
}
@@ -296,6 +304,13 @@ class OmniparserConfig(AsyncAgentConfig):
llm_model = model.split("+")[-1]
# Get screen dimensions from computer handler
try:
width, height = await computer_handler.get_dimensions()
except Exception:
# Fallback to default dimensions if method fails
width, height = 1024, 768
# Prepare tools for OpenAI API
openai_tools, id2xy = _prepare_tools_for_omniparser(tools)
@@ -309,27 +324,43 @@ class OmniparserConfig(AsyncAgentConfig):
result = parser.parse(image_data)
if _on_screenshot:
await _on_screenshot(result.annotated_image_base64, "annotated_image")
for element in result.elements:
id2xy[element.id] = (
(element.bbox.x1 + element.bbox.x2) / 2,
(element.bbox.y1 + element.bbox.y2) / 2,
)
# handle computer calls -> function calls
new_messages = []
for message in messages:
# Convert OmniParser normalized bbox centers (0-1) to absolute pixel coordinates
for element in result.elements:
norm_x = (element.bbox.x1 + element.bbox.x2) / 2
norm_y = (element.bbox.y1 + element.bbox.y2) / 2
pixel_x = int(norm_x * width)
pixel_y = int(norm_y * height)
id2xy[element.id] = (pixel_x, pixel_y)
# Replace the original screenshot with the annotated image
annotated_image_url = f"data:image/png;base64,{result.annotated_image_base64}"
last_computer_call_output["output"]["image_url"] = annotated_image_url
xy2id = {v: k for k, v in id2xy.items()}
messages_with_element_ids = []
for i, message in enumerate(messages):
if not isinstance(message, dict):
message = message.__dict__
new_messages += await replace_computer_call_with_function(message, id2xy) # type: ignore
messages = new_messages
msg_type = message.get("type")
if msg_type == "computer_call" and "action" in message:
action = message.get("action", {})
converted = await replace_computer_call_with_function(message, xy2id) # type: ignore
messages_with_element_ids += converted
completion_messages = convert_responses_items_to_completion_messages(
messages_with_element_ids, allow_images_in_tool_results=False
)
# Prepare API call kwargs
api_kwargs = {
"model": llm_model,
"input": messages,
"messages": completion_messages,
"tools": openai_tools if openai_tools else None,
"stream": stream,
"truncation": "auto",
"num_retries": max_retries,
**kwargs,
}
@@ -340,8 +371,8 @@ class OmniparserConfig(AsyncAgentConfig):
print(str(api_kwargs)[:1000])
# Use liteLLM responses
response = await litellm.aresponses(**api_kwargs)
# Use liteLLM completion
response = await litellm.acompletion(**api_kwargs)
# Call API end hook
if _on_api_end:
@@ -355,12 +386,45 @@ class OmniparserConfig(AsyncAgentConfig):
if _on_usage:
await _on_usage(usage)
# handle som function calls -> xy computer calls
new_output = []
for i in range(len(response.output)): # type: ignore
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) # type: ignore
response_dict = response.model_dump() # type: ignore
choice_messages = [choice["message"] for choice in response_dict["choices"]]
responses_items = []
for choice_message in choice_messages:
responses_items.extend(convert_completion_messages_to_responses_items([choice_message]))
return {"output": new_output, "usage": usage}
# Convert element_id → x,y (similar to moondream's convert_computer_calls_desc2xy)
final_output = []
for item in responses_items:
if item.get("type") == "computer_call" and "action" in item:
action = item["action"].copy()
# Handle single element_id
if "element_id" in action:
element_id = action["element_id"]
if element_id in id2xy:
x, y = id2xy[element_id]
action["x"] = x
action["y"] = y
del action["element_id"]
# Handle start_element_id and end_element_id for drag operations
elif "start_element_id" in action and "end_element_id" in action:
start_id = action["start_element_id"]
end_id = action["end_element_id"]
if start_id in id2xy and end_id in id2xy:
start_x, start_y = id2xy[start_id]
end_x, end_y = id2xy[end_id]
action["path"] = [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}]
del action["start_element_id"]
del action["end_element_id"]
converted_item = item.copy()
converted_item["action"] = action
final_output.append(converted_item)
else:
final_output.append(item)
return {"output": final_output, "usage": usage}
async def predict_click(
self, model: str, image_b64: str, instruction: str, **kwargs

View File

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
[project]
name = "cua-agent"
version = "0.4.35"
version = "0.4.37"
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
readme = "README.md"
authors = [

View File

@@ -0,0 +1,84 @@
"""Pytest configuration and shared fixtures for agent package tests.
This file contains shared fixtures and configuration for all agent tests.
Following SRP: This file ONLY handles test setup/teardown.
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
@pytest.fixture
def mock_litellm():
    """Patch ``litellm.acompletion`` with a canned chat-completion reply.

    While the fixture is active, awaiting the patched function returns a fixed
    dict payload instead of contacting a real LLM. The patched mock itself is
    yielded so tests can inspect how it was called.
    """

    async def _canned_completion(*args, **kwargs):
        """Return a fixed completion payload, echoing the requested model name."""
        assistant_message = {
            "role": "assistant",
            "content": "This is a mocked response for testing.",
        }
        token_usage = {
            "prompt_tokens": 10,
            "completion_tokens": 20,
            "total_tokens": 30,
        }
        return {
            "id": "chatcmpl-test123",
            "object": "chat.completion",
            "created": 1234567890,
            "model": kwargs.get("model", "anthropic/claude-3-5-sonnet-20241022"),
            "choices": [
                {
                    "index": 0,
                    "message": assistant_message,
                    "finish_reason": "stop",
                }
            ],
            "usage": token_usage,
        }

    with patch("litellm.acompletion") as patched_completion:
        # Every call to the patched function is routed to the canned coroutine.
        patched_completion.side_effect = _canned_completion
        yield patched_completion
@pytest.fixture
def mock_computer():
    """Build an ``AsyncMock`` that stands in for a Computer instance.

    The mock exposes an ``interface`` with awaitable ``screenshot`` (returning
    fixed bytes), ``left_click``, ``type`` and ``key`` methods, and supports
    ``async with`` by returning itself from ``__aenter__``.
    """
    interface = AsyncMock()
    interface.screenshot = AsyncMock(return_value=b"fake_screenshot_data")
    for action_name in ("left_click", "type", "key"):
        setattr(interface, action_name, AsyncMock())

    computer = AsyncMock()
    computer.interface = interface

    # Async context-manager protocol: entering yields the mock itself.
    computer.__aenter__ = AsyncMock(return_value=computer)
    computer.__aexit__ = AsyncMock()
    return computer
@pytest.fixture
def disable_telemetry(monkeypatch):
    """Set ``CUA_TELEMETRY_DISABLED=1`` in the environment for the test's duration.

    The variable is set through ``monkeypatch``, so pytest restores the previous
    environment when the test finishes.
    """
    env_var, flag = "CUA_TELEMETRY_DISABLED", "1"
    monkeypatch.setenv(env_var, flag)
@pytest.fixture
def sample_messages():
    """Return a one-element message list containing a single user turn."""
    user_turn = {
        "role": "user",
        "content": "Take a screenshot and tell me what you see",
    }
    return [user_turn]

View File

@@ -0,0 +1,139 @@
"""Unit tests for ComputerAgent class.
This file tests ONLY the ComputerAgent initialization and basic functionality.
Following SRP: This file tests ONE class (ComputerAgent).
All external dependencies (liteLLM, Computer) are mocked.
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
class TestComputerAgentInitialization:
    """Construction-time checks for ``ComputerAgent``.

    Each test patches ``agent.agent.litellm`` and imports ``ComputerAgent``
    inside the test body.
    """

    # Model identifier shared by every construction test in this class.
    MODEL_NAME = "anthropic/claude-3-5-sonnet-20241022"

    @patch("agent.agent.litellm")
    def test_agent_initialization_with_model(self, mock_litellm, disable_telemetry):
        """A model string alone is enough to construct the agent."""
        from agent import ComputerAgent

        created = ComputerAgent(model=self.MODEL_NAME)

        assert created is not None
        assert hasattr(created, "model")
        assert created.model == self.MODEL_NAME

    @patch("agent.agent.litellm")
    def test_agent_initialization_with_tools(self, mock_litellm, disable_telemetry, mock_computer):
        """The agent accepts a list of tools at construction."""
        from agent import ComputerAgent

        created = ComputerAgent(model=self.MODEL_NAME, tools=[mock_computer])

        assert created is not None
        assert hasattr(created, "tools")

    @patch("agent.agent.litellm")
    def test_agent_initialization_with_max_budget(self, mock_litellm, disable_telemetry):
        """The agent accepts a ``max_trajectory_budget`` at construction."""
        from agent import ComputerAgent

        created = ComputerAgent(model=self.MODEL_NAME, max_trajectory_budget=5.0)

        assert created is not None

    @patch("agent.agent.litellm")
    def test_agent_requires_model(self, mock_litellm, disable_telemetry):
        """Constructing the agent without ``model`` raises ``TypeError``."""
        from agent import ComputerAgent

        # The required argument is omitted on purpose.
        with pytest.raises(TypeError):
            ComputerAgent()  # type: ignore[call-arg]
class TestComputerAgentRun:
    """Checks for ``ComputerAgent.run()`` and the attributes it is expected to expose."""

    @pytest.mark.asyncio
    @patch("agent.agent.litellm")
    async def test_agent_run_with_messages(self, mock_litellm, disable_telemetry, sample_messages):
        """``run()`` returns an object exposing ``__anext__`` for valid messages."""
        from agent import ComputerAgent

        # Canned completion returned by the patched liteLLM module.
        assistant_choice = {
            "message": {"role": "assistant", "content": "Test response"},
            "finish_reason": "stop",
        }
        canned_completion = {
            "id": "chatcmpl-test",
            "choices": [assistant_choice],
            "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
        }
        mock_litellm.acompletion = AsyncMock(return_value=canned_completion)

        agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")

        # The generator is only created here, never iterated.
        stream = agent.run(sample_messages)

        assert stream is not None
        assert hasattr(stream, "__anext__")

    def test_agent_has_run_method(self, disable_telemetry):
        """The agent exposes a callable ``run`` attribute."""
        from agent import ComputerAgent

        agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")

        assert hasattr(agent, "run")
        assert callable(agent.run)

    def test_agent_has_agent_loop(self, disable_telemetry):
        """The agent has a non-``None`` ``agent_loop`` after construction."""
        from agent import ComputerAgent

        agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")

        assert hasattr(agent, "agent_loop")
        assert agent.agent_loop is not None
class TestComputerAgentTypes:
    """Checks that the ``agent`` package exports its public type names."""

    def test_messages_type_exists(self):
        """``Messages`` can be imported from ``agent`` and is not ``None``."""
        from agent import Messages as exported_type

        assert exported_type is not None

    def test_agent_response_type_exists(self):
        """``AgentResponse`` can be imported from ``agent`` and is not ``None``."""
        from agent import AgentResponse as exported_type

        assert exported_type is not None
class TestComputerAgentIntegration:
    """Checks that ``ComputerAgent`` can be constructed with a (mocked) Computer tool."""

    def test_agent_accepts_computer_tool(self, disable_telemetry, mock_computer):
        """Passing the mocked computer in ``tools`` yields an agent with a ``tools`` attribute."""
        from agent import ComputerAgent

        tool_list = [mock_computer]
        agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022", tools=tool_list)

        assert agent is not None
        assert hasattr(agent, "tools")

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.1.25
current_version = 0.1.28
commit = True
tag = True
tag_name = computer-server-v{new_version}

View File

@@ -85,6 +85,102 @@ class BaseFileHandler(ABC):
pass
class BaseDesktopHandler(ABC):
    """Abstract interface for OS-specific desktop handlers.

    Concrete handlers must implement:
        - get_desktop_environment: report the current desktop environment
        - set_wallpaper: change the desktop wallpaper
    """

    @abstractmethod
    async def get_desktop_environment(self) -> Dict[str, Any]:
        """Get the current desktop environment name."""
        ...

    @abstractmethod
    async def set_wallpaper(self, path: str) -> Dict[str, Any]:
        """Set the desktop wallpaper to the file at ``path``."""
        ...
class BaseWindowHandler(ABC):
"""Abstract class for OS-specific window management handlers.
Categories:
- Window Management: Methods for application/window control
"""
# Window Management
@abstractmethod
async def open(self, target: str) -> Dict[str, Any]:
"""Open a file or URL with the default application."""
pass
@abstractmethod
async def launch(self, app: str, args: Optional[List[str]] = None) -> Dict[str, Any]:
"""Launch an application with optional arguments."""
pass
@abstractmethod
async def get_current_window_id(self) -> Dict[str, Any]:
"""Get the currently active window ID."""
pass
@abstractmethod
async def get_application_windows(self, app: str) -> Dict[str, Any]:
"""Get windows belonging to an application (by name or bundle)."""
pass
@abstractmethod
async def get_window_name(self, window_id: str) -> Dict[str, Any]:
"""Get the title/name of a window by ID."""
pass
@abstractmethod
async def get_window_size(self, window_id: str | int) -> Dict[str, Any]:
"""Get the size of a window by ID as {width, height}."""
pass
@abstractmethod
async def activate_window(self, window_id: str | int) -> Dict[str, Any]:
"""Bring a window to the foreground by ID."""
pass
@abstractmethod
async def close_window(self, window_id: str | int) -> Dict[str, Any]:
"""Close a window by ID."""
pass
@abstractmethod
async def get_window_position(self, window_id: str | int) -> Dict[str, Any]:
"""Get the top-left position of a window as {x, y}."""
pass
@abstractmethod
async def set_window_size(
self, window_id: str | int, width: int, height: int
) -> Dict[str, Any]:
"""Set the size of a window by ID."""
pass
@abstractmethod
async def set_window_position(self, window_id: str | int, x: int, y: int) -> Dict[str, Any]:
"""Set the position of a window by ID."""
pass
@abstractmethod
async def maximize_window(self, window_id: str | int) -> Dict[str, Any]:
"""Maximize a window by ID."""
pass
@abstractmethod
async def minimize_window(self, window_id: str | int) -> Dict[str, Any]:
"""Minimize a window by ID."""
pass
class BaseAutomationHandler(ABC):
"""Abstract base class for OS-specific automation handlers.

View File

@@ -4,7 +4,13 @@ from typing import Tuple, Type
from computer_server.diorama.base import BaseDioramaHandler
from .base import BaseAccessibilityHandler, BaseAutomationHandler, BaseFileHandler
from .base import (
BaseAccessibilityHandler,
BaseAutomationHandler,
BaseDesktopHandler,
BaseFileHandler,
BaseWindowHandler,
)
# Conditionally import platform-specific handlers
system = platform.system().lower()
@@ -17,7 +23,7 @@ elif system == "linux":
elif system == "windows":
from .windows import WindowsAccessibilityHandler, WindowsAutomationHandler
from .generic import GenericFileHandler
from .generic import GenericDesktopHandler, GenericFileHandler, GenericWindowHandler
class HandlerFactory:
@@ -49,9 +55,14 @@ class HandlerFactory:
raise RuntimeError(f"Failed to determine current OS: {str(e)}")
@staticmethod
def create_handlers() -> (
Tuple[BaseAccessibilityHandler, BaseAutomationHandler, BaseDioramaHandler, BaseFileHandler]
):
def create_handlers() -> Tuple[
BaseAccessibilityHandler,
BaseAutomationHandler,
BaseDioramaHandler,
BaseFileHandler,
BaseDesktopHandler,
BaseWindowHandler,
]:
"""Create and return appropriate handlers for the current OS.
Returns:
@@ -70,6 +81,8 @@ class HandlerFactory:
MacOSAutomationHandler(),
MacOSDioramaHandler(),
GenericFileHandler(),
GenericDesktopHandler(),
GenericWindowHandler(),
)
elif os_type == "linux":
return (
@@ -77,6 +90,8 @@ class HandlerFactory:
LinuxAutomationHandler(),
BaseDioramaHandler(),
GenericFileHandler(),
GenericDesktopHandler(),
GenericWindowHandler(),
)
elif os_type == "windows":
return (
@@ -84,6 +99,8 @@ class HandlerFactory:
WindowsAutomationHandler(),
BaseDioramaHandler(),
GenericFileHandler(),
GenericDesktopHandler(),
GenericWindowHandler(),
)
else:
raise NotImplementedError(f"OS '{os_type}' is not supported")

View File

@@ -2,15 +2,26 @@
Generic handlers for all OSes.
Includes:
- DesktopHandler
- WindowHandler
- FileHandler
"""
import base64
import os
import platform
import subprocess
import webbrowser
from pathlib import Path
from typing import Any, Dict, Optional
from .base import BaseFileHandler
from ..utils import wallpaper
from .base import BaseDesktopHandler, BaseFileHandler, BaseWindowHandler
try:
import pywinctl as pwc
except Exception: # pragma: no cover
pwc = None # type: ignore
def resolve_path(path: str) -> Path:
@@ -25,6 +36,233 @@ def resolve_path(path: str) -> Path:
return Path(path).expanduser().resolve()
# ===== Cross-platform Desktop command handlers =====
class GenericDesktopHandler(BaseDesktopHandler):
    """
    Desktop handler shared by all operating systems.

    Both operations delegate to the ``wallpaper`` utility module and report
    failures in the result dict instead of raising.
    """

    async def get_desktop_environment(self) -> Dict[str, Any]:
        """
        Detect the current desktop environment.

        Returns:
            ``{"success": True, "environment": <name>}`` on success, otherwise
            ``{"success": False, "error": <message>}``.
        """
        try:
            detected = wallpaper.get_desktop_environment()
        except Exception as exc:
            return {"success": False, "error": str(exc)}
        return {"success": True, "environment": detected}

    async def set_wallpaper(self, path: str) -> Dict[str, Any]:
        """
        Set the desktop wallpaper to the file at ``path``.

        Args:
            path: Location of the image; it is passed through ``resolve_path``
                before being handed to the wallpaper utility.

        Returns:
            ``{"success": <bool>}`` reflecting the utility's return value, or
            ``{"success": False, "error": <message>}`` if an exception occurred.
        """
        try:
            resolved = str(resolve_path(path))
            applied = bool(wallpaper.set_wallpaper(resolved))
        except Exception as exc:
            return {"success": False, "error": str(exc)}
        return {"success": applied}
# ===== Cross-platform window control command handlers =====
class GenericWindowHandler(BaseWindowHandler):
    """
    Cross-platform window management using pywinctl where possible.

    Every public coroutine returns a dict with a boolean ``success`` key.
    Failures are reported through an ``error`` string rather than raised.
    """

    async def open(self, target: str) -> Dict[str, Any]:
        """Open a URL in the web browser, or a path with the OS default opener.

        Args:
            target: An ``http://``/``https://`` URL, or a file path (passed
                through ``resolve_path``).

        Returns:
            ``{"success": <bool>}``, or ``{"success": False, "error": ...}`` for an
            unsupported OS or any exception.
        """
        try:
            if target.startswith(("http://", "https://")):
                return {"success": bool(webbrowser.open(target))}

            path = str(resolve_path(target))
            # Not called ``sys`` so the local cannot be mistaken for the stdlib module.
            system_name = platform.system().lower()
            if system_name == "darwin":
                subprocess.Popen(["open", path])
            elif system_name == "linux":
                subprocess.Popen(["xdg-open", path])
            elif system_name == "windows":
                os.startfile(path)  # type: ignore[attr-defined]
            else:
                return {"success": False, "error": f"Unsupported OS: {system_name}"}
            return {"success": True}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def launch(self, app: str, args: Optional[list[str]] = None) -> Dict[str, Any]:
        """Start an application and return its process ID.

        Args:
            app: Executable name/path, or a full shell command when ``args`` is empty.
            args: Optional argument list; when given, no shell is involved.

        Returns:
            ``{"success": True, "pid": <int>}`` or ``{"success": False, "error": ...}``.
        """
        try:
            if args:
                proc = subprocess.Popen([app, *args])
            else:
                # allow shell command like "libreoffice --writer"
                # NOTE(security): ``app`` is interpreted by the shell here; callers
                # must not pass untrusted input through this branch.
                proc = subprocess.Popen(app, shell=True)
            return {"success": True, "pid": proc.pid}
        except Exception as e:
            return {"success": False, "error": str(e)}

    def _get_window_by_id(self, window_id: int | str) -> Optional[Any]:
        """Return the pywinctl window whose native handle matches ``window_id``.

        Handles are compared as strings, so integer and string IDs are
        interchangeable.

        Returns:
            The matching window object, or ``None`` if nothing matches or the
            enumeration itself fails.

        Raises:
            RuntimeError: If pywinctl could not be imported.
        """
        if pwc is None:
            raise RuntimeError("pywinctl not available")
        # Find by native handle among Window objects; getAllWindowsDict keys are titles
        try:
            wanted = str(window_id)
            return next(
                (w for w in pwc.getAllWindows() if str(w.getHandle()) == wanted),
                None,
            )
        except Exception:
            # Best effort: a failed enumeration is reported as "not found".
            return None

    def _apply_to_window(self, window_id: int | str, action: Any) -> Dict[str, Any]:
        """Look up a window and run ``action`` on it, normalising every failure mode.

        Args:
            window_id: Native handle of the target window.
            action: Callable taking the window object and returning the result dict.

        Returns:
            The dict produced by ``action``, or ``{"success": False, "error": ...}``
            when pywinctl is missing, the window is not found, or anything raises.
        """
        try:
            if pwc is None:
                return {"success": False, "error": "pywinctl not available"}
            window = self._get_window_by_id(window_id)
            if not window:
                return {"success": False, "error": "Window not found"}
            return action(window)
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def get_current_window_id(self) -> Dict[str, Any]:
        """Return the native handle of the active window as ``window_id``."""
        try:
            if pwc is None:
                return {"success": False, "error": "pywinctl not available"}
            win = pwc.getActiveWindow()
            if not win:
                return {"success": False, "error": "No active window"}
            return {"success": True, "window_id": win.getHandle()}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def get_application_windows(self, app: str) -> Dict[str, Any]:
        """Return handles of windows whose title contains ``app`` (case-insensitive)."""
        try:
            if pwc is None:
                return {"success": False, "error": "pywinctl not available"}
            wins = pwc.getWindowsWithTitle(app, condition=pwc.Re.CONTAINS, flags=pwc.Re.IGNORECASE)
            return {"success": True, "windows": [w.getHandle() for w in wins]}
        except Exception as e:
            return {"success": False, "error": str(e)}

    async def get_window_name(self, window_id: int | str) -> Dict[str, Any]:
        """Return the window's title as ``name``."""
        return self._apply_to_window(window_id, lambda w: {"success": True, "name": w.title})

    async def get_window_size(self, window_id: int | str) -> Dict[str, Any]:
        """Return the window's size as integer ``width`` and ``height``."""

        def read_size(w: Any) -> Dict[str, Any]:
            width, height = w.size
            return {"success": True, "width": int(width), "height": int(height)}

        return self._apply_to_window(window_id, read_size)

    async def get_window_position(self, window_id: int | str) -> Dict[str, Any]:
        """Return the window's position as integer ``x`` and ``y``."""

        def read_position(w: Any) -> Dict[str, Any]:
            x, y = w.position
            return {"success": True, "x": int(x), "y": int(y)}

        return self._apply_to_window(window_id, read_position)

    async def set_window_size(
        self, window_id: int | str, width: int, height: int
    ) -> Dict[str, Any]:
        """Resize the window; ``success`` is the truthiness of ``resizeTo``'s result."""
        return self._apply_to_window(
            window_id, lambda w: {"success": bool(w.resizeTo(int(width), int(height)))}
        )

    async def set_window_position(self, window_id: int | str, x: int, y: int) -> Dict[str, Any]:
        """Move the window; ``success`` is the truthiness of ``moveTo``'s result."""
        return self._apply_to_window(
            window_id, lambda w: {"success": bool(w.moveTo(int(x), int(y)))}
        )

    async def maximize_window(self, window_id: int | str) -> Dict[str, Any]:
        """Maximize the window; ``success`` is the truthiness of ``maximize``'s result."""
        return self._apply_to_window(window_id, lambda w: {"success": bool(w.maximize())})

    async def minimize_window(self, window_id: int | str) -> Dict[str, Any]:
        """Minimize the window; ``success`` is the truthiness of ``minimize``'s result."""
        return self._apply_to_window(window_id, lambda w: {"success": bool(w.minimize())})

    async def activate_window(self, window_id: int | str) -> Dict[str, Any]:
        """Activate the window; ``success`` is the truthiness of ``activate``'s result."""
        return self._apply_to_window(window_id, lambda w: {"success": bool(w.activate())})

    async def close_window(self, window_id: int | str) -> Dict[str, Any]:
        """Close the window; ``success`` is the truthiness of ``close``'s result."""
        return self._apply_to_window(window_id, lambda w: {"success": bool(w.close())})
# ===== Cross-platform file system command handlers =====
class GenericFileHandler(BaseFileHandler):
"""
Generic file handler that provides file system operations for all operating systems.

View File

@@ -75,9 +75,14 @@ except Exception:
except Exception:
package_version = "unknown"
accessibility_handler, automation_handler, diorama_handler, file_handler = (
HandlerFactory.create_handlers()
)
(
accessibility_handler,
automation_handler,
diorama_handler,
file_handler,
desktop_handler,
window_handler,
) = HandlerFactory.create_handlers()
handlers = {
"version": lambda: {"protocol": protocol_version, "package": package_version},
# App-Use commands
@@ -99,6 +104,23 @@ handlers = {
"delete_file": file_handler.delete_file,
"create_dir": file_handler.create_dir,
"delete_dir": file_handler.delete_dir,
# Desktop commands
"get_desktop_environment": desktop_handler.get_desktop_environment,
"set_wallpaper": desktop_handler.set_wallpaper,
# Window management
"open": window_handler.open,
"launch": window_handler.launch,
"get_current_window_id": window_handler.get_current_window_id,
"get_application_windows": window_handler.get_application_windows,
"get_window_name": window_handler.get_window_name,
"get_window_size": window_handler.get_window_size,
"get_window_position": window_handler.get_window_position,
"set_window_size": window_handler.set_window_size,
"set_window_position": window_handler.set_window_position,
"maximize_window": window_handler.maximize_window,
"minimize_window": window_handler.minimize_window,
"activate_window": window_handler.activate_window,
"close_window": window_handler.close_window,
# Mouse commands
"mouse_down": automation_handler.mouse_down,
"mouse_up": automation_handler.mouse_up,

View File

@@ -0,0 +1,3 @@
from . import wallpaper
__all__ = ["wallpaper"]

View File

@@ -0,0 +1,321 @@
"""Set the desktop wallpaper."""
import os
import subprocess
import sys
from pathlib import Path
def get_desktop_environment() -> str:
    """
    Identify the desktop environment this process is running under.

    Returns "windows" or "mac" on those platforms. Elsewhere the lower-cased
    ``DESKTOP_SESSION`` variable is inspected first, then ``KDE_FULL_SESSION``
    and ``GNOME_DESKTOP_SESSION_ID``, and finally the process list. If nothing
    matches, "unknown" is returned.
    """
    # From https://stackoverflow.com/a/21213358/2624876
    # which takes from:
    # http://stackoverflow.com/questions/2035657/what-is-my-current-desktop-environment
    # and http://ubuntuforums.org/showthread.php?t=652320
    # and http://ubuntuforums.org/showthread.php?t=1139057
    platform = sys.platform
    if platform in ["win32", "cygwin"]:
        return "windows"
    if platform == "darwin":
        return "mac"

    # Most likely a POSIX system (or something uncommon) from here on.
    session = os.environ.get("DESKTOP_SESSION")
    if session is not None:
        # Lower-case once so every comparison below is case-insensitive.
        session = session.lower()
        known_sessions = [
            "gnome",
            "unity",
            "cinnamon",
            "mate",
            "xfce4",
            "lxde",
            "fluxbox",
            "blackbox",
            "openbox",
            "icewm",
            "jwm",
            "afterstep",
            "trinity",
            "kde",
        ]
        if session in known_sessions:
            return session

        ## Special cases ##
        # Canonical sets $DESKTOP_SESSION to Lubuntu rather than LXDE if using LXDE.
        # There is no guarantee that they will not do the same with the other
        # desktop environments.
        if "xfce" in session:
            return "xfce4"
        # Order matters: the first matching prefix wins.
        prefix_table = (
            ("xubuntu", "xfce4"),
            ("ubuntustudio", "kde"),
            ("ubuntu", "gnome"),
            ("lubuntu", "lxde"),
            ("kubuntu", "kde"),
            ("razor", "razor-qt"),  # e.g. razorkwin
            ("wmaker", "windowmaker"),  # e.g. wmaker-common
        )
        for prefix, environment in prefix_table:
            if session.startswith(prefix):
                return environment

    if os.environ.get("KDE_FULL_SESSION") == "true":
        return "kde"

    gnome_session_id = os.environ.get("GNOME_DESKTOP_SESSION_ID")
    if gnome_session_id:
        if "deprecated" not in gnome_session_id:
            return "gnome2"
    # From http://ubuntuforums.org/showthread.php?t=652320
    elif is_running("xfce-mcs-manage"):
        return "xfce4"
    elif is_running("ksmserver"):
        return "kde"
    return "unknown"
def is_running(process: str) -> bool:
    """Return whether a process matching ``process`` is (likely) currently running.

    The process listing comes from ``ps axw``; if that command cannot be started,
    ``tasklist /v`` is tried instead. Matching is a plain substring search over
    each line of the decoded listing, so false positives are possible.

    Args:
        process: Text to look for in the process listing.

    Returns:
        True if any listing line contains ``process``. False otherwise, including
        when neither listing command can be started.
    """
    # From http://www.bloggerpolis.com/2011/05/how-to-check-if-a-process-is-running-using-python/
    # and http://richarddingwall.name/2009/06/18/windows-equivalents-of-ps-and-kill-commands/
    listing_commands = (
        ["ps", "axw"],  # Linux/Unix
        ["tasklist", "/v"],  # Windows
    )
    for command in listing_commands:
        try:
            # run() waits for the child and closes the pipe, so nothing is left
            # unreaped even when a match is found early.
            completed = subprocess.run(command, stdout=subprocess.PIPE, check=False)
        except OSError:
            # Command not available on this platform; try the next one.
            continue
        # Decode so the search runs on the real text, not on the repr of bytes.
        listing = completed.stdout.decode(errors="replace")
        return any(process in line for line in listing.splitlines())
    return False
def set_wallpaper(file_loc: str, first_run: bool = True) -> bool:
    """Set the desktop wallpaper to the image at ``file_loc``.

    The desktop environment is detected with ``get_desktop_environment()`` and the
    matching tool or API is invoked. Most helper commands are started with
    ``subprocess.Popen`` and are not waited on, so a True result means a method
    was attempted for a recognised environment, not that the change was verified.

    Args:
        file_loc: Path of the image file to use as wallpaper.
        first_run: When True, the razor-qt configuration file is rewritten and
            the "unsupported environment" warning is printed. Pass False on
            repeated calls to avoid repeating the warning.

    Returns:
        False if the desktop environment is not supported, True otherwise.
    """
    # From https://stackoverflow.com/a/21213504/2624876
    # I have not personally tested most of this. -- @1j01
    # -----------------------------------------
    # Note: There are two common Linux desktop environments where
    # I have not been able to set the desktop background from
    # command line: KDE, Enlightenment
    desktop_env = get_desktop_environment()
    if desktop_env in ["gnome", "unity", "cinnamon"]:
        # Tested on Ubuntu 22 -- @1j01
        uri = Path(file_loc).as_uri()
        SCHEMA = "org.gnome.desktop.background"
        KEY = "picture-uri"
        # Needed for Ubuntu 22 in dark mode
        # Might be better to set only one or the other, depending on the current theme
        # In the settings it will say "This background selection only applies to the dark style"
        # even if it's set for both, arguably referring to the selection that you can make on that page.
        # -- @1j01
        KEY_DARK = "picture-uri-dark"
        try:
            from gi.repository import Gio  # type: ignore

            gsettings = Gio.Settings.new(SCHEMA)  # type: ignore
            gsettings.set_string(KEY, uri)
            gsettings.set_string(KEY_DARK, uri)
        except Exception:
            # Fallback tested on Ubuntu 22 -- @1j01
            args = ["gsettings", "set", SCHEMA, KEY, uri]
            subprocess.Popen(args)
            args = ["gsettings", "set", SCHEMA, KEY_DARK, uri]
            subprocess.Popen(args)
    elif desktop_env == "mate":
        try:  # MATE >= 1.6
            # info from http://wiki.mate-desktop.org/docs:gsettings
            args = ["gsettings", "set", "org.mate.background", "picture-filename", file_loc]
            subprocess.Popen(args)
        except Exception:  # MATE < 1.6
            # From https://bugs.launchpad.net/variety/+bug/1033918
            args = [
                "mateconftool-2",
                "-t",
                "string",
                "--set",
                "/desktop/mate/background/picture_filename",
                file_loc,
            ]
            subprocess.Popen(args)
    elif desktop_env == "gnome2":  # Not tested
        # From https://bugs.launchpad.net/variety/+bug/1033918
        args = [
            "gconftool-2",
            "-t",
            "string",
            "--set",
            "/desktop/gnome/background/picture_filename",
            file_loc,
        ]
        subprocess.Popen(args)
    ## KDE4 is difficult
    ## see http://blog.zx2c4.com/699 for a solution that might work
    elif desktop_env in ["kde3", "trinity"]:
        # From http://ubuntuforums.org/archive/index.php/t-803417.html
        args = ["dcop", "kdesktop", "KBackgroundIface", "setWallpaper", "0", file_loc, "6"]
        subprocess.Popen(args)
    elif desktop_env == "xfce4":
        # Iterate over all wallpaper-related keys and set to file_loc
        try:
            list_proc = subprocess.run(
                ["xfconf-query", "-c", "xfce4-desktop", "-l"],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True,
                check=False,
            )
            keys = []
            if list_proc.stdout:
                for line in list_proc.stdout.splitlines():
                    line = line.strip()
                    if not line:
                        continue
                    # Common keys: .../last-image and .../image-path
                    if "/last-image" in line or "/image-path" in line:
                        keys.append(line)
            # Fallback: known defaults if none were listed
            if not keys:
                keys = [
                    "/backdrop/screen0/monitorVNC-0/workspace0/last-image",
                    "/backdrop/screen0/monitor0/image-path",
                ]
            for key in keys:
                subprocess.run(
                    [
                        "xfconf-query",
                        "-c",
                        "xfce4-desktop",
                        "-p",
                        key,
                        "-s",
                        file_loc,
                    ],
                    check=False,
                )
        except Exception:
            # Best effort: a missing xfconf-query must not abort the reload below.
            pass
        # Reload xfdesktop to apply changes
        subprocess.Popen(["xfdesktop", "--reload"])
    elif desktop_env == "razor-qt":  # TODO: implement reload of desktop when possible
        if first_run:
            import configparser

            desktop_conf = configparser.ConfigParser()
            # Development version
            desktop_conf_file = os.path.join(get_config_dir("razor"), "desktop.conf")
            if os.path.isfile(desktop_conf_file):
                config_option = R"screens\1\desktops\1\wallpaper"
            else:
                desktop_conf_file = os.path.join(get_home_dir(), ".razor/desktop.conf")
                config_option = R"desktops\1\wallpaper"
            desktop_conf.read(desktop_conf_file)
            try:
                if desktop_conf.has_option("razor", config_option):  # only replacing a value
                    desktop_conf.set("razor", config_option, file_loc)
                    with open(desktop_conf_file, "w", encoding="utf-8", errors="replace") as f:
                        desktop_conf.write(f)
            except Exception:
                # Best effort: an unreadable/unwritable config is silently ignored.
                pass
        else:
            # TODO: reload desktop when possible
            pass
    elif desktop_env in ["fluxbox", "jwm", "openbox", "afterstep"]:
        # http://fluxbox-wiki.org/index.php/Howto_set_the_background
        # used fbsetbg on jwm too since I am too lazy to edit the XML configuration
        # now where fbsetbg does the job excellent anyway.
        # and I have not figured out how else it can be set on Openbox and AfterSTep
        # but fbsetbg works excellent here too.
        try:
            args = ["fbsetbg", file_loc]
            subprocess.Popen(args)
        except Exception:
            sys.stderr.write("ERROR: Failed to set wallpaper with fbsetbg!\n")
            sys.stderr.write("Please make sre that You have fbsetbg installed.\n")
    elif desktop_env == "icewm":
        # command found at http://urukrama.wordpress.com/2007/12/05/desktop-backgrounds-in-window-managers/
        args = ["icewmbg", file_loc]
        subprocess.Popen(args)
    elif desktop_env == "blackbox":
        # command found at http://blackboxwm.sourceforge.net/BlackboxDocumentation/BlackboxBackground
        args = ["bsetbg", "-full", file_loc]
        subprocess.Popen(args)
    elif desktop_env == "lxde":
        args = ["pcmanfm", "--set-wallpaper", file_loc, "--wallpaper-mode=scaled"]
        subprocess.Popen(args)
    elif desktop_env == "windowmaker":
        # From http://www.commandlinefu.com/commands/view/3857/set-wallpaper-on-windowmaker-in-one-line
        args = ["wmsetbg", "-s", "-u", file_loc]
        subprocess.Popen(args)
    # elif desktop_env == "enlightenment": # I have not been able to make it work on e17. On e16 it would have been something in this direction
    #     args = ["enlightenment_remote", "-desktop-bg-add", "0", "0", "0", "0", file_loc]
    #     subprocess.Popen(args)
    elif desktop_env == "windows":
        # From https://stackoverflow.com/questions/1977694/change-desktop-background
        # Tested on Windows 10. -- @1j01
        import ctypes

        SPI_SETDESKWALLPAPER = 20
        ctypes.windll.user32.SystemParametersInfoW(SPI_SETDESKWALLPAPER, 0, file_loc, 0)  # type: ignore
    elif desktop_env == "mac":
        # From https://stackoverflow.com/questions/431205/how-can-i-programatically-change-the-background-in-mac-os-x
        try:
            # Tested on macOS 10.14.6 (Mojave) -- @1j01
            assert (
                sys.platform == "darwin"
            )  # ignore `Import "appscript" could not be resolved` for other platforms
            from appscript import app, mactypes

            app("Finder").desktop_picture.set(mactypes.File(file_loc))
        except ImportError:
            # Tested on macOS 10.14.6 (Mojave) -- @1j01
            # import subprocess
            # SCRIPT = f"""/usr/bin/osascript<<END
            # tell application "Finder" to set desktop picture to POSIX file "{file_loc}"
            # END"""
            # subprocess.Popen(SCRIPT, shell=True)
            # Safer version, avoiding string interpolation,
            # to protect against command injection (both in the shell and in AppleScript):
            OSASCRIPT = """
on run (clp)
if clp's length is not 1 then error "Incorrect Parameters"
local file_loc
set file_loc to clp's item 1
tell application "Finder" to set desktop picture to POSIX file file_loc
end run
"""
            subprocess.Popen(["osascript", "-e", OSASCRIPT, "--", file_loc])
    else:
        if first_run:  # don't spam the user with the same message over and over again
            # Each message ends with a newline so the two warnings do not run
            # together on stderr.
            sys.stderr.write(
                "Warning: Failed to set wallpaper. Your desktop environment is not supported.\n"
            )
            sys.stderr.write(f"You can try manually to set your wallpaper to {file_loc}\n")
        return False
    return True
def get_config_dir(app_name: str) -> str:
    """Build the configuration directory path for ``app_name``.

    The base directory is taken from ``XDG_CONFIG_HOME``, then ``APPDATA``, then
    the ``xdg`` package if importable, and finally ``~/.config``.
    """
    env = os.environ
    if "XDG_CONFIG_HOME" in env:
        base_dir = env["XDG_CONFIG_HOME"]
    elif "APPDATA" in env:  # On Windows
        base_dir = env["APPDATA"]
    else:
        try:
            from xdg import BaseDirectory

            base_dir = BaseDirectory.xdg_config_home
        except ImportError:  # Most likely a Linux/Unix system anyway
            base_dir = os.path.join(get_home_dir(), ".config")
    return os.path.join(base_dir, app_name)
def get_home_dir() -> str:
    """Return the current user's home directory, obtained by expanding ``~``."""
    home = os.path.expanduser("~")
    return home

View File

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
[project]
name = "cua-computer-server"
version = "0.1.25"
version = "0.1.28"
description = "Server component for the Computer-Use Interface (CUI) framework powering Cua"
authors = [
@@ -22,7 +22,15 @@ dependencies = [
"pillow>=10.2.0",
"aiohttp>=3.9.1",
"pyperclip>=1.9.0",
"websockets>=12.0"
"websockets>=12.0",
"pywinctl>=0.4.1",
# OS-specific runtime deps
"pyobjc-framework-Cocoa>=10.1; sys_platform == 'darwin'",
"pyobjc-framework-Quartz>=10.1; sys_platform == 'darwin'",
"pyobjc-framework-ApplicationServices>=10.1; sys_platform == 'darwin'",
"python-xlib>=0.33; sys_platform == 'linux'",
"pywin32>=310; sys_platform == 'win32'",
"pip-system-certs; sys_platform == 'win32'",
]
[project.optional-dependencies]

View File

@@ -0,0 +1,47 @@
"""Pytest configuration and shared fixtures for computer-server package tests.
This file contains shared fixtures and configuration for all computer-server tests.
Following SRP: This file ONLY handles test setup/teardown.
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
@pytest.fixture
def mock_websocket():
    """Provide an AsyncMock standing in for a WebSocket connection.

    ``send``, ``recv`` and ``close`` are each explicit AsyncMock attributes, so
    WebSocket logic can be tested without a real connection.
    """
    ws = AsyncMock()
    for method_name in ("send", "recv", "close"):
        setattr(ws, method_name, AsyncMock())
    return ws
@pytest.fixture
def mock_computer_interface():
    """Provide an AsyncMock standing in for the computer interface.

    ``screenshot`` resolves to ``b"fake_screenshot"``; ``left_click``, ``type``
    and ``key`` are plain AsyncMocks. Lets server logic be tested without real
    computer operations.
    """
    fake_interface = AsyncMock()
    fake_interface.screenshot = AsyncMock(return_value=b"fake_screenshot")
    for action in ("left_click", "type", "key"):
        setattr(fake_interface, action, AsyncMock())
    return fake_interface
@pytest.fixture
def disable_telemetry(monkeypatch):
    """Disable telemetry for tests.

    Use this fixture to ensure no telemetry is sent during tests.

    Sets ``CUA_TELEMETRY_ENABLED=false``, the switch exercised by the core
    telemetry tests, and keeps the previously used ``CUA_TELEMETRY_DISABLED=1``
    for anything that still reads it. ``monkeypatch`` restores both afterwards.
    """
    monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")
    monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")

View File

@@ -0,0 +1,40 @@
"""Unit tests for computer-server package.
This file tests ONLY basic server functionality.
Following SRP: This file tests server initialization and basic operations.
All external dependencies are mocked.
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
class TestServerImports:
    """Import checks for the server package (SRP: only tests imports)."""

    def test_server_module_exists(self):
        """The ``computer_server`` package imports; skipped when it is not installed."""
        try:
            import computer_server
        except ImportError:
            pytest.skip("computer_server module not installed")
        else:
            assert computer_server is not None
class TestServerInitialization:
    """Test server initialization (SRP: Only tests initialization)."""

    @pytest.mark.asyncio
    async def test_server_can_be_imported(self):
        """Basic smoke test: verify server components can be imported.

        Import-time failures are turned into skips, but the assertion itself runs
        outside the ``try`` so a failed check is reported as a failure rather than
        being swallowed by ``except Exception`` and converted into a skip.
        """
        try:
            from computer_server import server
        except ImportError:
            pytest.skip("Server module not available")
        except Exception as e:
            # Some initialization errors are acceptable in unit tests
            pytest.skip(f"Server initialization requires specific setup: {e}")
        else:
            assert server is not None

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.7
current_version = 0.4.11
commit = True
tag = True
tag_name = computer-v{new_version}

View File

@@ -436,6 +436,189 @@ class BaseComputerInterface(ABC):
"""
pass
# Desktop actions
@abstractmethod
async def get_desktop_environment(self) -> str:
    """Report which desktop environment is currently in use.

    Returns:
        The current desktop environment's name.
    """
    ...
@abstractmethod
async def set_wallpaper(self, path: str) -> None:
    """Use the file at ``path`` as the desktop wallpaper.

    Args:
        path: File path of the image to set as wallpaper.
    """
    ...
# Window management
@abstractmethod
async def open(self, target: str) -> None:
    """Open ``target`` with the system's default handler.

    This typically means opening a file, folder, or URL in its associated
    application.

    Args:
        target: File path, folder path, or URL to open.
    """
    ...
@abstractmethod
async def launch(self, app: str, args: List[str] | None = None) -> Optional[int]:
"""Launch an application with optional arguments.
Args:
app: The application executable or bundle identifier.
args: Optional list of arguments to pass to the application.
Returns:
Optional process ID (PID) of the launched application if available, otherwise None.
"""
pass
@abstractmethod
async def get_current_window_id(self) -> int | str:
"""Get the identifier of the currently active/focused window.
Returns:
A window identifier that can be used with other window management methods.
"""
pass
@abstractmethod
async def get_application_windows(self, app: str) -> List[int | str]:
"""Get all window identifiers for a specific application.
Args:
app: The application name, executable, or identifier to query.
Returns:
A list of window identifiers belonging to the specified application.
"""
pass
@abstractmethod
async def get_window_name(self, window_id: int | str) -> str:
"""Get the title/name of a window.
Args:
window_id: The window identifier.
Returns:
The window's title or name string.
"""
pass
@abstractmethod
async def get_window_size(self, window_id: int | str) -> tuple[int, int]:
"""Get the size of a window in pixels.
Args:
window_id: The window identifier.
Returns:
A tuple of (width, height) representing the window size in pixels.
"""
pass
@abstractmethod
async def get_window_position(self, window_id: int | str) -> tuple[int, int]:
"""Get the screen position of a window.
Args:
window_id: The window identifier.
Returns:
A tuple of (x, y) representing the window's top-left corner in screen coordinates.
"""
pass
@abstractmethod
async def set_window_size(self, window_id: int | str, width: int, height: int) -> None:
"""Set the size of a window in pixels.
Args:
window_id: The window identifier.
width: Desired width in pixels.
height: Desired height in pixels.
"""
pass
@abstractmethod
async def set_window_position(self, window_id: int | str, x: int, y: int) -> None:
"""Move a window to a specific position on the screen.
Args:
window_id: The window identifier.
x: X coordinate for the window's top-left corner.
y: Y coordinate for the window's top-left corner.
"""
pass
@abstractmethod
async def maximize_window(self, window_id: int | str) -> None:
"""Maximize a window.
Args:
window_id: The window identifier.
"""
pass
@abstractmethod
async def minimize_window(self, window_id: int | str) -> None:
"""Minimize a window.
Args:
window_id: The window identifier.
"""
pass
@abstractmethod
async def activate_window(self, window_id: int | str) -> None:
"""Bring a window to the foreground and focus it.
Args:
window_id: The window identifier.
"""
pass
@abstractmethod
async def close_window(self, window_id: int | str) -> None:
"""Close a window.
Args:
window_id: The window identifier.
"""
pass
# Convenience aliases
async def get_window_title(self, window_id: int | str) -> str:
"""Convenience alias for get_window_name().
Args:
window_id: The window identifier.
Returns:
The window's title or name string.
"""
return await self.get_window_name(window_id)
async def window_size(self, window_id: int | str) -> tuple[int, int]:
"""Convenience alias for get_window_size().
Args:
window_id: The window identifier.
Returns:
A tuple of (width, height) representing the window size in pixels.
"""
return await self.get_window_size(window_id)
# Shell actions
@abstractmethod
async def run_command(self, command: str) -> CommandResult:
"""Run shell command and return structured result.

View File

@@ -487,6 +487,104 @@ class GenericComputerInterface(BaseComputerInterface):
raise RuntimeError(result.get("error", "Failed to list directory"))
return result.get("files", [])
# Desktop actions
async def get_desktop_environment(self) -> str:
    """Send ``get_desktop_environment`` and return the reported environment name.

    Returns:
        The reply's "environment" field, or "unknown" when the reply omits it.

    Raises:
        RuntimeError: If the reply does not report success.
    """
    reply = await self._send_command("get_desktop_environment")
    if reply.get("success", False):
        return reply.get("environment", "unknown")
    raise RuntimeError(reply.get("error", "Failed to get desktop environment"))
async def set_wallpaper(self, path: str) -> None:
    """Send ``set_wallpaper`` with ``path``.

    Raises:
        RuntimeError: If the reply does not report success.
    """
    reply = await self._send_command("set_wallpaper", {"path": path})
    if reply.get("success", False):
        return
    raise RuntimeError(reply.get("error", "Failed to set wallpaper"))
# Window management
async def open(self, target: str) -> None:
    """Send ``open`` with ``target``.

    Raises:
        RuntimeError: If the reply does not report success.
    """
    reply = await self._send_command("open", {"target": target})
    if reply.get("success", False):
        return
    raise RuntimeError(reply.get("error", "Failed to open target"))
async def launch(self, app: str, args: list[str] | None = None) -> int | None:
payload: dict[str, object] = {"app": app}
if args is not None:
payload["args"] = args
result = await self._send_command("launch", payload)
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to launch application"))
return result.get("pid") # type: ignore[return-value]
async def get_current_window_id(self) -> int | str:
result = await self._send_command("get_current_window_id")
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to get current window id"))
return result["window_id"] # type: ignore[return-value]
async def get_application_windows(self, app: str) -> list[int | str]:
result = await self._send_command("get_application_windows", {"app": app})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to get application windows"))
return list(result.get("windows", [])) # type: ignore[return-value]
async def get_window_name(self, window_id: int | str) -> str:
result = await self._send_command("get_window_name", {"window_id": window_id})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to get window name"))
return result.get("name", "") # type: ignore[return-value]
async def get_window_size(self, window_id: int | str) -> tuple[int, int]:
result = await self._send_command("get_window_size", {"window_id": window_id})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to get window size"))
return int(result.get("width", 0)), int(result.get("height", 0))
async def get_window_position(self, window_id: int | str) -> tuple[int, int]:
result = await self._send_command("get_window_position", {"window_id": window_id})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to get window position"))
return int(result.get("x", 0)), int(result.get("y", 0))
async def set_window_size(self, window_id: int | str, width: int, height: int) -> None:
result = await self._send_command(
"set_window_size", {"window_id": window_id, "width": width, "height": height}
)
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to set window size"))
async def set_window_position(self, window_id: int | str, x: int, y: int) -> None:
result = await self._send_command(
"set_window_position", {"window_id": window_id, "x": x, "y": y}
)
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to set window position"))
async def maximize_window(self, window_id: int | str) -> None:
result = await self._send_command("maximize_window", {"window_id": window_id})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to maximize window"))
async def minimize_window(self, window_id: int | str) -> None:
result = await self._send_command("minimize_window", {"window_id": window_id})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to minimize window"))
async def activate_window(self, window_id: int | str) -> None:
result = await self._send_command("activate_window", {"window_id": window_id})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to activate window"))
async def close_window(self, window_id: int | str) -> None:
result = await self._send_command("close_window", {"window_id": window_id})
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to close window"))
# Convenience aliases
async def get_window_title(self, window_id: int | str) -> str:
return await self.get_window_name(window_id)
async def window_size(self, window_id: int | str) -> tuple[int, int]:
return await self.get_window_size(window_id)
# Command execution
async def run_command(self, command: str) -> CommandResult:
result = await self._send_command("run_command", {"command": command})

View File

@@ -258,14 +258,20 @@ class DockerProvider(BaseVMProvider):
logger.info(f"Container {name} is already running")
return existing_vm
elif existing_vm["status"] in ["stopped", "paused"]:
# Start existing container
logger.info(f"Starting existing container {name}")
start_cmd = ["docker", "start", name]
result = subprocess.run(start_cmd, capture_output=True, text=True, check=True)
if self.ephemeral:
# Delete existing container
logger.info(f"Deleting existing container {name}")
delete_cmd = ["docker", "rm", name]
result = subprocess.run(delete_cmd, capture_output=True, text=True, check=True)
else:
# Start existing container
logger.info(f"Starting existing container {name}")
start_cmd = ["docker", "start", name]
result = subprocess.run(start_cmd, capture_output=True, text=True, check=True)
# Wait for container to be ready
await self._wait_for_container_ready(name)
return await self.get_vm(name, storage)
# Wait for container to be ready
await self._wait_for_container_ready(name)
return await self.get_vm(name, storage)
# Use provided image or default
docker_image = image if image != "default" else self.image
@@ -307,6 +313,20 @@ class DockerProvider(BaseVMProvider):
cmd.extend(["-e", "VNC_PW=password"]) # Set VNC password
cmd.extend(["-e", "VNCOPTIONS=-disableBasicAuth"]) # Disable VNC basic auth
# Apply display resolution if provided as a dict with "width" and "height" (exported as e.g. VNC_RESOLUTION=1024x768)
display_resolution = run_opts.get("display")
if (
isinstance(display_resolution, dict)
and "width" in display_resolution
and "height" in display_resolution
):
cmd.extend(
[
"-e",
f"VNC_RESOLUTION={display_resolution['width']}x{display_resolution['height']}",
]
)
# Add the image
cmd.append(docker_image)
@@ -388,6 +408,11 @@ class DockerProvider(BaseVMProvider):
logger.info(f"Container {name} stopped successfully")
# Delete container if ephemeral=True
if self.ephemeral:
cmd = ["docker", "rm", name]
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
return {
"name": name,
"status": "stopped",

View File

@@ -10,6 +10,8 @@ import subprocess
import urllib.parse
from typing import Any, Dict, List, Optional
from computer.utils import safe_join
# Setup logging
logger = logging.getLogger(__name__)
@@ -59,7 +61,7 @@ def lume_api_get(
# --max-time: Maximum time for the whole operation (20 seconds)
# -f: Fail silently (no output at all) on server errors
# The URL is passed as-is here; safe_join() quotes it below so special characters survive the shell
cmd = ["curl", "--connect-timeout", "15", "--max-time", "20", "-s", "-f", f"'{api_url}'"]
cmd = ["curl", "--connect-timeout", "15", "--max-time", "20", "-s", "-f", api_url]
# For logging and display, show the properly escaped URL
display_cmd = ["curl", "--connect-timeout", "15", "--max-time", "20", "-s", "-f", api_url]
@@ -71,7 +73,7 @@ def lume_api_get(
# Execute the command - for execution we need to use shell=True to handle URLs with special characters
try:
# Use a single string with shell=True for proper URL handling
shell_cmd = " ".join(cmd)
shell_cmd = safe_join(cmd)
result = subprocess.run(shell_cmd, shell=True, capture_output=True, text=True)
# Handle curl exit codes
@@ -514,7 +516,7 @@ def lume_api_delete(
"-s",
"-X",
"DELETE",
f"'{api_url}'",
api_url,
]
# For logging and display, show the properly escaped URL
@@ -537,7 +539,7 @@ def lume_api_delete(
# Execute the command - for execution we need to use shell=True to handle URLs with special characters
try:
# Use a single string with shell=True for proper URL handling
shell_cmd = " ".join(cmd)
shell_cmd = safe_join(cmd)
result = subprocess.run(shell_cmd, shell=True, capture_output=True, text=True)
# Handle curl exit codes

View File

@@ -1,7 +1,10 @@
import base64
import io
import os
import shlex
from typing import Any, Dict, Optional, Tuple
import mslex
from PIL import Image, ImageDraw
@@ -104,3 +107,25 @@ def parse_vm_info(vm_info: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Parse VM info from pylume response."""
if not vm_info:
return None
def safe_join(argv: list[str]) -> str:
    """
    Return a platform-correct string that safely quotes `argv` for shell execution.

    - On POSIX: uses `shlex.join`.
    - On Windows (`os.name == "nt"`): uses `mslex.join`.

    Args:
        argv: iterable of arguments; each one is coerced to `str` before quoting.

    Returns:
        A safely quoted command-line string appropriate for the current platform that protects against
        shell injection vulnerabilities.
    """
    # Coerce up front so non-str arguments (ints, paths) are quoted instead of
    # raising inside the quoting helper.
    args = [str(arg) for arg in argv]
    if os.name == "nt":
        # On Windows, use mslex for proper quoting
        return mslex.join(args)
    # On POSIX systems, use shlex
    return shlex.join(args)

View File

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
[project]
name = "cua-computer"
version = "0.4.10"
version = "0.4.11"
description = "Computer-Use Interface (CUI) framework powering Cua"
readme = "README.md"
authors = [
@@ -16,7 +16,8 @@ dependencies = [
"websockets>=12.0",
"aiohttp>=3.9.0",
"cua-core>=0.1.0,<0.2.0",
"pydantic>=2.11.1"
"pydantic>=2.11.1",
"mslex>=1.3.0",
]
requires-python = ">=3.12"
@@ -47,4 +48,4 @@ source-includes = ["tests/", "README.md", "LICENSE"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
python_files = "test_*.py"
python_files = "test_*.py"

View File

@@ -0,0 +1,69 @@
"""Pytest configuration and shared fixtures for computer package tests.
This file contains shared fixtures and configuration for all computer tests.
Following SRP: This file ONLY handles test setup/teardown.
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
@pytest.fixture
def mock_interface():
    """Provide an AsyncMock standing in for the computer interface.

    ``screenshot`` resolves to ``b"fake_screenshot"`` and ``get_screen_size`` to
    ``(1920, 1080)``; the remaining input actions are plain AsyncMocks. Lets
    Computer logic be tested without real OS calls.
    """
    fake_interface = AsyncMock()
    fake_interface.screenshot = AsyncMock(return_value=b"fake_screenshot")
    plain_actions = (
        "left_click",
        "right_click",
        "middle_click",
        "double_click",
        "type",
        "key",
        "move_mouse",
        "scroll",
    )
    for action in plain_actions:
        setattr(fake_interface, action, AsyncMock())
    fake_interface.get_screen_size = AsyncMock(return_value=(1920, 1080))
    return fake_interface
@pytest.fixture
def mock_cloud_provider():
    """Provide an AsyncMock standing in for a cloud provider.

    ``get_status`` resolves to "running" and ``execute_command`` to
    "command output", so provider logic can be tested without real API calls.
    """
    cloud = AsyncMock()
    for lifecycle_method in ("start", "stop"):
        setattr(cloud, lifecycle_method, AsyncMock())
    cloud.get_status = AsyncMock(return_value="running")
    cloud.execute_command = AsyncMock(return_value="command output")
    return cloud
@pytest.fixture
def mock_local_provider():
    """Provide an AsyncMock standing in for a local provider.

    ``get_status`` resolves to "running" and ``execute_command`` to
    "command output", so provider logic can be tested without real VM operations.
    """
    local = AsyncMock()
    for lifecycle_method in ("start", "stop"):
        setattr(local, lifecycle_method, AsyncMock())
    local.get_status = AsyncMock(return_value="running")
    local.execute_command = AsyncMock(return_value="command output")
    return local
@pytest.fixture
def disable_telemetry(monkeypatch):
    """Disable telemetry for tests.

    Use this fixture to ensure no telemetry is sent during tests.

    Sets ``CUA_TELEMETRY_ENABLED=false``, the switch exercised by the core
    telemetry tests, and keeps the previously used ``CUA_TELEMETRY_DISABLED=1``
    for anything that still reads it. ``monkeypatch`` restores both afterwards.
    """
    monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")
    monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")

View File

@@ -0,0 +1,67 @@
"""Unit tests for Computer class.
This file tests ONLY the Computer class initialization and context manager.
Following SRP: This file tests ONE class (Computer).
All external dependencies (providers, interfaces) are mocked.
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
class TestComputerImport:
    """Import checks for the computer package (SRP: only tests imports)."""

    def test_computer_class_exists(self):
        """``Computer`` is importable from the package root."""
        import computer

        assert computer.Computer is not None

    def test_vm_provider_type_exists(self):
        """``VMProviderType`` is importable from the package root."""
        import computer

        assert computer.VMProviderType is not None
class TestComputerInitialization:
    """Initialization checks for Computer (SRP: only tests initialization)."""

    def test_computer_class_can_be_imported(self, disable_telemetry):
        """Importing ``Computer`` with telemetry disabled raises no error."""
        from computer import Computer as computer_cls

        assert computer_cls is not None

    def test_computer_has_required_methods(self, disable_telemetry):
        """``Computer`` defines both async context manager methods."""
        from computer import Computer as computer_cls

        for required in ("__aenter__", "__aexit__"):
            assert hasattr(computer_cls, required)
class TestComputerContextManager:
    """Context manager protocol checks for Computer (SRP: only tests context manager)."""

    def test_computer_is_async_context_manager(self, disable_telemetry):
        """``Computer`` exposes callable ``__aenter__`` and ``__aexit__``."""
        from computer import Computer as computer_cls

        protocol_methods = ("__aenter__", "__aexit__")
        for method_name in protocol_methods:
            assert hasattr(computer_cls, method_name)
        assert callable(computer_cls.__aenter__)
        assert callable(computer_cls.__aexit__)
class TestComputerInterface:
    """Structure checks for Computer (SRP: only tests interface access)."""

    def test_computer_class_structure(self, disable_telemetry):
        """``Computer`` is a class object."""
        from computer import Computer as computer_cls

        # A class is an instance of ``type``.
        is_class = isinstance(computer_cls, type)
        assert is_class

View File

@@ -0,0 +1,43 @@
"""Pytest configuration and shared fixtures for core package tests.
This file contains shared fixtures and configuration for all core tests.
Following SRP: This file ONLY handles test setup/teardown.
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
@pytest.fixture
def mock_httpx_client():
    """Patch ``httpx.AsyncClient`` and yield the client seen inside ``async with``.

    While the fixture is active, entering ``httpx.AsyncClient(...)`` as an async
    context manager produces the yielded AsyncMock, so no real HTTP request is made.
    """
    fake_client = AsyncMock()
    with patch("httpx.AsyncClient") as client_cls:
        client_cls.return_value.__aenter__.return_value = fake_client
        yield fake_client
@pytest.fixture
def mock_posthog():
    """Patch ``posthog.Posthog`` and yield the instance its constructor returns.

    While the fixture is active, ``posthog.Posthog(...)`` returns the yielded
    ``Mock``, so no real telemetry is sent.
    """
    with patch("posthog.Posthog") as posthog_cls:
        fake_client = Mock()
        posthog_cls.return_value = fake_client
        yield fake_client
@pytest.fixture
def disable_telemetry(monkeypatch):
    """Disable telemetry for the duration of a test.

    Sets ``CUA_TELEMETRY_ENABLED=false``, the switch the core telemetry tests
    exercise as the way to turn telemetry off. ``CUA_TELEMETRY_DISABLED=1`` is
    still set as well so anything that reads the older name keeps working.

    ``monkeypatch`` restores both variables when the test finishes.
    """
    monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")
    # NOTE(review): no test in this package reads CUA_TELEMETRY_DISABLED;
    # kept only for backward compatibility - confirm whether it can be dropped.
    monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
    yield

View File

@@ -0,0 +1,255 @@
"""Unit tests for core telemetry functionality.
This file tests ONLY telemetry logic, following SRP.
All external dependencies (PostHog, file system) are mocked.
"""
import os
from pathlib import Path
from unittest.mock import MagicMock, Mock, mock_open, patch
import pytest
class TestTelemetryEnabled:
    """Verify how ``is_telemetry_enabled`` reacts to its environment switches.

    Every scenario first removes both ``CUA_TELEMETRY`` and
    ``CUA_TELEMETRY_ENABLED`` and then sets only the variable under test, so a
    value exported in the surrounding shell or CI job cannot change the result.
    """

    # The two environment variables these tests exercise.
    _ENV_SWITCHES = ("CUA_TELEMETRY", "CUA_TELEMETRY_ENABLED")

    @classmethod
    def _clear_switches(cls, monkeypatch):
        """Remove every telemetry switch from the environment for this test."""
        for name in cls._ENV_SWITCHES:
            monkeypatch.delenv(name, raising=False)

    def test_telemetry_enabled_by_default(self, monkeypatch):
        """With neither switch set, telemetry is reported as enabled."""
        self._clear_switches(monkeypatch)

        from core.telemetry import is_telemetry_enabled

        assert is_telemetry_enabled() is True

    def test_telemetry_disabled_with_legacy_flag(self, monkeypatch):
        """The legacy ``CUA_TELEMETRY=off`` switch disables telemetry."""
        self._clear_switches(monkeypatch)
        monkeypatch.setenv("CUA_TELEMETRY", "off")

        from core.telemetry import is_telemetry_enabled

        assert is_telemetry_enabled() is False

    def test_telemetry_disabled_with_new_flag(self, monkeypatch):
        """``CUA_TELEMETRY_ENABLED=false`` disables telemetry."""
        self._clear_switches(monkeypatch)
        monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")

        from core.telemetry import is_telemetry_enabled

        assert is_telemetry_enabled() is False

    @pytest.mark.parametrize("value", ["0", "false", "no", "off"])
    def test_telemetry_disabled_with_various_values(self, monkeypatch, value):
        """Each falsy spelling of ``CUA_TELEMETRY_ENABLED`` disables telemetry."""
        self._clear_switches(monkeypatch)
        monkeypatch.setenv("CUA_TELEMETRY_ENABLED", value)

        from core.telemetry import is_telemetry_enabled

        assert is_telemetry_enabled() is False

    @pytest.mark.parametrize("value", ["1", "true", "yes", "on"])
    def test_telemetry_enabled_with_various_values(self, monkeypatch, value):
        """Each truthy spelling of ``CUA_TELEMETRY_ENABLED`` enables telemetry."""
        # Without this, an ambient CUA_TELEMETRY=off could mask the value under test.
        self._clear_switches(monkeypatch)
        monkeypatch.setenv("CUA_TELEMETRY_ENABLED", value)

        from core.telemetry import is_telemetry_enabled

        assert is_telemetry_enabled() is True
class TestPostHogTelemetryClient:
    """Exercise ``PostHogTelemetryClient`` with ``posthog`` and ``Path`` patched out.

    Each test patches ``core.telemetry.posthog.posthog`` and
    ``core.telemetry.posthog.Path``. The helpers below build the chain of mocks
    returned by ``Path(...).parent.parent / ...`` so the tests share one wiring.
    """

    @staticmethod
    def _stub_storage_dir(path_cls):
        """Make ``Path(...).parent.parent / x`` a mock whose ``exists()`` is False."""
        storage_dir = MagicMock()
        storage_dir.exists.return_value = False
        core_dir = MagicMock()
        core_dir.__truediv__.return_value = storage_dir
        path_cls.return_value.parent.parent = core_dir

    @staticmethod
    def _stub_id_file(path_cls, *, exists, contents=None):
        """Make ``Path(...).parent.parent / x / y`` a mock file.

        Args:
            path_cls: The patched ``Path`` class.
            exists: Value returned by the mock file's ``exists()``.
            contents: When given, value returned by the mock file's ``read_text()``.
        """
        id_file = MagicMock()
        id_file.exists.return_value = exists
        if contents is not None:
            id_file.read_text.return_value = contents
        storage_dir = MagicMock()
        storage_dir.__truediv__.return_value = id_file
        core_dir = MagicMock()
        core_dir.__truediv__.return_value = storage_dir
        path_cls.return_value.parent.parent = core_dir

    @staticmethod
    def _fresh_client(client_cls):
        """Reset the singleton, then construct and return a new client."""
        client_cls.destroy_client()
        return client_cls()

    @patch("core.telemetry.posthog.posthog")
    @patch("core.telemetry.posthog.Path")
    def test_client_initialization(self, mock_path, mock_posthog, disable_telemetry):
        """A newly built client carries the attributes the tests rely on."""
        from core.telemetry.posthog import PostHogTelemetryClient

        self._stub_storage_dir(mock_path)
        client = self._fresh_client(PostHogTelemetryClient)

        assert client is not None
        for attr in ("installation_id", "initialized", "queued_events"):
            assert hasattr(client, attr)

    @patch("core.telemetry.posthog.posthog")
    @patch("core.telemetry.posthog.Path")
    def test_installation_id_generation(self, mock_path, mock_posthog, disable_telemetry):
        """With no ID file present, the client ends up with a 36-character ID."""
        from core.telemetry.posthog import PostHogTelemetryClient

        self._stub_id_file(mock_path, exists=False)
        client = self._fresh_client(PostHogTelemetryClient)

        generated = client.installation_id
        assert generated is not None
        assert len(generated) == 36  # UUID format

    @patch("core.telemetry.posthog.posthog")
    @patch("core.telemetry.posthog.Path")
    def test_installation_id_persistence(self, mock_path, mock_posthog, disable_telemetry):
        """With an ID file present, the client adopts the stored ID."""
        from core.telemetry.posthog import PostHogTelemetryClient

        existing_id = "test-installation-id-123"
        self._stub_id_file(mock_path, exists=True, contents=existing_id)
        client = self._fresh_client(PostHogTelemetryClient)

        assert client.installation_id == existing_id

    @patch("core.telemetry.posthog.posthog")
    @patch("core.telemetry.posthog.Path")
    def test_record_event_when_disabled(self, mock_path, mock_posthog, monkeypatch):
        """With telemetry switched off, ``record_event`` never calls ``capture``."""
        from core.telemetry.posthog import PostHogTelemetryClient

        # Disable telemetry explicitly using the correct environment variable
        monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")
        self._stub_storage_dir(mock_path)
        client = self._fresh_client(PostHogTelemetryClient)

        client.record_event("test_event", {"key": "value"})

        mock_posthog.capture.assert_not_called()

    @patch("core.telemetry.posthog.posthog")
    @patch("core.telemetry.posthog.Path")
    def test_record_event_when_enabled(self, mock_path, mock_posthog, monkeypatch):
        """With telemetry switched on, ``record_event`` reaches ``capture``."""
        from core.telemetry.posthog import PostHogTelemetryClient

        monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "true")
        self._stub_storage_dir(mock_path)
        client = self._fresh_client(PostHogTelemetryClient)
        client.initialized = True  # Pretend it's initialized

        name, props = "test_event", {"key": "value"}
        client.record_event(name, props)

        assert mock_posthog.capture.call_count >= 1

    @patch("core.telemetry.posthog.posthog")
    @patch("core.telemetry.posthog.Path")
    def test_singleton_pattern(self, mock_path, mock_posthog, disable_telemetry):
        """Two consecutive ``get_client()`` calls return the same object."""
        from core.telemetry.posthog import PostHogTelemetryClient

        self._stub_storage_dir(mock_path)
        PostHogTelemetryClient.destroy_client()

        first = PostHogTelemetryClient.get_client()
        second = PostHogTelemetryClient.get_client()

        assert first is second
class TestRecordEvent:
    """Check that the module-level ``record_event`` forwards to the client."""

    @patch("core.telemetry.posthog.PostHogTelemetryClient")
    def test_record_event_calls_client(self, mock_client_class, disable_telemetry):
        """Name and properties are passed through to the client unchanged."""
        from core.telemetry import record_event

        fake_client = Mock()
        mock_client_class.get_client.return_value = fake_client

        name = "test_event"
        props = {"key": "value"}
        record_event(name, props)

        fake_client.record_event.assert_called_once_with(name, props)

    @patch("core.telemetry.posthog.PostHogTelemetryClient")
    def test_record_event_without_properties(self, mock_client_class, disable_telemetry):
        """Omitting properties results in an empty dict reaching the client."""
        from core.telemetry import record_event

        fake_client = Mock()
        mock_client_class.get_client.return_value = fake_client

        name = "test_event"
        record_event(name)

        fake_client.record_event.assert_called_once_with(name, {})
class TestDestroyTelemetryClient:
    """Check that ``destroy_telemetry_client`` forwards to the client class."""

    @patch("core.telemetry.posthog.PostHogTelemetryClient")
    def test_destroy_client_calls_class_method(self, mock_client_class):
        """One call to the public function triggers exactly one ``destroy_client``."""
        from core.telemetry import destroy_telemetry_client as destroy

        destroy()

        mock_client_class.destroy_client.assert_called_once()

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.1.14
current_version = 0.1.15
commit = True
tag = True
tag_name = mcp-server-v{new_version}

View File

@@ -7,7 +7,7 @@ name = "cua-mcp-server"
description = "MCP Server for Computer-Use Agent (CUA)"
readme = "README.md"
requires-python = ">=3.12"
version = "0.1.14"
version = "0.1.15"
authors = [
{name = "TryCua", email = "gh@trycua.com"}
]

View File

@@ -0,0 +1,51 @@
"""Pytest configuration and shared fixtures for mcp-server package tests.
This file contains shared fixtures and configuration for all mcp-server tests.
Following SRP: This file ONLY handles test setup/teardown.
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
@pytest.fixture
def mock_mcp_context():
    """Build a stand-in MCP context that needs no real MCP connection.

    The returned ``AsyncMock`` has an ``AsyncMock`` ``request_context`` and a
    plain ``Mock`` ``session`` whose ``send_resource_updated`` is awaitable.
    """
    session = Mock()
    session.send_resource_updated = AsyncMock()

    ctx = AsyncMock()
    ctx.request_context = AsyncMock()
    ctx.session = session
    return ctx
@pytest.fixture
def mock_computer():
    """Build a stand-in ``Computer`` that performs no real operations.

    ``interface.screenshot`` resolves to ``b"fake_screenshot"``;
    ``interface.left_click`` and ``interface.type`` are awaitable no-ops.
    Entering the mock with ``async with`` yields the mock itself.
    """
    interface = AsyncMock()
    interface.screenshot = AsyncMock(return_value=b"fake_screenshot")
    interface.left_click = AsyncMock()
    interface.type = AsyncMock()

    computer = AsyncMock()
    computer.interface = interface
    # Mock context manager
    computer.__aenter__ = AsyncMock(return_value=computer)
    computer.__aexit__ = AsyncMock()
    return computer
@pytest.fixture
def disable_telemetry(monkeypatch):
    """Disable telemetry for the duration of a test.

    Sets ``CUA_TELEMETRY_ENABLED=false`` and, for backward compatibility, the
    previously used ``CUA_TELEMETRY_DISABLED=1``. ``monkeypatch`` restores both
    variables when the test finishes.
    """
    # NOTE(review): CUA_TELEMETRY_ENABLED is the switch the core telemetry tests
    # use to turn telemetry off; nothing visible reads CUA_TELEMETRY_DISABLED.
    # Confirm against core.telemetry and drop the second line if it is unused.
    monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")
    monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")

View File

@@ -0,0 +1,44 @@
"""Unit tests for mcp-server package.
This file tests ONLY basic MCP server functionality.
Following SRP: This file tests MCP server initialization.
All external dependencies are mocked.
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
class TestMCPServerImports:
    """Smoke test: the ``mcp_server`` package is importable."""

    def test_mcp_server_module_exists(self):
        """Import ``mcp_server``; skip when it or its MCP dependency is missing."""
        try:
            import mcp_server
        except ImportError:
            pytest.skip("mcp_server module not installed")
        except SystemExit:
            pytest.skip("MCP dependencies (mcp.server.fastmcp) not available")
        else:
            assert mcp_server is not None
class TestMCPServerInitialization:
    """Smoke test: the ``mcp_server.server`` component is importable."""

    @pytest.mark.asyncio
    async def test_mcp_server_can_be_imported(self):
        """Import ``mcp_server.server``; skip when the environment cannot load it.

        Only the import is guarded. The assertion runs in the ``else`` branch so
        a failing assertion is reported as a failure instead of being caught by
        ``except Exception`` and turned into a skip.
        """
        try:
            from mcp_server import server
        except ImportError:
            pytest.skip("MCP server module not available")
        except SystemExit:
            pytest.skip("MCP dependencies (mcp.server.fastmcp) not available")
        except Exception as e:
            # Some initialization errors are acceptable in unit tests
            pytest.skip(f"MCP server initialization requires specific setup: {e}")
        else:
            assert server is not None

View File

@@ -1,10 +0,0 @@
[bumpversion]
current_version = 0.2.1
commit = True
tag = True
tag_name = pylume-v{new_version}
message = Bump pylume to v{new_version}
[bumpversion:file:pylume/__init__.py]
search = __version__ = "{current_version}"
replace = __version__ = "{new_version}"

View File

@@ -1,46 +0,0 @@
<div align="center">
<h1>
<div class="image-wrapper" style="display: inline-block;">
<picture>
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="https://raw.githubusercontent.com/trycua/cua/main/img/logo_white.png" style="display: block; margin: auto;">
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="https://raw.githubusercontent.com/trycua/cua/main/img/logo_black.png" style="display: block; margin: auto;">
<img alt="Shows my svg">
</picture>
</div>
[![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#)
[![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85)
[![PyPI](https://img.shields.io/pypi/v/pylume?color=333333)](https://pypi.org/project/pylume/)
</h1>
</div>
**pylume** is a lightweight Python library based on [lume](https://github.com/trycua/lume) to create, run and manage macOS and Linux virtual machines (VMs) natively on Apple Silicon.
```bash
pip install pylume
```
## Usage
Please refer to this [Notebook](./samples/nb.ipynb) for a quickstart. More details about the underlying API used by pylume are available [here](https://github.com/trycua/lume/docs/API-Reference.md).
## Prebuilt Images
Pre-built images are available on [ghcr.io/trycua](https://github.com/orgs/trycua/packages).
These images come pre-configured with an SSH server and auto-login enabled.
## Contributing
We welcome and greatly appreciate contributions to pylume! Whether you're improving documentation, adding new features, fixing bugs, or adding new VM images, your efforts help make pylume better for everyone.
Join our [Discord community](https://discord.com/invite/mVnXXpdE85) to discuss ideas or get assistance.
## License
pylume is open-sourced under the MIT License - see the [LICENSE](LICENSE) file for details.
## Stargazers over time
[![Stargazers over time](https://starchart.cc/trycua/pylume.svg?variant=adaptive)](https://starchart.cc/trycua/pylume)

View File

@@ -1,9 +0,0 @@
"""
PyLume Python SDK - A client library for managing macOS VMs with PyLume.
"""
# Star-imports pull the public names of the three submodules into this namespace.
from pylume.exceptions import *
from pylume.models import *
from pylume.pylume import *
# NOTE(review): the pylume bumpversion config records 0.2.1 and rewrites
# ``__version__`` in pylume/__init__.py; this literal says 0.1.0 - confirm
# whether this file is meant to be kept in sync.
__version__ = "0.1.0"

View File

@@ -1,59 +0,0 @@
"""
PyLume Python SDK - A client library for managing macOS VMs with PyLume.

The ``PyLume`` client is asynchronous: it is used as an async context manager
and its VM methods are coroutines that must be awaited.

Example:
    >>> import asyncio
    >>> from pylume import PyLume, VMConfig
    >>> async def main():
    ...     async with PyLume() as client:
    ...         config = VMConfig(name="my-vm", cpu=4, memory="8GB", disk_size="64GB")
    ...         await client.create_vm(config)
    ...         await client.run_vm("my-vm")
    >>> asyncio.run(main())
"""
# Import exceptions then all models
from .exceptions import (
    LumeConfigError,
    LumeConnectionError,
    LumeError,
    LumeImageError,
    LumeNotFoundError,
    LumeServerError,
    LumeTimeoutError,
    LumeVMError,
)
from .models import (
    CloneSpec,
    ImageInfo,
    ImageList,
    ImageRef,
    SharedDirectory,
    VMConfig,
    VMRunOpts,
    VMStatus,
    VMUpdateOpts,
)
# Import main class last to avoid circular imports
from .pylume import PyLume
__version__ = "0.2.1"
# Names exported by ``from pylume import *``: the client, the models, the exceptions.
__all__ = [
    "PyLume",
    "VMConfig",
    "VMStatus",
    "VMRunOpts",
    "VMUpdateOpts",
    "ImageRef",
    "CloneSpec",
    "SharedDirectory",
    "ImageList",
    "ImageInfo",
    "LumeError",
    "LumeServerError",
    "LumeConnectionError",
    "LumeTimeoutError",
    "LumeNotFoundError",
    "LumeConfigError",
    "LumeVMError",
    "LumeImageError",
]

View File

@@ -1,119 +0,0 @@
import asyncio
import json
import shlex
import subprocess
from typing import Any, Dict, Optional
from .exceptions import (
LumeConfigError,
LumeConnectionError,
LumeError,
LumeNotFoundError,
LumeServerError,
LumeTimeoutError,
)
class LumeClient:
    """Async HTTP client for the lume server that shells out to ``curl``.

    Each request spawns one ``curl`` subprocess. The HTTP status code is
    appended to the body via ``-w %{http_code}`` and split off again here.

    Attributes:
        base_url: Prefix prepended to every request path.
        timeout: Default per-request timeout in seconds (curl ``-m``).
        debug: When True, debug messages are printed to stdout.
    """

    def __init__(self, base_url: str, timeout: float = 60.0, debug: bool = False):
        self.base_url = base_url
        self.timeout = timeout
        self.debug = debug

    def _log_debug(self, message: str, **kwargs) -> None:
        """Print ``message`` (and any keyword details as JSON) when debugging."""
        if self.debug:
            print(f"DEBUG: {message}")
            if kwargs:
                print(json.dumps(kwargs, indent=2))

    def _build_url(self, path: str, params: Optional[Dict[str, Any]] = None) -> str:
        """Return ``base_url + path`` with ``params`` as an encoded query string.

        Keys and values are percent-encoded, so characters such as ``&``, ``=``
        or spaces in a value cannot corrupt the query string.
        """
        from urllib.parse import urlencode

        url = f"{self.base_url}{path}"
        if params:
            url = f"{url}?{urlencode(params)}"
        return url

    async def _run_curl(
        self,
        method: str,
        path: str,
        data: Optional[Dict[str, Any]] = None,
        params: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
    ) -> Any:
        """Execute a curl command and return the decoded JSON response.

        Args:
            method: HTTP method passed to ``curl -X``.
            path: Request path appended to ``base_url``.
            data: Optional JSON body; sent with a JSON content-type header.
            params: Optional query parameters (URL-encoded).
            timeout: Per-call timeout in seconds; defaults to ``self.timeout``.

        Returns:
            The parsed JSON body, or ``None`` when the body is empty.

        Raises:
            LumeConnectionError: curl exited with a non-zero return code.
            LumeNotFoundError: HTTP 404.
            LumeConfigError: HTTP 400.
            LumeServerError: HTTP 5xx.
            LumeError: Any other HTTP status >= 400.
            LumeTimeoutError: ``asyncio.TimeoutError`` was raised while waiting.
        """
        effective_timeout = self.timeout if timeout is None else timeout
        url = self._build_url(path, params)

        cmd = ["curl", "-X", method, "-s", "-w", "%{http_code}", "-m", str(effective_timeout)]
        if data is not None:
            cmd.extend(["-H", "Content-Type: application/json", "-d", json.dumps(data)])
        cmd.append(url)

        self._log_debug(f"Running curl command: {' '.join(map(shlex.quote, cmd))}")

        try:
            process = await asyncio.create_subprocess_exec(
                *cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
            )
            stdout, stderr = await process.communicate()

            if process.returncode != 0:
                raise LumeConnectionError(f"Curl command failed: {stderr.decode()}")

            # The last 3 characters are the status code
            response = stdout.decode()
            status_code = int(response[-3:])
            response_body = response[:-3]  # Remove status code from response

            if status_code >= 400:
                if status_code == 404:
                    raise LumeNotFoundError(f"Resource not found: {path}")
                elif status_code == 400:
                    raise LumeConfigError(f"Invalid request: {response_body}")
                elif status_code >= 500:
                    raise LumeServerError(f"Server error: {response_body}")
                else:
                    raise LumeError(f"Request failed with status {status_code}: {response_body}")

            return json.loads(response_body) if response_body.strip() else None
        except asyncio.TimeoutError:
            raise LumeTimeoutError(f"Request timed out after {effective_timeout} seconds")

    async def get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Any:
        """Make a GET request."""
        return await self._run_curl("GET", path, params=params)

    async def post(
        self, path: str, data: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None
    ) -> Any:
        """Make a POST request, optionally overriding the timeout for this call.

        The override is passed down per call rather than written to
        ``self.timeout``, so requests running concurrently on the same client
        keep their own timeouts.
        """
        return await self._run_curl("POST", path, data=data, timeout=timeout)

    async def patch(self, path: str, data: Dict[str, Any]) -> None:
        """Make a PATCH request."""
        await self._run_curl("PATCH", path, data=data)

    async def delete(self, path: str) -> None:
        """Make a DELETE request."""
        await self._run_curl("DELETE", path)

    def print_curl(self, method: str, path: str, data: Optional[Dict[str, Any]] = None) -> None:
        """Print equivalent curl command for debugging."""
        curl_cmd = f"curl -X {method} \\\n'{self.base_url}{path}'"
        if data:
            curl_cmd += f" \\\n -H 'Content-Type: application/json' \\\n -d '{json.dumps(data)}'"
        print("\nEquivalent curl command:")
        print(curl_cmd)
        print()

    async def close(self) -> None:
        """Close the client resources."""
        pass  # No shared resources to clean up

View File

@@ -1,54 +0,0 @@
from typing import Optional
class LumeError(Exception):
    """Root of the PyLume exception hierarchy; catch this to handle any PyLume error."""
class LumeServerError(LumeError):
    """Error reported by, or attributed to, the PyLume server.

    Attributes:
        status_code: HTTP status associated with the failure, if known.
        response_text: Raw response text associated with the failure, if known.
    """

    def __init__(
        self, message: str, status_code: Optional[int] = None, response_text: Optional[str] = None
    ):
        super().__init__(message)
        self.status_code = status_code
        self.response_text = response_text
class LumeConnectionError(LumeError):
    """A connection to the PyLume server could not be made or failed."""
class LumeTimeoutError(LumeError):
    """A request to the PyLume server did not complete within its time limit."""
class LumeNotFoundError(LumeError):
    """The requested resource does not exist."""
class LumeConfigError(LumeError):
    """The supplied configuration is invalid."""
class LumeVMError(LumeError):
    """A VM operation failed."""
class LumeImageError(LumeError):
    """An image operation failed."""

Binary file not shown.

View File

@@ -1,265 +0,0 @@
import re
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, RootModel, computed_field, validator
class DiskInfo(BaseModel):
    """Disk storage figures for a VM, both expressed in bytes."""

    # Total disk space in bytes.
    total: int
    # Currently allocated disk space in bytes.
    allocated: int
class VMConfig(BaseModel):
    """Configuration for creating a new VM.

    Note: Memory and disk sizes should be specified with units (e.g., "4GB", "64GB")

    Attributes:
        name: Name of the virtual machine
        os: Operating system type, either "macOS" or "linux"
        cpu: Number of CPU cores to allocate (at least 1)
        memory: Amount of memory to allocate with units
        disk_size: Size of the disk to create with units (serialized as ``diskSize``)
        display: Display resolution in format "widthxheight"
        ipsw: IPSW path or 'latest' for macOS VMs, None for other OS types
    """

    # ``populate_by_name`` is the pydantic v2 setting that lets ``disk_size=...``
    # be passed by field name as well as by its ``diskSize`` alias. The former
    # ``class Config: populate_by_alias = True`` is not a pydantic option, so
    # the field-name form was not accepted.
    model_config = ConfigDict(populate_by_name=True)

    name: str
    os: Literal["macOS", "linux"] = "macOS"
    cpu: int = Field(default=2, ge=1)
    memory: str = "4GB"
    disk_size: str = Field(default="64GB", alias="diskSize")
    display: str = "1024x768"
    ipsw: Optional[str] = Field(default=None, description="IPSW path or 'latest', for macOS VMs")
class SharedDirectory(BaseModel):
    """Configuration for a shared directory.

    Attributes:
        host_path: Path to the directory on the host system
        read_only: Whether the directory should be mounted as read-only
    """

    # Declared with ``ConfigDict`` (as ``VMRunOpts`` does) instead of the
    # class-based ``Config`` that pydantic v2 deprecates; settings are unchanged.
    model_config = ConfigDict(
        populate_by_name=True,  # Allow both alias and original name
        # snake_case -> camelCase for any field without an explicit alias
        alias_generator=lambda s: "".join(
            word.capitalize() if i else word for i, word in enumerate(s.split("_"))
        ),
    )

    host_path: str = Field(..., alias="hostPath")  # Allow host_path but serialize as hostPath
    read_only: bool = False
class VMRunOpts(BaseModel):
    """Options accepted when starting a VM.

    Attributes:
        no_display: Whether to not display the VNC client
        shared_directories: List of directories to share with the VM
    """

    no_display: bool = Field(default=False, alias="noDisplay")
    shared_directories: Optional[list[SharedDirectory]] = Field(
        default=None, alias="sharedDirectories"
    )

    model_config = ConfigDict(
        populate_by_name=True,
        alias_generator=lambda name: "".join(
            part.capitalize() if idx else part for idx, part in enumerate(name.split("_"))
        ),
    )

    def model_dump(self, **kwargs):
        """Dump the model, emitting shared directories in API (camelCase) form.

        When the dump is requested with a truthy ``by_alias`` and shared
        directories are set, ``sharedDirectories`` is rebuilt as a list of
        ``{"hostPath": ..., "readOnly": ...}`` dicts and any snake_case
        ``shared_directories`` key is removed.

        Args:
            **kwargs: Keyword arguments passed to parent model_dump method

        Returns:
            dict: Model data with properly formatted field names
        """
        dumped = super().model_dump(**kwargs)

        # Nothing to rewrite unless aliases were requested and directories exist.
        if not (self.shared_directories and kwargs.get("by_alias")):
            return dumped

        dumped["sharedDirectories"] = [
            {"hostPath": entry.host_path, "readOnly": entry.read_only}
            for entry in self.shared_directories
        ]
        dumped.pop("shared_directories", None)
        return dumped
class VMStatus(BaseModel):
    """Status information for a virtual machine.

    Attributes:
        name: Name of the virtual machine
        status: Current status of the VM
        os: Operating system type
        cpu_count: Number of CPU cores allocated (alias ``cpuCount``)
        memory_size: Amount of memory allocated in bytes (alias ``memorySize``)
        disk_size: Disk storage information (alias ``diskSize``)
        vnc_url: URL for VNC connection if available (alias ``vncUrl``)
        ip_address: IP address of the VM if available (alias ``ipAddress``)
    """

    # ``populate_by_name`` lets the model be built from snake_case field names
    # as well as the camelCase aliases. The former
    # ``class Config: populate_by_alias = True`` is not a pydantic option.
    model_config = ConfigDict(populate_by_name=True)

    name: str
    status: str
    os: Literal["macOS", "linux"]
    cpu_count: int = Field(alias="cpuCount")
    memory_size: int = Field(alias="memorySize")  # API returns memory size in bytes
    disk_size: DiskInfo = Field(alias="diskSize")
    vnc_url: Optional[str] = Field(default=None, alias="vncUrl")
    ip_address: Optional[str] = Field(default=None, alias="ipAddress")

    @computed_field
    @property
    def state(self) -> str:
        """Get the current state of the VM.

        Returns:
            str: Current VM status (same value as ``status``)
        """
        return self.status

    @computed_field
    @property
    def cpu(self) -> int:
        """Get the number of CPU cores.

        Returns:
            int: Number of CPU cores allocated to the VM (same value as ``cpu_count``)
        """
        return self.cpu_count

    @computed_field
    @property
    def memory(self) -> str:
        """Get memory allocation in human-readable format.

        Returns:
            str: Memory size formatted as "{size}GB", truncated to whole GiB
        """
        # Convert bytes to GB
        gb = self.memory_size / (1024 * 1024 * 1024)
        return f"{int(gb)}GB"
class VMUpdateOpts(BaseModel):
    """Optional settings for updating an existing VM; every field defaults to ``None``."""

    # New number of CPU cores.
    cpu: Optional[int] = None
    # New memory amount, with units.
    memory: Optional[str] = None
    # New disk size, with units.
    # NOTE(review): unlike VMConfig.disk_size this field has no ``diskSize``
    # alias - confirm which key the update endpoint expects.
    disk_size: Optional[str] = None
class ImageRef(BaseModel):
    """Reference to a VM image.

    Attributes:
        image: Name of the image
        tag: Tag version of the image (default "latest")
        registry: Registry hostname where image is stored (default "ghcr.io")
        organization: Organization or namespace in the registry (default "trycua")
    """

    image: str
    tag: str = "latest"
    registry: Optional[str] = "ghcr.io"
    organization: Optional[str] = "trycua"

    def model_dump(self, **kwargs):
        """Return the reference as a plain ``"image:tag"`` string.

        Unlike the base-class method this returns a ``str``, not a dict, and
        ``registry`` / ``organization`` are not part of the result.

        Args:
            **kwargs: Keyword arguments (ignored)

        Returns:
            str: Image reference in "image:tag" format
        """
        return ":".join((self.image, self.tag))
class CloneSpec(BaseModel):
    """Specification for cloning a VM.

    Attributes:
        name: Name of the source VM to clone
        new_name: Name for the new cloned VM (serialized as ``newName``)
    """

    # ``populate_by_name`` lets ``new_name=...`` be passed by field name as well
    # as by its ``newName`` alias. The former
    # ``class Config: populate_by_alias = True`` is not a pydantic option.
    model_config = ConfigDict(populate_by_name=True)

    name: str
    new_name: str = Field(alias="newName")
class ImageInfo(BaseModel):
    """A single entry of the image listing."""

    # Unique identifier for the image.
    imageId: str
class ImageList(RootModel):
    """Response model for the images endpoint.

    Wraps a list of ``ImageInfo`` objects and forwards iteration, indexing
    and ``len()`` to that list.
    """

    root: List[ImageInfo]

    def __iter__(self):
        """Return an iterator over the wrapped ``ImageInfo`` objects."""
        images = self.root
        return iter(images)

    def __getitem__(self, item):
        """Return the image(s) at ``item``.

        Args:
            item: Index or slice to retrieve

        Returns:
            ImageInfo or list of ImageInfo objects
        """
        images = self.root
        return images[item]

    def __len__(self):
        """Return the number of images in the list."""
        images = self.root
        return len(images)

View File

@@ -1,315 +0,0 @@
import asyncio
import json
import os
import re
import signal
import subprocess
import sys
import time
from functools import wraps
from typing import Any, Callable, List, Optional, TypeVar, Union
from .client import LumeClient
from .exceptions import (
LumeConfigError,
LumeConnectionError,
LumeError,
LumeImageError,
LumeNotFoundError,
LumeServerError,
LumeTimeoutError,
LumeVMError,
)
from .models import (
CloneSpec,
ImageList,
ImageRef,
SharedDirectory,
VMConfig,
VMRunOpts,
VMStatus,
VMUpdateOpts,
)
from .server import LumeServer
# Return type of the coroutine being decorated
T = TypeVar("T")


def ensure_server(func: Callable[..., T]) -> Callable[..., T]:
    """Wrap a coroutine method so the server and client are ready before it runs.

    The wrapper awaits ``self.server.ensure_running()`` and then
    ``self._init_client()`` before awaiting the wrapped method with the
    original arguments and returning its result.
    """

    @wraps(func)
    async def guarded(self: "PyLume", *args: Any, **kwargs: Any) -> T:
        # Server first, then the client that talks to it.
        await self.server.ensure_running()
        await self._init_client()
        result = await func(self, *args, **kwargs)  # type: ignore
        return result

    return guarded  # type: ignore
class PyLume:
def __init__(
    self,
    debug: bool = False,
    server_start_timeout: int = 60,
    port: Optional[int] = None,
    use_existing_server: bool = False,
    host: str = "localhost",
):
    """Initialize the async PyLume client.

    Builds the ``LumeServer`` helper from the given options; the HTTP client
    itself is left as ``None`` here.

    Args:
        debug: Enable debug logging
        server_start_timeout: Timeout in seconds to wait for server to start
        port: Port number for the lume server. Required when use_existing_server is True.
        use_existing_server: If True, will try to connect to an existing server on the specified port
            instead of starting a new one.
        host: Host to use for connections (e.g., "localhost", "127.0.0.1", "host.docker.internal")

    Raises:
        LumeConfigError: ``use_existing_server`` is True but no ``port`` was given.
    """
    if port is None and use_existing_server:
        raise LumeConfigError("Port must be specified when using an existing server")

    server_options = {
        "debug": debug,
        "server_start_timeout": server_start_timeout,
        "port": port,
        "use_existing_server": use_existing_server,
        "host": host,
    }
    self.server = LumeServer(**server_options)
    self.client = None
async def __aenter__(self) -> "PyLume":
    """Enter the async context: make sure the server is up and the client exists.

    Raises:
        LumeConfigError: An existing server was requested without a port.
    """
    server = self.server

    if server.use_existing_server:
        # Just ensure base_url is set for existing server
        if server.requested_port is None:
            raise LumeConfigError("Port must be specified when using an existing server")
        if not server.base_url:
            server.port = server.requested_port
            server.base_url = f"http://{server.host}:{server.port}/lume"

    # Ensure the server is running (will connect to existing or start new as needed)
    await server.ensure_running()

    # Initialize the client
    await self._init_client()
    return self
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
    """Exit the async context: close the client, then stop the server.

    The server is stopped in a ``finally`` clause so it is not left running if
    closing the client raises.
    """
    try:
        if self.client is not None:
            await self.client.close()
    finally:
        await self.server.stop()
async def _init_client(self) -> None:
    """Create the ``LumeClient`` on first use; do nothing if one already exists.

    Raises:
        RuntimeError: The server has no base URL yet.
    """
    if self.client is not None:
        return

    base_url = self.server.base_url
    if base_url is None:
        raise RuntimeError("Server base URL not set")

    self.client = LumeClient(base_url, debug=self.server.debug)
def _log_debug(self, message: str, **kwargs) -> None:
    """Print ``message`` (and keyword details as indented JSON) in debug mode."""
    if not self.server.debug:
        return

    print(f"DEBUG: {message}")
    if kwargs:
        print(json.dumps(kwargs, indent=2))
async def _handle_api_error(self, e: Exception, operation: str) -> None:
    """Translate a low-level failure into a PyLume exception; always raises.

    Args:
        e: The exception that was caught.
        operation: Short description of what was attempted, used in messages.

    Raises:
        LumeConnectionError: ``e`` is a ``subprocess.SubprocessError``.
        LumeTimeoutError: ``e`` is an ``asyncio.TimeoutError``.
        LumeNotFoundError: ``e.status`` is 404.
        LumeConfigError: ``e.status`` is 400.
        LumeServerError: Every other case.
    """
    if isinstance(e, subprocess.SubprocessError):
        raise LumeConnectionError(f"Failed to connect to PyLume server: {str(e)}")
    if isinstance(e, asyncio.TimeoutError):
        raise LumeTimeoutError(f"Request timed out: {str(e)}")

    # Without a ``status`` attribute there is nothing more specific to report.
    carries_status = hasattr(e, "status") or isinstance(e, subprocess.CalledProcessError)
    if not carries_status:
        raise LumeServerError(f"Unknown error during {operation}: {str(e)}")

    status_code = getattr(e, "status", 500)
    response_text = str(e)
    self._log_debug(
        f"{operation} request failed", status_code=status_code, response_text=response_text
    )

    if status_code == 404:
        raise LumeNotFoundError(f"Resource not found during {operation}")
    if status_code == 400:
        raise LumeConfigError(f"Invalid configuration for {operation}: {response_text}")

    if status_code >= 500:
        message = f"Server error during {operation}"
    else:
        message = f"Error during {operation}"
    raise LumeServerError(message, status_code=status_code, response_text=response_text)
async def _read_output(self) -> None:
    """Read and log server output until startup is detected or the process ends.

    Each line is decoded once (undecodable bytes are replaced) before it is
    logged and inspected, so log lines and the ``RuntimeError`` message hold
    readable text instead of a ``b'...'`` bytes representation.

    Returns when a stdout line contains "Server started" or when the server
    process is gone.

    Raises:
        RuntimeError: A stderr line contains "error" (case-insensitive).
    """
    try:
        while True:
            process = self.server.server_process
            if not process or process.poll() is not None:
                self._log_debug("Server process ended")
                break

            # NOTE(review): readline() is called synchronously (not awaited), so
            # it appears to block the event loop until a line or EOF arrives -
            # confirm how the pipes are configured by LumeServer.
            if process.stdout:
                while True:
                    raw = process.stdout.readline()
                    if not raw:
                        break
                    text = raw.decode("utf-8", errors="replace").strip()
                    self._log_debug(f"Server stdout: {text}")
                    if "Server started" in text:
                        self._log_debug("Detected server started message")
                        return

            if process.stderr:
                while True:
                    raw = process.stderr.readline()
                    if not raw:
                        break
                    text = raw.decode("utf-8", errors="replace").strip()
                    self._log_debug(f"Server stderr: {text}")
                    if "error" in text.lower():
                        raise RuntimeError(f"Server error: {text}")

            await asyncio.sleep(0.1)  # Small delay to prevent CPU spinning
    except Exception as e:
        self._log_debug(f"Error in output reader: {str(e)}")
        raise
@ensure_server
async def create_vm(self, spec: Union[VMConfig, dict]) -> None:
    """Create a VM by POSTing its configuration to ``/vms``.

    Args:
        spec: A ``VMConfig`` (dumped by alias with ``None`` fields dropped)
            or a dict, which is sent unchanged.
    """
    await self._init_client()  # make sure the client exists before use
    payload = (
        spec.model_dump(by_alias=True, exclude_none=True)
        if isinstance(spec, VMConfig)
        else spec
    )
    self.client.print_curl("POST", "/vms", payload)  # type: ignore[attr-defined]
    await self.client.post("/vms", payload)  # type: ignore[attr-defined]
@ensure_server
async def run_vm(self, name: str, opts: Optional[Union[VMRunOpts, dict]] = None) -> None:
    """Start the VM called ``name`` via ``POST /vms/{name}/run``.

    Args:
        name: Name of the VM to run.
        opts: Run options as a ``VMRunOpts`` or a dict of its constructor
            keywords; when omitted, ``VMRunOpts(no_display=False)`` is used.
    """
    if isinstance(opts, dict):
        run_opts = VMRunOpts(**opts)
    elif opts is None:
        run_opts = VMRunOpts(no_display=False)  # type: ignore[attr-defined]
    else:
        run_opts = opts
    body = run_opts.model_dump(by_alias=True, exclude_none=True)
    endpoint = f"/vms/{name}/run"
    self.client.print_curl("POST", endpoint, body)  # type: ignore[attr-defined]
    await self.client.post(endpoint, body)  # type: ignore[attr-defined]
@ensure_server
async def list_vms(self) -> List[VMStatus]:
    """Fetch ``/vms`` and validate each returned entry as a ``VMStatus``."""
    raw_vms = await self.client.get("/vms")  # type: ignore[attr-defined]
    return list(map(VMStatus.model_validate, raw_vms))
@ensure_server
async def get_vm(self, name: str) -> VMStatus:
    """Fetch ``/vms/{name}`` and return it validated as a ``VMStatus``."""
    endpoint = f"/vms/{name}"
    details = await self.client.get(endpoint)  # type: ignore[attr-defined]
    return VMStatus.model_validate(details)
@ensure_server
async def update_vm(self, name: str, params: Union[VMUpdateOpts, dict]) -> None:
    """Update the settings of VM ``name`` via ``PATCH /vms/{name}``.

    Args:
        name: Name of the VM to update.
        params: A ``VMUpdateOpts`` or a dict of its constructor keywords.
    """
    update_opts = VMUpdateOpts(**params) if isinstance(params, dict) else params
    body = update_opts.model_dump(by_alias=True, exclude_none=True)
    endpoint = f"/vms/{name}"
    self.client.print_curl("PATCH", endpoint, body)  # type: ignore[attr-defined]
    await self.client.patch(endpoint, body)  # type: ignore[attr-defined]
@ensure_server
async def stop_vm(self, name: str) -> None:
    """Stop the VM called ``name`` via ``POST /vms/{name}/stop``."""
    endpoint = f"/vms/{name}/stop"
    await self.client.post(endpoint)  # type: ignore[attr-defined]
@ensure_server
async def delete_vm(self, name: str) -> None:
    """Delete the VM called ``name`` via ``DELETE /vms/{name}``."""
    endpoint = f"/vms/{name}"
    await self.client.delete(endpoint)  # type: ignore[attr-defined]
@ensure_server
async def pull_image(
    self, spec: Union[ImageRef, dict, str], name: Optional[str] = None
) -> None:
    """Pull a VM image via ``POST /pull`` (300 s request timeout).

    Args:
        spec: What to pull. A string is used as ``image[:tag]`` (``:latest``
            is appended when it has no colon) with registry ``ghcr.io`` and
            organization ``trycua``. A dict may carry ``image``, ``tag``,
            ``registry`` and ``organization`` keys with the same defaults.
            Any other value is read through its ``image``, ``tag``,
            ``registry`` and ``organization`` attributes.
        name: Optional value sent as the payload's ``name`` field.
    """
    await self._init_client()

    if isinstance(spec, dict):
        image_str = f'{spec.get("image", "")}:{spec.get("tag", "latest")}'
        registry = spec.get("registry", "ghcr.io")
        organization = spec.get("organization", "trycua")
    elif isinstance(spec, str):
        image_str = spec if ":" in spec else f"{spec}:latest"
        registry, organization = "ghcr.io", "trycua"
    else:
        image_str = f"{spec.image}:{spec.tag}"
        registry, organization = spec.registry, spec.organization

    payload = {
        "image": image_str,
        "name": name,
        "registry": registry,
        "organization": organization,
    }
    self.client.print_curl("POST", "/pull", payload)  # type: ignore[attr-defined]
    await self.client.post("/pull", payload, timeout=300.0)  # type: ignore[attr-defined]
@ensure_server
async def clone_vm(self, name: str, new_name: str) -> None:
    """Clone VM ``name`` into a new VM ``new_name`` via ``POST /vms/clone``."""
    clone_spec = CloneSpec(name=name, newName=new_name)
    endpoint = "/vms/clone"
    self.client.print_curl("POST", endpoint, clone_spec.model_dump())  # type: ignore[attr-defined]
    await self.client.post(endpoint, clone_spec.model_dump())  # type: ignore[attr-defined]
@ensure_server
async def get_latest_ipsw_url(self) -> str:
    """Return the ``url`` field of the ``GET /ipsw`` response."""
    await self._init_client()
    response = await self.client.get("/ipsw")  # type: ignore[attr-defined]
    return response["url"]
@ensure_server
async def get_images(self, organization: Optional[str] = None) -> ImageList:
    """Return the images reported by ``GET /images``.

    Args:
        organization: When truthy, sent as the ``organization`` query
            parameter; otherwise no parameters are sent.
    """
    await self._init_client()
    query = None
    if organization:
        query = {"organization": organization}
    listing = await self.client.get("/images", query)  # type: ignore[attr-defined]
    return ImageList(root=listing)
async def close(self) -> None:
    """Close the API client (if any), wait one second, then stop the server."""
    active_client = self.client
    if active_client is not None:
        await active_client.close()
        self.client = None  # only cleared once close() has succeeded
    await asyncio.sleep(1)
    await self.server.stop()
async def _ensure_client(self) -> None:
    """Create the API client through ``_init_client`` if none is set yet."""
    if self.client is not None:
        return
    await self._init_client()

View File

@@ -1,481 +0,0 @@
import asyncio
import json
import logging
import os
import random
import shlex
import signal
import socket
import subprocess
import sys
import tempfile
import time
from logging import getLogger
from typing import Optional
from .exceptions import LumeConnectionError
class LumeServer:
    """Start, probe and stop a local ``lume serve`` process.

    The server binary is expected to be a file named ``lume`` in the same
    directory as this module. Health checks shell out to ``curl`` against the
    server's ``/lume/vms`` endpoint and treat HTTP 200 as "up".
    """

    def __init__(
        self,
        debug: bool = False,
        server_start_timeout: int = 60,
        port: Optional[int] = None,
        use_existing_server: bool = False,
        host: str = "localhost",
    ):
        """Initialize the LumeServer.

        Args:
            debug: Enable debug logging
            server_start_timeout: Timeout in seconds to wait for server to start
            port: Specific port to use for the server
            use_existing_server: If True, will try to connect to an existing server
                instead of starting a new one
            host: Host to use for connections (e.g., "localhost", "127.0.0.1",
                "host.docker.internal")
        """
        self.debug = debug
        self.server_start_timeout = server_start_timeout
        self.server_process = None
        self.output_file = None
        self.requested_port = port
        self.port = None
        self.base_url = None
        self.use_existing_server = use_existing_server
        self.host = host

        # Configure logging; the handler is attached only once per process
        # because the "pylume.server" logger is shared between instances.
        self.logger = getLogger("pylume.server")
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.DEBUG if debug else logging.INFO)

        self.logger.debug(f"Server initialized with host: {self.host}")

    def _check_port_available(self, port: int) -> bool:
        """Return True when nothing accepts TCP connections on ``port``.

        127.0.0.1 is always probed; ``self.host`` is probed as well when it is
        not a localhost alias. The check is best effort: a probe that fails
        with an error counts as "not in use".
        """
        probe_hosts = ["127.0.0.1"]
        if self.host not in ["localhost", "127.0.0.1"]:
            probe_hosts.append(self.host)

        for probe_host in probe_hosts:
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.settimeout(0.5)
                    if s.connect_ex((probe_host, port)) == 0:
                        return False  # something is listening there
            except Exception as e:
                # Deliberately not a bare ``except``: Ctrl-C must still work.
                self.logger.debug(f"Port probe on {probe_host}:{port} failed: {e}")
        return True

    def _get_server_port(self) -> int:
        """Return the requested port, or a random free port in 49152-65535.

        Raises:
            RuntimeError: If the requested port is in use, or no free port was
                found in 10 random attempts.
        """
        if self.requested_port is not None:
            if not self._check_port_available(self.requested_port):
                raise RuntimeError(f"Requested port {self.requested_port} is not available")
            return self.requested_port

        for _ in range(10):  # Try up to 10 times
            port = random.randint(49152, 65535)
            if self._check_port_available(port):
                return port

        raise RuntimeError("Could not find an available port")

    async def _curl_status(self, url: str, max_time: int) -> tuple[Optional[int], bytes]:
        """Run ``curl`` against ``url`` and report the HTTP status.

        Returns:
            ``(status, stderr)``. ``status`` is ``None`` when curl exits
            non-zero; ``stderr`` is curl's raw error output.
        """
        cmd = ["curl", "-s", "-w", "%{http_code}", "-m", str(max_time), url]
        process = await asyncio.create_subprocess_exec(
            *cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        stdout, stderr = await process.communicate()
        if process.returncode != 0:
            return None, stderr
        # ``-w %{http_code}`` appends the status after any response body.
        return int(stdout.decode()[-3:]), stderr

    async def _require_http_ok(self, url: str) -> None:
        """Raise RuntimeError unless ``url`` answers HTTP 200 within 10 s."""
        status_code, stderr = await self._curl_status(url, 10)
        if status_code is None:
            raise RuntimeError(f"Curl command failed: {stderr.decode()}")
        if status_code != 200:
            raise RuntimeError(f"Server returned status code {status_code}")
        self.logger.debug("PyLume server started successfully")

    async def _ensure_server_running(self) -> None:
        """Ensure the lume server is running, start it if it's not.

        Probes ``{base_url}/vms`` first. If that does not return 200, the
        ``lume`` binary is launched in a new session with its output sent to a
        temporary file, and the method waits up to ``server_start_timeout``
        seconds for either a "Server started" line or a 200 response.

        Raises:
            RuntimeError: On any failure; the message is prefixed with
                "Failed to start lume server: ".
        """
        try:
            self.logger.debug("Checking if lume server is running...")
            vms_url = f"{self.base_url}/vms"

            status_code, _ = await self._curl_status(vms_url, 5)
            if status_code == 200:
                self.logger.debug("PyLume server is running")
                return

            self.logger.debug("PyLume server not running, attempting to start it")
            lume_path = os.path.join(os.path.dirname(__file__), "lume")
            if not os.path.exists(lume_path):
                raise RuntimeError(f"Could not find lume binary at {lume_path}")

            # Make sure the file is executable
            os.chmod(lume_path, 0o755)

            # Server output goes to a temp file that is polled below.
            self.output_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)
            self.logger.debug(f"Using temporary file for server output: {self.output_file.name}")

            self.logger.debug(f"Starting lume server with: {lume_path} serve --port {self.port}")
            try:
                self.server_process = subprocess.Popen(
                    [lume_path, "serve", "--port", str(self.port)],
                    stdout=self.output_file,
                    stderr=self.output_file,
                    cwd=os.path.dirname(lume_path),
                    start_new_session=True,  # Run in new session to avoid blocking
                )
            except Exception as e:
                self.output_file.close()
                os.unlink(self.output_file.name)
                self.output_file = None
                raise RuntimeError(f"Failed to start lume server process: {str(e)}") from e

            self.logger.debug(
                f"Waiting up to {self.server_start_timeout} seconds for server to start..."
            )
            # Monotonic clock: the timeout must not be affected by wall-clock jumps.
            deadline = time.monotonic() + self.server_start_timeout
            server_ready = False
            last_size = 0

            while time.monotonic() < deadline:
                if self.server_process.poll() is not None:
                    # Process died: capture its output before cleanup removes it.
                    error_msg = await self._get_error_output()
                    await self._cleanup()
                    raise RuntimeError(error_msg)

                # Only read the part of the log written since the last pass.
                self.output_file.seek(0, os.SEEK_END)
                size = self.output_file.tell()
                if size > last_size:
                    self.output_file.seek(last_size)
                    new_output = self.output_file.read()
                    if new_output.strip():
                        self.logger.debug(f"Server output: {new_output.strip()}")
                    last_size = size

                    if "Server started" in new_output:
                        server_ready = True
                        self.logger.debug("Server startup detected")
                        break

                # Fall back to probing the HTTP endpoint directly.
                try:
                    status_code, _ = await self._curl_status(vms_url, 5)
                    if status_code == 200:
                        server_ready = True
                        self.logger.debug("Server is responding to requests")
                        break
                except Exception as e:
                    # ``Exception`` (not bare except) so cancellation propagates.
                    self.logger.debug(f"Server not ready yet: {e}")

                await asyncio.sleep(1.0)

            if not server_ready:
                await self._cleanup()
                raise RuntimeError(
                    f"Failed to start lume server after {self.server_start_timeout} seconds. "
                    "Check the debug output for more details."
                )

            # Give the server a moment to fully initialize
            await asyncio.sleep(2.0)

            try:
                await self._require_http_ok(vms_url)
            except Exception as e:
                self.logger.debug(f"Server verification failed: {str(e)}")
                await self._cleanup()
                raise RuntimeError(f"Server started but is not responding: {str(e)}") from e

            self.logger.debug("Server startup completed successfully")

        except Exception as e:
            raise RuntimeError(f"Failed to start lume server: {str(e)}") from e

    async def _start_server(self) -> None:
        """Start the lume server using the lume executable.

        Picks a port, sets ``base_url``, launches ``lume serve --port <port>``
        with output captured in a temporary file and waits until it responds.

        Raises:
            RuntimeError: If the binary is missing or the server fails to
                start; resources are cleaned up first in the latter case.
        """
        self.logger.debug("Starting PyLume server")

        # The executable lives in the same directory as this file.
        lume_path = os.path.join(os.path.dirname(__file__), "lume")
        if not os.path.exists(lume_path):
            raise RuntimeError(f"Could not find lume binary at {lume_path}")

        try:
            os.chmod(lume_path, 0o755)  # Make executable

            self.port = self._get_server_port()
            self.base_url = f"http://{self.host}:{self.port}/lume"

            self.output_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)

            env = os.environ.copy()
            env["RUST_BACKTRACE"] = "1"  # Enable backtrace for better error reporting

            # Only the port is passed; no bind host is given on the command line.
            self.server_process = subprocess.Popen(
                [lume_path, "serve", "--port", str(self.port)],
                stdout=self.output_file,
                stderr=subprocess.STDOUT,
                cwd=os.path.dirname(lume_path),  # Run from same directory as executable
                env=env,
            )

            # Wait for server to initialize
            await asyncio.sleep(2)
            await self._wait_for_server()

        except Exception as e:
            await self._cleanup()
            raise RuntimeError(f"Failed to start lume server process: {str(e)}") from e

    async def _tail_log(self) -> None:
        """Print new server log lines (``SERVER: ...``) until the process exits.

        Starts at the end of the log as it is when first read and then prints
        whatever is appended, polling every 0.1 s. Read errors are printed and
        retried rather than raised.
        """
        position = None  # where the previous pass stopped reading
        while True:
            try:
                if position is None:
                    self.output_file.seek(0, os.SEEK_END)  # type: ignore[attr-defined]
                    position = self.output_file.tell()  # type: ignore[attr-defined]

                # Re-seek to the remembered position on every pass; seeking to
                # the end each time (as before) meant nothing was ever read.
                self.output_file.seek(position)  # type: ignore[attr-defined]
                new_text = self.output_file.read()  # type: ignore[attr-defined]
                position = self.output_file.tell()  # type: ignore[attr-defined]
                for line in new_text.splitlines():
                    line = line.strip()
                    if line:
                        print(f"SERVER: {line}")

                if self.server_process.poll() is not None:  # type: ignore[attr-defined]
                    print("Server process ended")
                    break
                await asyncio.sleep(0.1)
            except Exception as e:
                print(f"Error reading log: {e}")
                await asyncio.sleep(0.1)

    async def _wait_for_server(self) -> None:
        """Wait until the server responds or ``server_start_timeout`` elapses.

        Raises:
            RuntimeError: If the process exits early (message includes its
                output) or the timeout is reached. Resources are cleaned up
                before raising.
        """
        deadline = time.monotonic() + self.server_start_timeout
        while time.monotonic() < deadline:
            if self.server_process.poll() is not None:  # type: ignore[attr-defined]
                error_msg = await self._get_error_output()
                await self._cleanup()
                raise RuntimeError(error_msg)

            try:
                await self._verify_server()
                self.logger.debug("Server is now responsive")
                return
            except Exception as e:
                self.logger.debug(f"Server not ready yet: {str(e)}")

            await asyncio.sleep(1.0)

        await self._cleanup()
        raise RuntimeError(f"Server failed to start after {self.server_start_timeout} seconds")

    async def _verify_server(self) -> None:
        """Verify server is responding to requests.

        Raises:
            RuntimeError: "Server not responding: ..." when curl fails or the
                ``/lume/vms`` endpoint does not return HTTP 200.
        """
        try:
            await self._require_http_ok(f"http://{self.host}:{self.port}/lume/vms")
        except Exception as e:
            raise RuntimeError(f"Server not responding: {str(e)}") from e

    async def _get_error_output(self) -> str:
        """Return a failure report built from the captured server output.

        Returns "No output available" when there is no output file.
        """
        if not self.output_file:
            return "No output available"
        self.output_file.seek(0)
        output = self.output_file.read()
        return (
            f"Server process terminated unexpectedly.\n"
            f"Exit code: {self.server_process.returncode}\n"  # type: ignore[attr-defined]
            f"Output: {output}"
        )

    async def _cleanup(self) -> None:
        """Terminate the server process and delete its output file.

        Best effort: failures are logged at debug level, never raised. Both
        ``server_process`` and ``output_file`` are reset to ``None``.
        """
        if self.server_process:
            try:
                self.server_process.terminate()
                try:
                    self.server_process.wait(timeout=5)
                except subprocess.TimeoutExpired:
                    self.server_process.kill()
                    # Reap the killed process so it does not linger as a zombie.
                    self.server_process.wait(timeout=5)
            except Exception as e:
                self.logger.debug(f"Error stopping server process: {e}")
            self.server_process = None

        if self.output_file:
            try:
                self.output_file.close()
                os.unlink(self.output_file.name)
            except Exception as e:
                self.logger.debug(f"Error cleaning up output file: {e}")
            self.output_file = None

    @staticmethod
    def _running_in_docker() -> bool:
        """Return True if ``/.dockerenv`` exists or PID 1's cgroup mentions docker."""
        if os.path.exists("/.dockerenv"):
            return True
        try:
            # ``with`` closes the handle; the previous bare open() leaked it.
            with open("/proc/1/cgroup", "r") as cgroup_file:
                return "docker" in cgroup_file.read()
        except OSError:
            return False  # missing or unreadable: treat as "not in Docker"

    async def ensure_running(self) -> None:
        """Ensure the server is running.

        1. For a non-localhost ``host``, a port is required and that server is
           probed. On failure this raises when ``use_existing_server`` is set
           or we are in Docker; otherwise ``host`` falls back to "localhost".
        2. With ``use_existing_server``, a port is required and the server
           must respond; nothing is ever started.
        3. Otherwise, a server already answering on the requested port is
           reused. If none answers, a new one is started, unless we are in
           Docker and a port was requested, which raises instead.

        Raises:
            RuntimeError: If a required port is missing, an expected server is
                unreachable, or starting a new server fails.
        """
        in_docker = self._running_in_docker()

        # Remote host (e.g. host.docker.internal): connection details must be explicit.
        if self.host not in ["localhost", "127.0.0.1"]:
            if self.requested_port is None:
                raise RuntimeError("Port must be specified when using a remote host")

            self.port = self.requested_port
            self.base_url = f"http://{self.host}:{self.port}/lume"
            self.logger.debug(f"Using remote host server at {self.base_url}")

            try:
                await self._verify_server()
                self.logger.debug("Successfully connected to remote server")
                return
            except Exception as e:
                if self.use_existing_server or in_docker:
                    # We may not (or cannot) start a server ourselves here.
                    raise RuntimeError(
                        f"Failed to connect to remote server at {self.base_url}: {str(e)}"
                    ) from e
                self.logger.debug(f"Remote server not available at {self.base_url}: {str(e)}")
                # Fall back to localhost for starting a new server
                self.host = "localhost"

        if self.use_existing_server:
            if self.requested_port is None:
                raise RuntimeError("Port must be specified when using an existing server")

            self.port = self.requested_port
            self.base_url = f"http://{self.host}:{self.port}/lume"

            try:
                await self._verify_server()
                self.logger.debug("Successfully connected to existing server")
            except Exception as e:
                raise RuntimeError(
                    f"Failed to connect to existing server at {self.base_url}: {str(e)}"
                ) from e
        else:
            # Try to connect to an existing server first
            if self.requested_port is not None:
                self.port = self.requested_port
                self.base_url = f"http://{self.host}:{self.port}/lume"

                try:
                    await self._verify_server()
                    self.logger.debug("Successfully connected to existing server")
                    return
                except Exception:
                    self.logger.debug(f"No existing server found at {self.base_url}")

                    # If in Docker and can't connect to existing server, raise an error
                    if in_docker:
                        raise RuntimeError(
                            f"Failed to connect to server at {self.base_url} and cannot start a new server in Docker"
                        )

            self.logger.debug("Starting a new server instance")
            await self._start_server()

    async def stop(self) -> None:
        """Stop the server if we're managing it (no-op for an existing server)."""
        if not self.use_existing_server:
            self.logger.debug("Stopping lume server...")
            await self._cleanup()

View File

@@ -1,51 +0,0 @@
[build-system]
build-backend = "pdm.backend"
requires = ["pdm-backend"]
[project]
authors = [{ name = "TryCua", email = "gh@trycua.com" }]
classifiers = [
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Operating System :: MacOS :: MacOS X",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
dependencies = ["pydantic>=2.11.1"]
description = "Python SDK for lume - run macOS and Linux VMs on Apple Silicon"
dynamic = ["version"]
keywords = ["apple-silicon", "macos", "virtualization", "vm"]
license = { text = "MIT" }
name = "pylume"
readme = "README.md"
requires-python = ">=3.12"
[tool.pdm.version]
path = "pylume/__init__.py"
source = "file"
[project.urls]
homepage = "https://github.com/trycua/pylume"
repository = "https://github.com/trycua/pylume"
[tool.pdm]
distribution = true
[tool.pdm.dev-dependencies]
dev = [
"black>=23.0.0",
"isort>=5.12.0",
"pytest-asyncio>=0.23.0",
"pytest>=7.0.0",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
python_files = "test_*.py"
testpaths = ["tests"]
[tool.pdm.build]
includes = ["pylume/"]
source-includes = ["LICENSE", "README.md", "tests/"]

View File

@@ -0,0 +1,23 @@
"""Pytest configuration for pylume tests.
This module provides test fixtures for the pylume package.
Note: This package has macOS-specific dependencies and will skip tests
if the required modules are not available.
"""
from unittest.mock import Mock, patch
import pytest
@pytest.fixture
def mock_subprocess():
    """Patch ``subprocess.run`` to return a successful, output-free result.

    Yields the mock that replaces ``subprocess.run`` so tests can inspect calls.
    """
    completed = Mock(returncode=0, stdout="", stderr="")
    with patch("subprocess.run", return_value=completed) as fake_run:
        yield fake_run
@pytest.fixture
def mock_requests():
    """Patch ``requests.get`` and ``requests.post`` for the duration of a test.

    Yields a dict ``{"get": <mock>, "post": <mock>}``. Tests using this
    fixture are skipped when ``requests`` is not installed, because ``patch``
    must import the module and would otherwise error during fixture setup.
    """
    pytest.importorskip("requests")
    with patch("requests.get") as mock_get, patch("requests.post") as mock_post:
        yield {"get": mock_get, "post": mock_post}

View File

@@ -0,0 +1,38 @@
"""Unit tests for pylume package.
This file tests ONLY basic pylume functionality.
Following SRP: This file tests pylume module imports and basic operations.
All external dependencies are mocked.
"""
import pytest
class TestPylumeImports:
    """Test pylume module imports (SRP: Only tests imports)."""

    def test_pylume_module_exists(self):
        """Import ``pylume``; skip instead of failing when it is not installed."""
        try:
            import pylume
        except ImportError:
            pytest.skip("pylume module not installed")
        else:
            assert pylume is not None
class TestPylumeInitialization:
    """Test pylume initialization (SRP: Only tests initialization)."""

    def test_pylume_can_be_imported(self):
        """Basic smoke test: verify pylume components can be imported.

        Skips when the package is missing or fails to initialise on import.
        """
        try:
            import pylume
        except ImportError:
            pytest.skip("pylume module not available")
        except Exception as e:
            # Some initialization errors are acceptable in unit tests
            pytest.skip(f"pylume initialization requires specific setup: {e}")
        else:
            # Kept outside the ``try`` so a failing assertion fails the test;
            # inside it, ``except Exception`` would turn it into a skip.
            assert pylume is not None

View File

@@ -0,0 +1,24 @@
"""Pytest configuration for som tests.
This module provides test fixtures for the som (Set-of-Mark) package.
The som package depends on heavy ML models and will skip tests if not available.
"""
from unittest.mock import Mock, patch
import pytest
@pytest.fixture
def mock_torch():
    """Patch ``torch.load`` so it returns a ``Mock`` instead of loading a file.

    Yields the mock that replaces ``torch.load``. Tests using this fixture are
    skipped when ``torch`` is not installed, because ``patch`` must import the
    module and would otherwise error during fixture setup.
    """
    pytest.importorskip("torch")
    with patch("torch.load") as mock_load:
        mock_load.return_value = Mock()
        yield mock_load
@pytest.fixture
def mock_icon_detector():
    """Patch ``omniparser.IconDetector`` and yield the mocked instance.

    Calling the patched class returns the yielded ``Mock``. Tests using this
    fixture are skipped when ``omniparser`` cannot be imported, because
    ``patch`` must import the module and would otherwise error during setup.
    """
    # NOTE(review): the package under test is imported as ``som`` elsewhere;
    # confirm ``omniparser.IconDetector`` is still the right patch target.
    pytest.importorskip("omniparser")
    with patch("omniparser.IconDetector") as mock_detector:
        instance = Mock()
        mock_detector.return_value = instance
        yield instance

View File

@@ -1,13 +1,73 @@
# """Basic tests for the omniparser package."""
"""Unit tests for som package (Set-of-Mark).
# import pytest
# from omniparser import IconDetector
This file tests ONLY basic som functionality.
Following SRP: This file tests som module imports and basic operations.
All external dependencies (ML models, OCR) are mocked.
"""
# def test_icon_detector_import():
# """Test that we can import the IconDetector class."""
# assert IconDetector is not None
import pytest
# def test_icon_detector_init():
# """Test that we can create an IconDetector instance."""
# detector = IconDetector(force_cpu=True)
# assert detector is not None
class TestSomImports:
    """Test som module imports (SRP: Only tests imports).

    Each test skips when ``som`` (or its dependencies) cannot be imported.
    Assertions live in ``else`` blocks so a failing assertion fails the test
    instead of being caught by ``except Exception`` and reported as a skip.
    """

    def test_som_module_exists(self):
        """Test that som module can be imported."""
        try:
            import som
        except ImportError:
            pytest.skip("som module not installed")
        else:
            assert som is not None

    def test_omniparser_import(self):
        """Test that OmniParser can be imported."""
        try:
            from som import OmniParser
        except ImportError:
            pytest.skip("som module not available")
        except Exception as e:
            pytest.skip(f"som initialization requires ML models: {e}")
        else:
            assert OmniParser is not None

    def test_models_import(self):
        """Test that model classes can be imported."""
        try:
            from som import BoundingBox, ParseResult, UIElement
        except ImportError:
            pytest.skip("som models not available")
        except Exception as e:
            pytest.skip(f"som models require dependencies: {e}")
        else:
            assert BoundingBox is not None
            assert UIElement is not None
            assert ParseResult is not None
class TestSomModels:
    """Test som data models (SRP: Only tests model structure).

    Each test skips when the models cannot be imported. Assertions live in
    ``else`` blocks so a failing assertion fails the test instead of being
    caught by ``except Exception`` and reported as a skip.
    """

    def test_bounding_box_structure(self):
        """Test BoundingBox class structure."""
        try:
            from som import BoundingBox
        except ImportError:
            pytest.skip("som models not available")
        except Exception as e:
            pytest.skip(f"som models require dependencies: {e}")
        else:
            # Check the class exists and has expected structure
            assert hasattr(BoundingBox, "__init__")

    def test_ui_element_structure(self):
        """Test UIElement class structure."""
        try:
            from som import UIElement
        except ImportError:
            pytest.skip("som models not available")
        except Exception as e:
            pytest.skip(f"som models require dependencies: {e}")
        else:
            # Check the class exists and has expected structure
            assert hasattr(UIElement, "__init__")