Merge branch 'main' into feat/add-desktop-commands

This commit is contained in:
Dillon DuPont
2025-10-29 16:15:54 -04:00
103 changed files with 2333 additions and 2559 deletions

View File

@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.4.35
current_version = 0.4.36
commit = True
tag = True
tag_name = agent-v{new_version}

View File

@@ -28,8 +28,12 @@ class AsyncComputerHandler(Protocol):
"""Get screen dimensions as (width, height)."""
...
async def screenshot(self) -> str:
"""Take a screenshot and return as base64 string."""
async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.
Args:
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
"""
...
async def click(self, x: int, y: int, button: str = "left") -> None:

View File

@@ -36,8 +36,12 @@ class cuaComputerHandler(AsyncComputerHandler):
screen_size = await self.interface.get_screen_size()
return screen_size["width"], screen_size["height"]
async def screenshot(self) -> str:
"""Take a screenshot and return as base64 string."""
async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.
Args:
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
"""
assert self.interface is not None
screenshot_bytes = await self.interface.screenshot()
return base64.b64encode(screenshot_bytes).decode("utf-8")

View File

@@ -122,8 +122,12 @@ class CustomComputerHandler(AsyncComputerHandler):
return self._last_screenshot_size
async def screenshot(self) -> str:
"""Take a screenshot and return as base64 string."""
async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.
Args:
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
"""
result = await self._call_function(self.functions["screenshot"])
b64_str = self._to_b64_str(result) # type: ignore

View File

@@ -243,18 +243,20 @@ async def replace_computer_call_with_function(
"id": item.get("id"),
"call_id": item.get("call_id"),
"status": "completed",
# Fall back to string representation
"content": f"Used tool: {action_data.get("type")}({json.dumps(fn_args)})",
}
]
elif item_type == "computer_call_output":
# Simple conversion: computer_call_output -> function_call_output
output = item.get("output")
if isinstance(output, dict):
output = [output]
return [
{
"type": "function_call_output",
"call_id": item.get("call_id"),
"content": [item.get("output")],
"output": output,
"id": item.get("id"),
"status": "completed",
}

View File

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
[project]
name = "cua-agent"
version = "0.4.35"
version = "0.4.36"
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
readme = "README.md"
authors = [

View File

@@ -0,0 +1,84 @@
"""Pytest configuration and shared fixtures for agent package tests.
This file contains shared fixtures and configuration for all agent tests.
Following SRP: This file ONLY handles test setup/teardown.
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
@pytest.fixture
def mock_litellm():
"""Mock liteLLM completion calls.
Use this fixture to avoid making real LLM API calls during tests.
Returns a mock that simulates LLM responses.
"""
with patch("litellm.acompletion") as mock_completion:
async def mock_response(*args, **kwargs):
"""Simulate a typical LLM response."""
return {
"id": "chatcmpl-test123",
"object": "chat.completion",
"created": 1234567890,
"model": kwargs.get("model", "anthropic/claude-3-5-sonnet-20241022"),
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a mocked response for testing.",
},
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
},
}
mock_completion.side_effect = mock_response
yield mock_completion
@pytest.fixture
def mock_computer():
"""Mock Computer interface for agent tests.
Use this fixture to test agent logic without requiring a real Computer instance.
"""
computer = AsyncMock()
computer.interface = AsyncMock()
computer.interface.screenshot = AsyncMock(return_value=b"fake_screenshot_data")
computer.interface.left_click = AsyncMock()
computer.interface.type = AsyncMock()
computer.interface.key = AsyncMock()
# Mock context manager
computer.__aenter__ = AsyncMock(return_value=computer)
computer.__aexit__ = AsyncMock()
return computer
@pytest.fixture
def disable_telemetry(monkeypatch):
"""Disable telemetry for tests.
Use this fixture to ensure no telemetry is sent during tests.
"""
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
@pytest.fixture
def sample_messages():
"""Provide sample messages for testing.
Returns a list of messages in the expected format.
"""
return [{"role": "user", "content": "Take a screenshot and tell me what you see"}]

View File

@@ -0,0 +1,139 @@
"""Unit tests for ComputerAgent class.
This file tests ONLY the ComputerAgent initialization and basic functionality.
Following SRP: This file tests ONE class (ComputerAgent).
All external dependencies (liteLLM, Computer) are mocked.
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
class TestComputerAgentInitialization:
"""Test ComputerAgent initialization (SRP: Only tests initialization)."""
@patch("agent.agent.litellm")
def test_agent_initialization_with_model(self, mock_litellm, disable_telemetry):
"""Test that agent can be initialized with a model string."""
from agent import ComputerAgent
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
assert agent is not None
assert hasattr(agent, "model")
assert agent.model == "anthropic/claude-3-5-sonnet-20241022"
@patch("agent.agent.litellm")
def test_agent_initialization_with_tools(self, mock_litellm, disable_telemetry, mock_computer):
"""Test that agent can be initialized with tools."""
from agent import ComputerAgent
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022", tools=[mock_computer])
assert agent is not None
assert hasattr(agent, "tools")
@patch("agent.agent.litellm")
def test_agent_initialization_with_max_budget(self, mock_litellm, disable_telemetry):
"""Test that agent can be initialized with max trajectory budget."""
from agent import ComputerAgent
budget = 5.0
agent = ComputerAgent(
model="anthropic/claude-3-5-sonnet-20241022", max_trajectory_budget=budget
)
assert agent is not None
@patch("agent.agent.litellm")
def test_agent_requires_model(self, mock_litellm, disable_telemetry):
"""Test that agent requires a model parameter."""
from agent import ComputerAgent
with pytest.raises(TypeError):
# Should fail without model parameter - intentionally missing required argument
ComputerAgent() # type: ignore[call-arg]
class TestComputerAgentRun:
"""Test ComputerAgent.run() method (SRP: Only tests run logic)."""
@pytest.mark.asyncio
@patch("agent.agent.litellm")
async def test_agent_run_with_messages(self, mock_litellm, disable_telemetry, sample_messages):
"""Test that agent.run() works with valid messages."""
from agent import ComputerAgent
# Mock liteLLM response
mock_response = {
"id": "chatcmpl-test",
"choices": [
{
"message": {"role": "assistant", "content": "Test response"},
"finish_reason": "stop",
}
],
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
}
mock_litellm.acompletion = AsyncMock(return_value=mock_response)
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
# Run should return an async generator
result_generator = agent.run(sample_messages)
assert result_generator is not None
# Check it's an async generator
assert hasattr(result_generator, "__anext__")
def test_agent_has_run_method(self, disable_telemetry):
"""Test that agent has run method available."""
from agent import ComputerAgent
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
# Verify run method exists
assert hasattr(agent, "run")
assert callable(agent.run)
def test_agent_has_agent_loop(self, disable_telemetry):
"""Test that agent has agent_loop initialized."""
from agent import ComputerAgent
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
# Verify agent_loop is initialized
assert hasattr(agent, "agent_loop")
assert agent.agent_loop is not None
class TestComputerAgentTypes:
"""Test AgentResponse and Messages types (SRP: Only tests type definitions)."""
def test_messages_type_exists(self):
"""Test that Messages type is exported."""
from agent import Messages
assert Messages is not None
def test_agent_response_type_exists(self):
"""Test that AgentResponse type is exported."""
from agent import AgentResponse
assert AgentResponse is not None
class TestComputerAgentIntegration:
"""Test ComputerAgent integration with Computer tool (SRP: Integration within package)."""
def test_agent_accepts_computer_tool(self, disable_telemetry, mock_computer):
"""Test that agent can be initialized with Computer tool."""
from agent import ComputerAgent
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022", tools=[mock_computer])
# Verify agent accepted the tool
assert agent is not None
assert hasattr(agent, "tools")

View File

@@ -0,0 +1,47 @@
"""Pytest configuration and shared fixtures for computer-server package tests.
This file contains shared fixtures and configuration for all computer-server tests.
Following SRP: This file ONLY handles test setup/teardown.
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
@pytest.fixture
def mock_websocket():
"""Mock WebSocket connection for testing.
Use this fixture to test WebSocket logic without real connections.
"""
websocket = AsyncMock()
websocket.send = AsyncMock()
websocket.recv = AsyncMock()
websocket.close = AsyncMock()
return websocket
@pytest.fixture
def mock_computer_interface():
"""Mock computer interface for server tests.
Use this fixture to test server logic without real computer operations.
"""
interface = AsyncMock()
interface.screenshot = AsyncMock(return_value=b"fake_screenshot")
interface.left_click = AsyncMock()
interface.type = AsyncMock()
interface.key = AsyncMock()
return interface
@pytest.fixture
def disable_telemetry(monkeypatch):
"""Disable telemetry for tests.
Use this fixture to ensure no telemetry is sent during tests.
"""
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")

