mirror of
https://github.com/trycua/computer.git
synced 2025-12-31 18:40:04 -06:00
Merge branch 'main' into feat/add-desktop-commands
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.4.35
|
||||
current_version = 0.4.36
|
||||
commit = True
|
||||
tag = True
|
||||
tag_name = agent-v{new_version}
|
||||
|
||||
@@ -28,8 +28,12 @@ class AsyncComputerHandler(Protocol):
|
||||
"""Get screen dimensions as (width, height)."""
|
||||
...
|
||||
|
||||
async def screenshot(self) -> str:
|
||||
"""Take a screenshot and return as base64 string."""
|
||||
async def screenshot(self, text: Optional[str] = None) -> str:
|
||||
"""Take a screenshot and return as base64 string.
|
||||
|
||||
Args:
|
||||
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
|
||||
"""
|
||||
...
|
||||
|
||||
async def click(self, x: int, y: int, button: str = "left") -> None:
|
||||
|
||||
@@ -36,8 +36,12 @@ class cuaComputerHandler(AsyncComputerHandler):
|
||||
screen_size = await self.interface.get_screen_size()
|
||||
return screen_size["width"], screen_size["height"]
|
||||
|
||||
async def screenshot(self) -> str:
|
||||
"""Take a screenshot and return as base64 string."""
|
||||
async def screenshot(self, text: Optional[str] = None) -> str:
|
||||
"""Take a screenshot and return as base64 string.
|
||||
|
||||
Args:
|
||||
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
|
||||
"""
|
||||
assert self.interface is not None
|
||||
screenshot_bytes = await self.interface.screenshot()
|
||||
return base64.b64encode(screenshot_bytes).decode("utf-8")
|
||||
|
||||
@@ -122,8 +122,12 @@ class CustomComputerHandler(AsyncComputerHandler):
|
||||
|
||||
return self._last_screenshot_size
|
||||
|
||||
async def screenshot(self) -> str:
|
||||
"""Take a screenshot and return as base64 string."""
|
||||
async def screenshot(self, text: Optional[str] = None) -> str:
|
||||
"""Take a screenshot and return as base64 string.
|
||||
|
||||
Args:
|
||||
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
|
||||
"""
|
||||
result = await self._call_function(self.functions["screenshot"])
|
||||
b64_str = self._to_b64_str(result) # type: ignore
|
||||
|
||||
|
||||
@@ -243,18 +243,20 @@ async def replace_computer_call_with_function(
|
||||
"id": item.get("id"),
|
||||
"call_id": item.get("call_id"),
|
||||
"status": "completed",
|
||||
# Fall back to string representation
|
||||
"content": f"Used tool: {action_data.get("type")}({json.dumps(fn_args)})",
|
||||
}
|
||||
]
|
||||
|
||||
elif item_type == "computer_call_output":
|
||||
# Simple conversion: computer_call_output -> function_call_output
|
||||
output = item.get("output")
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = [output]
|
||||
|
||||
return [
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": item.get("call_id"),
|
||||
"content": [item.get("output")],
|
||||
"output": output,
|
||||
"id": item.get("id"),
|
||||
"status": "completed",
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
name = "cua-agent"
|
||||
version = "0.4.35"
|
||||
version = "0.4.36"
|
||||
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
|
||||
84
libs/python/agent/tests/conftest.py
Normal file
84
libs/python/agent/tests/conftest.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Pytest configuration and shared fixtures for agent package tests.
|
||||
|
||||
This file contains shared fixtures and configuration for all agent tests.
|
||||
Following SRP: This file ONLY handles test setup/teardown.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm():
|
||||
"""Mock liteLLM completion calls.
|
||||
|
||||
Use this fixture to avoid making real LLM API calls during tests.
|
||||
Returns a mock that simulates LLM responses.
|
||||
"""
|
||||
with patch("litellm.acompletion") as mock_completion:
|
||||
|
||||
async def mock_response(*args, **kwargs):
|
||||
"""Simulate a typical LLM response."""
|
||||
return {
|
||||
"id": "chatcmpl-test123",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": kwargs.get("model", "anthropic/claude-3-5-sonnet-20241022"),
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "This is a mocked response for testing.",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30,
|
||||
},
|
||||
}
|
||||
|
||||
mock_completion.side_effect = mock_response
|
||||
yield mock_completion
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_computer():
|
||||
"""Mock Computer interface for agent tests.
|
||||
|
||||
Use this fixture to test agent logic without requiring a real Computer instance.
|
||||
"""
|
||||
computer = AsyncMock()
|
||||
computer.interface = AsyncMock()
|
||||
computer.interface.screenshot = AsyncMock(return_value=b"fake_screenshot_data")
|
||||
computer.interface.left_click = AsyncMock()
|
||||
computer.interface.type = AsyncMock()
|
||||
computer.interface.key = AsyncMock()
|
||||
|
||||
# Mock context manager
|
||||
computer.__aenter__ = AsyncMock(return_value=computer)
|
||||
computer.__aexit__ = AsyncMock()
|
||||
|
||||
return computer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_telemetry(monkeypatch):
|
||||
"""Disable telemetry for tests.
|
||||
|
||||
Use this fixture to ensure no telemetry is sent during tests.
|
||||
"""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_messages():
|
||||
"""Provide sample messages for testing.
|
||||
|
||||
Returns a list of messages in the expected format.
|
||||
"""
|
||||
return [{"role": "user", "content": "Take a screenshot and tell me what you see"}]
|
||||
139
libs/python/agent/tests/test_computer_agent.py
Normal file
139
libs/python/agent/tests/test_computer_agent.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Unit tests for ComputerAgent class.
|
||||
|
||||
This file tests ONLY the ComputerAgent initialization and basic functionality.
|
||||
Following SRP: This file tests ONE class (ComputerAgent).
|
||||
All external dependencies (liteLLM, Computer) are mocked.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestComputerAgentInitialization:
|
||||
"""Test ComputerAgent initialization (SRP: Only tests initialization)."""
|
||||
|
||||
@patch("agent.agent.litellm")
|
||||
def test_agent_initialization_with_model(self, mock_litellm, disable_telemetry):
|
||||
"""Test that agent can be initialized with a model string."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
assert agent is not None
|
||||
assert hasattr(agent, "model")
|
||||
assert agent.model == "anthropic/claude-3-5-sonnet-20241022"
|
||||
|
||||
@patch("agent.agent.litellm")
|
||||
def test_agent_initialization_with_tools(self, mock_litellm, disable_telemetry, mock_computer):
|
||||
"""Test that agent can be initialized with tools."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022", tools=[mock_computer])
|
||||
|
||||
assert agent is not None
|
||||
assert hasattr(agent, "tools")
|
||||
|
||||
@patch("agent.agent.litellm")
|
||||
def test_agent_initialization_with_max_budget(self, mock_litellm, disable_telemetry):
|
||||
"""Test that agent can be initialized with max trajectory budget."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
budget = 5.0
|
||||
agent = ComputerAgent(
|
||||
model="anthropic/claude-3-5-sonnet-20241022", max_trajectory_budget=budget
|
||||
)
|
||||
|
||||
assert agent is not None
|
||||
|
||||
@patch("agent.agent.litellm")
|
||||
def test_agent_requires_model(self, mock_litellm, disable_telemetry):
|
||||
"""Test that agent requires a model parameter."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Should fail without model parameter - intentionally missing required argument
|
||||
ComputerAgent() # type: ignore[call-arg]
|
||||
|
||||
|
||||
class TestComputerAgentRun:
|
||||
"""Test ComputerAgent.run() method (SRP: Only tests run logic)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("agent.agent.litellm")
|
||||
async def test_agent_run_with_messages(self, mock_litellm, disable_telemetry, sample_messages):
|
||||
"""Test that agent.run() works with valid messages."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
# Mock liteLLM response
|
||||
mock_response = {
|
||||
"id": "chatcmpl-test",
|
||||
"choices": [
|
||||
{
|
||||
"message": {"role": "assistant", "content": "Test response"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
|
||||
}
|
||||
|
||||
mock_litellm.acompletion = AsyncMock(return_value=mock_response)
|
||||
|
||||
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Run should return an async generator
|
||||
result_generator = agent.run(sample_messages)
|
||||
|
||||
assert result_generator is not None
|
||||
# Check it's an async generator
|
||||
assert hasattr(result_generator, "__anext__")
|
||||
|
||||
def test_agent_has_run_method(self, disable_telemetry):
|
||||
"""Test that agent has run method available."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Verify run method exists
|
||||
assert hasattr(agent, "run")
|
||||
assert callable(agent.run)
|
||||
|
||||
def test_agent_has_agent_loop(self, disable_telemetry):
|
||||
"""Test that agent has agent_loop initialized."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Verify agent_loop is initialized
|
||||
assert hasattr(agent, "agent_loop")
|
||||
assert agent.agent_loop is not None
|
||||
|
||||
|
||||
class TestComputerAgentTypes:
|
||||
"""Test AgentResponse and Messages types (SRP: Only tests type definitions)."""
|
||||
|
||||
def test_messages_type_exists(self):
|
||||
"""Test that Messages type is exported."""
|
||||
from agent import Messages
|
||||
|
||||
assert Messages is not None
|
||||
|
||||
def test_agent_response_type_exists(self):
|
||||
"""Test that AgentResponse type is exported."""
|
||||
from agent import AgentResponse
|
||||
|
||||
assert AgentResponse is not None
|
||||
|
||||
|
||||
class TestComputerAgentIntegration:
|
||||
"""Test ComputerAgent integration with Computer tool (SRP: Integration within package)."""
|
||||
|
||||
def test_agent_accepts_computer_tool(self, disable_telemetry, mock_computer):
|
||||
"""Test that agent can be initialized with Computer tool."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022", tools=[mock_computer])
|
||||
|
||||
# Verify agent accepted the tool
|
||||
assert agent is not None
|
||||
assert hasattr(agent, "tools")
|
||||
47
libs/python/computer-server/tests/conftest.py
Normal file
47
libs/python/computer-server/tests/conftest.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Pytest configuration and shared fixtures for computer-server package tests.
|
||||
|
||||
This file contains shared fixtures and configuration for all computer-server tests.
|
||||
Following SRP: This file ONLY handles test setup/teardown.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
"""Mock WebSocket connection for testing.
|
||||
|
||||
Use this fixture to test WebSocket logic without real connections.
|
||||
"""
|
||||
websocket = AsyncMock()
|
||||
websocket.send = AsyncMock()
|
||||
websocket.recv = AsyncMock()
|
||||
websocket.close = AsyncMock()
|
||||
|
||||
return websocket
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_computer_interface():
|
||||
"""Mock computer interface for server tests.
|
||||
|
||||
Use this fixture to test server logic without real computer operations.
|
||||
"""
|
||||
interface = AsyncMock()
|
||||
interface.screenshot = AsyncMock(return_value=b"fake_screenshot")
|
||||
interface.left_click = AsyncMock()
|
||||
interface.type = AsyncMock()
|
||||
interface.key = AsyncMock()
|
||||
|
||||
return interface
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_telemetry(monkeypatch):
|
||||
"""Disable telemetry for tests.
|
||||
|
||||
Use this fixture to ensure no telemetry is sent during tests.
|
||||
"""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
|
||||
40
libs/python/computer-server/tests/test_server.py
Normal file
40
libs/python/computer-server/tests/test_server.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Unit tests for computer-server package.
|
||||
|
||||
This file tests ONLY basic server functionality.
|
||||
Following SRP: This file tests server initialization and basic operations.
|
||||
All external dependencies are mocked.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestServerImports:
|
||||
"""Test server module imports (SRP: Only tests imports)."""
|
||||
|
||||
def test_server_module_exists(self):
|
||||
"""Test that server module can be imported."""
|
||||
try:
|
||||
import computer_server
|
||||
|
||||
assert computer_server is not None
|
||||
except ImportError:
|
||||
pytest.skip("computer_server module not installed")
|
||||
|
||||
|
||||
class TestServerInitialization:
|
||||
"""Test server initialization (SRP: Only tests initialization)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_can_be_imported(self):
|
||||
"""Basic smoke test: verify server components can be imported."""
|
||||
try:
|
||||
from computer_server import server
|
||||
|
||||
assert server is not None
|
||||
except ImportError:
|
||||
pytest.skip("Server module not available")
|
||||
except Exception as e:
|
||||
# Some initialization errors are acceptable in unit tests
|
||||
pytest.skip(f"Server initialization requires specific setup: {e}")
|
||||
69
libs/python/computer/tests/conftest.py
Normal file
69
libs/python/computer/tests/conftest.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Pytest configuration and shared fixtures for computer package tests.
|
||||
|
||||
This file contains shared fixtures and configuration for all computer tests.
|
||||
Following SRP: This file ONLY handles test setup/teardown.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_interface():
|
||||
"""Mock computer interface for testing.
|
||||
|
||||
Use this fixture to test Computer logic without real OS calls.
|
||||
"""
|
||||
interface = AsyncMock()
|
||||
interface.screenshot = AsyncMock(return_value=b"fake_screenshot")
|
||||
interface.left_click = AsyncMock()
|
||||
interface.right_click = AsyncMock()
|
||||
interface.middle_click = AsyncMock()
|
||||
interface.double_click = AsyncMock()
|
||||
interface.type = AsyncMock()
|
||||
interface.key = AsyncMock()
|
||||
interface.move_mouse = AsyncMock()
|
||||
interface.scroll = AsyncMock()
|
||||
interface.get_screen_size = AsyncMock(return_value=(1920, 1080))
|
||||
|
||||
return interface
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cloud_provider():
|
||||
"""Mock cloud provider for testing.
|
||||
|
||||
Use this fixture to test cloud provider logic without real API calls.
|
||||
"""
|
||||
provider = AsyncMock()
|
||||
provider.start = AsyncMock()
|
||||
provider.stop = AsyncMock()
|
||||
provider.get_status = AsyncMock(return_value="running")
|
||||
provider.execute_command = AsyncMock(return_value="command output")
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_local_provider():
|
||||
"""Mock local provider for testing.
|
||||
|
||||
Use this fixture to test local provider logic without real VM operations.
|
||||
"""
|
||||
provider = AsyncMock()
|
||||
provider.start = AsyncMock()
|
||||
provider.stop = AsyncMock()
|
||||
provider.get_status = AsyncMock(return_value="running")
|
||||
provider.execute_command = AsyncMock(return_value="command output")
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_telemetry(monkeypatch):
|
||||
"""Disable telemetry for tests.
|
||||
|
||||
Use this fixture to ensure no telemetry is sent during tests.
|
||||
"""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
|
||||
67
libs/python/computer/tests/test_computer.py
Normal file
67
libs/python/computer/tests/test_computer.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Unit tests for Computer class.
|
||||
|
||||
This file tests ONLY the Computer class initialization and context manager.
|
||||
Following SRP: This file tests ONE class (Computer).
|
||||
All external dependencies (providers, interfaces) are mocked.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestComputerImport:
|
||||
"""Test Computer module imports (SRP: Only tests imports)."""
|
||||
|
||||
def test_computer_class_exists(self):
|
||||
"""Test that Computer class can be imported."""
|
||||
from computer import Computer
|
||||
|
||||
assert Computer is not None
|
||||
|
||||
def test_vm_provider_type_exists(self):
|
||||
"""Test that VMProviderType enum can be imported."""
|
||||
from computer import VMProviderType
|
||||
|
||||
assert VMProviderType is not None
|
||||
|
||||
|
||||
class TestComputerInitialization:
|
||||
"""Test Computer initialization (SRP: Only tests initialization)."""
|
||||
|
||||
def test_computer_class_can_be_imported(self, disable_telemetry):
|
||||
"""Test that Computer class can be imported without errors."""
|
||||
from computer import Computer
|
||||
|
||||
assert Computer is not None
|
||||
|
||||
def test_computer_has_required_methods(self, disable_telemetry):
|
||||
"""Test that Computer class has required methods."""
|
||||
from computer import Computer
|
||||
|
||||
assert hasattr(Computer, "__aenter__")
|
||||
assert hasattr(Computer, "__aexit__")
|
||||
|
||||
|
||||
class TestComputerContextManager:
|
||||
"""Test Computer context manager protocol (SRP: Only tests context manager)."""
|
||||
|
||||
def test_computer_is_async_context_manager(self, disable_telemetry):
|
||||
"""Test that Computer has async context manager methods."""
|
||||
from computer import Computer
|
||||
|
||||
assert hasattr(Computer, "__aenter__")
|
||||
assert hasattr(Computer, "__aexit__")
|
||||
assert callable(Computer.__aenter__)
|
||||
assert callable(Computer.__aexit__)
|
||||
|
||||
|
||||
class TestComputerInterface:
|
||||
"""Test Computer.interface property (SRP: Only tests interface access)."""
|
||||
|
||||
def test_computer_class_structure(self, disable_telemetry):
|
||||
"""Test that Computer class has expected structure."""
|
||||
from computer import Computer
|
||||
|
||||
# Verify Computer is a class
|
||||
assert isinstance(Computer, type)
|
||||
43
libs/python/core/tests/conftest.py
Normal file
43
libs/python/core/tests/conftest.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Pytest configuration and shared fixtures for core package tests.
|
||||
|
||||
This file contains shared fixtures and configuration for all core tests.
|
||||
Following SRP: This file ONLY handles test setup/teardown.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_httpx_client():
|
||||
"""Mock httpx.AsyncClient for API calls.
|
||||
|
||||
Use this fixture to avoid making real HTTP requests during tests.
|
||||
"""
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_instance = AsyncMock()
|
||||
mock_client.return_value.__aenter__.return_value = mock_instance
|
||||
yield mock_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_posthog():
|
||||
"""Mock PostHog client for telemetry tests.
|
||||
|
||||
Use this fixture to avoid sending real telemetry during tests.
|
||||
"""
|
||||
with patch("posthog.Posthog") as mock_ph:
|
||||
mock_instance = Mock()
|
||||
mock_ph.return_value = mock_instance
|
||||
yield mock_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_telemetry(monkeypatch):
|
||||
"""Disable telemetry for tests that don't need it.
|
||||
|
||||
Use this fixture to ensure telemetry is disabled during tests.
|
||||
"""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
|
||||
yield
|
||||
255
libs/python/core/tests/test_telemetry.py
Normal file
255
libs/python/core/tests/test_telemetry.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Unit tests for core telemetry functionality.
|
||||
|
||||
This file tests ONLY telemetry logic, following SRP.
|
||||
All external dependencies (PostHog, file system) are mocked.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, Mock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestTelemetryEnabled:
|
||||
"""Test telemetry enable/disable logic (SRP: Only tests enable/disable)."""
|
||||
|
||||
def test_telemetry_enabled_by_default(self, monkeypatch):
|
||||
"""Test that telemetry is enabled by default."""
|
||||
# Remove any environment variables that might affect the test
|
||||
monkeypatch.delenv("CUA_TELEMETRY", raising=False)
|
||||
monkeypatch.delenv("CUA_TELEMETRY_ENABLED", raising=False)
|
||||
|
||||
from core.telemetry import is_telemetry_enabled
|
||||
|
||||
assert is_telemetry_enabled() is True
|
||||
|
||||
def test_telemetry_disabled_with_legacy_flag(self, monkeypatch):
|
||||
"""Test that telemetry can be disabled with legacy CUA_TELEMETRY=off."""
|
||||
monkeypatch.setenv("CUA_TELEMETRY", "off")
|
||||
|
||||
from core.telemetry import is_telemetry_enabled
|
||||
|
||||
assert is_telemetry_enabled() is False
|
||||
|
||||
def test_telemetry_disabled_with_new_flag(self, monkeypatch):
|
||||
"""Test that telemetry can be disabled with CUA_TELEMETRY_ENABLED=false."""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")
|
||||
|
||||
from core.telemetry import is_telemetry_enabled
|
||||
|
||||
assert is_telemetry_enabled() is False
|
||||
|
||||
@pytest.mark.parametrize("value", ["0", "false", "no", "off"])
|
||||
def test_telemetry_disabled_with_various_values(self, monkeypatch, value):
|
||||
"""Test that telemetry respects various disable values."""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", value)
|
||||
|
||||
from core.telemetry import is_telemetry_enabled
|
||||
|
||||
assert is_telemetry_enabled() is False
|
||||
|
||||
@pytest.mark.parametrize("value", ["1", "true", "yes", "on"])
|
||||
def test_telemetry_enabled_with_various_values(self, monkeypatch, value):
|
||||
"""Test that telemetry respects various enable values."""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", value)
|
||||
|
||||
from core.telemetry import is_telemetry_enabled
|
||||
|
||||
assert is_telemetry_enabled() is True
|
||||
|
||||
|
||||
class TestPostHogTelemetryClient:
|
||||
"""Test PostHogTelemetryClient class (SRP: Only tests client logic)."""
|
||||
|
||||
@patch("core.telemetry.posthog.posthog")
|
||||
@patch("core.telemetry.posthog.Path")
|
||||
def test_client_initialization(self, mock_path, mock_posthog, disable_telemetry):
|
||||
"""Test that client initializes correctly."""
|
||||
from core.telemetry.posthog import PostHogTelemetryClient
|
||||
|
||||
# Mock the storage directory
|
||||
mock_storage_dir = MagicMock()
|
||||
mock_storage_dir.exists.return_value = False
|
||||
mock_path.return_value.parent.parent = MagicMock()
|
||||
mock_path.return_value.parent.parent.__truediv__.return_value = mock_storage_dir
|
||||
|
||||
# Reset singleton
|
||||
PostHogTelemetryClient.destroy_client()
|
||||
|
||||
client = PostHogTelemetryClient()
|
||||
|
||||
assert client is not None
|
||||
assert hasattr(client, "installation_id")
|
||||
assert hasattr(client, "initialized")
|
||||
assert hasattr(client, "queued_events")
|
||||
|
||||
@patch("core.telemetry.posthog.posthog")
|
||||
@patch("core.telemetry.posthog.Path")
|
||||
def test_installation_id_generation(self, mock_path, mock_posthog, disable_telemetry):
|
||||
"""Test that installation ID is generated if not exists."""
|
||||
from core.telemetry.posthog import PostHogTelemetryClient
|
||||
|
||||
# Mock file system
|
||||
mock_id_file = MagicMock()
|
||||
mock_id_file.exists.return_value = False
|
||||
mock_storage_dir = MagicMock()
|
||||
mock_storage_dir.__truediv__.return_value = mock_id_file
|
||||
|
||||
mock_core_dir = MagicMock()
|
||||
mock_core_dir.__truediv__.return_value = mock_storage_dir
|
||||
mock_path.return_value.parent.parent = mock_core_dir
|
||||
|
||||
# Reset singleton
|
||||
PostHogTelemetryClient.destroy_client()
|
||||
|
||||
client = PostHogTelemetryClient()
|
||||
|
||||
# Should have generated a new UUID
|
||||
assert client.installation_id is not None
|
||||
assert len(client.installation_id) == 36 # UUID format
|
||||
|
||||
@patch("core.telemetry.posthog.posthog")
|
||||
@patch("core.telemetry.posthog.Path")
|
||||
def test_installation_id_persistence(self, mock_path, mock_posthog, disable_telemetry):
|
||||
"""Test that installation ID is read from file if exists."""
|
||||
from core.telemetry.posthog import PostHogTelemetryClient
|
||||
|
||||
existing_id = "test-installation-id-123"
|
||||
|
||||
# Mock file system
|
||||
mock_id_file = MagicMock()
|
||||
mock_id_file.exists.return_value = True
|
||||
mock_id_file.read_text.return_value = existing_id
|
||||
|
||||
mock_storage_dir = MagicMock()
|
||||
mock_storage_dir.__truediv__.return_value = mock_id_file
|
||||
|
||||
mock_core_dir = MagicMock()
|
||||
mock_core_dir.__truediv__.return_value = mock_storage_dir
|
||||
mock_path.return_value.parent.parent = mock_core_dir
|
||||
|
||||
# Reset singleton
|
||||
PostHogTelemetryClient.destroy_client()
|
||||
|
||||
client = PostHogTelemetryClient()
|
||||
|
||||
assert client.installation_id == existing_id
|
||||
|
||||
@patch("core.telemetry.posthog.posthog")
|
||||
@patch("core.telemetry.posthog.Path")
|
||||
def test_record_event_when_disabled(self, mock_path, mock_posthog, monkeypatch):
|
||||
"""Test that events are not recorded when telemetry is disabled."""
|
||||
from core.telemetry.posthog import PostHogTelemetryClient
|
||||
|
||||
# Disable telemetry explicitly using the correct environment variable
|
||||
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")
|
||||
|
||||
# Mock file system
|
||||
mock_storage_dir = MagicMock()
|
||||
mock_storage_dir.exists.return_value = False
|
||||
mock_path.return_value.parent.parent = MagicMock()
|
||||
mock_path.return_value.parent.parent.__truediv__.return_value = mock_storage_dir
|
||||
|
||||
# Reset singleton
|
||||
PostHogTelemetryClient.destroy_client()
|
||||
|
||||
client = PostHogTelemetryClient()
|
||||
client.record_event("test_event", {"key": "value"})
|
||||
|
||||
# PostHog capture should not be called at all when telemetry is disabled
|
||||
mock_posthog.capture.assert_not_called()
|
||||
|
||||
@patch("core.telemetry.posthog.posthog")
|
||||
@patch("core.telemetry.posthog.Path")
|
||||
def test_record_event_when_enabled(self, mock_path, mock_posthog, monkeypatch):
|
||||
"""Test that events are recorded when telemetry is enabled."""
|
||||
from core.telemetry.posthog import PostHogTelemetryClient
|
||||
|
||||
# Enable telemetry
|
||||
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "true")
|
||||
|
||||
# Mock file system
|
||||
mock_storage_dir = MagicMock()
|
||||
mock_storage_dir.exists.return_value = False
|
||||
mock_path.return_value.parent.parent = MagicMock()
|
||||
mock_path.return_value.parent.parent.__truediv__.return_value = mock_storage_dir
|
||||
|
||||
# Reset singleton
|
||||
PostHogTelemetryClient.destroy_client()
|
||||
|
||||
client = PostHogTelemetryClient()
|
||||
client.initialized = True # Pretend it's initialized
|
||||
|
||||
event_name = "test_event"
|
||||
event_props = {"key": "value"}
|
||||
client.record_event(event_name, event_props)
|
||||
|
||||
# PostHog capture should be called
|
||||
assert mock_posthog.capture.call_count >= 1
|
||||
|
||||
@patch("core.telemetry.posthog.posthog")
|
||||
@patch("core.telemetry.posthog.Path")
|
||||
def test_singleton_pattern(self, mock_path, mock_posthog, disable_telemetry):
|
||||
"""Test that get_client returns the same instance."""
|
||||
from core.telemetry.posthog import PostHogTelemetryClient
|
||||
|
||||
# Mock file system
|
||||
mock_storage_dir = MagicMock()
|
||||
mock_storage_dir.exists.return_value = False
|
||||
mock_path.return_value.parent.parent = MagicMock()
|
||||
mock_path.return_value.parent.parent.__truediv__.return_value = mock_storage_dir
|
||||
|
||||
# Reset singleton
|
||||
PostHogTelemetryClient.destroy_client()
|
||||
|
||||
client1 = PostHogTelemetryClient.get_client()
|
||||
client2 = PostHogTelemetryClient.get_client()
|
||||
|
||||
assert client1 is client2
|
||||
|
||||
|
||||
class TestRecordEvent:
|
||||
"""Test the public record_event function (SRP: Only tests public API)."""
|
||||
|
||||
@patch("core.telemetry.posthog.PostHogTelemetryClient")
|
||||
def test_record_event_calls_client(self, mock_client_class, disable_telemetry):
|
||||
"""Test that record_event delegates to the client."""
|
||||
from core.telemetry import record_event
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_class.get_client.return_value = mock_client_instance
|
||||
|
||||
event_name = "test_event"
|
||||
event_props = {"key": "value"}
|
||||
|
||||
record_event(event_name, event_props)
|
||||
|
||||
mock_client_instance.record_event.assert_called_once_with(event_name, event_props)
|
||||
|
||||
@patch("core.telemetry.posthog.PostHogTelemetryClient")
|
||||
def test_record_event_without_properties(self, mock_client_class, disable_telemetry):
|
||||
"""Test that record_event works without properties."""
|
||||
from core.telemetry import record_event
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_class.get_client.return_value = mock_client_instance
|
||||
|
||||
event_name = "test_event"
|
||||
|
||||
record_event(event_name)
|
||||
|
||||
mock_client_instance.record_event.assert_called_once_with(event_name, {})
|
||||
|
||||
|
||||
class TestDestroyTelemetryClient:
|
||||
"""Test client destruction (SRP: Only tests cleanup)."""
|
||||
|
||||
@patch("core.telemetry.posthog.PostHogTelemetryClient")
|
||||
def test_destroy_client_calls_class_method(self, mock_client_class):
|
||||
"""Test that destroy_telemetry_client delegates correctly."""
|
||||
from core.telemetry import destroy_telemetry_client
|
||||
|
||||
destroy_telemetry_client()
|
||||
|
||||
mock_client_class.destroy_client.assert_called_once()
|
||||
51
libs/python/mcp-server/tests/conftest.py
Normal file
51
libs/python/mcp-server/tests/conftest.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Pytest configuration and shared fixtures for mcp-server package tests.
|
||||
|
||||
This file contains shared fixtures and configuration for all mcp-server tests.
|
||||
Following SRP: This file ONLY handles test setup/teardown.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_context():
|
||||
"""Mock MCP context for testing.
|
||||
|
||||
Use this fixture to test MCP server logic without real MCP connections.
|
||||
"""
|
||||
context = AsyncMock()
|
||||
context.request_context = AsyncMock()
|
||||
context.session = Mock()
|
||||
context.session.send_resource_updated = AsyncMock()
|
||||
|
||||
return context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_computer():
|
||||
"""Mock Computer instance for MCP server tests.
|
||||
|
||||
Use this fixture to test MCP logic without real Computer operations.
|
||||
"""
|
||||
computer = AsyncMock()
|
||||
computer.interface = AsyncMock()
|
||||
computer.interface.screenshot = AsyncMock(return_value=b"fake_screenshot")
|
||||
computer.interface.left_click = AsyncMock()
|
||||
computer.interface.type = AsyncMock()
|
||||
|
||||
# Mock context manager
|
||||
computer.__aenter__ = AsyncMock(return_value=computer)
|
||||
computer.__aexit__ = AsyncMock()
|
||||
|
||||
return computer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_telemetry(monkeypatch):
|
||||
"""Disable telemetry for tests.
|
||||
|
||||
Use this fixture to ensure no telemetry is sent during tests.
|
||||
"""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
|
||||
44
libs/python/mcp-server/tests/test_mcp_server.py
Normal file
44
libs/python/mcp-server/tests/test_mcp_server.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Unit tests for mcp-server package.
|
||||
|
||||
This file tests ONLY basic MCP server functionality.
|
||||
Following SRP: This file tests MCP server initialization.
|
||||
All external dependencies are mocked.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestMCPServerImports:
|
||||
"""Test MCP server module imports (SRP: Only tests imports)."""
|
||||
|
||||
def test_mcp_server_module_exists(self):
|
||||
"""Test that mcp_server module can be imported."""
|
||||
try:
|
||||
import mcp_server
|
||||
|
||||
assert mcp_server is not None
|
||||
except ImportError:
|
||||
pytest.skip("mcp_server module not installed")
|
||||
except SystemExit:
|
||||
pytest.skip("MCP dependencies (mcp.server.fastmcp) not available")
|
||||
|
||||
|
||||
class TestMCPServerInitialization:
|
||||
"""Test MCP server initialization (SRP: Only tests initialization)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_server_can_be_imported(self):
|
||||
"""Basic smoke test: verify MCP server components can be imported."""
|
||||
try:
|
||||
from mcp_server import server
|
||||
|
||||
assert server is not None
|
||||
except ImportError:
|
||||
pytest.skip("MCP server module not available")
|
||||
except SystemExit:
|
||||
pytest.skip("MCP dependencies (mcp.server.fastmcp) not available")
|
||||
except Exception as e:
|
||||
# Some initialization errors are acceptable in unit tests
|
||||
pytest.skip(f"MCP server initialization requires specific setup: {e}")
|
||||
@@ -1,10 +0,0 @@
|
||||
[bumpversion]
|
||||
current_version = 0.2.1
|
||||
commit = True
|
||||
tag = True
|
||||
tag_name = pylume-v{new_version}
|
||||
message = Bump pylume to v{new_version}
|
||||
|
||||
[bumpversion:file:pylume/__init__.py]
|
||||
search = __version__ = "{current_version}"
|
||||
replace = __version__ = "{new_version}"
|
||||
@@ -1,46 +0,0 @@
|
||||
<div align="center">
|
||||
<h1>
|
||||
<div class="image-wrapper" style="display: inline-block;">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="https://raw.githubusercontent.com/trycua/cua/main/img/logo_white.png" style="display: block; margin: auto;">
|
||||
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="https://raw.githubusercontent.com/trycua/cua/main/img/logo_black.png" style="display: block; margin: auto;">
|
||||
<img alt="Shows my svg">
|
||||
</picture>
|
||||
</div>
|
||||
|
||||
[](#)
|
||||
[](#)
|
||||
[](https://discord.com/invite/mVnXXpdE85)
|
||||
[](https://pypi.org/project/pylume/)
|
||||
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
**pylume** is a lightweight Python library based on [lume](https://github.com/trycua/lume) to create, run and manage macOS and Linux virtual machines (VMs) natively on Apple Silicon.
|
||||
|
||||
```bash
|
||||
pip install pylume
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Please refer to this [Notebook](./samples/nb.ipynb) for a quickstart. More details about the underlying API used by pylume are available [here](https://github.com/trycua/lume/docs/API-Reference.md).
|
||||
|
||||
## Prebuilt Images
|
||||
|
||||
Pre-built images are available on [ghcr.io/trycua](https://github.com/orgs/trycua/packages).
|
||||
These images come pre-configured with an SSH server and auto-login enabled.
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome and greatly appreciate contributions to lume! Whether you're improving documentation, adding new features, fixing bugs, or adding new VM images, your efforts help make pylume better for everyone.
|
||||
|
||||
Join our [Discord community](https://discord.com/invite/mVnXXpdE85) to discuss ideas or get assistance.
|
||||
|
||||
## License
|
||||
|
||||
lume is open-sourced under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
## Stargazers over time
|
||||
|
||||
[](https://starchart.cc/trycua/pylume)
|
||||
@@ -1,9 +0,0 @@
|
||||
"""
|
||||
PyLume Python SDK - A client library for managing macOS VMs with PyLume.
|
||||
"""
|
||||
|
||||
from pylume.exceptions import *
|
||||
from pylume.models import *
|
||||
from pylume.pylume import *
|
||||
|
||||
__version__ = "0.1.0"
|
||||
@@ -1,59 +0,0 @@
|
||||
"""
|
||||
PyLume Python SDK - A client library for managing macOS VMs with PyLume.
|
||||
|
||||
Example:
|
||||
>>> from pylume import PyLume, VMConfig
|
||||
>>> client = PyLume()
|
||||
>>> config = VMConfig(name="my-vm", cpu=4, memory="8GB", disk_size="64GB")
|
||||
>>> client.create_vm(config)
|
||||
>>> client.run_vm("my-vm")
|
||||
"""
|
||||
|
||||
# Import exceptions then all models
|
||||
from .exceptions import (
|
||||
LumeConfigError,
|
||||
LumeConnectionError,
|
||||
LumeError,
|
||||
LumeImageError,
|
||||
LumeNotFoundError,
|
||||
LumeServerError,
|
||||
LumeTimeoutError,
|
||||
LumeVMError,
|
||||
)
|
||||
from .models import (
|
||||
CloneSpec,
|
||||
ImageInfo,
|
||||
ImageList,
|
||||
ImageRef,
|
||||
SharedDirectory,
|
||||
VMConfig,
|
||||
VMRunOpts,
|
||||
VMStatus,
|
||||
VMUpdateOpts,
|
||||
)
|
||||
|
||||
# Import main class last to avoid circular imports
|
||||
from .pylume import PyLume
|
||||
|
||||
__version__ = "0.2.1"
|
||||
|
||||
__all__ = [
|
||||
"PyLume",
|
||||
"VMConfig",
|
||||
"VMStatus",
|
||||
"VMRunOpts",
|
||||
"VMUpdateOpts",
|
||||
"ImageRef",
|
||||
"CloneSpec",
|
||||
"SharedDirectory",
|
||||
"ImageList",
|
||||
"ImageInfo",
|
||||
"LumeError",
|
||||
"LumeServerError",
|
||||
"LumeConnectionError",
|
||||
"LumeTimeoutError",
|
||||
"LumeNotFoundError",
|
||||
"LumeConfigError",
|
||||
"LumeVMError",
|
||||
"LumeImageError",
|
||||
]
|
||||
@@ -1,119 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import shlex
|
||||
import subprocess
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .exceptions import (
|
||||
LumeConfigError,
|
||||
LumeConnectionError,
|
||||
LumeError,
|
||||
LumeNotFoundError,
|
||||
LumeServerError,
|
||||
LumeTimeoutError,
|
||||
)
|
||||
|
||||
|
||||
class LumeClient:
|
||||
def __init__(self, base_url: str, timeout: float = 60.0, debug: bool = False):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
self.debug = debug
|
||||
|
||||
def _log_debug(self, message: str, **kwargs) -> None:
|
||||
"""Log debug information if debug mode is enabled."""
|
||||
if self.debug:
|
||||
print(f"DEBUG: {message}")
|
||||
if kwargs:
|
||||
print(json.dumps(kwargs, indent=2))
|
||||
|
||||
async def _run_curl(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Any:
|
||||
"""Execute a curl command and return the response."""
|
||||
url = f"{self.base_url}{path}"
|
||||
if params:
|
||||
param_str = "&".join(f"{k}={v}" for k, v in params.items())
|
||||
url = f"{url}?{param_str}"
|
||||
|
||||
cmd = ["curl", "-X", method, "-s", "-w", "%{http_code}", "-m", str(self.timeout)]
|
||||
|
||||
if data is not None:
|
||||
cmd.extend(["-H", "Content-Type: application/json", "-d", json.dumps(data)])
|
||||
|
||||
cmd.append(url)
|
||||
|
||||
self._log_debug(f"Running curl command: {' '.join(map(shlex.quote, cmd))}")
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
raise LumeConnectionError(f"Curl command failed: {stderr.decode()}")
|
||||
|
||||
# The last 3 characters are the status code
|
||||
response = stdout.decode()
|
||||
status_code = int(response[-3:])
|
||||
response_body = response[:-3] # Remove status code from response
|
||||
|
||||
if status_code >= 400:
|
||||
if status_code == 404:
|
||||
raise LumeNotFoundError(f"Resource not found: {path}")
|
||||
elif status_code == 400:
|
||||
raise LumeConfigError(f"Invalid request: {response_body}")
|
||||
elif status_code >= 500:
|
||||
raise LumeServerError(f"Server error: {response_body}")
|
||||
else:
|
||||
raise LumeError(f"Request failed with status {status_code}: {response_body}")
|
||||
|
||||
return json.loads(response_body) if response_body.strip() else None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise LumeTimeoutError(f"Request timed out after {self.timeout} seconds")
|
||||
|
||||
async def get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""Make a GET request."""
|
||||
return await self._run_curl("GET", path, params=params)
|
||||
|
||||
async def post(
|
||||
self, path: str, data: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None
|
||||
) -> Any:
|
||||
"""Make a POST request."""
|
||||
old_timeout = self.timeout
|
||||
if timeout is not None:
|
||||
self.timeout = timeout
|
||||
try:
|
||||
return await self._run_curl("POST", path, data=data)
|
||||
finally:
|
||||
self.timeout = old_timeout
|
||||
|
||||
async def patch(self, path: str, data: Dict[str, Any]) -> None:
|
||||
"""Make a PATCH request."""
|
||||
await self._run_curl("PATCH", path, data=data)
|
||||
|
||||
async def delete(self, path: str) -> None:
|
||||
"""Make a DELETE request."""
|
||||
await self._run_curl("DELETE", path)
|
||||
|
||||
def print_curl(self, method: str, path: str, data: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Print equivalent curl command for debugging."""
|
||||
curl_cmd = f"""curl -X {method} \\
|
||||
'{self.base_url}{path}'"""
|
||||
|
||||
if data:
|
||||
curl_cmd += f" \\\n -H 'Content-Type: application/json' \\\n -d '{json.dumps(data)}'"
|
||||
|
||||
print("\nEquivalent curl command:")
|
||||
print(curl_cmd)
|
||||
print()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the client resources."""
|
||||
pass # No shared resources to clean up
|
||||
@@ -1,54 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class LumeError(Exception):
|
||||
"""Base exception for all PyLume errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LumeServerError(LumeError):
|
||||
"""Raised when there's an error with the PyLume server."""
|
||||
|
||||
def __init__(
|
||||
self, message: str, status_code: Optional[int] = None, response_text: Optional[str] = None
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.response_text = response_text
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class LumeConnectionError(LumeError):
|
||||
"""Raised when there's an error connecting to the PyLume server."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LumeTimeoutError(LumeError):
|
||||
"""Raised when a request to the PyLume server times out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LumeNotFoundError(LumeError):
|
||||
"""Raised when a requested resource is not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LumeConfigError(LumeError):
|
||||
"""Raised when there's an error with the configuration."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LumeVMError(LumeError):
|
||||
"""Raised when there's an error with a VM operation."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LumeImageError(LumeError):
|
||||
"""Raised when there's an error with an image operation."""
|
||||
|
||||
pass
|
||||
Binary file not shown.
@@ -1,265 +0,0 @@
|
||||
import re
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, computed_field, validator
|
||||
|
||||
|
||||
class DiskInfo(BaseModel):
|
||||
"""Information about disk storage allocation.
|
||||
|
||||
Attributes:
|
||||
total: Total disk space in bytes
|
||||
allocated: Currently allocated disk space in bytes
|
||||
"""
|
||||
|
||||
total: int
|
||||
allocated: int
|
||||
|
||||
|
||||
class VMConfig(BaseModel):
|
||||
"""Configuration for creating a new VM.
|
||||
|
||||
Note: Memory and disk sizes should be specified with units (e.g., "4GB", "64GB")
|
||||
|
||||
Attributes:
|
||||
name: Name of the virtual machine
|
||||
os: Operating system type, either "macOS" or "linux"
|
||||
cpu: Number of CPU cores to allocate
|
||||
memory: Amount of memory to allocate with units
|
||||
disk_size: Size of the disk to create with units
|
||||
display: Display resolution in format "widthxheight"
|
||||
ipsw: IPSW path or 'latest' for macOS VMs, None for other OS types
|
||||
"""
|
||||
|
||||
name: str
|
||||
os: Literal["macOS", "linux"] = "macOS"
|
||||
cpu: int = Field(default=2, ge=1)
|
||||
memory: str = "4GB"
|
||||
disk_size: str = Field(default="64GB", alias="diskSize")
|
||||
display: str = "1024x768"
|
||||
ipsw: Optional[str] = Field(default=None, description="IPSW path or 'latest', for macOS VMs")
|
||||
|
||||
class Config:
|
||||
populate_by_alias = True
|
||||
|
||||
|
||||
class SharedDirectory(BaseModel):
|
||||
"""Configuration for a shared directory.
|
||||
|
||||
Attributes:
|
||||
host_path: Path to the directory on the host system
|
||||
read_only: Whether the directory should be mounted as read-only
|
||||
"""
|
||||
|
||||
host_path: str = Field(..., alias="hostPath") # Allow host_path but serialize as hostPath
|
||||
read_only: bool = False
|
||||
|
||||
class Config:
|
||||
populate_by_name = True # Allow both alias and original name
|
||||
alias_generator = lambda s: "".join(
|
||||
word.capitalize() if i else word for i, word in enumerate(s.split("_"))
|
||||
)
|
||||
|
||||
|
||||
class VMRunOpts(BaseModel):
|
||||
"""Configuration for running a VM.
|
||||
|
||||
Args:
|
||||
no_display: Whether to not display the VNC client
|
||||
shared_directories: List of directories to share with the VM
|
||||
"""
|
||||
|
||||
no_display: bool = Field(default=False, alias="noDisplay")
|
||||
shared_directories: Optional[list[SharedDirectory]] = Field(
|
||||
default=None, alias="sharedDirectories"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
alias_generator=lambda s: "".join(
|
||||
word.capitalize() if i else word for i, word in enumerate(s.split("_"))
|
||||
),
|
||||
)
|
||||
|
||||
def model_dump(self, **kwargs):
|
||||
"""Export model data with proper field name conversion.
|
||||
|
||||
Converts shared directory fields to match API expectations when using aliases.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments passed to parent model_dump method
|
||||
|
||||
Returns:
|
||||
dict: Model data with properly formatted field names
|
||||
"""
|
||||
data = super().model_dump(**kwargs)
|
||||
# Convert shared directory fields to match API expectations
|
||||
if self.shared_directories and "by_alias" in kwargs and kwargs["by_alias"]:
|
||||
data["sharedDirectories"] = [
|
||||
{"hostPath": d.host_path, "readOnly": d.read_only} for d in self.shared_directories
|
||||
]
|
||||
# Remove the snake_case version if it exists
|
||||
data.pop("shared_directories", None)
|
||||
return data
|
||||
|
||||
|
||||
class VMStatus(BaseModel):
|
||||
"""Status information for a virtual machine.
|
||||
|
||||
Attributes:
|
||||
name: Name of the virtual machine
|
||||
status: Current status of the VM
|
||||
os: Operating system type
|
||||
cpu_count: Number of CPU cores allocated
|
||||
memory_size: Amount of memory allocated in bytes
|
||||
disk_size: Disk storage information
|
||||
vnc_url: URL for VNC connection if available
|
||||
ip_address: IP address of the VM if available
|
||||
"""
|
||||
|
||||
name: str
|
||||
status: str
|
||||
os: Literal["macOS", "linux"]
|
||||
cpu_count: int = Field(alias="cpuCount")
|
||||
memory_size: int = Field(alias="memorySize") # API returns memory size in bytes
|
||||
disk_size: DiskInfo = Field(alias="diskSize")
|
||||
vnc_url: Optional[str] = Field(default=None, alias="vncUrl")
|
||||
ip_address: Optional[str] = Field(default=None, alias="ipAddress")
|
||||
|
||||
class Config:
|
||||
populate_by_alias = True
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def state(self) -> str:
|
||||
"""Get the current state of the VM.
|
||||
|
||||
Returns:
|
||||
str: Current VM status
|
||||
"""
|
||||
return self.status
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def cpu(self) -> int:
|
||||
"""Get the number of CPU cores.
|
||||
|
||||
Returns:
|
||||
int: Number of CPU cores allocated to the VM
|
||||
"""
|
||||
return self.cpu_count
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def memory(self) -> str:
|
||||
"""Get memory allocation in human-readable format.
|
||||
|
||||
Returns:
|
||||
str: Memory size formatted as "{size}GB"
|
||||
"""
|
||||
# Convert bytes to GB
|
||||
gb = self.memory_size / (1024 * 1024 * 1024)
|
||||
return f"{int(gb)}GB"
|
||||
|
||||
|
||||
class VMUpdateOpts(BaseModel):
|
||||
"""Options for updating VM configuration.
|
||||
|
||||
Attributes:
|
||||
cpu: Number of CPU cores to update to
|
||||
memory: Amount of memory to update to with units
|
||||
disk_size: Size of disk to update to with units
|
||||
"""
|
||||
|
||||
cpu: Optional[int] = None
|
||||
memory: Optional[str] = None
|
||||
disk_size: Optional[str] = None
|
||||
|
||||
|
||||
class ImageRef(BaseModel):
|
||||
"""Reference to a VM image.
|
||||
|
||||
Attributes:
|
||||
image: Name of the image
|
||||
tag: Tag version of the image
|
||||
registry: Registry hostname where image is stored
|
||||
organization: Organization or namespace in the registry
|
||||
"""
|
||||
|
||||
image: str
|
||||
tag: str = "latest"
|
||||
registry: Optional[str] = "ghcr.io"
|
||||
organization: Optional[str] = "trycua"
|
||||
|
||||
def model_dump(self, **kwargs):
|
||||
"""Override model_dump to return just the image:tag format.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments (ignored)
|
||||
|
||||
Returns:
|
||||
str: Image reference in "image:tag" format
|
||||
"""
|
||||
return f"{self.image}:{self.tag}"
|
||||
|
||||
|
||||
class CloneSpec(BaseModel):
|
||||
"""Specification for cloning a VM.
|
||||
|
||||
Attributes:
|
||||
name: Name of the source VM to clone
|
||||
new_name: Name for the new cloned VM
|
||||
"""
|
||||
|
||||
name: str
|
||||
new_name: str = Field(alias="newName")
|
||||
|
||||
class Config:
|
||||
populate_by_alias = True
|
||||
|
||||
|
||||
class ImageInfo(BaseModel):
|
||||
"""Model for individual image information.
|
||||
|
||||
Attributes:
|
||||
imageId: Unique identifier for the image
|
||||
"""
|
||||
|
||||
imageId: str
|
||||
|
||||
|
||||
class ImageList(RootModel):
|
||||
"""Response model for the images endpoint.
|
||||
|
||||
A list-like container for ImageInfo objects that provides
|
||||
iteration and indexing capabilities.
|
||||
"""
|
||||
|
||||
root: List[ImageInfo]
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over the image list.
|
||||
|
||||
Returns:
|
||||
Iterator over ImageInfo objects
|
||||
"""
|
||||
return iter(self.root)
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""Get an item from the image list by index.
|
||||
|
||||
Args:
|
||||
item: Index or slice to retrieve
|
||||
|
||||
Returns:
|
||||
ImageInfo or list of ImageInfo objects
|
||||
"""
|
||||
return self.root[item]
|
||||
|
||||
def __len__(self):
|
||||
"""Get the number of images in the list.
|
||||
|
||||
Returns:
|
||||
int: Number of images in the list
|
||||
"""
|
||||
return len(self.root)
|
||||
@@ -1,315 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, List, Optional, TypeVar, Union
|
||||
|
||||
from .client import LumeClient
|
||||
from .exceptions import (
|
||||
LumeConfigError,
|
||||
LumeConnectionError,
|
||||
LumeError,
|
||||
LumeImageError,
|
||||
LumeNotFoundError,
|
||||
LumeServerError,
|
||||
LumeTimeoutError,
|
||||
LumeVMError,
|
||||
)
|
||||
from .models import (
|
||||
CloneSpec,
|
||||
ImageList,
|
||||
ImageRef,
|
||||
SharedDirectory,
|
||||
VMConfig,
|
||||
VMRunOpts,
|
||||
VMStatus,
|
||||
VMUpdateOpts,
|
||||
)
|
||||
from .server import LumeServer
|
||||
|
||||
# Type variable for the decorator
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def ensure_server(func: Callable[..., T]) -> Callable[..., T]:
|
||||
"""Decorator to ensure server is running before executing the method."""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: "PyLume", *args: Any, **kwargs: Any) -> T:
|
||||
# ensure_running is an async method, so we need to await it
|
||||
await self.server.ensure_running()
|
||||
# Initialize client if needed
|
||||
await self._init_client()
|
||||
return await func(self, *args, **kwargs) # type: ignore
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
|
||||
class PyLume:
|
||||
def __init__(
|
||||
self,
|
||||
debug: bool = False,
|
||||
server_start_timeout: int = 60,
|
||||
port: Optional[int] = None,
|
||||
use_existing_server: bool = False,
|
||||
host: str = "localhost",
|
||||
):
|
||||
"""Initialize the async PyLume client.
|
||||
|
||||
Args:
|
||||
debug: Enable debug logging
|
||||
auto_start_server: Whether to automatically start the lume server if not running
|
||||
server_start_timeout: Timeout in seconds to wait for server to start
|
||||
port: Port number for the lume server. Required when use_existing_server is True.
|
||||
use_existing_server: If True, will try to connect to an existing server on the specified port
|
||||
instead of starting a new one.
|
||||
host: Host to use for connections (e.g., "localhost", "127.0.0.1", "host.docker.internal")
|
||||
"""
|
||||
if use_existing_server and port is None:
|
||||
raise LumeConfigError("Port must be specified when using an existing server")
|
||||
|
||||
self.server = LumeServer(
|
||||
debug=debug,
|
||||
server_start_timeout=server_start_timeout,
|
||||
port=port,
|
||||
use_existing_server=use_existing_server,
|
||||
host=host,
|
||||
)
|
||||
self.client = None
|
||||
|
||||
async def __aenter__(self) -> "PyLume":
|
||||
"""Async context manager entry."""
|
||||
if self.server.use_existing_server:
|
||||
# Just ensure base_url is set for existing server
|
||||
if self.server.requested_port is None:
|
||||
raise LumeConfigError("Port must be specified when using an existing server")
|
||||
|
||||
if not self.server.base_url:
|
||||
self.server.port = self.server.requested_port
|
||||
self.server.base_url = f"http://{self.server.host}:{self.server.port}/lume"
|
||||
|
||||
# Ensure the server is running (will connect to existing or start new as needed)
|
||||
await self.server.ensure_running()
|
||||
|
||||
# Initialize the client
|
||||
await self._init_client()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
"""Async context manager exit."""
|
||||
if self.client is not None:
|
||||
await self.client.close()
|
||||
await self.server.stop()
|
||||
|
||||
async def _init_client(self) -> None:
|
||||
"""Initialize the client if not already initialized."""
|
||||
if self.client is None:
|
||||
if self.server.base_url is None:
|
||||
raise RuntimeError("Server base URL not set")
|
||||
self.client = LumeClient(self.server.base_url, debug=self.server.debug)
|
||||
|
||||
def _log_debug(self, message: str, **kwargs) -> None:
|
||||
"""Log debug information if debug mode is enabled."""
|
||||
if self.server.debug:
|
||||
print(f"DEBUG: {message}")
|
||||
if kwargs:
|
||||
print(json.dumps(kwargs, indent=2))
|
||||
|
||||
async def _handle_api_error(self, e: Exception, operation: str) -> None:
|
||||
"""Handle API errors and raise appropriate custom exceptions."""
|
||||
if isinstance(e, subprocess.SubprocessError):
|
||||
raise LumeConnectionError(f"Failed to connect to PyLume server: {str(e)}")
|
||||
elif isinstance(e, asyncio.TimeoutError):
|
||||
raise LumeTimeoutError(f"Request timed out: {str(e)}")
|
||||
|
||||
if not hasattr(e, "status") and not isinstance(e, subprocess.CalledProcessError):
|
||||
raise LumeServerError(f"Unknown error during {operation}: {str(e)}")
|
||||
|
||||
status_code = getattr(e, "status", 500)
|
||||
response_text = str(e)
|
||||
|
||||
self._log_debug(
|
||||
f"{operation} request failed", status_code=status_code, response_text=response_text
|
||||
)
|
||||
|
||||
if status_code == 404:
|
||||
raise LumeNotFoundError(f"Resource not found during {operation}")
|
||||
elif status_code == 400:
|
||||
raise LumeConfigError(f"Invalid configuration for {operation}: {response_text}")
|
||||
elif status_code >= 500:
|
||||
raise LumeServerError(
|
||||
f"Server error during {operation}",
|
||||
status_code=status_code,
|
||||
response_text=response_text,
|
||||
)
|
||||
else:
|
||||
raise LumeServerError(
|
||||
f"Error during {operation}", status_code=status_code, response_text=response_text
|
||||
)
|
||||
|
||||
async def _read_output(self) -> None:
|
||||
"""Read and log server output."""
|
||||
try:
|
||||
while True:
|
||||
if not self.server.server_process or self.server.server_process.poll() is not None:
|
||||
self._log_debug("Server process ended")
|
||||
break
|
||||
|
||||
# Read stdout without blocking
|
||||
if self.server.server_process.stdout:
|
||||
while True:
|
||||
line = self.server.server_process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
self._log_debug(f"Server stdout: {line}")
|
||||
if "Server started" in line.decode("utf-8"):
|
||||
self._log_debug("Detected server started message")
|
||||
return
|
||||
|
||||
# Read stderr without blocking
|
||||
if self.server.server_process.stderr:
|
||||
while True:
|
||||
line = self.server.server_process.stderr.readline()
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
self._log_debug(f"Server stderr: {line}")
|
||||
if "error" in line.decode("utf-8").lower():
|
||||
raise RuntimeError(f"Server error: {line}")
|
||||
|
||||
await asyncio.sleep(0.1) # Small delay to prevent CPU spinning
|
||||
except Exception as e:
|
||||
self._log_debug(f"Error in output reader: {str(e)}")
|
||||
raise
|
||||
|
||||
@ensure_server
|
||||
async def create_vm(self, spec: Union[VMConfig, dict]) -> None:
|
||||
"""Create a VM with the given configuration."""
|
||||
# Ensure client is initialized
|
||||
await self._init_client()
|
||||
|
||||
if isinstance(spec, VMConfig):
|
||||
spec = spec.model_dump(by_alias=True, exclude_none=True)
|
||||
|
||||
# Suppress optional attribute access errors
|
||||
self.client.print_curl("POST", "/vms", spec) # type: ignore[attr-defined]
|
||||
await self.client.post("/vms", spec) # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def run_vm(self, name: str, opts: Optional[Union[VMRunOpts, dict]] = None) -> None:
|
||||
"""Run a VM."""
|
||||
if opts is None:
|
||||
opts = VMRunOpts(no_display=False) # type: ignore[attr-defined]
|
||||
elif isinstance(opts, dict):
|
||||
opts = VMRunOpts(**opts)
|
||||
|
||||
payload = opts.model_dump(by_alias=True, exclude_none=True)
|
||||
self.client.print_curl("POST", f"/vms/{name}/run", payload) # type: ignore[attr-defined]
|
||||
await self.client.post(f"/vms/{name}/run", payload) # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def list_vms(self) -> List[VMStatus]:
|
||||
"""List all VMs."""
|
||||
data = await self.client.get("/vms") # type: ignore[attr-defined]
|
||||
return [VMStatus.model_validate(vm) for vm in data]
|
||||
|
||||
@ensure_server
|
||||
async def get_vm(self, name: str) -> VMStatus:
|
||||
"""Get VM details."""
|
||||
data = await self.client.get(f"/vms/{name}") # type: ignore[attr-defined]
|
||||
return VMStatus.model_validate(data)
|
||||
|
||||
@ensure_server
|
||||
async def update_vm(self, name: str, params: Union[VMUpdateOpts, dict]) -> None:
|
||||
"""Update VM settings."""
|
||||
if isinstance(params, dict):
|
||||
params = VMUpdateOpts(**params)
|
||||
|
||||
payload = params.model_dump(by_alias=True, exclude_none=True)
|
||||
self.client.print_curl("PATCH", f"/vms/{name}", payload) # type: ignore[attr-defined]
|
||||
await self.client.patch(f"/vms/{name}", payload) # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def stop_vm(self, name: str) -> None:
|
||||
"""Stop a VM."""
|
||||
await self.client.post(f"/vms/{name}/stop") # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def delete_vm(self, name: str) -> None:
|
||||
"""Delete a VM."""
|
||||
await self.client.delete(f"/vms/{name}") # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def pull_image(
|
||||
self, spec: Union[ImageRef, dict, str], name: Optional[str] = None
|
||||
) -> None:
|
||||
"""Pull a VM image."""
|
||||
await self._init_client()
|
||||
if isinstance(spec, str):
|
||||
if ":" in spec:
|
||||
image_str = spec
|
||||
else:
|
||||
image_str = f"{spec}:latest"
|
||||
registry = "ghcr.io"
|
||||
organization = "trycua"
|
||||
elif isinstance(spec, dict):
|
||||
image = spec.get("image", "")
|
||||
tag = spec.get("tag", "latest")
|
||||
image_str = f"{image}:{tag}"
|
||||
registry = spec.get("registry", "ghcr.io")
|
||||
organization = spec.get("organization", "trycua")
|
||||
else:
|
||||
image_str = f"{spec.image}:{spec.tag}"
|
||||
registry = spec.registry
|
||||
organization = spec.organization
|
||||
|
||||
payload = {
|
||||
"image": image_str,
|
||||
"name": name,
|
||||
"registry": registry,
|
||||
"organization": organization,
|
||||
}
|
||||
|
||||
self.client.print_curl("POST", "/pull", payload) # type: ignore[attr-defined]
|
||||
await self.client.post("/pull", payload, timeout=300.0) # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def clone_vm(self, name: str, new_name: str) -> None:
|
||||
"""Clone a VM with the given name to a new VM with new_name."""
|
||||
config = CloneSpec(name=name, newName=new_name)
|
||||
self.client.print_curl("POST", "/vms/clone", config.model_dump()) # type: ignore[attr-defined]
|
||||
await self.client.post("/vms/clone", config.model_dump()) # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def get_latest_ipsw_url(self) -> str:
|
||||
"""Get the latest IPSW URL."""
|
||||
await self._init_client()
|
||||
data = await self.client.get("/ipsw") # type: ignore[attr-defined]
|
||||
return data["url"]
|
||||
|
||||
@ensure_server
|
||||
async def get_images(self, organization: Optional[str] = None) -> ImageList:
|
||||
"""Get list of available images."""
|
||||
await self._init_client()
|
||||
params = {"organization": organization} if organization else None
|
||||
data = await self.client.get("/images", params) # type: ignore[attr-defined]
|
||||
return ImageList(root=data)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the client and stop the server."""
|
||||
if self.client is not None:
|
||||
await self.client.close()
|
||||
self.client = None
|
||||
await asyncio.sleep(1)
|
||||
await self.server.stop()
|
||||
|
||||
async def _ensure_client(self) -> None:
|
||||
"""Ensure client is initialized."""
|
||||
if self.client is None:
|
||||
await self._init_client()
|
||||
@@ -1,481 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import shlex
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from .exceptions import LumeConnectionError
|
||||
|
||||
|
||||
class LumeServer:
|
||||
def __init__(
|
||||
self,
|
||||
debug: bool = False,
|
||||
server_start_timeout: int = 60,
|
||||
port: Optional[int] = None,
|
||||
use_existing_server: bool = False,
|
||||
host: str = "localhost",
|
||||
):
|
||||
"""Initialize the LumeServer.
|
||||
|
||||
Args:
|
||||
debug: Enable debug logging
|
||||
server_start_timeout: Timeout in seconds to wait for server to start
|
||||
port: Specific port to use for the server
|
||||
use_existing_server: If True, will try to connect to an existing server
|
||||
instead of starting a new one
|
||||
host: Host to use for connections (e.g., "localhost", "127.0.0.1", "host.docker.internal")
|
||||
"""
|
||||
self.debug = debug
|
||||
self.server_start_timeout = server_start_timeout
|
||||
self.server_process = None
|
||||
self.output_file = None
|
||||
self.requested_port = port
|
||||
self.port = None
|
||||
self.base_url = None
|
||||
self.use_existing_server = use_existing_server
|
||||
self.host = host
|
||||
|
||||
# Configure logging
|
||||
self.logger = getLogger("pylume.server")
|
||||
if not self.logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
self.logger.addHandler(handler)
|
||||
self.logger.setLevel(logging.DEBUG if debug else logging.INFO)
|
||||
|
||||
self.logger.debug(f"Server initialized with host: {self.host}")
|
||||
|
||||
def _check_port_available(self, port: int) -> bool:
|
||||
"""Check if a port is available."""
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(0.5)
|
||||
result = s.connect_ex(("127.0.0.1", port))
|
||||
if result == 0: # Port is in use on localhost
|
||||
return False
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check the specified host (e.g., "host.docker.internal") if it's not a localhost alias
|
||||
if self.host not in ["localhost", "127.0.0.1"]:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(0.5)
|
||||
result = s.connect_ex((self.host, port))
|
||||
if result == 0: # Port is in use on host
|
||||
return False
|
||||
except:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
def _get_server_port(self) -> int:
|
||||
"""Get an available port for the server."""
|
||||
# Use requested port if specified
|
||||
if self.requested_port is not None:
|
||||
if not self._check_port_available(self.requested_port):
|
||||
raise RuntimeError(f"Requested port {self.requested_port} is not available")
|
||||
return self.requested_port
|
||||
|
||||
# Find a free port
|
||||
for _ in range(10): # Try up to 10 times
|
||||
port = random.randint(49152, 65535)
|
||||
if self._check_port_available(port):
|
||||
return port
|
||||
|
||||
raise RuntimeError("Could not find an available port")
|
||||
|
||||
async def _ensure_server_running(self) -> None:
|
||||
"""Ensure the lume server is running, start it if it's not."""
|
||||
try:
|
||||
self.logger.debug("Checking if lume server is running...")
|
||||
# Try to connect to the server with a short timeout
|
||||
cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "5", f"{self.base_url}/vms"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
response = stdout.decode()
|
||||
status_code = int(response[-3:])
|
||||
if status_code == 200:
|
||||
self.logger.debug("PyLume server is running")
|
||||
return
|
||||
|
||||
self.logger.debug("PyLume server not running, attempting to start it")
|
||||
# Server not running, try to start it
|
||||
lume_path = os.path.join(os.path.dirname(__file__), "lume")
|
||||
if not os.path.exists(lume_path):
|
||||
raise RuntimeError(f"Could not find lume binary at {lume_path}")
|
||||
|
||||
# Make sure the file is executable
|
||||
os.chmod(lume_path, 0o755)
|
||||
|
||||
# Create a temporary file for server output
|
||||
self.output_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)
|
||||
self.logger.debug(f"Using temporary file for server output: {self.output_file.name}")
|
||||
|
||||
# Start the server
|
||||
self.logger.debug(f"Starting lume server with: {lume_path} serve --port {self.port}")
|
||||
|
||||
# Start server in background using subprocess.Popen
|
||||
try:
|
||||
self.server_process = subprocess.Popen(
|
||||
[lume_path, "serve", "--port", str(self.port)],
|
||||
stdout=self.output_file,
|
||||
stderr=self.output_file,
|
||||
cwd=os.path.dirname(lume_path),
|
||||
start_new_session=True, # Run in new session to avoid blocking
|
||||
)
|
||||
except Exception as e:
|
||||
self.output_file.close()
|
||||
os.unlink(self.output_file.name)
|
||||
raise RuntimeError(f"Failed to start lume server process: {str(e)}")
|
||||
|
||||
# Wait for server to start
|
||||
self.logger.debug(
|
||||
f"Waiting up to {self.server_start_timeout} seconds for server to start..."
|
||||
)
|
||||
start_time = time.time()
|
||||
server_ready = False
|
||||
last_size = 0
|
||||
|
||||
while time.time() - start_time < self.server_start_timeout:
|
||||
if self.server_process.poll() is not None:
|
||||
# Process has terminated
|
||||
self.output_file.seek(0)
|
||||
output = self.output_file.read()
|
||||
self.output_file.close()
|
||||
os.unlink(self.output_file.name)
|
||||
error_msg = (
|
||||
f"Server process terminated unexpectedly.\n"
|
||||
f"Exit code: {self.server_process.returncode}\n"
|
||||
f"Output: {output}"
|
||||
)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Check output file for server ready message
|
||||
self.output_file.seek(0, os.SEEK_END)
|
||||
size = self.output_file.tell()
|
||||
if size > last_size: # Only read if there's new content
|
||||
self.output_file.seek(last_size)
|
||||
new_output = self.output_file.read()
|
||||
if new_output.strip(): # Only log non-empty output
|
||||
self.logger.debug(f"Server output: {new_output.strip()}")
|
||||
last_size = size
|
||||
|
||||
if "Server started" in new_output:
|
||||
server_ready = True
|
||||
self.logger.debug("Server startup detected")
|
||||
break
|
||||
|
||||
# Try to connect to the server periodically
|
||||
try:
|
||||
cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "5", f"{self.base_url}/vms"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
response = stdout.decode()
|
||||
status_code = int(response[-3:])
|
||||
if status_code == 200:
|
||||
server_ready = True
|
||||
self.logger.debug("Server is responding to requests")
|
||||
break
|
||||
except:
|
||||
pass # Server not ready yet
|
||||
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
if not server_ready:
|
||||
# Cleanup if server didn't start
|
||||
if self.server_process:
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.server_process.kill()
|
||||
self.output_file.close()
|
||||
os.unlink(self.output_file.name)
|
||||
raise RuntimeError(
|
||||
f"Failed to start lume server after {self.server_start_timeout} seconds. "
|
||||
"Check the debug output for more details."
|
||||
)
|
||||
|
||||
# Give the server a moment to fully initialize
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
# Verify server is responding
|
||||
try:
|
||||
cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "10", f"{self.base_url}/vms"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
raise RuntimeError(f"Curl command failed: {stderr.decode()}")
|
||||
|
||||
response = stdout.decode()
|
||||
status_code = int(response[-3:])
|
||||
|
||||
if status_code != 200:
|
||||
raise RuntimeError(f"Server returned status code {status_code}")
|
||||
|
||||
self.logger.debug("PyLume server started successfully")
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Server verification failed: {str(e)}")
|
||||
if self.server_process:
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.server_process.kill()
|
||||
self.output_file.close()
|
||||
os.unlink(self.output_file.name)
|
||||
raise RuntimeError(f"Server started but is not responding: {str(e)}")
|
||||
|
||||
self.logger.debug("Server startup completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to start lume server: {str(e)}")
|
||||
|
||||
async def _start_server(self) -> None:
|
||||
"""Start the lume server using the lume executable."""
|
||||
self.logger.debug("Starting PyLume server")
|
||||
|
||||
# Get absolute path to lume executable in the same directory as this file
|
||||
lume_path = os.path.join(os.path.dirname(__file__), "lume")
|
||||
if not os.path.exists(lume_path):
|
||||
raise RuntimeError(f"Could not find lume binary at {lume_path}")
|
||||
|
||||
try:
|
||||
# Make executable
|
||||
os.chmod(lume_path, 0o755)
|
||||
|
||||
# Get and validate port
|
||||
self.port = self._get_server_port()
|
||||
self.base_url = f"http://{self.host}:{self.port}/lume"
|
||||
|
||||
# Set up output handling
|
||||
self.output_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)
|
||||
|
||||
# Start the server process with the lume executable
|
||||
env = os.environ.copy()
|
||||
env["RUST_BACKTRACE"] = "1" # Enable backtrace for better error reporting
|
||||
|
||||
# Specify the host to bind to (0.0.0.0 to allow external connections)
|
||||
self.server_process = subprocess.Popen(
|
||||
[lume_path, "serve", "--port", str(self.port)],
|
||||
stdout=self.output_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
cwd=os.path.dirname(lume_path), # Run from same directory as executable
|
||||
env=env,
|
||||
)
|
||||
|
||||
# Wait for server to initialize
|
||||
await asyncio.sleep(2)
|
||||
await self._wait_for_server()
|
||||
|
||||
except Exception as e:
|
||||
await self._cleanup()
|
||||
raise RuntimeError(f"Failed to start lume server process: {str(e)}")
|
||||
|
||||
async def _tail_log(self) -> None:
|
||||
"""Read and display server log output in debug mode."""
|
||||
while True:
|
||||
try:
|
||||
self.output_file.seek(0, os.SEEK_END) # type: ignore[attr-defined]
|
||||
line = self.output_file.readline() # type: ignore[attr-defined]
|
||||
if line:
|
||||
line = line.strip()
|
||||
if line:
|
||||
print(f"SERVER: {line}")
|
||||
if self.server_process.poll() is not None: # type: ignore[attr-defined]
|
||||
print("Server process ended")
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
print(f"Error reading log: {e}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _wait_for_server(self) -> None:
|
||||
"""Wait for server to start and become responsive with increased timeout."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.server_start_timeout:
|
||||
if self.server_process.poll() is not None: # type: ignore[attr-defined]
|
||||
error_msg = await self._get_error_output()
|
||||
await self._cleanup()
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
try:
|
||||
await self._verify_server()
|
||||
self.logger.debug("Server is now responsive")
|
||||
return
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Server not ready yet: {str(e)}")
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
await self._cleanup()
|
||||
raise RuntimeError(f"Server failed to start after {self.server_start_timeout} seconds")
|
||||
|
||||
async def _verify_server(self) -> None:
|
||||
"""Verify server is responding to requests."""
|
||||
try:
|
||||
cmd = [
|
||||
"curl",
|
||||
"-s",
|
||||
"-w",
|
||||
"%{http_code}",
|
||||
"-m",
|
||||
"10",
|
||||
f"http://{self.host}:{self.port}/lume/vms",
|
||||
]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
raise RuntimeError(f"Curl command failed: {stderr.decode()}")
|
||||
|
||||
response = stdout.decode()
|
||||
status_code = int(response[-3:])
|
||||
|
||||
if status_code != 200:
|
||||
raise RuntimeError(f"Server returned status code {status_code}")
|
||||
|
||||
self.logger.debug("PyLume server started successfully")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Server not responding: {str(e)}")
|
||||
|
||||
async def _get_error_output(self) -> str:
|
||||
"""Get error output from the server process."""
|
||||
if not self.output_file:
|
||||
return "No output available"
|
||||
self.output_file.seek(0)
|
||||
output = self.output_file.read()
|
||||
return (
|
||||
f"Server process terminated unexpectedly.\n"
|
||||
f"Exit code: {self.server_process.returncode}\n" # type: ignore[attr-defined]
|
||||
f"Output: {output}"
|
||||
)
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
"""Clean up all server resources."""
|
||||
if self.server_process:
|
||||
try:
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.server_process.kill()
|
||||
except:
|
||||
pass
|
||||
self.server_process = None
|
||||
|
||||
# Clean up output file
|
||||
if self.output_file:
|
||||
try:
|
||||
self.output_file.close()
|
||||
os.unlink(self.output_file.name)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Error cleaning up output file: {e}")
|
||||
self.output_file = None
|
||||
|
||||
async def ensure_running(self) -> None:
|
||||
"""Ensure the server is running.
|
||||
|
||||
If use_existing_server is True, will only try to connect to an existing server.
|
||||
Otherwise will:
|
||||
1. Try to connect to an existing server on the specified port
|
||||
2. If that fails and not in Docker, start a new server
|
||||
3. If in Docker and no existing server is found, raise an error
|
||||
"""
|
||||
# First check if we're in Docker
|
||||
in_docker = os.path.exists("/.dockerenv") or (
|
||||
os.path.exists("/proc/1/cgroup") and "docker" in open("/proc/1/cgroup", "r").read()
|
||||
)
|
||||
|
||||
# If using a non-localhost host like host.docker.internal, set up the connection details
|
||||
if self.host not in ["localhost", "127.0.0.1"]:
|
||||
if self.requested_port is None:
|
||||
raise RuntimeError("Port must be specified when using a remote host")
|
||||
|
||||
self.port = self.requested_port
|
||||
self.base_url = f"http://{self.host}:{self.port}/lume"
|
||||
self.logger.debug(f"Using remote host server at {self.base_url}")
|
||||
|
||||
# Try to verify the server is accessible
|
||||
try:
|
||||
await self._verify_server()
|
||||
self.logger.debug("Successfully connected to remote server")
|
||||
return
|
||||
except Exception as e:
|
||||
if self.use_existing_server or in_docker:
|
||||
# If explicitly requesting an existing server or in Docker, we can't start a new one
|
||||
raise RuntimeError(
|
||||
f"Failed to connect to remote server at {self.base_url}: {str(e)}"
|
||||
)
|
||||
else:
|
||||
self.logger.debug(f"Remote server not available at {self.base_url}: {str(e)}")
|
||||
# Fall back to localhost for starting a new server
|
||||
self.host = "localhost"
|
||||
|
||||
# If explicitly using an existing server, verify it's running
|
||||
if self.use_existing_server:
|
||||
if self.requested_port is None:
|
||||
raise RuntimeError("Port must be specified when using an existing server")
|
||||
|
||||
self.port = self.requested_port
|
||||
self.base_url = f"http://{self.host}:{self.port}/lume"
|
||||
|
||||
try:
|
||||
await self._verify_server()
|
||||
self.logger.debug("Successfully connected to existing server")
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to connect to existing server at {self.base_url}: {str(e)}"
|
||||
)
|
||||
else:
|
||||
# Try to connect to an existing server first
|
||||
if self.requested_port is not None:
|
||||
self.port = self.requested_port
|
||||
self.base_url = f"http://{self.host}:{self.port}/lume"
|
||||
|
||||
try:
|
||||
await self._verify_server()
|
||||
self.logger.debug("Successfully connected to existing server")
|
||||
return
|
||||
except Exception:
|
||||
self.logger.debug(f"No existing server found at {self.base_url}")
|
||||
|
||||
# If in Docker and can't connect to existing server, raise an error
|
||||
if in_docker:
|
||||
raise RuntimeError(
|
||||
f"Failed to connect to server at {self.base_url} and cannot start a new server in Docker"
|
||||
)
|
||||
|
||||
# Start a new server
|
||||
self.logger.debug("Starting a new server instance")
|
||||
await self._start_server()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the server if we're managing it."""
|
||||
if not self.use_existing_server:
|
||||
self.logger.debug("Stopping lume server...")
|
||||
await self._cleanup()
|
||||
@@ -1,51 +0,0 @@
|
||||
[build-system]
|
||||
build-backend = "pdm.backend"
|
||||
requires = ["pdm-backend"]
|
||||
|
||||
[project]
|
||||
authors = [{ name = "TryCua", email = "gh@trycua.com" }]
|
||||
classifiers = [
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: MacOS :: MacOS X",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
dependencies = ["pydantic>=2.11.1"]
|
||||
description = "Python SDK for lume - run macOS and Linux VMs on Apple Silicon"
|
||||
dynamic = ["version"]
|
||||
keywords = ["apple-silicon", "macos", "virtualization", "vm"]
|
||||
license = { text = "MIT" }
|
||||
name = "pylume"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
|
||||
[tool.pdm.version]
|
||||
path = "pylume/__init__.py"
|
||||
source = "file"
|
||||
|
||||
[project.urls]
|
||||
homepage = "https://github.com/trycua/pylume"
|
||||
repository = "https://github.com/trycua/pylume"
|
||||
|
||||
[tool.pdm]
|
||||
distribution = true
|
||||
|
||||
[tool.pdm.dev-dependencies]
|
||||
dev = [
|
||||
"black>=23.0.0",
|
||||
"isort>=5.12.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"pytest>=7.0.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
python_files = "test_*.py"
|
||||
testpaths = ["tests"]
|
||||
|
||||
[tool.pdm.build]
|
||||
includes = ["pylume/"]
|
||||
source-includes = ["LICENSE", "README.md", "tests/"]
|
||||
23
libs/python/pylume/tests/conftest.py
Normal file
23
libs/python/pylume/tests/conftest.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Pytest configuration for pylume tests.
|
||||
|
||||
This module provides test fixtures for the pylume package.
|
||||
Note: This package has macOS-specific dependencies and will skip tests
|
||||
if the required modules are not available.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_subprocess():
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = Mock(returncode=0, stdout="", stderr="")
|
||||
yield mock_run
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_requests():
|
||||
with patch("requests.get") as mock_get, patch("requests.post") as mock_post:
|
||||
yield {"get": mock_get, "post": mock_post}
|
||||
38
libs/python/pylume/tests/test_pylume.py
Normal file
38
libs/python/pylume/tests/test_pylume.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Unit tests for pylume package.
|
||||
|
||||
This file tests ONLY basic pylume functionality.
|
||||
Following SRP: This file tests pylume module imports and basic operations.
|
||||
All external dependencies are mocked.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestPylumeImports:
|
||||
"""Test pylume module imports (SRP: Only tests imports)."""
|
||||
|
||||
def test_pylume_module_exists(self):
|
||||
"""Test that pylume module can be imported."""
|
||||
try:
|
||||
import pylume
|
||||
|
||||
assert pylume is not None
|
||||
except ImportError:
|
||||
pytest.skip("pylume module not installed")
|
||||
|
||||
|
||||
class TestPylumeInitialization:
|
||||
"""Test pylume initialization (SRP: Only tests initialization)."""
|
||||
|
||||
def test_pylume_can_be_imported(self):
|
||||
"""Basic smoke test: verify pylume components can be imported."""
|
||||
try:
|
||||
import pylume
|
||||
|
||||
# Check for basic attributes
|
||||
assert pylume is not None
|
||||
except ImportError:
|
||||
pytest.skip("pylume module not available")
|
||||
except Exception as e:
|
||||
# Some initialization errors are acceptable in unit tests
|
||||
pytest.skip(f"pylume initialization requires specific setup: {e}")
|
||||
24
libs/python/som/tests/conftest.py
Normal file
24
libs/python/som/tests/conftest.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Pytest configuration for som tests.
|
||||
|
||||
This module provides test fixtures for the som (Set-of-Mark) package.
|
||||
The som package depends on heavy ML models and will skip tests if not available.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_torch():
|
||||
with patch("torch.load") as mock_load:
|
||||
mock_load.return_value = Mock()
|
||||
yield mock_load
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_icon_detector():
|
||||
with patch("omniparser.IconDetector") as mock_detector:
|
||||
instance = Mock()
|
||||
mock_detector.return_value = instance
|
||||
yield instance
|
||||
@@ -1,13 +1,73 @@
|
||||
# """Basic tests for the omniparser package."""
|
||||
"""Unit tests for som package (Set-of-Mark).
|
||||
|
||||
# import pytest
|
||||
# from omniparser import IconDetector
|
||||
This file tests ONLY basic som functionality.
|
||||
Following SRP: This file tests som module imports and basic operations.
|
||||
All external dependencies (ML models, OCR) are mocked.
|
||||
"""
|
||||
|
||||
# def test_icon_detector_import():
|
||||
# """Test that we can import the IconDetector class."""
|
||||
# assert IconDetector is not None
|
||||
import pytest
|
||||
|
||||
# def test_icon_detector_init():
|
||||
# """Test that we can create an IconDetector instance."""
|
||||
# detector = IconDetector(force_cpu=True)
|
||||
# assert detector is not None
|
||||
|
||||
class TestSomImports:
|
||||
"""Test som module imports (SRP: Only tests imports)."""
|
||||
|
||||
def test_som_module_exists(self):
|
||||
"""Test that som module can be imported."""
|
||||
try:
|
||||
import som
|
||||
|
||||
assert som is not None
|
||||
except ImportError:
|
||||
pytest.skip("som module not installed")
|
||||
|
||||
def test_omniparser_import(self):
|
||||
"""Test that OmniParser can be imported."""
|
||||
try:
|
||||
from som import OmniParser
|
||||
|
||||
assert OmniParser is not None
|
||||
except ImportError:
|
||||
pytest.skip("som module not available")
|
||||
except Exception as e:
|
||||
pytest.skip(f"som initialization requires ML models: {e}")
|
||||
|
||||
def test_models_import(self):
|
||||
"""Test that model classes can be imported."""
|
||||
try:
|
||||
from som import BoundingBox, ParseResult, UIElement
|
||||
|
||||
assert BoundingBox is not None
|
||||
assert UIElement is not None
|
||||
assert ParseResult is not None
|
||||
except ImportError:
|
||||
pytest.skip("som models not available")
|
||||
except Exception as e:
|
||||
pytest.skip(f"som models require dependencies: {e}")
|
||||
|
||||
|
||||
class TestSomModels:
|
||||
"""Test som data models (SRP: Only tests model structure)."""
|
||||
|
||||
def test_bounding_box_structure(self):
|
||||
"""Test BoundingBox class structure."""
|
||||
try:
|
||||
from som import BoundingBox
|
||||
|
||||
# Check the class exists and has expected structure
|
||||
assert hasattr(BoundingBox, "__init__")
|
||||
except ImportError:
|
||||
pytest.skip("som models not available")
|
||||
except Exception as e:
|
||||
pytest.skip(f"som models require dependencies: {e}")
|
||||
|
||||
def test_ui_element_structure(self):
|
||||
"""Test UIElement class structure."""
|
||||
try:
|
||||
from som import UIElement
|
||||
|
||||
# Check the class exists and has expected structure
|
||||
assert hasattr(UIElement, "__init__")
|
||||
except ImportError:
|
||||
pytest.skip("som models not available")
|
||||
except Exception as e:
|
||||
pytest.skip(f"som models require dependencies: {e}")
|
||||
|
||||
Reference in New Issue
Block a user