View File

@@ -0,0 +1,40 @@
"""Unit tests for computer-server package.
This file tests ONLY basic server functionality.
Following SRP: This file tests server initialization and basic operations.
All external dependencies are mocked.
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
class TestServerImports:
"""Test server module imports (SRP: Only tests imports)."""
def test_server_module_exists(self):
"""Test that server module can be imported."""
try:
import computer_server
assert computer_server is not None
except ImportError:
pytest.skip("computer_server module not installed")
class TestServerInitialization:
"""Test server initialization (SRP: Only tests initialization)."""
@pytest.mark.asyncio
async def test_server_can_be_imported(self):
"""Basic smoke test: verify server components can be imported."""
try:
from computer_server import server
assert server is not None
except ImportError:
pytest.skip("Server module not available")
except Exception as e:
# Some initialization errors are acceptable in unit tests
pytest.skip(f"Server initialization requires specific setup: {e}")

View File

@@ -0,0 +1,69 @@
"""Pytest configuration and shared fixtures for computer package tests.
This file contains shared fixtures and configuration for all computer tests.
Following SRP: This file ONLY handles test setup/teardown.
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
@pytest.fixture
def mock_interface():
"""Mock computer interface for testing.
Use this fixture to test Computer logic without real OS calls.
"""
interface = AsyncMock()
interface.screenshot = AsyncMock(return_value=b"fake_screenshot")
interface.left_click = AsyncMock()
interface.right_click = AsyncMock()
interface.middle_click = AsyncMock()
interface.double_click = AsyncMock()
interface.type = AsyncMock()
interface.key = AsyncMock()
interface.move_mouse = AsyncMock()
interface.scroll = AsyncMock()
interface.get_screen_size = AsyncMock(return_value=(1920, 1080))
return interface
@pytest.fixture
def mock_cloud_provider():
"""Mock cloud provider for testing.
Use this fixture to test cloud provider logic without real API calls.
"""
provider = AsyncMock()
provider.start = AsyncMock()
provider.stop = AsyncMock()
provider.get_status = AsyncMock(return_value="running")
provider.execute_command = AsyncMock(return_value="command output")
return provider
@pytest.fixture
def mock_local_provider():
"""Mock local provider for testing.
Use this fixture to test local provider logic without real VM operations.
"""
provider = AsyncMock()
provider.start = AsyncMock()
provider.stop = AsyncMock()
provider.get_status = AsyncMock(return_value="running")
provider.execute_command = AsyncMock(return_value="command output")
return provider
@pytest.fixture
def disable_telemetry(monkeypatch):
"""Disable telemetry for tests.
Use this fixture to ensure no telemetry is sent during tests.
"""
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")

View File

@@ -0,0 +1,67 @@
"""Unit tests for Computer class.
This file tests ONLY the Computer class initialization and context manager.
Following SRP: This file tests ONE class (Computer).
All external dependencies (providers, interfaces) are mocked.
"""
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
class TestComputerImport:
"""Test Computer module imports (SRP: Only tests imports)."""
def test_computer_class_exists(self):
"""Test that Computer class can be imported."""
from computer import Computer
assert Computer is not None
def test_vm_provider_type_exists(self):
"""Test that VMProviderType enum can be imported."""
from computer import VMProviderType
assert VMProviderType is not None
class TestComputerInitialization:
"""Test Computer initialization (SRP: Only tests initialization)."""
def test_computer_class_can_be_imported(self, disable_telemetry):
"""Test that Computer class can be imported without errors."""
from computer import Computer
assert Computer is not None
def test_computer_has_required_methods(self, disable_telemetry):
"""Test that Computer class has required methods."""
from computer import Computer
assert hasattr(Computer, "__aenter__")
assert hasattr(Computer, "__aexit__")
class TestComputerContextManager:
"""Test Computer context manager protocol (SRP: Only tests context manager)."""
def test_computer_is_async_context_manager(self, disable_telemetry):
"""Test that Computer has async context manager methods."""
from computer import Computer
assert hasattr(Computer, "__aenter__")
assert hasattr(Computer, "__aexit__")
assert callable(Computer.__aenter__)
assert callable(Computer.__aexit__)
class TestComputerInterface:
"""Test Computer.interface property (SRP: Only tests interface access)."""
def test_computer_class_structure(self, disable_telemetry):
"""Test that Computer class has expected structure."""
from computer import Computer
# Verify Computer is a class
assert isinstance(Computer, type)

View File

@@ -0,0 +1,43 @@
"""Pytest configuration and shared fixtures for core package tests.
This file contains shared fixtures and configuration for all core tests.
Following SRP: This file ONLY handles test setup/teardown.
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
@pytest.fixture
def mock_httpx_client():
"""Mock httpx.AsyncClient for API calls.
Use this fixture to avoid making real HTTP requests during tests.
"""
with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_client.return_value.__aenter__.return_value = mock_instance
yield mock_instance
@pytest.fixture
def mock_posthog():
"""Mock PostHog client for telemetry tests.
Use this fixture to avoid sending real telemetry during tests.
"""
with patch("posthog.Posthog") as mock_ph:
mock_instance = Mock()
mock_ph.return_value = mock_instance
yield mock_instance
@pytest.fixture
def disable_telemetry(monkeypatch):
"""Disable telemetry for tests that don't need it.
Use this fixture to ensure telemetry is disabled during tests.
"""
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
yield

View File

@@ -0,0 +1,255 @@
"""Unit tests for core telemetry functionality.
This file tests ONLY telemetry logic, following SRP.
All external dependencies (PostHog, file system) are mocked.
"""
import os
from pathlib import Path
from unittest.mock import MagicMock, Mock, mock_open, patch
import pytest
class TestTelemetryEnabled:
"""Test telemetry enable/disable logic (SRP: Only tests enable/disable)."""
def test_telemetry_enabled_by_default(self, monkeypatch):
"""Test that telemetry is enabled by default."""
# Remove any environment variables that might affect the test
monkeypatch.delenv("CUA_TELEMETRY", raising=False)
monkeypatch.delenv("CUA_TELEMETRY_ENABLED", raising=False)
from core.telemetry import is_telemetry_enabled
assert is_telemetry_enabled() is True
def test_telemetry_disabled_with_legacy_flag(self, monkeypatch):
"""Test that telemetry can be disabled with legacy CUA_TELEMETRY=off."""
monkeypatch.setenv("CUA_TELEMETRY", "off")
from core.telemetry import is_telemetry_enabled
assert is_telemetry_enabled() is False
def test_telemetry_disabled_with_new_flag(self, monkeypatch):
"""Test that telemetry can be disabled with CUA_TELEMETRY_ENABLED=false."""
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")
from core.telemetry import is_telemetry_enabled
assert is_telemetry_enabled() is False
@pytest.mark.parametrize("value", ["0", "false", "no", "off"])
def test_telemetry_disabled_with_various_values(self, monkeypatch, value):
"""Test that telemetry respects various disable values."""
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", value)
from core.telemetry import is_telemetry_enabled
assert is_telemetry_enabled() is False
@pytest.mark.parametrize("value", ["1", "true", "yes", "on"])
def test_telemetry_enabled_with_various_values(self, monkeypatch, value):
"""Test that telemetry respects various enable values."""
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", value)
from core.telemetry import is_telemetry_enabled
assert is_telemetry_enabled() is True
class TestPostHogTelemetryClient:
"""Test PostHogTelemetryClient class (SRP: Only tests client logic)."""
@patch("core.telemetry.posthog.posthog")
@patch("core.telemetry.posthog.Path")
def test_client_initialization(self, mock_path, mock_posthog, disable_telemetry):
"""Test that client initializes correctly."""
from core.telemetry.posthog import PostHogTelemetryClient
# Mock the storage directory
mock_storage_dir = MagicMock()
mock_storage_dir.exists.return_value = False
mock_path.return_value.parent.parent = MagicMock()
mock_path.return_value.parent.parent.__truediv__.return_value = mock_storage_dir
# Reset singleton
PostHogTelemetryClient.destroy_client()
client = PostHogTelemetryClient()
assert client is not None
assert hasattr(client, "installation_id")
assert hasattr(client, "initialized")
assert hasattr(client, "queued_events")
@patch("core.telemetry.posthog.posthog")
@patch("core.telemetry.posthog.Path")
def test_installation_id_generation(self, mock_path, mock_posthog, disable_telemetry):
"""Test that installation ID is generated if not exists."""
from core.telemetry.posthog import PostHogTelemetryClient
# Mock file system
mock_id_file = MagicMock()
mock_id_file.exists.return_value = False
mock_storage_dir = MagicMock()
mock_storage_dir.__truediv__.return_value = mock_id_file
mock_core_dir = MagicMock()
mock_core_dir.__truediv__.return_value = mock_storage_dir
mock_path.return_value.parent.parent = mock_core_dir
# Reset singleton
PostHogTelemetryClient.destroy_client()
client = PostHogTelemetryClient()
# Should have generated a new UUID
assert client.installation_id is not None
assert len(client.installation_id) == 36 # UUID format
@patch("core.telemetry.posthog.posthog")
@patch("core.telemetry.posthog.Path")
def test_installation_id_persistence(self, mock_path, mock_posthog, disable_telemetry):
"""Test that installation ID is read from file if exists."""
from core.telemetry.posthog import PostHogTelemetryClient
existing_id = "test-installation-id-123"
# Mock file system
mock_id_file = MagicMock()
mock_id_file.exists.return_value = True
mock_id_file.read_text.return_value = existing_id
mock_storage_dir = MagicMock()
mock_storage_dir.__truediv__.return_value = mock_id_file
mock_core_dir = MagicMock()
mock_core_dir.__truediv__.return_value = mock_storage_dir
mock_path.return_value.parent.parent = mock_core_dir
# Reset singleton
PostHogTelemetryClient.destroy_client()
client = PostHogTelemetryClient()
assert client.installation_id == existing_id
@patch("core.telemetry.posthog.posthog")
@patch("core.telemetry.posthog.Path")
def test_record_event_when_disabled(self, mock_path, mock_posthog, monkeypatch):
"""Test that events are not recorded when telemetry is disabled."""
from core.telemetry.posthog import PostHogTelemetryClient
# Disable telemetry explicitly using the correct environment variable
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")
# Mock file system
mock_storage_dir = MagicMock()
mock_storage_dir.exists.return_value = False
mock_path.return_value.parent.parent = MagicMock()
mock_path.return_value.parent.parent.__truediv__.return_value = mock_storage_dir
# Reset singleton
PostHogTelemetryClient.destroy_client()
client = PostHogTelemetryClient()
client.record_event("test_event", {"key": "value"})
# PostHog capture should not be called at all when telemetry is disabled
mock_posthog.capture.assert_not_called()
@patch("core.telemetry.posthog.posthog")
@patch("core.telemetry.posthog.Path")
def test_record_event_when_enabled(self, mock_path, mock_posthog, monkeypatch):
"""Test that events are recorded when telemetry is enabled."""
from core.telemetry.posthog import PostHogTelemetryClient
# Enable telemetry
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "true")
# Mock file system
mock_storage_dir = MagicMock()
mock_storage_dir.exists.return_value = False
mock_path.return_value.parent.parent = MagicMock()
mock_path.return_value.parent.parent.__truediv__.return_value = mock_storage_dir
# Reset singleton
PostHogTelemetryClient.destroy_client()
client = PostHogTelemetryClient()
client.initialized = True # Pretend it's initialized
event_name = "test_event"
event_props = {"key": "value"}
client.record_event(event_name, event_props)
# PostHog capture should be called
assert mock_posthog.capture.call_count >= 1
@patch("core.telemetry.posthog.posthog")
@patch("core.telemetry.posthog.Path")
def test_singleton_pattern(self, mock_path, mock_posthog, disable_telemetry):
"""Test that get_client returns the same instance."""
from core.telemetry.posthog import PostHogTelemetryClient
# Mock file system
mock_storage_dir = MagicMock()
mock_storage_dir.exists.return_value = False
mock_path.return_value.parent.parent = MagicMock()
mock_path.return_value.parent.parent.__truediv__.return_value = mock_storage_dir
# Reset singleton
PostHogTelemetryClient.destroy_client()
client1 = PostHogTelemetryClient.get_client()
client2 = PostHogTelemetryClient.get_client()
assert client1 is client2
class TestRecordEvent:
"""Test the public record_event function (SRP: Only tests public API)."""
@patch("core.telemetry.posthog.PostHogTelemetryClient")
def test_record_event_calls_client(self, mock_client_class, disable_telemetry):
"""Test that record_event delegates to the client."""
from core.telemetry import record_event
mock_client_instance = Mock()
mock_client_class.get_client.return_value = mock_client_instance
event_name = "test_event"
event_props = {"key": "value"}
record_event(event_name, event_props)
mock_client_instance.record_event.assert_called_once_with(event_name, event_props)
@patch("core.telemetry.posthog.PostHogTelemetryClient")
def test_record_event_without_properties(self, mock_client_class, disable_telemetry):
"""Test that record_event works without properties."""
from core.telemetry import record_event
mock_client_instance = Mock()
mock_client_class.get_client.return_value = mock_client_instance
event_name = "test_event"
record_event(event_name)
mock_client_instance.record_event.assert_called_once_with(event_name, {})
class TestDestroyTelemetryClient:
"""Test client destruction (SRP: Only tests cleanup)."""
@patch("core.telemetry.posthog.PostHogTelemetryClient")
def test_destroy_client_calls_class_method(self, mock_client_class):
"""Test that destroy_telemetry_client delegates correctly."""
from core.telemetry import destroy_telemetry_client
destroy_telemetry_client()
mock_client_class.destroy_client.assert_called_once()

View File

@@ -0,0 +1,51 @@
"""Pytest configuration and shared fixtures for mcp-server package tests.
This file contains shared fixtures and configuration for all mcp-server tests.
Following SRP: This file ONLY handles test setup/teardown.
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
@pytest.fixture
def mock_mcp_context():
"""Mock MCP context for testing.
Use this fixture to test MCP server logic without real MCP connections.
"""
context = AsyncMock()
context.request_context = AsyncMock()
context.session = Mock()
context.session.send_resource_updated = AsyncMock()
return context
@pytest.fixture
def mock_computer():
"""Mock Computer instance for MCP server tests.
Use this fixture to test MCP logic without real Computer operations.
"""
computer = AsyncMock()
computer.interface = AsyncMock()
computer.interface.screenshot = AsyncMock(return_value=b"fake_screenshot")
computer.interface.left_click = AsyncMock()
computer.interface.type = AsyncMock()
# Mock context manager
computer.__aenter__ = AsyncMock(return_value=computer)
computer.__aexit__ = AsyncMock()
return computer
@pytest.fixture
def disable_telemetry(monkeypatch):
"""Disable telemetry for tests.
Use this fixture to ensure no telemetry is sent during tests.
"""
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")

View File

@@ -0,0 +1,44 @@
"""Unit tests for mcp-server package.
This file tests ONLY basic MCP server functionality.
Following SRP: This file tests MCP server initialization.
All external dependencies are mocked.
"""
from unittest.mock import AsyncMock, Mock, patch
import pytest
class TestMCPServerImports:
"""Test MCP server module imports (SRP: Only tests imports)."""
def test_mcp_server_module_exists(self):
"""Test that mcp_server module can be imported."""
try:
import mcp_server
assert mcp_server is not None
except ImportError:
pytest.skip("mcp_server module not installed")
except SystemExit:
pytest.skip("MCP dependencies (mcp.server.fastmcp) not available")
class TestMCPServerInitialization:
"""Test MCP server initialization (SRP: Only tests initialization)."""
@pytest.mark.asyncio
async def test_mcp_server_can_be_imported(self):
"""Basic smoke test: verify MCP server components can be imported."""
try:
from mcp_server import server
assert server is not None
except ImportError:
pytest.skip("MCP server module not available")
except SystemExit:
pytest.skip("MCP dependencies (mcp.server.fastmcp) not available")
except Exception as e:
# Some initialization errors are acceptable in unit tests
pytest.skip(f"MCP server initialization requires specific setup: {e}")

View File

@@ -1,10 +0,0 @@
[bumpversion]
current_version = 0.2.1
commit = True
tag = True
tag_name = pylume-v{new_version}
message = Bump pylume to v{new_version}
[bumpversion:file:pylume/__init__.py]
search = __version__ = "{current_version}"
replace = __version__ = "{new_version}"

View File

@@ -1,46 +0,0 @@
<div align="center">
<h1>
<div class="image-wrapper" style="display: inline-block;">
<picture>
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="https://raw.githubusercontent.com/trycua/cua/main/img/logo_white.png" style="display: block; margin: auto;">
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="https://raw.githubusercontent.com/trycua/cua/main/img/logo_black.png" style="display: block; margin: auto;">
<img alt="Shows my svg">
</picture>
</div>
[![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#)
[![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85)
[![PyPI](https://img.shields.io/pypi/v/pylume?color=333333)](https://pypi.org/project/pylume/)
</h1>
</div>
**pylume** is a lightweight Python library based on [lume](https://github.com/trycua/lume) to create, run and manage macOS and Linux virtual machines (VMs) natively on Apple Silicon.
```bash
pip install pylume
```
## Usage
Please refer to this [Notebook](./samples/nb.ipynb) for a quickstart. More details about the underlying API used by pylume are available [here](https://github.com/trycua/lume/docs/API-Reference.md).
## Prebuilt Images
Pre-built images are available on [ghcr.io/trycua](https://github.com/orgs/trycua/packages).
These images come pre-configured with an SSH server and auto-login enabled.
## Contributing
We welcome and greatly appreciate contributions to lume! Whether you're improving documentation, adding new features, fixing bugs, or adding new VM images, your efforts help make pylume better for everyone.
Join our [Discord community](https://discord.com/invite/mVnXXpdE85) to discuss ideas or get assistance.
## License
lume is open-sourced under the MIT License - see the [LICENSE](LICENSE) file for details.
## Stargazers over time
[![Stargazers over time](https://starchart.cc/trycua/pylume.svg?variant=adaptive)](https://starchart.cc/trycua/pylume)

View File

@@ -1,9 +0,0 @@
"""
PyLume Python SDK - A client library for managing macOS VMs with PyLume.
"""
from pylume.exceptions import *
from pylume.models import *
from pylume.pylume import *
__version__ = "0.1.0"

View File

@@ -1,59 +0,0 @@
"""
PyLume Python SDK - A client library for managing macOS VMs with PyLume.
Example:
>>> from pylume import PyLume, VMConfig
>>> client = PyLume()
>>> config = VMConfig(name="my-vm", cpu=4, memory="8GB", disk_size="64GB")
>>> client.create_vm(config)
>>> client.run_vm("my-vm")
"""
# Import exceptions then all models
from .exceptions import (
LumeConfigError,
LumeConnectionError,
LumeError,
LumeImageError,
LumeNotFoundError,
LumeServerError,
LumeTimeoutError,
LumeVMError,
)
from .models import (
CloneSpec,
ImageInfo,
ImageList,
ImageRef,
SharedDirectory,
VMConfig,
VMRunOpts,
VMStatus,
VMUpdateOpts,
)
# Import main class last to avoid circular imports
from .pylume import PyLume
__version__ = "0.2.1"
__all__ = [
"PyLume",
"VMConfig",
"VMStatus",
"VMRunOpts",
"VMUpdateOpts",
"ImageRef",
"CloneSpec",
"SharedDirectory",
"ImageList",
"ImageInfo",
"LumeError",
"LumeServerError",
"LumeConnectionError",
"LumeTimeoutError",
"LumeNotFoundError",
"LumeConfigError",
"LumeVMError",
"LumeImageError",
]

View File

@@ -1,119 +0,0 @@
import asyncio
import json
import shlex
import subprocess
from typing import Any, Dict, Optional
from .exceptions import (
LumeConfigError,
LumeConnectionError,
LumeError,
LumeNotFoundError,
LumeServerError,
LumeTimeoutError,
)
class LumeClient:
def __init__(self, base_url: str, timeout: float = 60.0, debug: bool = False):
self.base_url = base_url
self.timeout = timeout
self.debug = debug
def _log_debug(self, message: str, **kwargs) -> None:
"""Log debug information if debug mode is enabled."""
if self.debug:
print(f"DEBUG: {message}")
if kwargs:
print(json.dumps(kwargs, indent=2))
async def _run_curl(
self,
method: str,
path: str,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
) -> Any:
"""Execute a curl command and return the response."""
url = f"{self.base_url}{path}"
if params:
param_str = "&".join(f"{k}={v}" for k, v in params.items())
url = f"{url}?{param_str}"
cmd = ["curl", "-X", method, "-s", "-w", "%{http_code}", "-m", str(self.timeout)]
if data is not None:
cmd.extend(["-H", "Content-Type: application/json", "-d", json.dumps(data)])
cmd.append(url)
self._log_debug(f"Running curl command: {' '.join(map(shlex.quote, cmd))}")
try:
process = await asyncio.create_subprocess_exec(
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
raise LumeConnectionError(f"Curl command failed: {stderr.decode()}")
# The last 3 characters are the status code
response = stdout.decode()
status_code = int(response[-3:])
response_body = response[:-3] # Remove status code from response
if status_code >= 400:
if status_code == 404:
raise LumeNotFoundError(f"Resource not found: {path}")
elif status_code == 400:
raise LumeConfigError(f"Invalid request: {response_body}")
elif status_code >= 500:
raise LumeServerError(f"Server error: {response_body}")
else:
raise LumeError(f"Request failed with status {status_code}: {response_body}")
return json.loads(response_body) if response_body.strip() else None
except asyncio.TimeoutError:
raise LumeTimeoutError(f"Request timed out after {self.timeout} seconds")
async def get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Any:
"""Make a GET request."""
return await self._run_curl("GET", path, params=params)
async def post(
self, path: str, data: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None
) -> Any:
"""Make a POST request."""
old_timeout = self.timeout
if timeout is not None:
self.timeout = timeout
try:
return await self._run_curl("POST", path, data=data)
finally:
self.timeout = old_timeout
async def patch(self, path: str, data: Dict[str, Any]) -> None:
"""Make a PATCH request."""
await self._run_curl("PATCH", path, data=data)
async def delete(self, path: str) -> None:
"""Make a DELETE request."""
await self._run_curl("DELETE", path)
def print_curl(self, method: str, path: str, data: Optional[Dict[str, Any]] = None) -> None:
"""Print equivalent curl command for debugging."""
curl_cmd = f"""curl -X {method} \\
'{self.base_url}{path}'"""
if data:
curl_cmd += f" \\\n -H 'Content-Type: application/json' \\\n -d '{json.dumps(data)}'"
print("\nEquivalent curl command:")
print(curl_cmd)
print()
async def close(self) -> None:
"""Close the client resources."""
pass # No shared resources to clean up

View File

@@ -1,54 +0,0 @@
from typing import Optional
class LumeError(Exception):
"""Base exception for all PyLume errors."""
pass
class LumeServerError(LumeError):
"""Raised when there's an error with the PyLume server."""
def __init__(
self, message: str, status_code: Optional[int] = None, response_text: Optional[str] = None
):
self.status_code = status_code
self.response_text = response_text
super().__init__(message)
class LumeConnectionError(LumeError):
"""Raised when there's an error connecting to the PyLume server."""
pass
class LumeTimeoutError(LumeError):
"""Raised when a request to the PyLume server times out."""
pass
class LumeNotFoundError(LumeError):
"""Raised when a requested resource is not found."""
pass
class LumeConfigError(LumeError):
"""Raised when there's an error with the configuration."""
pass
class LumeVMError(LumeError):
"""Raised when there's an error with a VM operation."""
pass
class LumeImageError(LumeError):
"""Raised when there's an error with an image operation."""
pass

Binary file not shown.

View File

@@ -1,265 +0,0 @@
import re
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, RootModel, computed_field, validator
class DiskInfo(BaseModel):
"""Information about disk storage allocation.
Attributes:
total: Total disk space in bytes
allocated: Currently allocated disk space in bytes
"""
total: int
allocated: int
class VMConfig(BaseModel):
"""Configuration for creating a new VM.
Note: Memory and disk sizes should be specified with units (e.g., "4GB", "64GB")
Attributes:
name: Name of the virtual machine
os: Operating system type, either "macOS" or "linux"
cpu: Number of CPU cores to allocate
memory: Amount of memory to allocate with units
disk_size: Size of the disk to create with units
display: Display resolution in format "widthxheight"
ipsw: IPSW path or 'latest' for macOS VMs, None for other OS types
"""
name: str
os: Literal["macOS", "linux"] = "macOS"
cpu: int = Field(default=2, ge=1)
memory: str = "4GB"
disk_size: str = Field(default="64GB", alias="diskSize")
display: str = "1024x768"
ipsw: Optional[str] = Field(default=None, description="IPSW path or 'latest', for macOS VMs")
class Config:
populate_by_alias = True
class SharedDirectory(BaseModel):
"""Configuration for a shared directory.
Attributes:
host_path: Path to the directory on the host system
read_only: Whether the directory should be mounted as read-only
"""
host_path: str = Field(..., alias="hostPath") # Allow host_path but serialize as hostPath
read_only: bool = False
class Config:
populate_by_name = True # Allow both alias and original name
alias_generator = lambda s: "".join(
word.capitalize() if i else word for i, word in enumerate(s.split("_"))
)
class VMRunOpts(BaseModel):
"""Configuration for running a VM.
Args:
no_display: Whether to not display the VNC client
shared_directories: List of directories to share with the VM
"""
no_display: bool = Field(default=False, alias="noDisplay")
shared_directories: Optional[list[SharedDirectory]] = Field(
default=None, alias="sharedDirectories"
)
model_config = ConfigDict(
populate_by_name=True,
alias_generator=lambda s: "".join(
word.capitalize() if i else word for i, word in enumerate(s.split("_"))
),
)
def model_dump(self, **kwargs):
"""Export model data with proper field name conversion.
Converts shared directory fields to match API expectations when using aliases.
Args:
**kwargs: Keyword arguments passed to parent model_dump method
Returns:
dict: Model data with properly formatted field names
"""
data = super().model_dump(**kwargs)
# Convert shared directory fields to match API expectations
if self.shared_directories and "by_alias" in kwargs and kwargs["by_alias"]:
data["sharedDirectories"] = [
{"hostPath": d.host_path, "readOnly": d.read_only} for d in self.shared_directories
]
# Remove the snake_case version if it exists
data.pop("shared_directories", None)
return data
class VMStatus(BaseModel):
"""Status information for a virtual machine.
Attributes:
name: Name of the virtual machine
status: Current status of the VM
os: Operating system type
cpu_count: Number of CPU cores allocated
memory_size: Amount of memory allocated in bytes
disk_size: Disk storage information
vnc_url: URL for VNC connection if available
ip_address: IP address of the VM if available
"""
name: str
status: str
os: Literal["macOS", "linux"]
cpu_count: int = Field(alias="cpuCount")
memory_size: int = Field(alias="memorySize") # API returns memory size in bytes
disk_size: DiskInfo = Field(alias="diskSize")
vnc_url: Optional[str] = Field(default=None, alias="vncUrl")
ip_address: Optional[str] = Field(default=None, alias="ipAddress")
class Config:
populate_by_alias = True
@computed_field
@property
def state(self) -> str:
"""Get the current state of the VM.
Returns:
str: Current VM status
"""
return self.status
@computed_field
@property
def cpu(self) -> int:
"""Get the number of CPU cores.
Returns:
int: Number of CPU cores allocated to the VM
"""
return self.cpu_count
@computed_field
@property
def memory(self) -> str:
"""Get memory allocation in human-readable format.
Returns:
str: Memory size formatted as "{size}GB"
"""
# Convert bytes to GB
gb = self.memory_size / (1024 * 1024 * 1024)
return f"{int(gb)}GB"
class VMUpdateOpts(BaseModel):
"""Options for updating VM configuration.
Attributes:
cpu: Number of CPU cores to update to
memory: Amount of memory to update to with units
disk_size: Size of disk to update to with units
"""
cpu: Optional[int] = None
memory: Optional[str] = None
disk_size: Optional[str] = None
class ImageRef(BaseModel):
"""Reference to a VM image.
Attributes:
image: Name of the image
tag: Tag version of the image
registry: Registry hostname where image is stored
organization: Organization or namespace in the registry
"""
image: str
tag: str = "latest"
registry: Optional[str] = "ghcr.io"
organization: Optional[str] = "trycua"
def model_dump(self, **kwargs):
"""Override model_dump to return just the image:tag format.
Args:
**kwargs: Keyword arguments (ignored)
Returns:
str: Image reference in "image:tag" format
"""
return f"{self.image}:{self.tag}"
class CloneSpec(BaseModel):
"""Specification for cloning a VM.
Attributes:
name: Name of the source VM to clone
new_name: Name for the new cloned VM
"""
name: str
new_name: str = Field(alias="newName")
class Config:
populate_by_alias = True
class ImageInfo(BaseModel):
"""Model for individual image information.
Attributes:
imageId: Unique identifier for the image
"""
imageId: str
class ImageList(RootModel):
"""Response model for the images endpoint.
A list-like container for ImageInfo objects that provides
iteration and indexing capabilities.
"""
root: List[ImageInfo]
def __iter__(self):
"""Iterate over the image list.
Returns:
Iterator over ImageInfo objects
"""
return iter(self.root)
def __getitem__(self, item):
"""Get an item from the image list by index.
Args:
item: Index or slice to retrieve
Returns:
ImageInfo or list of ImageInfo objects
"""
return self.root[item]
def __len__(self):
"""Get the number of images in the list.
Returns:
int: Number of images in the list
"""
return len(self.root)

View File

@@ -1,315 +0,0 @@
import asyncio
import json
import os
import re
import signal
import subprocess
import sys
import time
from functools import wraps
from typing import Any, Callable, List, Optional, TypeVar, Union
from .client import LumeClient
from .exceptions import (
LumeConfigError,
LumeConnectionError,
LumeError,
LumeImageError,
LumeNotFoundError,
LumeServerError,
LumeTimeoutError,
LumeVMError,
)
from .models import (
CloneSpec,
ImageList,
ImageRef,
SharedDirectory,
VMConfig,
VMRunOpts,
VMStatus,
VMUpdateOpts,
)
from .server import LumeServer
# Type variable for the decorator
T = TypeVar("T")
def ensure_server(func: Callable[..., T]) -> Callable[..., T]:
"""Decorator to ensure server is running before executing the method."""
@wraps(func)
async def wrapper(self: "PyLume", *args: Any, **kwargs: Any) -> T:
# ensure_running is an async method, so we need to await it
await self.server.ensure_running()
# Initialize client if needed
await self._init_client()
return await func(self, *args, **kwargs) # type: ignore
return wrapper # type: ignore
class PyLume:
def __init__(
self,
debug: bool = False,
server_start_timeout: int = 60,
port: Optional[int] = None,
use_existing_server: bool = False,
host: str = "localhost",
):
"""Initialize the async PyLume client.
Args:
debug: Enable debug logging
auto_start_server: Whether to automatically start the lume server if not running
server_start_timeout: Timeout in seconds to wait for server to start
port: Port number for the lume server. Required when use_existing_server is True.
use_existing_server: If True, will try to connect to an existing server on the specified port
instead of starting a new one.
host: Host to use for connections (e.g., "localhost", "127.0.0.1", "host.docker.internal")
"""
if use_existing_server and port is None:
raise LumeConfigError("Port must be specified when using an existing server")
self.server = LumeServer(
debug=debug,
server_start_timeout=server_start_timeout,
port=port,
use_existing_server=use_existing_server,
host=host,
)
self.client = None
async def __aenter__(self) -> "PyLume":
"""Async context manager entry."""
if self.server.use_existing_server:
# Just ensure base_url is set for existing server
if self.server.requested_port is None:
raise LumeConfigError("Port must be specified when using an existing server")
if not self.server.base_url:
self.server.port = self.server.requested_port
self.server.base_url = f"http://{self.server.host}:{self.server.port}/lume"
# Ensure the server is running (will connect to existing or start new as needed)
await self.server.ensure_running()
# Initialize the client
await self._init_client()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
"""Async context manager exit."""
if self.client is not None:
await self.client.close()
await self.server.stop()
async def _init_client(self) -> None:
"""Initialize the client if not already initialized."""
if self.client is None:
if self.server.base_url is None:
raise RuntimeError("Server base URL not set")
self.client = LumeClient(self.server.base_url, debug=self.server.debug)
def _log_debug(self, message: str, **kwargs) -> None:
"""Log debug information if debug mode is enabled."""
if self.server.debug:
print(f"DEBUG: {message}")
if kwargs:
print(json.dumps(kwargs, indent=2))
async def _handle_api_error(self, e: Exception, operation: str) -> None:
"""Handle API errors and raise appropriate custom exceptions."""
if isinstance(e, subprocess.SubprocessError):
raise LumeConnectionError(f"Failed to connect to PyLume server: {str(e)}")
elif isinstance(e, asyncio.TimeoutError):
raise LumeTimeoutError(f"Request timed out: {str(e)}")
if not hasattr(e, "status") and not isinstance(e, subprocess.CalledProcessError):
raise LumeServerError(f"Unknown error during {operation}: {str(e)}")
status_code = getattr(e, "status", 500)
response_text = str(e)
self._log_debug(
f"{operation} request failed", status_code=status_code, response_text=response_text
)
if status_code == 404:
raise LumeNotFoundError(f"Resource not found during {operation}")
elif status_code == 400:
raise LumeConfigError(f"Invalid configuration for {operation}: {response_text}")
elif status_code >= 500:
raise LumeServerError(
f"Server error during {operation}",
status_code=status_code,
response_text=response_text,
)
else:
raise LumeServerError(
f"Error during {operation}", status_code=status_code, response_text=response_text
)
async def _read_output(self) -> None:
"""Read and log server output."""
try:
while True:
if not self.server.server_process or self.server.server_process.poll() is not None:
self._log_debug("Server process ended")
break
# Read stdout without blocking
if self.server.server_process.stdout:
while True:
line = self.server.server_process.stdout.readline()
if not line:
break
line = line.strip()
self._log_debug(f"Server stdout: {line}")
if "Server started" in line.decode("utf-8"):
self._log_debug("Detected server started message")
return
# Read stderr without blocking
if self.server.server_process.stderr:
while True:
line = self.server.server_process.stderr.readline()
if not line:
break
line = line.strip()
self._log_debug(f"Server stderr: {line}")
if "error" in line.decode("utf-8").lower():
raise RuntimeError(f"Server error: {line}")
await asyncio.sleep(0.1) # Small delay to prevent CPU spinning
except Exception as e:
self._log_debug(f"Error in output reader: {str(e)}")
raise
@ensure_server
async def create_vm(self, spec: Union[VMConfig, dict]) -> None:
"""Create a VM with the given configuration."""
# Ensure client is initialized
await self._init_client()
if isinstance(spec, VMConfig):
spec = spec.model_dump(by_alias=True, exclude_none=True)
# Suppress optional attribute access errors
self.client.print_curl("POST", "/vms", spec) # type: ignore[attr-defined]
await self.client.post("/vms", spec) # type: ignore[attr-defined]
@ensure_server
async def run_vm(self, name: str, opts: Optional[Union[VMRunOpts, dict]] = None) -> None:
"""Run a VM."""
if opts is None:
opts = VMRunOpts(no_display=False) # type: ignore[attr-defined]
elif isinstance(opts, dict):
opts = VMRunOpts(**opts)
payload = opts.model_dump(by_alias=True, exclude_none=True)
self.client.print_curl("POST", f"/vms/{name}/run", payload) # type: ignore[attr-defined]
await self.client.post(f"/vms/{name}/run", payload) # type: ignore[attr-defined]
@ensure_server
async def list_vms(self) -> List[VMStatus]:
"""List all VMs."""
data = await self.client.get("/vms") # type: ignore[attr-defined]
return [VMStatus.model_validate(vm) for vm in data]
@ensure_server
async def get_vm(self, name: str) -> VMStatus:
"""Get VM details."""
data = await self.client.get(f"/vms/{name}") # type: ignore[attr-defined]
return VMStatus.model_validate(data)
@ensure_server
async def update_vm(self, name: str, params: Union[VMUpdateOpts, dict]) -> None:
"""Update VM settings."""
if isinstance(params, dict):
params = VMUpdateOpts(**params)
payload = params.model_dump(by_alias=True, exclude_none=True)
self.client.print_curl("PATCH", f"/vms/{name}", payload) # type: ignore[attr-defined]
await self.client.patch(f"/vms/{name}", payload) # type: ignore[attr-defined]
@ensure_server
async def stop_vm(self, name: str) -> None:
"""Stop a VM."""
await self.client.post(f"/vms/{name}/stop") # type: ignore[attr-defined]
@ensure_server
async def delete_vm(self, name: str) -> None:
"""Delete a VM."""
await self.client.delete(f"/vms/{name}") # type: ignore[attr-defined]
@ensure_server
async def pull_image(
self, spec: Union[ImageRef, dict, str], name: Optional[str] = None
) -> None:
"""Pull a VM image."""
await self._init_client()
if isinstance(spec, str):
if ":" in spec:
image_str = spec
else:
image_str = f"{spec}:latest"
registry = "ghcr.io"
organization = "trycua"
elif isinstance(spec, dict):
image = spec.get("image", "")
tag = spec.get("tag", "latest")
image_str = f"{image}:{tag}"
registry = spec.get("registry", "ghcr.io")
organization = spec.get("organization", "trycua")
else:
image_str = f"{spec.image}:{spec.tag}"
registry = spec.registry
organization = spec.organization
payload = {
"image": image_str,
"name": name,
"registry": registry,
"organization": organization,
}
self.client.print_curl("POST", "/pull", payload) # type: ignore[attr-defined]
await self.client.post("/pull", payload, timeout=300.0) # type: ignore[attr-defined]
@ensure_server
async def clone_vm(self, name: str, new_name: str) -> None:
"""Clone a VM with the given name to a new VM with new_name."""
config = CloneSpec(name=name, newName=new_name)
self.client.print_curl("POST", "/vms/clone", config.model_dump()) # type: ignore[attr-defined]
await self.client.post("/vms/clone", config.model_dump()) # type: ignore[attr-defined]
@ensure_server
async def get_latest_ipsw_url(self) -> str:
"""Get the latest IPSW URL."""
await self._init_client()
data = await self.client.get("/ipsw") # type: ignore[attr-defined]
return data["url"]
@ensure_server
async def get_images(self, organization: Optional[str] = None) -> ImageList:
"""Get list of available images."""
await self._init_client()
params = {"organization": organization} if organization else None
data = await self.client.get("/images", params) # type: ignore[attr-defined]
return ImageList(root=data)
async def close(self) -> None:
"""Close the client and stop the server."""
if self.client is not None:
await self.client.close()
self.client = None
await asyncio.sleep(1)
await self.server.stop()
async def _ensure_client(self) -> None:
"""Ensure client is initialized."""
if self.client is None:
await self._init_client()

View File

@@ -1,481 +0,0 @@
import asyncio
import json
import logging
import os
import random
import shlex
import signal
import socket
import subprocess
import sys
import tempfile
import time
from logging import getLogger
from typing import Optional
from .exceptions import LumeConnectionError
class LumeServer:
def __init__(
self,
debug: bool = False,
server_start_timeout: int = 60,
port: Optional[int] = None,
use_existing_server: bool = False,
host: str = "localhost",
):
"""Initialize the LumeServer.
Args:
debug: Enable debug logging
server_start_timeout: Timeout in seconds to wait for server to start
port: Specific port to use for the server
use_existing_server: If True, will try to connect to an existing server
instead of starting a new one
host: Host to use for connections (e.g., "localhost", "127.0.0.1", "host.docker.internal")
"""
self.debug = debug
self.server_start_timeout = server_start_timeout
self.server_process = None
self.output_file = None
self.requested_port = port
self.port = None
self.base_url = None
self.use_existing_server = use_existing_server
self.host = host
# Configure logging
self.logger = getLogger("pylume.server")
if not self.logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(logging.DEBUG if debug else logging.INFO)
self.logger.debug(f"Server initialized with host: {self.host}")
def _check_port_available(self, port: int) -> bool:
"""Check if a port is available."""
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(0.5)
result = s.connect_ex(("127.0.0.1", port))
if result == 0: # Port is in use on localhost
return False
except:
pass
# Check the specified host (e.g., "host.docker.internal") if it's not a localhost alias
if self.host not in ["localhost", "127.0.0.1"]:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.settimeout(0.5)
result = s.connect_ex((self.host, port))
if result == 0: # Port is in use on host
return False
except:
pass
return True
def _get_server_port(self) -> int:
"""Get an available port for the server."""
# Use requested port if specified
if self.requested_port is not None:
if not self._check_port_available(self.requested_port):
raise RuntimeError(f"Requested port {self.requested_port} is not available")
return self.requested_port
# Find a free port
for _ in range(10): # Try up to 10 times
port = random.randint(49152, 65535)
if self._check_port_available(port):
return port
raise RuntimeError("Could not find an available port")
async def _ensure_server_running(self) -> None:
"""Ensure the lume server is running, start it if it's not."""
try:
self.logger.debug("Checking if lume server is running...")
# Try to connect to the server with a short timeout
cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "5", f"{self.base_url}/vms"]
process = await asyncio.create_subprocess_exec(
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
response = stdout.decode()
status_code = int(response[-3:])
if status_code == 200:
self.logger.debug("PyLume server is running")
return
self.logger.debug("PyLume server not running, attempting to start it")
# Server not running, try to start it
lume_path = os.path.join(os.path.dirname(__file__), "lume")
if not os.path.exists(lume_path):
raise RuntimeError(f"Could not find lume binary at {lume_path}")
# Make sure the file is executable
os.chmod(lume_path, 0o755)
# Create a temporary file for server output
self.output_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)
self.logger.debug(f"Using temporary file for server output: {self.output_file.name}")
# Start the server
self.logger.debug(f"Starting lume server with: {lume_path} serve --port {self.port}")
# Start server in background using subprocess.Popen
try:
self.server_process = subprocess.Popen(
[lume_path, "serve", "--port", str(self.port)],
stdout=self.output_file,
stderr=self.output_file,
cwd=os.path.dirname(lume_path),
start_new_session=True, # Run in new session to avoid blocking
)
except Exception as e:
self.output_file.close()
os.unlink(self.output_file.name)
raise RuntimeError(f"Failed to start lume server process: {str(e)}")
# Wait for server to start
self.logger.debug(
f"Waiting up to {self.server_start_timeout} seconds for server to start..."
)
start_time = time.time()
server_ready = False
last_size = 0
while time.time() - start_time < self.server_start_timeout:
if self.server_process.poll() is not None:
# Process has terminated
self.output_file.seek(0)
output = self.output_file.read()
self.output_file.close()
os.unlink(self.output_file.name)
error_msg = (
f"Server process terminated unexpectedly.\n"
f"Exit code: {self.server_process.returncode}\n"
f"Output: {output}"
)
raise RuntimeError(error_msg)
# Check output file for server ready message
self.output_file.seek(0, os.SEEK_END)
size = self.output_file.tell()
if size > last_size: # Only read if there's new content
self.output_file.seek(last_size)
new_output = self.output_file.read()
if new_output.strip(): # Only log non-empty output
self.logger.debug(f"Server output: {new_output.strip()}")
last_size = size
if "Server started" in new_output:
server_ready = True
self.logger.debug("Server startup detected")
break
# Try to connect to the server periodically
try:
cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "5", f"{self.base_url}/vms"]
process = await asyncio.create_subprocess_exec(
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode == 0:
response = stdout.decode()
status_code = int(response[-3:])
if status_code == 200:
server_ready = True
self.logger.debug("Server is responding to requests")
break
except:
pass # Server not ready yet
await asyncio.sleep(1.0)
if not server_ready:
# Cleanup if server didn't start
if self.server_process:
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.server_process.kill()
self.output_file.close()
os.unlink(self.output_file.name)
raise RuntimeError(
f"Failed to start lume server after {self.server_start_timeout} seconds. "
"Check the debug output for more details."
)
# Give the server a moment to fully initialize
await asyncio.sleep(2.0)
# Verify server is responding
try:
cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "10", f"{self.base_url}/vms"]
process = await asyncio.create_subprocess_exec(
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
raise RuntimeError(f"Curl command failed: {stderr.decode()}")
response = stdout.decode()
status_code = int(response[-3:])
if status_code != 200:
raise RuntimeError(f"Server returned status code {status_code}")
self.logger.debug("PyLume server started successfully")
except Exception as e:
self.logger.debug(f"Server verification failed: {str(e)}")
if self.server_process:
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.server_process.kill()
self.output_file.close()
os.unlink(self.output_file.name)
raise RuntimeError(f"Server started but is not responding: {str(e)}")
self.logger.debug("Server startup completed successfully")
except Exception as e:
raise RuntimeError(f"Failed to start lume server: {str(e)}")
async def _start_server(self) -> None:
"""Start the lume server using the lume executable."""
self.logger.debug("Starting PyLume server")
# Get absolute path to lume executable in the same directory as this file
lume_path = os.path.join(os.path.dirname(__file__), "lume")
if not os.path.exists(lume_path):
raise RuntimeError(f"Could not find lume binary at {lume_path}")
try:
# Make executable
os.chmod(lume_path, 0o755)
# Get and validate port
self.port = self._get_server_port()
self.base_url = f"http://{self.host}:{self.port}/lume"
# Set up output handling
self.output_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)
# Start the server process with the lume executable
env = os.environ.copy()
env["RUST_BACKTRACE"] = "1" # Enable backtrace for better error reporting
# Specify the host to bind to (0.0.0.0 to allow external connections)
self.server_process = subprocess.Popen(
[lume_path, "serve", "--port", str(self.port)],
stdout=self.output_file,
stderr=subprocess.STDOUT,
cwd=os.path.dirname(lume_path), # Run from same directory as executable
env=env,
)
# Wait for server to initialize
await asyncio.sleep(2)
await self._wait_for_server()
except Exception as e:
await self._cleanup()
raise RuntimeError(f"Failed to start lume server process: {str(e)}")
async def _tail_log(self) -> None:
"""Read and display server log output in debug mode."""
while True:
try:
self.output_file.seek(0, os.SEEK_END) # type: ignore[attr-defined]
line = self.output_file.readline() # type: ignore[attr-defined]
if line:
line = line.strip()
if line:
print(f"SERVER: {line}")
if self.server_process.poll() is not None: # type: ignore[attr-defined]
print("Server process ended")
break
await asyncio.sleep(0.1)
except Exception as e:
print(f"Error reading log: {e}")
await asyncio.sleep(0.1)
async def _wait_for_server(self) -> None:
"""Wait for server to start and become responsive with increased timeout."""
start_time = time.time()
while time.time() - start_time < self.server_start_timeout:
if self.server_process.poll() is not None: # type: ignore[attr-defined]
error_msg = await self._get_error_output()
await self._cleanup()
raise RuntimeError(error_msg)
try:
await self._verify_server()
self.logger.debug("Server is now responsive")
return
except Exception as e:
self.logger.debug(f"Server not ready yet: {str(e)}")
await asyncio.sleep(1.0)
await self._cleanup()
raise RuntimeError(f"Server failed to start after {self.server_start_timeout} seconds")
async def _verify_server(self) -> None:
"""Verify server is responding to requests."""
try:
cmd = [
"curl",
"-s",
"-w",
"%{http_code}",
"-m",
"10",
f"http://{self.host}:{self.port}/lume/vms",
]
process = await asyncio.create_subprocess_exec(
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
raise RuntimeError(f"Curl command failed: {stderr.decode()}")
response = stdout.decode()
status_code = int(response[-3:])
if status_code != 200:
raise RuntimeError(f"Server returned status code {status_code}")
self.logger.debug("PyLume server started successfully")
except Exception as e:
raise RuntimeError(f"Server not responding: {str(e)}")
async def _get_error_output(self) -> str:
"""Get error output from the server process."""
if not self.output_file:
return "No output available"
self.output_file.seek(0)
output = self.output_file.read()
return (
f"Server process terminated unexpectedly.\n"
f"Exit code: {self.server_process.returncode}\n" # type: ignore[attr-defined]
f"Output: {output}"
)
async def _cleanup(self) -> None:
"""Clean up all server resources."""
if self.server_process:
try:
self.server_process.terminate()
try:
self.server_process.wait(timeout=5)
except subprocess.TimeoutExpired:
self.server_process.kill()
except:
pass
self.server_process = None
# Clean up output file
if self.output_file:
try:
self.output_file.close()
os.unlink(self.output_file.name)
except Exception as e:
self.logger.debug(f"Error cleaning up output file: {e}")
self.output_file = None
async def ensure_running(self) -> None:
"""Ensure the server is running.
If use_existing_server is True, will only try to connect to an existing server.
Otherwise will:
1. Try to connect to an existing server on the specified port
2. If that fails and not in Docker, start a new server
3. If in Docker and no existing server is found, raise an error
"""
# First check if we're in Docker
in_docker = os.path.exists("/.dockerenv") or (
os.path.exists("/proc/1/cgroup") and "docker" in open("/proc/1/cgroup", "r").read()
)
# If using a non-localhost host like host.docker.internal, set up the connection details
if self.host not in ["localhost", "127.0.0.1"]:
if self.requested_port is None:
raise RuntimeError("Port must be specified when using a remote host")
self.port = self.requested_port
self.base_url = f"http://{self.host}:{self.port}/lume"
self.logger.debug(f"Using remote host server at {self.base_url}")
# Try to verify the server is accessible
try:
await self._verify_server()
self.logger.debug("Successfully connected to remote server")
return
except Exception as e:
if self.use_existing_server or in_docker:
# If explicitly requesting an existing server or in Docker, we can't start a new one
raise RuntimeError(
f"Failed to connect to remote server at {self.base_url}: {str(e)}"
)
else:
self.logger.debug(f"Remote server not available at {self.base_url}: {str(e)}")
# Fall back to localhost for starting a new server
self.host = "localhost"
# If explicitly using an existing server, verify it's running
if self.use_existing_server:
if self.requested_port is None:
raise RuntimeError("Port must be specified when using an existing server")
self.port = self.requested_port
self.base_url = f"http://{self.host}:{self.port}/lume"
try:
await self._verify_server()
self.logger.debug("Successfully connected to existing server")
except Exception as e:
raise RuntimeError(
f"Failed to connect to existing server at {self.base_url}: {str(e)}"
)
else:
# Try to connect to an existing server first
if self.requested_port is not None:
self.port = self.requested_port
self.base_url = f"http://{self.host}:{self.port}/lume"
try:
await self._verify_server()
self.logger.debug("Successfully connected to existing server")
return
except Exception:
self.logger.debug(f"No existing server found at {self.base_url}")
# If in Docker and can't connect to existing server, raise an error
if in_docker:
raise RuntimeError(
f"Failed to connect to server at {self.base_url} and cannot start a new server in Docker"
)
# Start a new server
self.logger.debug("Starting a new server instance")
await self._start_server()
async def stop(self) -> None:
"""Stop the server if we're managing it."""
if not self.use_existing_server:
self.logger.debug("Stopping lume server...")
await self._cleanup()

View File

@@ -1,51 +0,0 @@
[build-system]
build-backend = "pdm.backend"
requires = ["pdm-backend"]
[project]
authors = [{ name = "TryCua", email = "gh@trycua.com" }]
classifiers = [
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Operating System :: MacOS :: MacOS X",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
dependencies = ["pydantic>=2.11.1"]
description = "Python SDK for lume - run macOS and Linux VMs on Apple Silicon"
dynamic = ["version"]
keywords = ["apple-silicon", "macos", "virtualization", "vm"]
license = { text = "MIT" }
name = "pylume"
readme = "README.md"
requires-python = ">=3.12"
[tool.pdm.version]
path = "pylume/__init__.py"
source = "file"
[project.urls]
homepage = "https://github.com/trycua/pylume"
repository = "https://github.com/trycua/pylume"
[tool.pdm]
distribution = true
[tool.pdm.dev-dependencies]
dev = [
"black>=23.0.0",
"isort>=5.12.0",
"pytest-asyncio>=0.23.0",
"pytest>=7.0.0",
]
[tool.pytest.ini_options]
asyncio_mode = "auto"
python_files = "test_*.py"
testpaths = ["tests"]
[tool.pdm.build]
includes = ["pylume/"]
source-includes = ["LICENSE", "README.md", "tests/"]

View File

@@ -0,0 +1,23 @@
"""Pytest configuration for pylume tests.
This module provides test fixtures for the pylume package.
Note: This package has macOS-specific dependencies and will skip tests
if the required modules are not available.
"""
from unittest.mock import Mock, patch
import pytest
@pytest.fixture
def mock_subprocess():
with patch("subprocess.run") as mock_run:
mock_run.return_value = Mock(returncode=0, stdout="", stderr="")
yield mock_run
@pytest.fixture
def mock_requests():
with patch("requests.get") as mock_get, patch("requests.post") as mock_post:
yield {"get": mock_get, "post": mock_post}

View File

@@ -0,0 +1,38 @@
"""Unit tests for pylume package.
This file tests ONLY basic pylume functionality.
Following SRP: This file tests pylume module imports and basic operations.
All external dependencies are mocked.
"""
import pytest
class TestPylumeImports:
"""Test pylume module imports (SRP: Only tests imports)."""
def test_pylume_module_exists(self):
"""Test that pylume module can be imported."""
try:
import pylume
assert pylume is not None
except ImportError:
pytest.skip("pylume module not installed")
class TestPylumeInitialization:
"""Test pylume initialization (SRP: Only tests initialization)."""
def test_pylume_can_be_imported(self):
"""Basic smoke test: verify pylume components can be imported."""
try:
import pylume
# Check for basic attributes
assert pylume is not None
except ImportError:
pytest.skip("pylume module not available")
except Exception as e:
# Some initialization errors are acceptable in unit tests
pytest.skip(f"pylume initialization requires specific setup: {e}")

View File

@@ -0,0 +1,24 @@
"""Pytest configuration for som tests.
This module provides test fixtures for the som (Set-of-Mark) package.
The som package depends on heavy ML models and will skip tests if not available.
"""
from unittest.mock import Mock, patch
import pytest
@pytest.fixture
def mock_torch():
with patch("torch.load") as mock_load:
mock_load.return_value = Mock()
yield mock_load
@pytest.fixture
def mock_icon_detector():
with patch("omniparser.IconDetector") as mock_detector:
instance = Mock()
mock_detector.return_value = instance
yield instance

View File

@@ -1,13 +1,73 @@
# """Basic tests for the omniparser package."""
"""Unit tests for som package (Set-of-Mark).
# import pytest
# from omniparser import IconDetector
This file tests ONLY basic som functionality.
Following SRP: This file tests som module imports and basic operations.
All external dependencies (ML models, OCR) are mocked.
"""
# def test_icon_detector_import():
# """Test that we can import the IconDetector class."""
# assert IconDetector is not None
import pytest
# def test_icon_detector_init():
# """Test that we can create an IconDetector instance."""
# detector = IconDetector(force_cpu=True)
# assert detector is not None
class TestSomImports:
"""Test som module imports (SRP: Only tests imports)."""
def test_som_module_exists(self):
"""Test that som module can be imported."""
try:
import som
assert som is not None
except ImportError:
pytest.skip("som module not installed")
def test_omniparser_import(self):
"""Test that OmniParser can be imported."""
try:
from som import OmniParser
assert OmniParser is not None
except ImportError:
pytest.skip("som module not available")
except Exception as e:
pytest.skip(f"som initialization requires ML models: {e}")
def test_models_import(self):
"""Test that model classes can be imported."""
try:
from som import BoundingBox, ParseResult, UIElement
assert BoundingBox is not None
assert UIElement is not None
assert ParseResult is not None
except ImportError:
pytest.skip("som models not available")
except Exception as e:
pytest.skip(f"som models require dependencies: {e}")
class TestSomModels:
"""Test som data models (SRP: Only tests model structure)."""
def test_bounding_box_structure(self):
"""Test BoundingBox class structure."""
try:
from som import BoundingBox
# Check the class exists and has expected structure
assert hasattr(BoundingBox, "__init__")
except ImportError:
pytest.skip("som models not available")
except Exception as e:
pytest.skip(f"som models require dependencies: {e}")
def test_ui_element_structure(self):
"""Test UIElement class structure."""
try:
from som import UIElement
# Check the class exists and has expected structure
assert hasattr(UIElement, "__init__")
except ImportError:
pytest.skip("som models not available")
except Exception as e:
pytest.skip(f"som models require dependencies: {e}")