mirror of
https://github.com/trycua/computer.git
synced 2026-01-01 11:00:31 -06:00
Merge branch 'main' into feat/api_key_overrides
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.4.35
|
||||
current_version = 0.4.37
|
||||
commit = True
|
||||
tag = True
|
||||
tag_name = agent-v{new_version}
|
||||
|
||||
@@ -73,8 +73,8 @@ if __name__ == "__main__":
|
||||
## Docs
|
||||
|
||||
- [Agent Loops](https://trycua.com/docs/agent-sdk/agent-loops)
|
||||
- [Supported Agents](https://trycua.com/docs/agent-sdk/supported-agents)
|
||||
- [Supported Models](https://trycua.com/docs/agent-sdk/supported-models)
|
||||
- [Supported Agents](https://trycua.com/docs/agent-sdk/supported-agents/computer-use-agents)
|
||||
- [Supported Models](https://trycua.com/docs/agent-sdk/supported-model-providers)
|
||||
- [Chat History](https://trycua.com/docs/agent-sdk/chat-history)
|
||||
- [Callbacks](https://trycua.com/docs/agent-sdk/callbacks)
|
||||
- [Custom Tools](https://trycua.com/docs/agent-sdk/custom-tools)
|
||||
|
||||
@@ -28,8 +28,12 @@ class AsyncComputerHandler(Protocol):
|
||||
"""Get screen dimensions as (width, height)."""
|
||||
...
|
||||
|
||||
async def screenshot(self) -> str:
|
||||
"""Take a screenshot and return as base64 string."""
|
||||
async def screenshot(self, text: Optional[str] = None) -> str:
|
||||
"""Take a screenshot and return as base64 string.
|
||||
|
||||
Args:
|
||||
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
|
||||
"""
|
||||
...
|
||||
|
||||
async def click(self, x: int, y: int, button: str = "left") -> None:
|
||||
|
||||
@@ -36,8 +36,12 @@ class cuaComputerHandler(AsyncComputerHandler):
|
||||
screen_size = await self.interface.get_screen_size()
|
||||
return screen_size["width"], screen_size["height"]
|
||||
|
||||
async def screenshot(self) -> str:
|
||||
"""Take a screenshot and return as base64 string."""
|
||||
async def screenshot(self, text: Optional[str] = None) -> str:
|
||||
"""Take a screenshot and return as base64 string.
|
||||
|
||||
Args:
|
||||
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
|
||||
"""
|
||||
assert self.interface is not None
|
||||
screenshot_bytes = await self.interface.screenshot()
|
||||
return base64.b64encode(screenshot_bytes).decode("utf-8")
|
||||
|
||||
@@ -122,8 +122,12 @@ class CustomComputerHandler(AsyncComputerHandler):
|
||||
|
||||
return self._last_screenshot_size
|
||||
|
||||
async def screenshot(self) -> str:
|
||||
"""Take a screenshot and return as base64 string."""
|
||||
async def screenshot(self, text: Optional[str] = None) -> str:
|
||||
"""Take a screenshot and return as base64 string.
|
||||
|
||||
Args:
|
||||
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
|
||||
"""
|
||||
result = await self._call_function(self.functions["screenshot"])
|
||||
b64_str = self._to_b64_str(result) # type: ignore
|
||||
|
||||
|
||||
@@ -14,67 +14,73 @@ import litellm
|
||||
|
||||
from ..decorators import register_agent
|
||||
from ..loops.base import AsyncAgentConfig
|
||||
from ..responses import (
|
||||
convert_completion_messages_to_responses_items,
|
||||
convert_responses_items_to_completion_messages,
|
||||
)
|
||||
from ..types import AgentCapability, AgentResponse, Messages, Tools
|
||||
|
||||
SOM_TOOL_SCHEMA = {
|
||||
"type": "function",
|
||||
"name": "computer",
|
||||
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"screenshot",
|
||||
"click",
|
||||
"double_click",
|
||||
"drag",
|
||||
"type",
|
||||
"keypress",
|
||||
"scroll",
|
||||
"move",
|
||||
"wait",
|
||||
"get_current_url",
|
||||
"get_dimensions",
|
||||
"get_environment",
|
||||
],
|
||||
"description": "The action to perform",
|
||||
},
|
||||
"element_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)",
|
||||
},
|
||||
"start_element_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the element to start dragging from (required for drag action)",
|
||||
},
|
||||
"end_element_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the element to drag to (required for drag action)",
|
||||
},
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The text to type (required for type action)",
|
||||
},
|
||||
"keys": {
|
||||
"type": "string",
|
||||
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')",
|
||||
},
|
||||
"button": {
|
||||
"type": "string",
|
||||
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
|
||||
},
|
||||
"scroll_x": {
|
||||
"type": "integer",
|
||||
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
|
||||
},
|
||||
"scroll_y": {
|
||||
"type": "integer",
|
||||
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
|
||||
"function": {
|
||||
"name": "computer",
|
||||
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"screenshot",
|
||||
"click",
|
||||
"double_click",
|
||||
"drag",
|
||||
"type",
|
||||
"keypress",
|
||||
"scroll",
|
||||
"move",
|
||||
"wait",
|
||||
"get_current_url",
|
||||
"get_dimensions",
|
||||
"get_environment",
|
||||
],
|
||||
"description": "The action to perform",
|
||||
},
|
||||
"element_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)",
|
||||
},
|
||||
"start_element_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the element to start dragging from (required for drag action)",
|
||||
},
|
||||
"end_element_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the element to drag to (required for drag action)",
|
||||
},
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The text to type (required for type action)",
|
||||
},
|
||||
"keys": {
|
||||
"type": "string",
|
||||
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')",
|
||||
},
|
||||
"button": {
|
||||
"type": "string",
|
||||
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
|
||||
},
|
||||
"scroll_x": {
|
||||
"type": "integer",
|
||||
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
|
||||
},
|
||||
"scroll_y": {
|
||||
"type": "integer",
|
||||
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
|
||||
},
|
||||
},
|
||||
"required": ["action", "element_id"],
|
||||
},
|
||||
"required": ["action"],
|
||||
},
|
||||
}
|
||||
|
||||
@@ -243,18 +249,20 @@ async def replace_computer_call_with_function(
|
||||
"id": item.get("id"),
|
||||
"call_id": item.get("call_id"),
|
||||
"status": "completed",
|
||||
# Fall back to string representation
|
||||
"content": f"Used tool: {action_data.get("type")}({json.dumps(fn_args)})",
|
||||
}
|
||||
]
|
||||
|
||||
elif item_type == "computer_call_output":
|
||||
# Simple conversion: computer_call_output -> function_call_output
|
||||
output = item.get("output")
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = [output]
|
||||
|
||||
return [
|
||||
{
|
||||
"type": "function_call_output",
|
||||
"call_id": item.get("call_id"),
|
||||
"content": [item.get("output")],
|
||||
"output": item.get("output"),
|
||||
"id": item.get("id"),
|
||||
"status": "completed",
|
||||
}
|
||||
@@ -296,6 +304,13 @@ class OmniparserConfig(AsyncAgentConfig):
|
||||
|
||||
llm_model = model.split("+")[-1]
|
||||
|
||||
# Get screen dimensions from computer handler
|
||||
try:
|
||||
width, height = await computer_handler.get_dimensions()
|
||||
except Exception:
|
||||
# Fallback to default dimensions if method fails
|
||||
width, height = 1024, 768
|
||||
|
||||
# Prepare tools for OpenAI API
|
||||
openai_tools, id2xy = _prepare_tools_for_omniparser(tools)
|
||||
|
||||
@@ -309,27 +324,43 @@ class OmniparserConfig(AsyncAgentConfig):
|
||||
result = parser.parse(image_data)
|
||||
if _on_screenshot:
|
||||
await _on_screenshot(result.annotated_image_base64, "annotated_image")
|
||||
for element in result.elements:
|
||||
id2xy[element.id] = (
|
||||
(element.bbox.x1 + element.bbox.x2) / 2,
|
||||
(element.bbox.y1 + element.bbox.y2) / 2,
|
||||
)
|
||||
|
||||
# handle computer calls -> function calls
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
# Convert OmniParser normalized coordinates (0-1) to absolute pixels, convert to pixels
|
||||
for element in result.elements:
|
||||
norm_x = (element.bbox.x1 + element.bbox.x2) / 2
|
||||
norm_y = (element.bbox.y1 + element.bbox.y2) / 2
|
||||
pixel_x = int(norm_x * width)
|
||||
pixel_y = int(norm_y * height)
|
||||
id2xy[element.id] = (pixel_x, pixel_y)
|
||||
|
||||
# Replace the original screenshot with the annotated image
|
||||
annotated_image_url = f"data:image/png;base64,{result.annotated_image_base64}"
|
||||
last_computer_call_output["output"]["image_url"] = annotated_image_url
|
||||
|
||||
xy2id = {v: k for k, v in id2xy.items()}
|
||||
messages_with_element_ids = []
|
||||
for i, message in enumerate(messages):
|
||||
if not isinstance(message, dict):
|
||||
message = message.__dict__
|
||||
new_messages += await replace_computer_call_with_function(message, id2xy) # type: ignore
|
||||
messages = new_messages
|
||||
|
||||
msg_type = message.get("type")
|
||||
|
||||
if msg_type == "computer_call" and "action" in message:
|
||||
action = message.get("action", {})
|
||||
|
||||
converted = await replace_computer_call_with_function(message, xy2id) # type: ignore
|
||||
messages_with_element_ids += converted
|
||||
|
||||
completion_messages = convert_responses_items_to_completion_messages(
|
||||
messages_with_element_ids, allow_images_in_tool_results=False
|
||||
)
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": llm_model,
|
||||
"input": messages,
|
||||
"messages": completion_messages,
|
||||
"tools": openai_tools if openai_tools else None,
|
||||
"stream": stream,
|
||||
"truncation": "auto",
|
||||
"num_retries": max_retries,
|
||||
**kwargs,
|
||||
}
|
||||
@@ -340,8 +371,8 @@ class OmniparserConfig(AsyncAgentConfig):
|
||||
|
||||
print(str(api_kwargs)[:1000])
|
||||
|
||||
# Use liteLLM responses
|
||||
response = await litellm.aresponses(**api_kwargs)
|
||||
# Use liteLLM completion
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
@@ -355,12 +386,45 @@ class OmniparserConfig(AsyncAgentConfig):
|
||||
if _on_usage:
|
||||
await _on_usage(usage)
|
||||
|
||||
# handle som function calls -> xy computer calls
|
||||
new_output = []
|
||||
for i in range(len(response.output)): # type: ignore
|
||||
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) # type: ignore
|
||||
response_dict = response.model_dump() # type: ignore
|
||||
choice_messages = [choice["message"] for choice in response_dict["choices"]]
|
||||
responses_items = []
|
||||
for choice_message in choice_messages:
|
||||
responses_items.extend(convert_completion_messages_to_responses_items([choice_message]))
|
||||
|
||||
return {"output": new_output, "usage": usage}
|
||||
# Convert element_id → x,y (similar to moondream's convert_computer_calls_desc2xy)
|
||||
final_output = []
|
||||
for item in responses_items:
|
||||
if item.get("type") == "computer_call" and "action" in item:
|
||||
action = item["action"].copy()
|
||||
|
||||
# Handle single element_id
|
||||
if "element_id" in action:
|
||||
element_id = action["element_id"]
|
||||
if element_id in id2xy:
|
||||
x, y = id2xy[element_id]
|
||||
action["x"] = x
|
||||
action["y"] = y
|
||||
del action["element_id"]
|
||||
|
||||
# Handle start_element_id and end_element_id for drag operations
|
||||
elif "start_element_id" in action and "end_element_id" in action:
|
||||
start_id = action["start_element_id"]
|
||||
end_id = action["end_element_id"]
|
||||
if start_id in id2xy and end_id in id2xy:
|
||||
start_x, start_y = id2xy[start_id]
|
||||
end_x, end_y = id2xy[end_id]
|
||||
action["path"] = [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}]
|
||||
del action["start_element_id"]
|
||||
del action["end_element_id"]
|
||||
|
||||
converted_item = item.copy()
|
||||
converted_item["action"] = action
|
||||
final_output.append(converted_item)
|
||||
else:
|
||||
final_output.append(item)
|
||||
|
||||
return {"output": final_output, "usage": usage}
|
||||
|
||||
async def predict_click(
|
||||
self, model: str, image_b64: str, instruction: str, **kwargs
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
name = "cua-agent"
|
||||
version = "0.4.35"
|
||||
version = "0.4.37"
|
||||
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
|
||||
84
libs/python/agent/tests/conftest.py
Normal file
84
libs/python/agent/tests/conftest.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Pytest configuration and shared fixtures for agent package tests.
|
||||
|
||||
This file contains shared fixtures and configuration for all agent tests.
|
||||
Following SRP: This file ONLY handles test setup/teardown.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_litellm():
|
||||
"""Mock liteLLM completion calls.
|
||||
|
||||
Use this fixture to avoid making real LLM API calls during tests.
|
||||
Returns a mock that simulates LLM responses.
|
||||
"""
|
||||
with patch("litellm.acompletion") as mock_completion:
|
||||
|
||||
async def mock_response(*args, **kwargs):
|
||||
"""Simulate a typical LLM response."""
|
||||
return {
|
||||
"id": "chatcmpl-test123",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": kwargs.get("model", "anthropic/claude-3-5-sonnet-20241022"),
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "This is a mocked response for testing.",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"completion_tokens": 20,
|
||||
"total_tokens": 30,
|
||||
},
|
||||
}
|
||||
|
||||
mock_completion.side_effect = mock_response
|
||||
yield mock_completion
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_computer():
|
||||
"""Mock Computer interface for agent tests.
|
||||
|
||||
Use this fixture to test agent logic without requiring a real Computer instance.
|
||||
"""
|
||||
computer = AsyncMock()
|
||||
computer.interface = AsyncMock()
|
||||
computer.interface.screenshot = AsyncMock(return_value=b"fake_screenshot_data")
|
||||
computer.interface.left_click = AsyncMock()
|
||||
computer.interface.type = AsyncMock()
|
||||
computer.interface.key = AsyncMock()
|
||||
|
||||
# Mock context manager
|
||||
computer.__aenter__ = AsyncMock(return_value=computer)
|
||||
computer.__aexit__ = AsyncMock()
|
||||
|
||||
return computer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_telemetry(monkeypatch):
|
||||
"""Disable telemetry for tests.
|
||||
|
||||
Use this fixture to ensure no telemetry is sent during tests.
|
||||
"""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_messages():
|
||||
"""Provide sample messages for testing.
|
||||
|
||||
Returns a list of messages in the expected format.
|
||||
"""
|
||||
return [{"role": "user", "content": "Take a screenshot and tell me what you see"}]
|
||||
139
libs/python/agent/tests/test_computer_agent.py
Normal file
139
libs/python/agent/tests/test_computer_agent.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Unit tests for ComputerAgent class.
|
||||
|
||||
This file tests ONLY the ComputerAgent initialization and basic functionality.
|
||||
Following SRP: This file tests ONE class (ComputerAgent).
|
||||
All external dependencies (liteLLM, Computer) are mocked.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestComputerAgentInitialization:
|
||||
"""Test ComputerAgent initialization (SRP: Only tests initialization)."""
|
||||
|
||||
@patch("agent.agent.litellm")
|
||||
def test_agent_initialization_with_model(self, mock_litellm, disable_telemetry):
|
||||
"""Test that agent can be initialized with a model string."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
assert agent is not None
|
||||
assert hasattr(agent, "model")
|
||||
assert agent.model == "anthropic/claude-3-5-sonnet-20241022"
|
||||
|
||||
@patch("agent.agent.litellm")
|
||||
def test_agent_initialization_with_tools(self, mock_litellm, disable_telemetry, mock_computer):
|
||||
"""Test that agent can be initialized with tools."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022", tools=[mock_computer])
|
||||
|
||||
assert agent is not None
|
||||
assert hasattr(agent, "tools")
|
||||
|
||||
@patch("agent.agent.litellm")
|
||||
def test_agent_initialization_with_max_budget(self, mock_litellm, disable_telemetry):
|
||||
"""Test that agent can be initialized with max trajectory budget."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
budget = 5.0
|
||||
agent = ComputerAgent(
|
||||
model="anthropic/claude-3-5-sonnet-20241022", max_trajectory_budget=budget
|
||||
)
|
||||
|
||||
assert agent is not None
|
||||
|
||||
@patch("agent.agent.litellm")
|
||||
def test_agent_requires_model(self, mock_litellm, disable_telemetry):
|
||||
"""Test that agent requires a model parameter."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Should fail without model parameter - intentionally missing required argument
|
||||
ComputerAgent() # type: ignore[call-arg]
|
||||
|
||||
|
||||
class TestComputerAgentRun:
|
||||
"""Test ComputerAgent.run() method (SRP: Only tests run logic)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("agent.agent.litellm")
|
||||
async def test_agent_run_with_messages(self, mock_litellm, disable_telemetry, sample_messages):
|
||||
"""Test that agent.run() works with valid messages."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
# Mock liteLLM response
|
||||
mock_response = {
|
||||
"id": "chatcmpl-test",
|
||||
"choices": [
|
||||
{
|
||||
"message": {"role": "assistant", "content": "Test response"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
|
||||
}
|
||||
|
||||
mock_litellm.acompletion = AsyncMock(return_value=mock_response)
|
||||
|
||||
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Run should return an async generator
|
||||
result_generator = agent.run(sample_messages)
|
||||
|
||||
assert result_generator is not None
|
||||
# Check it's an async generator
|
||||
assert hasattr(result_generator, "__anext__")
|
||||
|
||||
def test_agent_has_run_method(self, disable_telemetry):
|
||||
"""Test that agent has run method available."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Verify run method exists
|
||||
assert hasattr(agent, "run")
|
||||
assert callable(agent.run)
|
||||
|
||||
def test_agent_has_agent_loop(self, disable_telemetry):
|
||||
"""Test that agent has agent_loop initialized."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Verify agent_loop is initialized
|
||||
assert hasattr(agent, "agent_loop")
|
||||
assert agent.agent_loop is not None
|
||||
|
||||
|
||||
class TestComputerAgentTypes:
|
||||
"""Test AgentResponse and Messages types (SRP: Only tests type definitions)."""
|
||||
|
||||
def test_messages_type_exists(self):
|
||||
"""Test that Messages type is exported."""
|
||||
from agent import Messages
|
||||
|
||||
assert Messages is not None
|
||||
|
||||
def test_agent_response_type_exists(self):
|
||||
"""Test that AgentResponse type is exported."""
|
||||
from agent import AgentResponse
|
||||
|
||||
assert AgentResponse is not None
|
||||
|
||||
|
||||
class TestComputerAgentIntegration:
|
||||
"""Test ComputerAgent integration with Computer tool (SRP: Integration within package)."""
|
||||
|
||||
def test_agent_accepts_computer_tool(self, disable_telemetry, mock_computer):
|
||||
"""Test that agent can be initialized with Computer tool."""
|
||||
from agent import ComputerAgent
|
||||
|
||||
agent = ComputerAgent(model="anthropic/claude-3-5-sonnet-20241022", tools=[mock_computer])
|
||||
|
||||
# Verify agent accepted the tool
|
||||
assert agent is not None
|
||||
assert hasattr(agent, "tools")
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.1.25
|
||||
current_version = 0.1.28
|
||||
commit = True
|
||||
tag = True
|
||||
tag_name = computer-server-v{new_version}
|
||||
|
||||
@@ -85,6 +85,102 @@ class BaseFileHandler(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class BaseDesktopHandler(ABC):
|
||||
"""Abstract base class for OS-specific desktop handlers.
|
||||
|
||||
Categories:
|
||||
- Wallpaper Actions: Methods for wallpaper operations
|
||||
- Desktop shortcut actions: Methods for managing desktop shortcuts
|
||||
"""
|
||||
|
||||
# Wallpaper Actions
|
||||
@abstractmethod
|
||||
async def get_desktop_environment(self) -> Dict[str, Any]:
|
||||
"""Get the current desktop environment name."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_wallpaper(self, path: str) -> Dict[str, Any]:
|
||||
"""Set the desktop wallpaper to the file at path."""
|
||||
pass
|
||||
|
||||
|
||||
class BaseWindowHandler(ABC):
|
||||
"""Abstract class for OS-specific window management handlers.
|
||||
|
||||
Categories:
|
||||
- Window Management: Methods for application/window control
|
||||
"""
|
||||
|
||||
# Window Management
|
||||
@abstractmethod
|
||||
async def open(self, target: str) -> Dict[str, Any]:
|
||||
"""Open a file or URL with the default application."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def launch(self, app: str, args: Optional[List[str]] = None) -> Dict[str, Any]:
|
||||
"""Launch an application with optional arguments."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_current_window_id(self) -> Dict[str, Any]:
|
||||
"""Get the currently active window ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_application_windows(self, app: str) -> Dict[str, Any]:
|
||||
"""Get windows belonging to an application (by name or bundle)."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_window_name(self, window_id: str) -> Dict[str, Any]:
|
||||
"""Get the title/name of a window by ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_window_size(self, window_id: str | int) -> Dict[str, Any]:
|
||||
"""Get the size of a window by ID as {width, height}."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def activate_window(self, window_id: str | int) -> Dict[str, Any]:
|
||||
"""Bring a window to the foreground by ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close_window(self, window_id: str | int) -> Dict[str, Any]:
|
||||
"""Close a window by ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_window_position(self, window_id: str | int) -> Dict[str, Any]:
|
||||
"""Get the top-left position of a window as {x, y}."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_window_size(
|
||||
self, window_id: str | int, width: int, height: int
|
||||
) -> Dict[str, Any]:
|
||||
"""Set the size of a window by ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_window_position(self, window_id: str | int, x: int, y: int) -> Dict[str, Any]:
|
||||
"""Set the position of a window by ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def maximize_window(self, window_id: str | int) -> Dict[str, Any]:
|
||||
"""Maximize a window by ID."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def minimize_window(self, window_id: str | int) -> Dict[str, Any]:
|
||||
"""Minimize a window by ID."""
|
||||
pass
|
||||
|
||||
|
||||
class BaseAutomationHandler(ABC):
|
||||
"""Abstract base class for OS-specific automation handlers.
|
||||
|
||||
|
||||
@@ -4,7 +4,13 @@ from typing import Tuple, Type
|
||||
|
||||
from computer_server.diorama.base import BaseDioramaHandler
|
||||
|
||||
from .base import BaseAccessibilityHandler, BaseAutomationHandler, BaseFileHandler
|
||||
from .base import (
|
||||
BaseAccessibilityHandler,
|
||||
BaseAutomationHandler,
|
||||
BaseDesktopHandler,
|
||||
BaseFileHandler,
|
||||
BaseWindowHandler,
|
||||
)
|
||||
|
||||
# Conditionally import platform-specific handlers
|
||||
system = platform.system().lower()
|
||||
@@ -17,7 +23,7 @@ elif system == "linux":
|
||||
elif system == "windows":
|
||||
from .windows import WindowsAccessibilityHandler, WindowsAutomationHandler
|
||||
|
||||
from .generic import GenericFileHandler
|
||||
from .generic import GenericDesktopHandler, GenericFileHandler, GenericWindowHandler
|
||||
|
||||
|
||||
class HandlerFactory:
|
||||
@@ -49,9 +55,14 @@ class HandlerFactory:
|
||||
raise RuntimeError(f"Failed to determine current OS: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def create_handlers() -> (
|
||||
Tuple[BaseAccessibilityHandler, BaseAutomationHandler, BaseDioramaHandler, BaseFileHandler]
|
||||
):
|
||||
def create_handlers() -> Tuple[
|
||||
BaseAccessibilityHandler,
|
||||
BaseAutomationHandler,
|
||||
BaseDioramaHandler,
|
||||
BaseFileHandler,
|
||||
BaseDesktopHandler,
|
||||
BaseWindowHandler,
|
||||
]:
|
||||
"""Create and return appropriate handlers for the current OS.
|
||||
|
||||
Returns:
|
||||
@@ -70,6 +81,8 @@ class HandlerFactory:
|
||||
MacOSAutomationHandler(),
|
||||
MacOSDioramaHandler(),
|
||||
GenericFileHandler(),
|
||||
GenericDesktopHandler(),
|
||||
GenericWindowHandler(),
|
||||
)
|
||||
elif os_type == "linux":
|
||||
return (
|
||||
@@ -77,6 +90,8 @@ class HandlerFactory:
|
||||
LinuxAutomationHandler(),
|
||||
BaseDioramaHandler(),
|
||||
GenericFileHandler(),
|
||||
GenericDesktopHandler(),
|
||||
GenericWindowHandler(),
|
||||
)
|
||||
elif os_type == "windows":
|
||||
return (
|
||||
@@ -84,6 +99,8 @@ class HandlerFactory:
|
||||
WindowsAutomationHandler(),
|
||||
BaseDioramaHandler(),
|
||||
GenericFileHandler(),
|
||||
GenericDesktopHandler(),
|
||||
GenericWindowHandler(),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"OS '{os_type}' is not supported")
|
||||
|
||||
@@ -2,15 +2,26 @@
|
||||
Generic handlers for all OSes.
|
||||
|
||||
Includes:
|
||||
- DesktopHandler
|
||||
- FileHandler
|
||||
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import platform
|
||||
import subprocess
|
||||
import webbrowser
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .base import BaseFileHandler
|
||||
from ..utils import wallpaper
|
||||
from .base import BaseDesktopHandler, BaseFileHandler, BaseWindowHandler
|
||||
|
||||
try:
|
||||
import pywinctl as pwc
|
||||
except Exception: # pragma: no cover
|
||||
pwc = None # type: ignore
|
||||
|
||||
|
||||
def resolve_path(path: str) -> Path:
|
||||
@@ -25,6 +36,233 @@ def resolve_path(path: str) -> Path:
|
||||
return Path(path).expanduser().resolve()
|
||||
|
||||
|
||||
# ===== Cross-platform Desktop command handlers =====
|
||||
|
||||
|
||||
class GenericDesktopHandler(BaseDesktopHandler):
|
||||
"""
|
||||
Generic desktop handler providing desktop-related operations.
|
||||
|
||||
Implements:
|
||||
- get_desktop_environment: detect current desktop environment
|
||||
- set_wallpaper: set desktop wallpaper path
|
||||
"""
|
||||
|
||||
async def get_desktop_environment(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get the current desktop environment.
|
||||
|
||||
Returns:
|
||||
Dict containing 'success' boolean and either 'environment' string or 'error' string
|
||||
"""
|
||||
try:
|
||||
env = wallpaper.get_desktop_environment()
|
||||
return {"success": True, "environment": env}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def set_wallpaper(self, path: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Set the desktop wallpaper to the specified path.
|
||||
|
||||
Args:
|
||||
path: The file path to set as wallpaper
|
||||
|
||||
Returns:
|
||||
Dict containing 'success' boolean and optionally 'error' string
|
||||
"""
|
||||
try:
|
||||
file_path = resolve_path(path)
|
||||
ok = wallpaper.set_wallpaper(str(file_path))
|
||||
return {"success": bool(ok)}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
# ===== Cross-platform window control command handlers =====
|
||||
|
||||
|
||||
class GenericWindowHandler(BaseWindowHandler):
|
||||
"""
|
||||
Cross-platform window management using pywinctl where possible.
|
||||
"""
|
||||
|
||||
async def open(self, target: str) -> Dict[str, Any]:
|
||||
try:
|
||||
if target.startswith("http://") or target.startswith("https://"):
|
||||
ok = webbrowser.open(target)
|
||||
return {"success": bool(ok)}
|
||||
path = str(resolve_path(target))
|
||||
sys = platform.system().lower()
|
||||
if sys == "darwin":
|
||||
subprocess.Popen(["open", path])
|
||||
elif sys == "linux":
|
||||
subprocess.Popen(["xdg-open", path])
|
||||
elif sys == "windows":
|
||||
os.startfile(path) # type: ignore[attr-defined]
|
||||
else:
|
||||
return {"success": False, "error": f"Unsupported OS: {sys}"}
|
||||
return {"success": True}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def launch(self, app: str, args: Optional[list[str]] = None) -> Dict[str, Any]:
|
||||
try:
|
||||
if args:
|
||||
proc = subprocess.Popen([app, *args])
|
||||
else:
|
||||
# allow shell command like "libreoffice --writer"
|
||||
proc = subprocess.Popen(app, shell=True)
|
||||
return {"success": True, "pid": proc.pid}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
def _get_window_by_id(self, window_id: int | str) -> Optional[Any]:
|
||||
if pwc is None:
|
||||
raise RuntimeError("pywinctl not available")
|
||||
# Find by native handle among Window objects; getAllWindowsDict keys are titles
|
||||
try:
|
||||
for w in pwc.getAllWindows():
|
||||
if str(w.getHandle()) == str(window_id):
|
||||
return w
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def get_current_window_id(self) -> Dict[str, Any]:
|
||||
try:
|
||||
if pwc is None:
|
||||
return {"success": False, "error": "pywinctl not available"}
|
||||
win = pwc.getActiveWindow()
|
||||
if not win:
|
||||
return {"success": False, "error": "No active window"}
|
||||
return {"success": True, "window_id": win.getHandle()}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def get_application_windows(self, app: str) -> Dict[str, Any]:
|
||||
try:
|
||||
if pwc is None:
|
||||
return {"success": False, "error": "pywinctl not available"}
|
||||
wins = pwc.getWindowsWithTitle(app, condition=pwc.Re.CONTAINS, flags=pwc.Re.IGNORECASE)
|
||||
ids = [w.getHandle() for w in wins]
|
||||
return {"success": True, "windows": ids}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def get_window_name(self, window_id: int | str) -> Dict[str, Any]:
|
||||
try:
|
||||
if pwc is None:
|
||||
return {"success": False, "error": "pywinctl not available"}
|
||||
w = self._get_window_by_id(window_id)
|
||||
if not w:
|
||||
return {"success": False, "error": "Window not found"}
|
||||
return {"success": True, "name": w.title}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def get_window_size(self, window_id: int | str) -> Dict[str, Any]:
|
||||
try:
|
||||
if pwc is None:
|
||||
return {"success": False, "error": "pywinctl not available"}
|
||||
w = self._get_window_by_id(window_id)
|
||||
if not w:
|
||||
return {"success": False, "error": "Window not found"}
|
||||
width, height = w.size
|
||||
return {"success": True, "width": int(width), "height": int(height)}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def get_window_position(self, window_id: int | str) -> Dict[str, Any]:
|
||||
try:
|
||||
if pwc is None:
|
||||
return {"success": False, "error": "pywinctl not available"}
|
||||
w = self._get_window_by_id(window_id)
|
||||
if not w:
|
||||
return {"success": False, "error": "Window not found"}
|
||||
x, y = w.position
|
||||
return {"success": True, "x": int(x), "y": int(y)}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def set_window_size(
|
||||
self, window_id: int | str, width: int, height: int
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
if pwc is None:
|
||||
return {"success": False, "error": "pywinctl not available"}
|
||||
w = self._get_window_by_id(window_id)
|
||||
if not w:
|
||||
return {"success": False, "error": "Window not found"}
|
||||
ok = w.resizeTo(int(width), int(height))
|
||||
return {"success": bool(ok)}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def set_window_position(self, window_id: int | str, x: int, y: int) -> Dict[str, Any]:
|
||||
try:
|
||||
if pwc is None:
|
||||
return {"success": False, "error": "pywinctl not available"}
|
||||
w = self._get_window_by_id(window_id)
|
||||
if not w:
|
||||
return {"success": False, "error": "Window not found"}
|
||||
ok = w.moveTo(int(x), int(y))
|
||||
return {"success": bool(ok)}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def maximize_window(self, window_id: int | str) -> Dict[str, Any]:
|
||||
try:
|
||||
if pwc is None:
|
||||
return {"success": False, "error": "pywinctl not available"}
|
||||
w = self._get_window_by_id(window_id)
|
||||
if not w:
|
||||
return {"success": False, "error": "Window not found"}
|
||||
ok = w.maximize()
|
||||
return {"success": bool(ok)}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def minimize_window(self, window_id: int | str) -> Dict[str, Any]:
|
||||
try:
|
||||
if pwc is None:
|
||||
return {"success": False, "error": "pywinctl not available"}
|
||||
w = self._get_window_by_id(window_id)
|
||||
if not w:
|
||||
return {"success": False, "error": "Window not found"}
|
||||
ok = w.minimize()
|
||||
return {"success": bool(ok)}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def activate_window(self, window_id: int | str) -> Dict[str, Any]:
|
||||
try:
|
||||
if pwc is None:
|
||||
return {"success": False, "error": "pywinctl not available"}
|
||||
w = self._get_window_by_id(window_id)
|
||||
if not w:
|
||||
return {"success": False, "error": "Window not found"}
|
||||
ok = w.activate()
|
||||
return {"success": bool(ok)}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def close_window(self, window_id: int | str) -> Dict[str, Any]:
|
||||
try:
|
||||
if pwc is None:
|
||||
return {"success": False, "error": "pywinctl not available"}
|
||||
w = self._get_window_by_id(window_id)
|
||||
if not w:
|
||||
return {"success": False, "error": "Window not found"}
|
||||
ok = w.close()
|
||||
return {"success": bool(ok)}
|
||||
except Exception as e:
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
# ===== Cross-platform file system command handlers =====
|
||||
|
||||
|
||||
class GenericFileHandler(BaseFileHandler):
|
||||
"""
|
||||
Generic file handler that provides file system operations for all operating systems.
|
||||
|
||||
@@ -75,9 +75,14 @@ except Exception:
|
||||
except Exception:
|
||||
package_version = "unknown"
|
||||
|
||||
accessibility_handler, automation_handler, diorama_handler, file_handler = (
|
||||
HandlerFactory.create_handlers()
|
||||
)
|
||||
(
|
||||
accessibility_handler,
|
||||
automation_handler,
|
||||
diorama_handler,
|
||||
file_handler,
|
||||
desktop_handler,
|
||||
window_handler,
|
||||
) = HandlerFactory.create_handlers()
|
||||
handlers = {
|
||||
"version": lambda: {"protocol": protocol_version, "package": package_version},
|
||||
# App-Use commands
|
||||
@@ -99,6 +104,23 @@ handlers = {
|
||||
"delete_file": file_handler.delete_file,
|
||||
"create_dir": file_handler.create_dir,
|
||||
"delete_dir": file_handler.delete_dir,
|
||||
# Desktop commands
|
||||
"get_desktop_environment": desktop_handler.get_desktop_environment,
|
||||
"set_wallpaper": desktop_handler.set_wallpaper,
|
||||
# Window management
|
||||
"open": window_handler.open,
|
||||
"launch": window_handler.launch,
|
||||
"get_current_window_id": window_handler.get_current_window_id,
|
||||
"get_application_windows": window_handler.get_application_windows,
|
||||
"get_window_name": window_handler.get_window_name,
|
||||
"get_window_size": window_handler.get_window_size,
|
||||
"get_window_position": window_handler.get_window_position,
|
||||
"set_window_size": window_handler.set_window_size,
|
||||
"set_window_position": window_handler.set_window_position,
|
||||
"maximize_window": window_handler.maximize_window,
|
||||
"minimize_window": window_handler.minimize_window,
|
||||
"activate_window": window_handler.activate_window,
|
||||
"close_window": window_handler.close_window,
|
||||
# Mouse commands
|
||||
"mouse_down": automation_handler.mouse_down,
|
||||
"mouse_up": automation_handler.mouse_up,
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
from . import wallpaper
|
||||
|
||||
__all__ = ["wallpaper"]
|
||||
321
libs/python/computer-server/computer_server/utils/wallpaper.py
Normal file
321
libs/python/computer-server/computer_server/utils/wallpaper.py
Normal file
@@ -0,0 +1,321 @@
|
||||
"""Set the desktop wallpaper."""
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def get_desktop_environment() -> str:
|
||||
"""
|
||||
Returns the name of the current desktop environment.
|
||||
"""
|
||||
# From https://stackoverflow.com/a/21213358/2624876
|
||||
# which takes from:
|
||||
# http://stackoverflow.com/questions/2035657/what-is-my-current-desktop-environment
|
||||
# and http://ubuntuforums.org/showthread.php?t=652320
|
||||
# and http://ubuntuforums.org/showthread.php?t=1139057
|
||||
if sys.platform in ["win32", "cygwin"]:
|
||||
return "windows"
|
||||
elif sys.platform == "darwin":
|
||||
return "mac"
|
||||
else: # Most likely either a POSIX system or something not much common
|
||||
desktop_session = os.environ.get("DESKTOP_SESSION")
|
||||
if (
|
||||
desktop_session is not None
|
||||
): # easier to match if we doesn't have to deal with character cases
|
||||
desktop_session = desktop_session.lower()
|
||||
if desktop_session in [
|
||||
"gnome",
|
||||
"unity",
|
||||
"cinnamon",
|
||||
"mate",
|
||||
"xfce4",
|
||||
"lxde",
|
||||
"fluxbox",
|
||||
"blackbox",
|
||||
"openbox",
|
||||
"icewm",
|
||||
"jwm",
|
||||
"afterstep",
|
||||
"trinity",
|
||||
"kde",
|
||||
]:
|
||||
return desktop_session
|
||||
## Special cases ##
|
||||
# Canonical sets $DESKTOP_SESSION to Lubuntu rather than LXDE if using LXDE.
|
||||
# There is no guarantee that they will not do the same with the other desktop environments.
|
||||
elif "xfce" in desktop_session or desktop_session.startswith("xubuntu"):
|
||||
return "xfce4"
|
||||
elif desktop_session.startswith("ubuntustudio"):
|
||||
return "kde"
|
||||
elif desktop_session.startswith("ubuntu"):
|
||||
return "gnome"
|
||||
elif desktop_session.startswith("lubuntu"):
|
||||
return "lxde"
|
||||
elif desktop_session.startswith("kubuntu"):
|
||||
return "kde"
|
||||
elif desktop_session.startswith("razor"): # e.g. razorkwin
|
||||
return "razor-qt"
|
||||
elif desktop_session.startswith("wmaker"): # e.g. wmaker-common
|
||||
return "windowmaker"
|
||||
gnome_desktop_session_id = os.environ.get("GNOME_DESKTOP_SESSION_ID")
|
||||
if os.environ.get("KDE_FULL_SESSION") == "true":
|
||||
return "kde"
|
||||
elif gnome_desktop_session_id:
|
||||
if "deprecated" not in gnome_desktop_session_id:
|
||||
return "gnome2"
|
||||
# From http://ubuntuforums.org/showthread.php?t=652320
|
||||
elif is_running("xfce-mcs-manage"):
|
||||
return "xfce4"
|
||||
elif is_running("ksmserver"):
|
||||
return "kde"
|
||||
return "unknown"
|
||||
|
||||
|
||||
def is_running(process: str) -> bool:
|
||||
"""Returns whether a process with the given name is (likely) currently running.
|
||||
|
||||
Uses a basic text search, and so may have false positives.
|
||||
"""
|
||||
# From http://www.bloggerpolis.com/2011/05/how-to-check-if-a-process-is-running-using-python/
|
||||
# and http://richarddingwall.name/2009/06/18/windows-equivalents-of-ps-and-kill-commands/
|
||||
try: # Linux/Unix
|
||||
s = subprocess.Popen(["ps", "axw"], stdout=subprocess.PIPE)
|
||||
except: # Windows
|
||||
s = subprocess.Popen(["tasklist", "/v"], stdout=subprocess.PIPE)
|
||||
assert s.stdout is not None
|
||||
for x in s.stdout:
|
||||
# if re.search(process, x):
|
||||
if process in str(x):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def set_wallpaper(file_loc: str, first_run: bool = True):
|
||||
"""Sets the wallpaper to the given file location."""
|
||||
# From https://stackoverflow.com/a/21213504/2624876
|
||||
# I have not personally tested most of this. -- @1j01
|
||||
# -----------------------------------------
|
||||
|
||||
# Note: There are two common Linux desktop environments where
|
||||
# I have not been able to set the desktop background from
|
||||
# command line: KDE, Enlightenment
|
||||
desktop_env = get_desktop_environment()
|
||||
if desktop_env in ["gnome", "unity", "cinnamon"]:
|
||||
# Tested on Ubuntu 22 -- @1j01
|
||||
uri = Path(file_loc).as_uri()
|
||||
SCHEMA = "org.gnome.desktop.background"
|
||||
KEY = "picture-uri"
|
||||
# Needed for Ubuntu 22 in dark mode
|
||||
# Might be better to set only one or the other, depending on the current theme
|
||||
# In the settings it will say "This background selection only applies to the dark style"
|
||||
# even if it's set for both, arguably referring to the selection that you can make on that page.
|
||||
# -- @1j01
|
||||
KEY_DARK = "picture-uri-dark"
|
||||
try:
|
||||
from gi.repository import Gio # type: ignore
|
||||
|
||||
gsettings = Gio.Settings.new(SCHEMA) # type: ignore
|
||||
gsettings.set_string(KEY, uri)
|
||||
gsettings.set_string(KEY_DARK, uri)
|
||||
except Exception:
|
||||
# Fallback tested on Ubuntu 22 -- @1j01
|
||||
args = ["gsettings", "set", SCHEMA, KEY, uri]
|
||||
subprocess.Popen(args)
|
||||
args = ["gsettings", "set", SCHEMA, KEY_DARK, uri]
|
||||
subprocess.Popen(args)
|
||||
elif desktop_env == "mate":
|
||||
try: # MATE >= 1.6
|
||||
# info from http://wiki.mate-desktop.org/docs:gsettings
|
||||
args = ["gsettings", "set", "org.mate.background", "picture-filename", file_loc]
|
||||
subprocess.Popen(args)
|
||||
except Exception: # MATE < 1.6
|
||||
# From https://bugs.launchpad.net/variety/+bug/1033918
|
||||
args = [
|
||||
"mateconftool-2",
|
||||
"-t",
|
||||
"string",
|
||||
"--set",
|
||||
"/desktop/mate/background/picture_filename",
|
||||
file_loc,
|
||||
]
|
||||
subprocess.Popen(args)
|
||||
elif desktop_env == "gnome2": # Not tested
|
||||
# From https://bugs.launchpad.net/variety/+bug/1033918
|
||||
args = [
|
||||
"gconftool-2",
|
||||
"-t",
|
||||
"string",
|
||||
"--set",
|
||||
"/desktop/gnome/background/picture_filename",
|
||||
file_loc,
|
||||
]
|
||||
subprocess.Popen(args)
|
||||
## KDE4 is difficult
|
||||
## see http://blog.zx2c4.com/699 for a solution that might work
|
||||
elif desktop_env in ["kde3", "trinity"]:
|
||||
# From http://ubuntuforums.org/archive/index.php/t-803417.html
|
||||
args = ["dcop", "kdesktop", "KBackgroundIface", "setWallpaper", "0", file_loc, "6"]
|
||||
subprocess.Popen(args)
|
||||
elif desktop_env == "xfce4":
|
||||
# Iterate over all wallpaper-related keys and set to file_loc
|
||||
try:
|
||||
list_proc = subprocess.run(
|
||||
["xfconf-query", "-c", "xfce4-desktop", "-l"],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
keys = []
|
||||
if list_proc.stdout:
|
||||
for line in list_proc.stdout.splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
# Common keys: .../last-image and .../image-path
|
||||
if "/last-image" in line or "/image-path" in line:
|
||||
keys.append(line)
|
||||
# Fallback: known defaults if none were listed
|
||||
if not keys:
|
||||
keys = [
|
||||
"/backdrop/screen0/monitorVNC-0/workspace0/last-image",
|
||||
"/backdrop/screen0/monitor0/image-path",
|
||||
]
|
||||
for key in keys:
|
||||
subprocess.run(
|
||||
[
|
||||
"xfconf-query",
|
||||
"-c",
|
||||
"xfce4-desktop",
|
||||
"-p",
|
||||
key,
|
||||
"-s",
|
||||
file_loc,
|
||||
],
|
||||
check=False,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
# Reload xfdesktop to apply changes
|
||||
subprocess.Popen(["xfdesktop", "--reload"])
|
||||
elif desktop_env == "razor-qt": # TODO: implement reload of desktop when possible
|
||||
if first_run:
|
||||
import configparser
|
||||
|
||||
desktop_conf = configparser.ConfigParser()
|
||||
# Development version
|
||||
desktop_conf_file = os.path.join(get_config_dir("razor"), "desktop.conf")
|
||||
if os.path.isfile(desktop_conf_file):
|
||||
config_option = R"screens\1\desktops\1\wallpaper"
|
||||
else:
|
||||
desktop_conf_file = os.path.join(get_home_dir(), ".razor/desktop.conf")
|
||||
config_option = R"desktops\1\wallpaper"
|
||||
desktop_conf.read(os.path.join(desktop_conf_file))
|
||||
try:
|
||||
if desktop_conf.has_option("razor", config_option): # only replacing a value
|
||||
desktop_conf.set("razor", config_option, file_loc)
|
||||
with open(desktop_conf_file, "w", encoding="utf-8", errors="replace") as f:
|
||||
desktop_conf.write(f)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
# TODO: reload desktop when possible
|
||||
pass
|
||||
elif desktop_env in ["fluxbox", "jwm", "openbox", "afterstep"]:
|
||||
# http://fluxbox-wiki.org/index.php/Howto_set_the_background
|
||||
# used fbsetbg on jwm too since I am too lazy to edit the XML configuration
|
||||
# now where fbsetbg does the job excellent anyway.
|
||||
# and I have not figured out how else it can be set on Openbox and AfterSTep
|
||||
# but fbsetbg works excellent here too.
|
||||
try:
|
||||
args = ["fbsetbg", file_loc]
|
||||
subprocess.Popen(args)
|
||||
except Exception:
|
||||
sys.stderr.write("ERROR: Failed to set wallpaper with fbsetbg!\n")
|
||||
sys.stderr.write("Please make sre that You have fbsetbg installed.\n")
|
||||
elif desktop_env == "icewm":
|
||||
# command found at http://urukrama.wordpress.com/2007/12/05/desktop-backgrounds-in-window-managers/
|
||||
args = ["icewmbg", file_loc]
|
||||
subprocess.Popen(args)
|
||||
elif desktop_env == "blackbox":
|
||||
# command found at http://blackboxwm.sourceforge.net/BlackboxDocumentation/BlackboxBackground
|
||||
args = ["bsetbg", "-full", file_loc]
|
||||
subprocess.Popen(args)
|
||||
elif desktop_env == "lxde":
|
||||
args = ["pcmanfm", "--set-wallpaper", file_loc, "--wallpaper-mode=scaled"]
|
||||
subprocess.Popen(args)
|
||||
elif desktop_env == "windowmaker":
|
||||
# From http://www.commandlinefu.com/commands/view/3857/set-wallpaper-on-windowmaker-in-one-line
|
||||
args = ["wmsetbg", "-s", "-u", file_loc]
|
||||
subprocess.Popen(args)
|
||||
# elif desktop_env == "enlightenment": # I have not been able to make it work on e17. On e16 it would have been something in this direction
|
||||
# args = ["enlightenment_remote", "-desktop-bg-add", "0", "0", "0", "0", file_loc]
|
||||
# subprocess.Popen(args)
|
||||
elif desktop_env == "windows":
|
||||
# From https://stackoverflow.com/questions/1977694/change-desktop-background
|
||||
# Tested on Windows 10. -- @1j01
|
||||
import ctypes
|
||||
|
||||
SPI_SETDESKWALLPAPER = 20
|
||||
ctypes.windll.user32.SystemParametersInfoW(SPI_SETDESKWALLPAPER, 0, file_loc, 0) # type: ignore
|
||||
elif desktop_env == "mac":
|
||||
# From https://stackoverflow.com/questions/431205/how-can-i-programatically-change-the-background-in-mac-os-x
|
||||
try:
|
||||
# Tested on macOS 10.14.6 (Mojave) -- @1j01
|
||||
assert (
|
||||
sys.platform == "darwin"
|
||||
) # ignore `Import "appscript" could not be resolved` for other platforms
|
||||
from appscript import app, mactypes
|
||||
|
||||
app("Finder").desktop_picture.set(mactypes.File(file_loc))
|
||||
except ImportError:
|
||||
# Tested on macOS 10.14.6 (Mojave) -- @1j01
|
||||
# import subprocess
|
||||
# SCRIPT = f"""/usr/bin/osascript<<END
|
||||
# tell application "Finder" to set desktop picture to POSIX file "{file_loc}"
|
||||
# END"""
|
||||
# subprocess.Popen(SCRIPT, shell=True)
|
||||
|
||||
# Safer version, avoiding string interpolation,
|
||||
# to protect against command injection (both in the shell and in AppleScript):
|
||||
OSASCRIPT = """
|
||||
on run (clp)
|
||||
if clp's length is not 1 then error "Incorrect Parameters"
|
||||
local file_loc
|
||||
set file_loc to clp's item 1
|
||||
tell application "Finder" to set desktop picture to POSIX file file_loc
|
||||
end run
|
||||
"""
|
||||
subprocess.Popen(["osascript", "-e", OSASCRIPT, "--", file_loc])
|
||||
else:
|
||||
if first_run: # don't spam the user with the same message over and over again
|
||||
sys.stderr.write(
|
||||
"Warning: Failed to set wallpaper. Your desktop environment is not supported."
|
||||
)
|
||||
sys.stderr.write(f"You can try manually to set your wallpaper to {file_loc}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_config_dir(app_name: str) -> str:
|
||||
"""Returns the configuration directory for the given application name."""
|
||||
if "XDG_CONFIG_HOME" in os.environ:
|
||||
config_home = os.environ["XDG_CONFIG_HOME"]
|
||||
elif "APPDATA" in os.environ: # On Windows
|
||||
config_home = os.environ["APPDATA"]
|
||||
else:
|
||||
try:
|
||||
from xdg import BaseDirectory
|
||||
|
||||
config_home = BaseDirectory.xdg_config_home
|
||||
except ImportError: # Most likely a Linux/Unix system anyway
|
||||
config_home = os.path.join(get_home_dir(), ".config")
|
||||
config_dir = os.path.join(config_home, app_name)
|
||||
return config_dir
|
||||
|
||||
|
||||
def get_home_dir() -> str:
|
||||
"""Returns the home directory of the current user."""
|
||||
return os.path.expanduser("~")
|
||||
@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
name = "cua-computer-server"
|
||||
version = "0.1.25"
|
||||
version = "0.1.28"
|
||||
|
||||
description = "Server component for the Computer-Use Interface (CUI) framework powering Cua"
|
||||
authors = [
|
||||
@@ -22,7 +22,15 @@ dependencies = [
|
||||
"pillow>=10.2.0",
|
||||
"aiohttp>=3.9.1",
|
||||
"pyperclip>=1.9.0",
|
||||
"websockets>=12.0"
|
||||
"websockets>=12.0",
|
||||
"pywinctl>=0.4.1",
|
||||
# OS-specific runtime deps
|
||||
"pyobjc-framework-Cocoa>=10.1; sys_platform == 'darwin'",
|
||||
"pyobjc-framework-Quartz>=10.1; sys_platform == 'darwin'",
|
||||
"pyobjc-framework-ApplicationServices>=10.1; sys_platform == 'darwin'",
|
||||
"python-xlib>=0.33; sys_platform == 'linux'",
|
||||
"pywin32>=310; sys_platform == 'win32'",
|
||||
"pip-system-certs; sys_platform == 'win32'",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
47
libs/python/computer-server/tests/conftest.py
Normal file
47
libs/python/computer-server/tests/conftest.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Pytest configuration and shared fixtures for computer-server package tests.
|
||||
|
||||
This file contains shared fixtures and configuration for all computer-server tests.
|
||||
Following SRP: This file ONLY handles test setup/teardown.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_websocket():
|
||||
"""Mock WebSocket connection for testing.
|
||||
|
||||
Use this fixture to test WebSocket logic without real connections.
|
||||
"""
|
||||
websocket = AsyncMock()
|
||||
websocket.send = AsyncMock()
|
||||
websocket.recv = AsyncMock()
|
||||
websocket.close = AsyncMock()
|
||||
|
||||
return websocket
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_computer_interface():
|
||||
"""Mock computer interface for server tests.
|
||||
|
||||
Use this fixture to test server logic without real computer operations.
|
||||
"""
|
||||
interface = AsyncMock()
|
||||
interface.screenshot = AsyncMock(return_value=b"fake_screenshot")
|
||||
interface.left_click = AsyncMock()
|
||||
interface.type = AsyncMock()
|
||||
interface.key = AsyncMock()
|
||||
|
||||
return interface
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_telemetry(monkeypatch):
|
||||
"""Disable telemetry for tests.
|
||||
|
||||
Use this fixture to ensure no telemetry is sent during tests.
|
||||
"""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
|
||||
40
libs/python/computer-server/tests/test_server.py
Normal file
40
libs/python/computer-server/tests/test_server.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""Unit tests for computer-server package.
|
||||
|
||||
This file tests ONLY basic server functionality.
|
||||
Following SRP: This file tests server initialization and basic operations.
|
||||
All external dependencies are mocked.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestServerImports:
|
||||
"""Test server module imports (SRP: Only tests imports)."""
|
||||
|
||||
def test_server_module_exists(self):
|
||||
"""Test that server module can be imported."""
|
||||
try:
|
||||
import computer_server
|
||||
|
||||
assert computer_server is not None
|
||||
except ImportError:
|
||||
pytest.skip("computer_server module not installed")
|
||||
|
||||
|
||||
class TestServerInitialization:
|
||||
"""Test server initialization (SRP: Only tests initialization)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_can_be_imported(self):
|
||||
"""Basic smoke test: verify server components can be imported."""
|
||||
try:
|
||||
from computer_server import server
|
||||
|
||||
assert server is not None
|
||||
except ImportError:
|
||||
pytest.skip("Server module not available")
|
||||
except Exception as e:
|
||||
# Some initialization errors are acceptable in unit tests
|
||||
pytest.skip(f"Server initialization requires specific setup: {e}")
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.4.7
|
||||
current_version = 0.4.11
|
||||
commit = True
|
||||
tag = True
|
||||
tag_name = computer-v{new_version}
|
||||
|
||||
@@ -436,6 +436,189 @@ class BaseComputerInterface(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
# Desktop actions
|
||||
@abstractmethod
|
||||
async def get_desktop_environment(self) -> str:
|
||||
"""Get the current desktop environment.
|
||||
|
||||
Returns:
|
||||
The name of the current desktop environment.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_wallpaper(self, path: str) -> None:
|
||||
"""Set the desktop wallpaper to the specified path.
|
||||
|
||||
Args:
|
||||
path: The file path to set as wallpaper
|
||||
"""
|
||||
pass
|
||||
|
||||
# Window management
|
||||
@abstractmethod
|
||||
async def open(self, target: str) -> None:
|
||||
"""Open a target using the system's default handler.
|
||||
|
||||
Typically opens files, folders, or URLs with the associated application.
|
||||
|
||||
Args:
|
||||
target: The file path, folder path, or URL to open.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def launch(self, app: str, args: List[str] | None = None) -> Optional[int]:
|
||||
"""Launch an application with optional arguments.
|
||||
|
||||
Args:
|
||||
app: The application executable or bundle identifier.
|
||||
args: Optional list of arguments to pass to the application.
|
||||
|
||||
Returns:
|
||||
Optional process ID (PID) of the launched application if available, otherwise None.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_current_window_id(self) -> int | str:
|
||||
"""Get the identifier of the currently active/focused window.
|
||||
|
||||
Returns:
|
||||
A window identifier that can be used with other window management methods.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_application_windows(self, app: str) -> List[int | str]:
|
||||
"""Get all window identifiers for a specific application.
|
||||
|
||||
Args:
|
||||
app: The application name, executable, or identifier to query.
|
||||
|
||||
Returns:
|
||||
A list of window identifiers belonging to the specified application.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_window_name(self, window_id: int | str) -> str:
|
||||
"""Get the title/name of a window.
|
||||
|
||||
Args:
|
||||
window_id: The window identifier.
|
||||
|
||||
Returns:
|
||||
The window's title or name string.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_window_size(self, window_id: int | str) -> tuple[int, int]:
|
||||
"""Get the size of a window in pixels.
|
||||
|
||||
Args:
|
||||
window_id: The window identifier.
|
||||
|
||||
Returns:
|
||||
A tuple of (width, height) representing the window size in pixels.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_window_position(self, window_id: int | str) -> tuple[int, int]:
|
||||
"""Get the screen position of a window.
|
||||
|
||||
Args:
|
||||
window_id: The window identifier.
|
||||
|
||||
Returns:
|
||||
A tuple of (x, y) representing the window's top-left corner in screen coordinates.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_window_size(self, window_id: int | str, width: int, height: int) -> None:
|
||||
"""Set the size of a window in pixels.
|
||||
|
||||
Args:
|
||||
window_id: The window identifier.
|
||||
width: Desired width in pixels.
|
||||
height: Desired height in pixels.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def set_window_position(self, window_id: int | str, x: int, y: int) -> None:
|
||||
"""Move a window to a specific position on the screen.
|
||||
|
||||
Args:
|
||||
window_id: The window identifier.
|
||||
x: X coordinate for the window's top-left corner.
|
||||
y: Y coordinate for the window's top-left corner.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def maximize_window(self, window_id: int | str) -> None:
|
||||
"""Maximize a window.
|
||||
|
||||
Args:
|
||||
window_id: The window identifier.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def minimize_window(self, window_id: int | str) -> None:
|
||||
"""Minimize a window.
|
||||
|
||||
Args:
|
||||
window_id: The window identifier.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def activate_window(self, window_id: int | str) -> None:
|
||||
"""Bring a window to the foreground and focus it.
|
||||
|
||||
Args:
|
||||
window_id: The window identifier.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def close_window(self, window_id: int | str) -> None:
|
||||
"""Close a window.
|
||||
|
||||
Args:
|
||||
window_id: The window identifier.
|
||||
"""
|
||||
pass
|
||||
|
||||
# Convenience aliases
|
||||
async def get_window_title(self, window_id: int | str) -> str:
|
||||
"""Convenience alias for get_window_name().
|
||||
|
||||
Args:
|
||||
window_id: The window identifier.
|
||||
|
||||
Returns:
|
||||
The window's title or name string.
|
||||
"""
|
||||
return await self.get_window_name(window_id)
|
||||
|
||||
async def window_size(self, window_id: int | str) -> tuple[int, int]:
|
||||
"""Convenience alias for get_window_size().
|
||||
|
||||
Args:
|
||||
window_id: The window identifier.
|
||||
|
||||
Returns:
|
||||
A tuple of (width, height) representing the window size in pixels.
|
||||
"""
|
||||
return await self.get_window_size(window_id)
|
||||
|
||||
# Shell actions
|
||||
@abstractmethod
|
||||
async def run_command(self, command: str) -> CommandResult:
|
||||
"""Run shell command and return structured result.
|
||||
|
||||
@@ -487,6 +487,104 @@ class GenericComputerInterface(BaseComputerInterface):
|
||||
raise RuntimeError(result.get("error", "Failed to list directory"))
|
||||
return result.get("files", [])
|
||||
|
||||
# Desktop actions
|
||||
async def get_desktop_environment(self) -> str:
|
||||
result = await self._send_command("get_desktop_environment")
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to get desktop environment"))
|
||||
return result.get("environment", "unknown")
|
||||
|
||||
async def set_wallpaper(self, path: str) -> None:
|
||||
result = await self._send_command("set_wallpaper", {"path": path})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to set wallpaper"))
|
||||
|
||||
# Window management
|
||||
async def open(self, target: str) -> None:
|
||||
result = await self._send_command("open", {"target": target})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to open target"))
|
||||
|
||||
async def launch(self, app: str, args: list[str] | None = None) -> int | None:
|
||||
payload: dict[str, object] = {"app": app}
|
||||
if args is not None:
|
||||
payload["args"] = args
|
||||
result = await self._send_command("launch", payload)
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to launch application"))
|
||||
return result.get("pid") # type: ignore[return-value]
|
||||
|
||||
async def get_current_window_id(self) -> int | str:
|
||||
result = await self._send_command("get_current_window_id")
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to get current window id"))
|
||||
return result["window_id"] # type: ignore[return-value]
|
||||
|
||||
async def get_application_windows(self, app: str) -> list[int | str]:
|
||||
result = await self._send_command("get_application_windows", {"app": app})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to get application windows"))
|
||||
return list(result.get("windows", [])) # type: ignore[return-value]
|
||||
|
||||
async def get_window_name(self, window_id: int | str) -> str:
|
||||
result = await self._send_command("get_window_name", {"window_id": window_id})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to get window name"))
|
||||
return result.get("name", "") # type: ignore[return-value]
|
||||
|
||||
async def get_window_size(self, window_id: int | str) -> tuple[int, int]:
|
||||
result = await self._send_command("get_window_size", {"window_id": window_id})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to get window size"))
|
||||
return int(result.get("width", 0)), int(result.get("height", 0))
|
||||
|
||||
async def get_window_position(self, window_id: int | str) -> tuple[int, int]:
|
||||
result = await self._send_command("get_window_position", {"window_id": window_id})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to get window position"))
|
||||
return int(result.get("x", 0)), int(result.get("y", 0))
|
||||
|
||||
async def set_window_size(self, window_id: int | str, width: int, height: int) -> None:
|
||||
result = await self._send_command(
|
||||
"set_window_size", {"window_id": window_id, "width": width, "height": height}
|
||||
)
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to set window size"))
|
||||
|
||||
async def set_window_position(self, window_id: int | str, x: int, y: int) -> None:
|
||||
result = await self._send_command(
|
||||
"set_window_position", {"window_id": window_id, "x": x, "y": y}
|
||||
)
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to set window position"))
|
||||
|
||||
async def maximize_window(self, window_id: int | str) -> None:
|
||||
result = await self._send_command("maximize_window", {"window_id": window_id})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to maximize window"))
|
||||
|
||||
async def minimize_window(self, window_id: int | str) -> None:
|
||||
result = await self._send_command("minimize_window", {"window_id": window_id})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to minimize window"))
|
||||
|
||||
async def activate_window(self, window_id: int | str) -> None:
|
||||
result = await self._send_command("activate_window", {"window_id": window_id})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to activate window"))
|
||||
|
||||
async def close_window(self, window_id: int | str) -> None:
|
||||
result = await self._send_command("close_window", {"window_id": window_id})
|
||||
if not result.get("success", False):
|
||||
raise RuntimeError(result.get("error", "Failed to close window"))
|
||||
|
||||
# Convenience aliases
|
||||
async def get_window_title(self, window_id: int | str) -> str:
|
||||
return await self.get_window_name(window_id)
|
||||
|
||||
async def window_size(self, window_id: int | str) -> tuple[int, int]:
|
||||
return await self.get_window_size(window_id)
|
||||
|
||||
# Command execution
|
||||
async def run_command(self, command: str) -> CommandResult:
|
||||
result = await self._send_command("run_command", {"command": command})
|
||||
|
||||
@@ -258,14 +258,20 @@ class DockerProvider(BaseVMProvider):
|
||||
logger.info(f"Container {name} is already running")
|
||||
return existing_vm
|
||||
elif existing_vm["status"] in ["stopped", "paused"]:
|
||||
# Start existing container
|
||||
logger.info(f"Starting existing container {name}")
|
||||
start_cmd = ["docker", "start", name]
|
||||
result = subprocess.run(start_cmd, capture_output=True, text=True, check=True)
|
||||
if self.ephemeral:
|
||||
# Delete existing container
|
||||
logger.info(f"Deleting existing container {name}")
|
||||
delete_cmd = ["docker", "rm", name]
|
||||
result = subprocess.run(delete_cmd, capture_output=True, text=True, check=True)
|
||||
else:
|
||||
# Start existing container
|
||||
logger.info(f"Starting existing container {name}")
|
||||
start_cmd = ["docker", "start", name]
|
||||
result = subprocess.run(start_cmd, capture_output=True, text=True, check=True)
|
||||
|
||||
# Wait for container to be ready
|
||||
await self._wait_for_container_ready(name)
|
||||
return await self.get_vm(name, storage)
|
||||
# Wait for container to be ready
|
||||
await self._wait_for_container_ready(name)
|
||||
return await self.get_vm(name, storage)
|
||||
|
||||
# Use provided image or default
|
||||
docker_image = image if image != "default" else self.image
|
||||
@@ -307,6 +313,20 @@ class DockerProvider(BaseVMProvider):
|
||||
cmd.extend(["-e", "VNC_PW=password"]) # Set VNC password
|
||||
cmd.extend(["-e", "VNCOPTIONS=-disableBasicAuth"]) # Disable VNC basic auth
|
||||
|
||||
# Apply display resolution if provided (e.g., "1024x768")
|
||||
display_resolution = run_opts.get("display")
|
||||
if (
|
||||
isinstance(display_resolution, dict)
|
||||
and "width" in display_resolution
|
||||
and "height" in display_resolution
|
||||
):
|
||||
cmd.extend(
|
||||
[
|
||||
"-e",
|
||||
f"VNC_RESOLUTION={display_resolution['width']}x{display_resolution['height']}",
|
||||
]
|
||||
)
|
||||
|
||||
# Add the image
|
||||
cmd.append(docker_image)
|
||||
|
||||
@@ -388,6 +408,11 @@ class DockerProvider(BaseVMProvider):
|
||||
|
||||
logger.info(f"Container {name} stopped successfully")
|
||||
|
||||
# Delete container if ephemeral=True
|
||||
if self.ephemeral:
|
||||
cmd = ["docker", "rm", name]
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
||||
|
||||
return {
|
||||
"name": name,
|
||||
"status": "stopped",
|
||||
|
||||
@@ -10,6 +10,8 @@ import subprocess
|
||||
import urllib.parse
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from computer.utils import safe_join
|
||||
|
||||
# Setup logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -59,7 +61,7 @@ def lume_api_get(
|
||||
# --max-time: Maximum time for the whole operation (20 seconds)
|
||||
# -f: Fail silently (no output at all) on server errors
|
||||
# Add single quotes around URL to ensure special characters are handled correctly
|
||||
cmd = ["curl", "--connect-timeout", "15", "--max-time", "20", "-s", "-f", f"'{api_url}'"]
|
||||
cmd = ["curl", "--connect-timeout", "15", "--max-time", "20", "-s", "-f", api_url]
|
||||
|
||||
# For logging and display, show the properly escaped URL
|
||||
display_cmd = ["curl", "--connect-timeout", "15", "--max-time", "20", "-s", "-f", api_url]
|
||||
@@ -71,7 +73,7 @@ def lume_api_get(
|
||||
# Execute the command - for execution we need to use shell=True to handle URLs with special characters
|
||||
try:
|
||||
# Use a single string with shell=True for proper URL handling
|
||||
shell_cmd = " ".join(cmd)
|
||||
shell_cmd = safe_join(cmd)
|
||||
result = subprocess.run(shell_cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# Handle curl exit codes
|
||||
@@ -514,7 +516,7 @@ def lume_api_delete(
|
||||
"-s",
|
||||
"-X",
|
||||
"DELETE",
|
||||
f"'{api_url}'",
|
||||
api_url,
|
||||
]
|
||||
|
||||
# For logging and display, show the properly escaped URL
|
||||
@@ -537,7 +539,7 @@ def lume_api_delete(
|
||||
# Execute the command - for execution we need to use shell=True to handle URLs with special characters
|
||||
try:
|
||||
# Use a single string with shell=True for proper URL handling
|
||||
shell_cmd = " ".join(cmd)
|
||||
shell_cmd = safe_join(cmd)
|
||||
result = subprocess.run(shell_cmd, shell=True, capture_output=True, text=True)
|
||||
|
||||
# Handle curl exit codes
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import shlex
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import mslex
|
||||
from PIL import Image, ImageDraw
|
||||
|
||||
|
||||
@@ -104,3 +107,25 @@ def parse_vm_info(vm_info: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Parse VM info from pylume response."""
|
||||
if not vm_info:
|
||||
return None
|
||||
|
||||
|
||||
def safe_join(argv: list[str]) -> str:
|
||||
"""
|
||||
Return a platform-correct string that safely quotes `argv` for shell execution.
|
||||
|
||||
- On POSIX: uses `shlex.join`.
|
||||
- On Windows: uses `shlex.join`.
|
||||
|
||||
Args:
|
||||
argv: iterable of argument strings (will be coerced to str).
|
||||
|
||||
Returns:
|
||||
A safely quoted command-line string appropriate for the current platform that protects against
|
||||
shell injection vulnerabilities.
|
||||
"""
|
||||
if os.name == "nt":
|
||||
# On Windows, use mslex for proper quoting
|
||||
return mslex.join(argv)
|
||||
else:
|
||||
# On POSIX systems, use shlex
|
||||
return shlex.join(argv)
|
||||
|
||||
@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
|
||||
|
||||
[project]
|
||||
name = "cua-computer"
|
||||
version = "0.4.10"
|
||||
version = "0.4.11"
|
||||
description = "Computer-Use Interface (CUI) framework powering Cua"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
@@ -16,7 +16,8 @@ dependencies = [
|
||||
"websockets>=12.0",
|
||||
"aiohttp>=3.9.0",
|
||||
"cua-core>=0.1.0,<0.2.0",
|
||||
"pydantic>=2.11.1"
|
||||
"pydantic>=2.11.1",
|
||||
"mslex>=1.3.0",
|
||||
]
|
||||
requires-python = ">=3.12"
|
||||
|
||||
@@ -47,4 +48,4 @@ source-includes = ["tests/", "README.md", "LICENSE"]
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
testpaths = ["tests"]
|
||||
python_files = "test_*.py"
|
||||
python_files = "test_*.py"
|
||||
|
||||
69
libs/python/computer/tests/conftest.py
Normal file
69
libs/python/computer/tests/conftest.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Pytest configuration and shared fixtures for computer package tests.
|
||||
|
||||
This file contains shared fixtures and configuration for all computer tests.
|
||||
Following SRP: This file ONLY handles test setup/teardown.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_interface():
|
||||
"""Mock computer interface for testing.
|
||||
|
||||
Use this fixture to test Computer logic without real OS calls.
|
||||
"""
|
||||
interface = AsyncMock()
|
||||
interface.screenshot = AsyncMock(return_value=b"fake_screenshot")
|
||||
interface.left_click = AsyncMock()
|
||||
interface.right_click = AsyncMock()
|
||||
interface.middle_click = AsyncMock()
|
||||
interface.double_click = AsyncMock()
|
||||
interface.type = AsyncMock()
|
||||
interface.key = AsyncMock()
|
||||
interface.move_mouse = AsyncMock()
|
||||
interface.scroll = AsyncMock()
|
||||
interface.get_screen_size = AsyncMock(return_value=(1920, 1080))
|
||||
|
||||
return interface
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_cloud_provider():
|
||||
"""Mock cloud provider for testing.
|
||||
|
||||
Use this fixture to test cloud provider logic without real API calls.
|
||||
"""
|
||||
provider = AsyncMock()
|
||||
provider.start = AsyncMock()
|
||||
provider.stop = AsyncMock()
|
||||
provider.get_status = AsyncMock(return_value="running")
|
||||
provider.execute_command = AsyncMock(return_value="command output")
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_local_provider():
|
||||
"""Mock local provider for testing.
|
||||
|
||||
Use this fixture to test local provider logic without real VM operations.
|
||||
"""
|
||||
provider = AsyncMock()
|
||||
provider.start = AsyncMock()
|
||||
provider.stop = AsyncMock()
|
||||
provider.get_status = AsyncMock(return_value="running")
|
||||
provider.execute_command = AsyncMock(return_value="command output")
|
||||
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_telemetry(monkeypatch):
|
||||
"""Disable telemetry for tests.
|
||||
|
||||
Use this fixture to ensure no telemetry is sent during tests.
|
||||
"""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
|
||||
67
libs/python/computer/tests/test_computer.py
Normal file
67
libs/python/computer/tests/test_computer.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Unit tests for Computer class.
|
||||
|
||||
This file tests ONLY the Computer class initialization and context manager.
|
||||
Following SRP: This file tests ONE class (Computer).
|
||||
All external dependencies (providers, interfaces) are mocked.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestComputerImport:
|
||||
"""Test Computer module imports (SRP: Only tests imports)."""
|
||||
|
||||
def test_computer_class_exists(self):
|
||||
"""Test that Computer class can be imported."""
|
||||
from computer import Computer
|
||||
|
||||
assert Computer is not None
|
||||
|
||||
def test_vm_provider_type_exists(self):
|
||||
"""Test that VMProviderType enum can be imported."""
|
||||
from computer import VMProviderType
|
||||
|
||||
assert VMProviderType is not None
|
||||
|
||||
|
||||
class TestComputerInitialization:
|
||||
"""Test Computer initialization (SRP: Only tests initialization)."""
|
||||
|
||||
def test_computer_class_can_be_imported(self, disable_telemetry):
|
||||
"""Test that Computer class can be imported without errors."""
|
||||
from computer import Computer
|
||||
|
||||
assert Computer is not None
|
||||
|
||||
def test_computer_has_required_methods(self, disable_telemetry):
|
||||
"""Test that Computer class has required methods."""
|
||||
from computer import Computer
|
||||
|
||||
assert hasattr(Computer, "__aenter__")
|
||||
assert hasattr(Computer, "__aexit__")
|
||||
|
||||
|
||||
class TestComputerContextManager:
|
||||
"""Test Computer context manager protocol (SRP: Only tests context manager)."""
|
||||
|
||||
def test_computer_is_async_context_manager(self, disable_telemetry):
|
||||
"""Test that Computer has async context manager methods."""
|
||||
from computer import Computer
|
||||
|
||||
assert hasattr(Computer, "__aenter__")
|
||||
assert hasattr(Computer, "__aexit__")
|
||||
assert callable(Computer.__aenter__)
|
||||
assert callable(Computer.__aexit__)
|
||||
|
||||
|
||||
class TestComputerInterface:
|
||||
"""Test Computer.interface property (SRP: Only tests interface access)."""
|
||||
|
||||
def test_computer_class_structure(self, disable_telemetry):
|
||||
"""Test that Computer class has expected structure."""
|
||||
from computer import Computer
|
||||
|
||||
# Verify Computer is a class
|
||||
assert isinstance(Computer, type)
|
||||
43
libs/python/core/tests/conftest.py
Normal file
43
libs/python/core/tests/conftest.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Pytest configuration and shared fixtures for core package tests.
|
||||
|
||||
This file contains shared fixtures and configuration for all core tests.
|
||||
Following SRP: This file ONLY handles test setup/teardown.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_httpx_client():
|
||||
"""Mock httpx.AsyncClient for API calls.
|
||||
|
||||
Use this fixture to avoid making real HTTP requests during tests.
|
||||
"""
|
||||
with patch("httpx.AsyncClient") as mock_client:
|
||||
mock_instance = AsyncMock()
|
||||
mock_client.return_value.__aenter__.return_value = mock_instance
|
||||
yield mock_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_posthog():
|
||||
"""Mock PostHog client for telemetry tests.
|
||||
|
||||
Use this fixture to avoid sending real telemetry during tests.
|
||||
"""
|
||||
with patch("posthog.Posthog") as mock_ph:
|
||||
mock_instance = Mock()
|
||||
mock_ph.return_value = mock_instance
|
||||
yield mock_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_telemetry(monkeypatch):
|
||||
"""Disable telemetry for tests that don't need it.
|
||||
|
||||
Use this fixture to ensure telemetry is disabled during tests.
|
||||
"""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
|
||||
yield
|
||||
255
libs/python/core/tests/test_telemetry.py
Normal file
255
libs/python/core/tests/test_telemetry.py
Normal file
@@ -0,0 +1,255 @@
|
||||
"""Unit tests for core telemetry functionality.
|
||||
|
||||
This file tests ONLY telemetry logic, following SRP.
|
||||
All external dependencies (PostHog, file system) are mocked.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, Mock, mock_open, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestTelemetryEnabled:
|
||||
"""Test telemetry enable/disable logic (SRP: Only tests enable/disable)."""
|
||||
|
||||
def test_telemetry_enabled_by_default(self, monkeypatch):
|
||||
"""Test that telemetry is enabled by default."""
|
||||
# Remove any environment variables that might affect the test
|
||||
monkeypatch.delenv("CUA_TELEMETRY", raising=False)
|
||||
monkeypatch.delenv("CUA_TELEMETRY_ENABLED", raising=False)
|
||||
|
||||
from core.telemetry import is_telemetry_enabled
|
||||
|
||||
assert is_telemetry_enabled() is True
|
||||
|
||||
def test_telemetry_disabled_with_legacy_flag(self, monkeypatch):
|
||||
"""Test that telemetry can be disabled with legacy CUA_TELEMETRY=off."""
|
||||
monkeypatch.setenv("CUA_TELEMETRY", "off")
|
||||
|
||||
from core.telemetry import is_telemetry_enabled
|
||||
|
||||
assert is_telemetry_enabled() is False
|
||||
|
||||
def test_telemetry_disabled_with_new_flag(self, monkeypatch):
|
||||
"""Test that telemetry can be disabled with CUA_TELEMETRY_ENABLED=false."""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")
|
||||
|
||||
from core.telemetry import is_telemetry_enabled
|
||||
|
||||
assert is_telemetry_enabled() is False
|
||||
|
||||
@pytest.mark.parametrize("value", ["0", "false", "no", "off"])
|
||||
def test_telemetry_disabled_with_various_values(self, monkeypatch, value):
|
||||
"""Test that telemetry respects various disable values."""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", value)
|
||||
|
||||
from core.telemetry import is_telemetry_enabled
|
||||
|
||||
assert is_telemetry_enabled() is False
|
||||
|
||||
@pytest.mark.parametrize("value", ["1", "true", "yes", "on"])
|
||||
def test_telemetry_enabled_with_various_values(self, monkeypatch, value):
|
||||
"""Test that telemetry respects various enable values."""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", value)
|
||||
|
||||
from core.telemetry import is_telemetry_enabled
|
||||
|
||||
assert is_telemetry_enabled() is True
|
||||
|
||||
|
||||
class TestPostHogTelemetryClient:
|
||||
"""Test PostHogTelemetryClient class (SRP: Only tests client logic)."""
|
||||
|
||||
@patch("core.telemetry.posthog.posthog")
|
||||
@patch("core.telemetry.posthog.Path")
|
||||
def test_client_initialization(self, mock_path, mock_posthog, disable_telemetry):
|
||||
"""Test that client initializes correctly."""
|
||||
from core.telemetry.posthog import PostHogTelemetryClient
|
||||
|
||||
# Mock the storage directory
|
||||
mock_storage_dir = MagicMock()
|
||||
mock_storage_dir.exists.return_value = False
|
||||
mock_path.return_value.parent.parent = MagicMock()
|
||||
mock_path.return_value.parent.parent.__truediv__.return_value = mock_storage_dir
|
||||
|
||||
# Reset singleton
|
||||
PostHogTelemetryClient.destroy_client()
|
||||
|
||||
client = PostHogTelemetryClient()
|
||||
|
||||
assert client is not None
|
||||
assert hasattr(client, "installation_id")
|
||||
assert hasattr(client, "initialized")
|
||||
assert hasattr(client, "queued_events")
|
||||
|
||||
@patch("core.telemetry.posthog.posthog")
|
||||
@patch("core.telemetry.posthog.Path")
|
||||
def test_installation_id_generation(self, mock_path, mock_posthog, disable_telemetry):
|
||||
"""Test that installation ID is generated if not exists."""
|
||||
from core.telemetry.posthog import PostHogTelemetryClient
|
||||
|
||||
# Mock file system
|
||||
mock_id_file = MagicMock()
|
||||
mock_id_file.exists.return_value = False
|
||||
mock_storage_dir = MagicMock()
|
||||
mock_storage_dir.__truediv__.return_value = mock_id_file
|
||||
|
||||
mock_core_dir = MagicMock()
|
||||
mock_core_dir.__truediv__.return_value = mock_storage_dir
|
||||
mock_path.return_value.parent.parent = mock_core_dir
|
||||
|
||||
# Reset singleton
|
||||
PostHogTelemetryClient.destroy_client()
|
||||
|
||||
client = PostHogTelemetryClient()
|
||||
|
||||
# Should have generated a new UUID
|
||||
assert client.installation_id is not None
|
||||
assert len(client.installation_id) == 36 # UUID format
|
||||
|
||||
@patch("core.telemetry.posthog.posthog")
|
||||
@patch("core.telemetry.posthog.Path")
|
||||
def test_installation_id_persistence(self, mock_path, mock_posthog, disable_telemetry):
|
||||
"""Test that installation ID is read from file if exists."""
|
||||
from core.telemetry.posthog import PostHogTelemetryClient
|
||||
|
||||
existing_id = "test-installation-id-123"
|
||||
|
||||
# Mock file system
|
||||
mock_id_file = MagicMock()
|
||||
mock_id_file.exists.return_value = True
|
||||
mock_id_file.read_text.return_value = existing_id
|
||||
|
||||
mock_storage_dir = MagicMock()
|
||||
mock_storage_dir.__truediv__.return_value = mock_id_file
|
||||
|
||||
mock_core_dir = MagicMock()
|
||||
mock_core_dir.__truediv__.return_value = mock_storage_dir
|
||||
mock_path.return_value.parent.parent = mock_core_dir
|
||||
|
||||
# Reset singleton
|
||||
PostHogTelemetryClient.destroy_client()
|
||||
|
||||
client = PostHogTelemetryClient()
|
||||
|
||||
assert client.installation_id == existing_id
|
||||
|
||||
@patch("core.telemetry.posthog.posthog")
|
||||
@patch("core.telemetry.posthog.Path")
|
||||
def test_record_event_when_disabled(self, mock_path, mock_posthog, monkeypatch):
|
||||
"""Test that events are not recorded when telemetry is disabled."""
|
||||
from core.telemetry.posthog import PostHogTelemetryClient
|
||||
|
||||
# Disable telemetry explicitly using the correct environment variable
|
||||
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "false")
|
||||
|
||||
# Mock file system
|
||||
mock_storage_dir = MagicMock()
|
||||
mock_storage_dir.exists.return_value = False
|
||||
mock_path.return_value.parent.parent = MagicMock()
|
||||
mock_path.return_value.parent.parent.__truediv__.return_value = mock_storage_dir
|
||||
|
||||
# Reset singleton
|
||||
PostHogTelemetryClient.destroy_client()
|
||||
|
||||
client = PostHogTelemetryClient()
|
||||
client.record_event("test_event", {"key": "value"})
|
||||
|
||||
# PostHog capture should not be called at all when telemetry is disabled
|
||||
mock_posthog.capture.assert_not_called()
|
||||
|
||||
@patch("core.telemetry.posthog.posthog")
|
||||
@patch("core.telemetry.posthog.Path")
|
||||
def test_record_event_when_enabled(self, mock_path, mock_posthog, monkeypatch):
|
||||
"""Test that events are recorded when telemetry is enabled."""
|
||||
from core.telemetry.posthog import PostHogTelemetryClient
|
||||
|
||||
# Enable telemetry
|
||||
monkeypatch.setenv("CUA_TELEMETRY_ENABLED", "true")
|
||||
|
||||
# Mock file system
|
||||
mock_storage_dir = MagicMock()
|
||||
mock_storage_dir.exists.return_value = False
|
||||
mock_path.return_value.parent.parent = MagicMock()
|
||||
mock_path.return_value.parent.parent.__truediv__.return_value = mock_storage_dir
|
||||
|
||||
# Reset singleton
|
||||
PostHogTelemetryClient.destroy_client()
|
||||
|
||||
client = PostHogTelemetryClient()
|
||||
client.initialized = True # Pretend it's initialized
|
||||
|
||||
event_name = "test_event"
|
||||
event_props = {"key": "value"}
|
||||
client.record_event(event_name, event_props)
|
||||
|
||||
# PostHog capture should be called
|
||||
assert mock_posthog.capture.call_count >= 1
|
||||
|
||||
@patch("core.telemetry.posthog.posthog")
|
||||
@patch("core.telemetry.posthog.Path")
|
||||
def test_singleton_pattern(self, mock_path, mock_posthog, disable_telemetry):
|
||||
"""Test that get_client returns the same instance."""
|
||||
from core.telemetry.posthog import PostHogTelemetryClient
|
||||
|
||||
# Mock file system
|
||||
mock_storage_dir = MagicMock()
|
||||
mock_storage_dir.exists.return_value = False
|
||||
mock_path.return_value.parent.parent = MagicMock()
|
||||
mock_path.return_value.parent.parent.__truediv__.return_value = mock_storage_dir
|
||||
|
||||
# Reset singleton
|
||||
PostHogTelemetryClient.destroy_client()
|
||||
|
||||
client1 = PostHogTelemetryClient.get_client()
|
||||
client2 = PostHogTelemetryClient.get_client()
|
||||
|
||||
assert client1 is client2
|
||||
|
||||
|
||||
class TestRecordEvent:
|
||||
"""Test the public record_event function (SRP: Only tests public API)."""
|
||||
|
||||
@patch("core.telemetry.posthog.PostHogTelemetryClient")
|
||||
def test_record_event_calls_client(self, mock_client_class, disable_telemetry):
|
||||
"""Test that record_event delegates to the client."""
|
||||
from core.telemetry import record_event
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_class.get_client.return_value = mock_client_instance
|
||||
|
||||
event_name = "test_event"
|
||||
event_props = {"key": "value"}
|
||||
|
||||
record_event(event_name, event_props)
|
||||
|
||||
mock_client_instance.record_event.assert_called_once_with(event_name, event_props)
|
||||
|
||||
@patch("core.telemetry.posthog.PostHogTelemetryClient")
|
||||
def test_record_event_without_properties(self, mock_client_class, disable_telemetry):
|
||||
"""Test that record_event works without properties."""
|
||||
from core.telemetry import record_event
|
||||
|
||||
mock_client_instance = Mock()
|
||||
mock_client_class.get_client.return_value = mock_client_instance
|
||||
|
||||
event_name = "test_event"
|
||||
|
||||
record_event(event_name)
|
||||
|
||||
mock_client_instance.record_event.assert_called_once_with(event_name, {})
|
||||
|
||||
|
||||
class TestDestroyTelemetryClient:
|
||||
"""Test client destruction (SRP: Only tests cleanup)."""
|
||||
|
||||
@patch("core.telemetry.posthog.PostHogTelemetryClient")
|
||||
def test_destroy_client_calls_class_method(self, mock_client_class):
|
||||
"""Test that destroy_telemetry_client delegates correctly."""
|
||||
from core.telemetry import destroy_telemetry_client
|
||||
|
||||
destroy_telemetry_client()
|
||||
|
||||
mock_client_class.destroy_client.assert_called_once()
|
||||
@@ -1,5 +1,5 @@
|
||||
[bumpversion]
|
||||
current_version = 0.1.14
|
||||
current_version = 0.1.15
|
||||
commit = True
|
||||
tag = True
|
||||
tag_name = mcp-server-v{new_version}
|
||||
|
||||
@@ -7,7 +7,7 @@ name = "cua-mcp-server"
|
||||
description = "MCP Server for Computer-Use Agent (CUA)"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
version = "0.1.14"
|
||||
version = "0.1.15"
|
||||
authors = [
|
||||
{name = "TryCua", email = "gh@trycua.com"}
|
||||
]
|
||||
|
||||
51
libs/python/mcp-server/tests/conftest.py
Normal file
51
libs/python/mcp-server/tests/conftest.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""Pytest configuration and shared fixtures for mcp-server package tests.
|
||||
|
||||
This file contains shared fixtures and configuration for all mcp-server tests.
|
||||
Following SRP: This file ONLY handles test setup/teardown.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_context():
|
||||
"""Mock MCP context for testing.
|
||||
|
||||
Use this fixture to test MCP server logic without real MCP connections.
|
||||
"""
|
||||
context = AsyncMock()
|
||||
context.request_context = AsyncMock()
|
||||
context.session = Mock()
|
||||
context.session.send_resource_updated = AsyncMock()
|
||||
|
||||
return context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_computer():
|
||||
"""Mock Computer instance for MCP server tests.
|
||||
|
||||
Use this fixture to test MCP logic without real Computer operations.
|
||||
"""
|
||||
computer = AsyncMock()
|
||||
computer.interface = AsyncMock()
|
||||
computer.interface.screenshot = AsyncMock(return_value=b"fake_screenshot")
|
||||
computer.interface.left_click = AsyncMock()
|
||||
computer.interface.type = AsyncMock()
|
||||
|
||||
# Mock context manager
|
||||
computer.__aenter__ = AsyncMock(return_value=computer)
|
||||
computer.__aexit__ = AsyncMock()
|
||||
|
||||
return computer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_telemetry(monkeypatch):
|
||||
"""Disable telemetry for tests.
|
||||
|
||||
Use this fixture to ensure no telemetry is sent during tests.
|
||||
"""
|
||||
monkeypatch.setenv("CUA_TELEMETRY_DISABLED", "1")
|
||||
44
libs/python/mcp-server/tests/test_mcp_server.py
Normal file
44
libs/python/mcp-server/tests/test_mcp_server.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""Unit tests for mcp-server package.
|
||||
|
||||
This file tests ONLY basic MCP server functionality.
|
||||
Following SRP: This file tests MCP server initialization.
|
||||
All external dependencies are mocked.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestMCPServerImports:
|
||||
"""Test MCP server module imports (SRP: Only tests imports)."""
|
||||
|
||||
def test_mcp_server_module_exists(self):
|
||||
"""Test that mcp_server module can be imported."""
|
||||
try:
|
||||
import mcp_server
|
||||
|
||||
assert mcp_server is not None
|
||||
except ImportError:
|
||||
pytest.skip("mcp_server module not installed")
|
||||
except SystemExit:
|
||||
pytest.skip("MCP dependencies (mcp.server.fastmcp) not available")
|
||||
|
||||
|
||||
class TestMCPServerInitialization:
|
||||
"""Test MCP server initialization (SRP: Only tests initialization)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_server_can_be_imported(self):
|
||||
"""Basic smoke test: verify MCP server components can be imported."""
|
||||
try:
|
||||
from mcp_server import server
|
||||
|
||||
assert server is not None
|
||||
except ImportError:
|
||||
pytest.skip("MCP server module not available")
|
||||
except SystemExit:
|
||||
pytest.skip("MCP dependencies (mcp.server.fastmcp) not available")
|
||||
except Exception as e:
|
||||
# Some initialization errors are acceptable in unit tests
|
||||
pytest.skip(f"MCP server initialization requires specific setup: {e}")
|
||||
@@ -1,10 +0,0 @@
|
||||
[bumpversion]
|
||||
current_version = 0.2.1
|
||||
commit = True
|
||||
tag = True
|
||||
tag_name = pylume-v{new_version}
|
||||
message = Bump pylume to v{new_version}
|
||||
|
||||
[bumpversion:file:pylume/__init__.py]
|
||||
search = __version__ = "{current_version}"
|
||||
replace = __version__ = "{new_version}"
|
||||
@@ -1,46 +0,0 @@
|
||||
<div align="center">
|
||||
<h1>
|
||||
<div class="image-wrapper" style="display: inline-block;">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="https://raw.githubusercontent.com/trycua/cua/main/img/logo_white.png" style="display: block; margin: auto;">
|
||||
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="https://raw.githubusercontent.com/trycua/cua/main/img/logo_black.png" style="display: block; margin: auto;">
|
||||
<img alt="Shows my svg">
|
||||
</picture>
|
||||
</div>
|
||||
|
||||
[](#)
|
||||
[](#)
|
||||
[](https://discord.com/invite/mVnXXpdE85)
|
||||
[](https://pypi.org/project/pylume/)
|
||||
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
**pylume** is a lightweight Python library based on [lume](https://github.com/trycua/lume) to create, run and manage macOS and Linux virtual machines (VMs) natively on Apple Silicon.
|
||||
|
||||
```bash
|
||||
pip install pylume
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
Please refer to this [Notebook](./samples/nb.ipynb) for a quickstart. More details about the underlying API used by pylume are available [here](https://github.com/trycua/lume/docs/API-Reference.md).
|
||||
|
||||
## Prebuilt Images
|
||||
|
||||
Pre-built images are available on [ghcr.io/trycua](https://github.com/orgs/trycua/packages).
|
||||
These images come pre-configured with an SSH server and auto-login enabled.
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome and greatly appreciate contributions to lume! Whether you're improving documentation, adding new features, fixing bugs, or adding new VM images, your efforts help make pylume better for everyone.
|
||||
|
||||
Join our [Discord community](https://discord.com/invite/mVnXXpdE85) to discuss ideas or get assistance.
|
||||
|
||||
## License
|
||||
|
||||
lume is open-sourced under the MIT License - see the [LICENSE](LICENSE) file for details.
|
||||
|
||||
## Stargazers over time
|
||||
|
||||
[](https://starchart.cc/trycua/pylume)
|
||||
@@ -1,9 +0,0 @@
|
||||
"""
|
||||
PyLume Python SDK - A client library for managing macOS VMs with PyLume.
|
||||
"""
|
||||
|
||||
from pylume.exceptions import *
|
||||
from pylume.models import *
|
||||
from pylume.pylume import *
|
||||
|
||||
__version__ = "0.1.0"
|
||||
@@ -1,59 +0,0 @@
|
||||
"""
|
||||
PyLume Python SDK - A client library for managing macOS VMs with PyLume.
|
||||
|
||||
Example:
|
||||
>>> from pylume import PyLume, VMConfig
|
||||
>>> client = PyLume()
|
||||
>>> config = VMConfig(name="my-vm", cpu=4, memory="8GB", disk_size="64GB")
|
||||
>>> client.create_vm(config)
|
||||
>>> client.run_vm("my-vm")
|
||||
"""
|
||||
|
||||
# Import exceptions then all models
|
||||
from .exceptions import (
|
||||
LumeConfigError,
|
||||
LumeConnectionError,
|
||||
LumeError,
|
||||
LumeImageError,
|
||||
LumeNotFoundError,
|
||||
LumeServerError,
|
||||
LumeTimeoutError,
|
||||
LumeVMError,
|
||||
)
|
||||
from .models import (
|
||||
CloneSpec,
|
||||
ImageInfo,
|
||||
ImageList,
|
||||
ImageRef,
|
||||
SharedDirectory,
|
||||
VMConfig,
|
||||
VMRunOpts,
|
||||
VMStatus,
|
||||
VMUpdateOpts,
|
||||
)
|
||||
|
||||
# Import main class last to avoid circular imports
|
||||
from .pylume import PyLume
|
||||
|
||||
__version__ = "0.2.1"
|
||||
|
||||
__all__ = [
|
||||
"PyLume",
|
||||
"VMConfig",
|
||||
"VMStatus",
|
||||
"VMRunOpts",
|
||||
"VMUpdateOpts",
|
||||
"ImageRef",
|
||||
"CloneSpec",
|
||||
"SharedDirectory",
|
||||
"ImageList",
|
||||
"ImageInfo",
|
||||
"LumeError",
|
||||
"LumeServerError",
|
||||
"LumeConnectionError",
|
||||
"LumeTimeoutError",
|
||||
"LumeNotFoundError",
|
||||
"LumeConfigError",
|
||||
"LumeVMError",
|
||||
"LumeImageError",
|
||||
]
|
||||
@@ -1,119 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import shlex
|
||||
import subprocess
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from .exceptions import (
|
||||
LumeConfigError,
|
||||
LumeConnectionError,
|
||||
LumeError,
|
||||
LumeNotFoundError,
|
||||
LumeServerError,
|
||||
LumeTimeoutError,
|
||||
)
|
||||
|
||||
|
||||
class LumeClient:
|
||||
def __init__(self, base_url: str, timeout: float = 60.0, debug: bool = False):
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
self.debug = debug
|
||||
|
||||
def _log_debug(self, message: str, **kwargs) -> None:
|
||||
"""Log debug information if debug mode is enabled."""
|
||||
if self.debug:
|
||||
print(f"DEBUG: {message}")
|
||||
if kwargs:
|
||||
print(json.dumps(kwargs, indent=2))
|
||||
|
||||
async def _run_curl(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
) -> Any:
|
||||
"""Execute a curl command and return the response."""
|
||||
url = f"{self.base_url}{path}"
|
||||
if params:
|
||||
param_str = "&".join(f"{k}={v}" for k, v in params.items())
|
||||
url = f"{url}?{param_str}"
|
||||
|
||||
cmd = ["curl", "-X", method, "-s", "-w", "%{http_code}", "-m", str(self.timeout)]
|
||||
|
||||
if data is not None:
|
||||
cmd.extend(["-H", "Content-Type: application/json", "-d", json.dumps(data)])
|
||||
|
||||
cmd.append(url)
|
||||
|
||||
self._log_debug(f"Running curl command: {' '.join(map(shlex.quote, cmd))}")
|
||||
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
raise LumeConnectionError(f"Curl command failed: {stderr.decode()}")
|
||||
|
||||
# The last 3 characters are the status code
|
||||
response = stdout.decode()
|
||||
status_code = int(response[-3:])
|
||||
response_body = response[:-3] # Remove status code from response
|
||||
|
||||
if status_code >= 400:
|
||||
if status_code == 404:
|
||||
raise LumeNotFoundError(f"Resource not found: {path}")
|
||||
elif status_code == 400:
|
||||
raise LumeConfigError(f"Invalid request: {response_body}")
|
||||
elif status_code >= 500:
|
||||
raise LumeServerError(f"Server error: {response_body}")
|
||||
else:
|
||||
raise LumeError(f"Request failed with status {status_code}: {response_body}")
|
||||
|
||||
return json.loads(response_body) if response_body.strip() else None
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise LumeTimeoutError(f"Request timed out after {self.timeout} seconds")
|
||||
|
||||
async def get(self, path: str, params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""Make a GET request."""
|
||||
return await self._run_curl("GET", path, params=params)
|
||||
|
||||
async def post(
|
||||
self, path: str, data: Optional[Dict[str, Any]] = None, timeout: Optional[float] = None
|
||||
) -> Any:
|
||||
"""Make a POST request."""
|
||||
old_timeout = self.timeout
|
||||
if timeout is not None:
|
||||
self.timeout = timeout
|
||||
try:
|
||||
return await self._run_curl("POST", path, data=data)
|
||||
finally:
|
||||
self.timeout = old_timeout
|
||||
|
||||
async def patch(self, path: str, data: Dict[str, Any]) -> None:
|
||||
"""Make a PATCH request."""
|
||||
await self._run_curl("PATCH", path, data=data)
|
||||
|
||||
async def delete(self, path: str) -> None:
|
||||
"""Make a DELETE request."""
|
||||
await self._run_curl("DELETE", path)
|
||||
|
||||
def print_curl(self, method: str, path: str, data: Optional[Dict[str, Any]] = None) -> None:
|
||||
"""Print equivalent curl command for debugging."""
|
||||
curl_cmd = f"""curl -X {method} \\
|
||||
'{self.base_url}{path}'"""
|
||||
|
||||
if data:
|
||||
curl_cmd += f" \\\n -H 'Content-Type: application/json' \\\n -d '{json.dumps(data)}'"
|
||||
|
||||
print("\nEquivalent curl command:")
|
||||
print(curl_cmd)
|
||||
print()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the client resources."""
|
||||
pass # No shared resources to clean up
|
||||
@@ -1,54 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class LumeError(Exception):
|
||||
"""Base exception for all PyLume errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LumeServerError(LumeError):
|
||||
"""Raised when there's an error with the PyLume server."""
|
||||
|
||||
def __init__(
|
||||
self, message: str, status_code: Optional[int] = None, response_text: Optional[str] = None
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.response_text = response_text
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class LumeConnectionError(LumeError):
|
||||
"""Raised when there's an error connecting to the PyLume server."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LumeTimeoutError(LumeError):
|
||||
"""Raised when a request to the PyLume server times out."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LumeNotFoundError(LumeError):
|
||||
"""Raised when a requested resource is not found."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LumeConfigError(LumeError):
|
||||
"""Raised when there's an error with the configuration."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LumeVMError(LumeError):
|
||||
"""Raised when there's an error with a VM operation."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class LumeImageError(LumeError):
|
||||
"""Raised when there's an error with an image operation."""
|
||||
|
||||
pass
|
||||
Binary file not shown.
@@ -1,265 +0,0 @@
|
||||
import re
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, RootModel, computed_field, validator
|
||||
|
||||
|
||||
class DiskInfo(BaseModel):
|
||||
"""Information about disk storage allocation.
|
||||
|
||||
Attributes:
|
||||
total: Total disk space in bytes
|
||||
allocated: Currently allocated disk space in bytes
|
||||
"""
|
||||
|
||||
total: int
|
||||
allocated: int
|
||||
|
||||
|
||||
class VMConfig(BaseModel):
|
||||
"""Configuration for creating a new VM.
|
||||
|
||||
Note: Memory and disk sizes should be specified with units (e.g., "4GB", "64GB")
|
||||
|
||||
Attributes:
|
||||
name: Name of the virtual machine
|
||||
os: Operating system type, either "macOS" or "linux"
|
||||
cpu: Number of CPU cores to allocate
|
||||
memory: Amount of memory to allocate with units
|
||||
disk_size: Size of the disk to create with units
|
||||
display: Display resolution in format "widthxheight"
|
||||
ipsw: IPSW path or 'latest' for macOS VMs, None for other OS types
|
||||
"""
|
||||
|
||||
name: str
|
||||
os: Literal["macOS", "linux"] = "macOS"
|
||||
cpu: int = Field(default=2, ge=1)
|
||||
memory: str = "4GB"
|
||||
disk_size: str = Field(default="64GB", alias="diskSize")
|
||||
display: str = "1024x768"
|
||||
ipsw: Optional[str] = Field(default=None, description="IPSW path or 'latest', for macOS VMs")
|
||||
|
||||
class Config:
|
||||
populate_by_alias = True
|
||||
|
||||
|
||||
class SharedDirectory(BaseModel):
|
||||
"""Configuration for a shared directory.
|
||||
|
||||
Attributes:
|
||||
host_path: Path to the directory on the host system
|
||||
read_only: Whether the directory should be mounted as read-only
|
||||
"""
|
||||
|
||||
host_path: str = Field(..., alias="hostPath") # Allow host_path but serialize as hostPath
|
||||
read_only: bool = False
|
||||
|
||||
class Config:
|
||||
populate_by_name = True # Allow both alias and original name
|
||||
alias_generator = lambda s: "".join(
|
||||
word.capitalize() if i else word for i, word in enumerate(s.split("_"))
|
||||
)
|
||||
|
||||
|
||||
class VMRunOpts(BaseModel):
|
||||
"""Configuration for running a VM.
|
||||
|
||||
Args:
|
||||
no_display: Whether to not display the VNC client
|
||||
shared_directories: List of directories to share with the VM
|
||||
"""
|
||||
|
||||
no_display: bool = Field(default=False, alias="noDisplay")
|
||||
shared_directories: Optional[list[SharedDirectory]] = Field(
|
||||
default=None, alias="sharedDirectories"
|
||||
)
|
||||
|
||||
model_config = ConfigDict(
|
||||
populate_by_name=True,
|
||||
alias_generator=lambda s: "".join(
|
||||
word.capitalize() if i else word for i, word in enumerate(s.split("_"))
|
||||
),
|
||||
)
|
||||
|
||||
def model_dump(self, **kwargs):
|
||||
"""Export model data with proper field name conversion.
|
||||
|
||||
Converts shared directory fields to match API expectations when using aliases.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments passed to parent model_dump method
|
||||
|
||||
Returns:
|
||||
dict: Model data with properly formatted field names
|
||||
"""
|
||||
data = super().model_dump(**kwargs)
|
||||
# Convert shared directory fields to match API expectations
|
||||
if self.shared_directories and "by_alias" in kwargs and kwargs["by_alias"]:
|
||||
data["sharedDirectories"] = [
|
||||
{"hostPath": d.host_path, "readOnly": d.read_only} for d in self.shared_directories
|
||||
]
|
||||
# Remove the snake_case version if it exists
|
||||
data.pop("shared_directories", None)
|
||||
return data
|
||||
|
||||
|
||||
class VMStatus(BaseModel):
|
||||
"""Status information for a virtual machine.
|
||||
|
||||
Attributes:
|
||||
name: Name of the virtual machine
|
||||
status: Current status of the VM
|
||||
os: Operating system type
|
||||
cpu_count: Number of CPU cores allocated
|
||||
memory_size: Amount of memory allocated in bytes
|
||||
disk_size: Disk storage information
|
||||
vnc_url: URL for VNC connection if available
|
||||
ip_address: IP address of the VM if available
|
||||
"""
|
||||
|
||||
name: str
|
||||
status: str
|
||||
os: Literal["macOS", "linux"]
|
||||
cpu_count: int = Field(alias="cpuCount")
|
||||
memory_size: int = Field(alias="memorySize") # API returns memory size in bytes
|
||||
disk_size: DiskInfo = Field(alias="diskSize")
|
||||
vnc_url: Optional[str] = Field(default=None, alias="vncUrl")
|
||||
ip_address: Optional[str] = Field(default=None, alias="ipAddress")
|
||||
|
||||
class Config:
|
||||
populate_by_alias = True
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def state(self) -> str:
|
||||
"""Get the current state of the VM.
|
||||
|
||||
Returns:
|
||||
str: Current VM status
|
||||
"""
|
||||
return self.status
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def cpu(self) -> int:
|
||||
"""Get the number of CPU cores.
|
||||
|
||||
Returns:
|
||||
int: Number of CPU cores allocated to the VM
|
||||
"""
|
||||
return self.cpu_count
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def memory(self) -> str:
|
||||
"""Get memory allocation in human-readable format.
|
||||
|
||||
Returns:
|
||||
str: Memory size formatted as "{size}GB"
|
||||
"""
|
||||
# Convert bytes to GB
|
||||
gb = self.memory_size / (1024 * 1024 * 1024)
|
||||
return f"{int(gb)}GB"
|
||||
|
||||
|
||||
class VMUpdateOpts(BaseModel):
|
||||
"""Options for updating VM configuration.
|
||||
|
||||
Attributes:
|
||||
cpu: Number of CPU cores to update to
|
||||
memory: Amount of memory to update to with units
|
||||
disk_size: Size of disk to update to with units
|
||||
"""
|
||||
|
||||
cpu: Optional[int] = None
|
||||
memory: Optional[str] = None
|
||||
disk_size: Optional[str] = None
|
||||
|
||||
|
||||
class ImageRef(BaseModel):
|
||||
"""Reference to a VM image.
|
||||
|
||||
Attributes:
|
||||
image: Name of the image
|
||||
tag: Tag version of the image
|
||||
registry: Registry hostname where image is stored
|
||||
organization: Organization or namespace in the registry
|
||||
"""
|
||||
|
||||
image: str
|
||||
tag: str = "latest"
|
||||
registry: Optional[str] = "ghcr.io"
|
||||
organization: Optional[str] = "trycua"
|
||||
|
||||
def model_dump(self, **kwargs):
|
||||
"""Override model_dump to return just the image:tag format.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments (ignored)
|
||||
|
||||
Returns:
|
||||
str: Image reference in "image:tag" format
|
||||
"""
|
||||
return f"{self.image}:{self.tag}"
|
||||
|
||||
|
||||
class CloneSpec(BaseModel):
|
||||
"""Specification for cloning a VM.
|
||||
|
||||
Attributes:
|
||||
name: Name of the source VM to clone
|
||||
new_name: Name for the new cloned VM
|
||||
"""
|
||||
|
||||
name: str
|
||||
new_name: str = Field(alias="newName")
|
||||
|
||||
class Config:
|
||||
populate_by_alias = True
|
||||
|
||||
|
||||
class ImageInfo(BaseModel):
|
||||
"""Model for individual image information.
|
||||
|
||||
Attributes:
|
||||
imageId: Unique identifier for the image
|
||||
"""
|
||||
|
||||
imageId: str
|
||||
|
||||
|
||||
class ImageList(RootModel):
|
||||
"""Response model for the images endpoint.
|
||||
|
||||
A list-like container for ImageInfo objects that provides
|
||||
iteration and indexing capabilities.
|
||||
"""
|
||||
|
||||
root: List[ImageInfo]
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over the image list.
|
||||
|
||||
Returns:
|
||||
Iterator over ImageInfo objects
|
||||
"""
|
||||
return iter(self.root)
|
||||
|
||||
def __getitem__(self, item):
|
||||
"""Get an item from the image list by index.
|
||||
|
||||
Args:
|
||||
item: Index or slice to retrieve
|
||||
|
||||
Returns:
|
||||
ImageInfo or list of ImageInfo objects
|
||||
"""
|
||||
return self.root[item]
|
||||
|
||||
def __len__(self):
|
||||
"""Get the number of images in the list.
|
||||
|
||||
Returns:
|
||||
int: Number of images in the list
|
||||
"""
|
||||
return len(self.root)
|
||||
@@ -1,315 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, List, Optional, TypeVar, Union
|
||||
|
||||
from .client import LumeClient
|
||||
from .exceptions import (
|
||||
LumeConfigError,
|
||||
LumeConnectionError,
|
||||
LumeError,
|
||||
LumeImageError,
|
||||
LumeNotFoundError,
|
||||
LumeServerError,
|
||||
LumeTimeoutError,
|
||||
LumeVMError,
|
||||
)
|
||||
from .models import (
|
||||
CloneSpec,
|
||||
ImageList,
|
||||
ImageRef,
|
||||
SharedDirectory,
|
||||
VMConfig,
|
||||
VMRunOpts,
|
||||
VMStatus,
|
||||
VMUpdateOpts,
|
||||
)
|
||||
from .server import LumeServer
|
||||
|
||||
# Type variable for the decorator
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def ensure_server(func: Callable[..., T]) -> Callable[..., T]:
|
||||
"""Decorator to ensure server is running before executing the method."""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(self: "PyLume", *args: Any, **kwargs: Any) -> T:
|
||||
# ensure_running is an async method, so we need to await it
|
||||
await self.server.ensure_running()
|
||||
# Initialize client if needed
|
||||
await self._init_client()
|
||||
return await func(self, *args, **kwargs) # type: ignore
|
||||
|
||||
return wrapper # type: ignore
|
||||
|
||||
|
||||
class PyLume:
|
||||
def __init__(
|
||||
self,
|
||||
debug: bool = False,
|
||||
server_start_timeout: int = 60,
|
||||
port: Optional[int] = None,
|
||||
use_existing_server: bool = False,
|
||||
host: str = "localhost",
|
||||
):
|
||||
"""Initialize the async PyLume client.
|
||||
|
||||
Args:
|
||||
debug: Enable debug logging
|
||||
auto_start_server: Whether to automatically start the lume server if not running
|
||||
server_start_timeout: Timeout in seconds to wait for server to start
|
||||
port: Port number for the lume server. Required when use_existing_server is True.
|
||||
use_existing_server: If True, will try to connect to an existing server on the specified port
|
||||
instead of starting a new one.
|
||||
host: Host to use for connections (e.g., "localhost", "127.0.0.1", "host.docker.internal")
|
||||
"""
|
||||
if use_existing_server and port is None:
|
||||
raise LumeConfigError("Port must be specified when using an existing server")
|
||||
|
||||
self.server = LumeServer(
|
||||
debug=debug,
|
||||
server_start_timeout=server_start_timeout,
|
||||
port=port,
|
||||
use_existing_server=use_existing_server,
|
||||
host=host,
|
||||
)
|
||||
self.client = None
|
||||
|
||||
async def __aenter__(self) -> "PyLume":
|
||||
"""Async context manager entry."""
|
||||
if self.server.use_existing_server:
|
||||
# Just ensure base_url is set for existing server
|
||||
if self.server.requested_port is None:
|
||||
raise LumeConfigError("Port must be specified when using an existing server")
|
||||
|
||||
if not self.server.base_url:
|
||||
self.server.port = self.server.requested_port
|
||||
self.server.base_url = f"http://{self.server.host}:{self.server.port}/lume"
|
||||
|
||||
# Ensure the server is running (will connect to existing or start new as needed)
|
||||
await self.server.ensure_running()
|
||||
|
||||
# Initialize the client
|
||||
await self._init_client()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
|
||||
"""Async context manager exit."""
|
||||
if self.client is not None:
|
||||
await self.client.close()
|
||||
await self.server.stop()
|
||||
|
||||
async def _init_client(self) -> None:
|
||||
"""Initialize the client if not already initialized."""
|
||||
if self.client is None:
|
||||
if self.server.base_url is None:
|
||||
raise RuntimeError("Server base URL not set")
|
||||
self.client = LumeClient(self.server.base_url, debug=self.server.debug)
|
||||
|
||||
def _log_debug(self, message: str, **kwargs) -> None:
|
||||
"""Log debug information if debug mode is enabled."""
|
||||
if self.server.debug:
|
||||
print(f"DEBUG: {message}")
|
||||
if kwargs:
|
||||
print(json.dumps(kwargs, indent=2))
|
||||
|
||||
async def _handle_api_error(self, e: Exception, operation: str) -> None:
|
||||
"""Handle API errors and raise appropriate custom exceptions."""
|
||||
if isinstance(e, subprocess.SubprocessError):
|
||||
raise LumeConnectionError(f"Failed to connect to PyLume server: {str(e)}")
|
||||
elif isinstance(e, asyncio.TimeoutError):
|
||||
raise LumeTimeoutError(f"Request timed out: {str(e)}")
|
||||
|
||||
if not hasattr(e, "status") and not isinstance(e, subprocess.CalledProcessError):
|
||||
raise LumeServerError(f"Unknown error during {operation}: {str(e)}")
|
||||
|
||||
status_code = getattr(e, "status", 500)
|
||||
response_text = str(e)
|
||||
|
||||
self._log_debug(
|
||||
f"{operation} request failed", status_code=status_code, response_text=response_text
|
||||
)
|
||||
|
||||
if status_code == 404:
|
||||
raise LumeNotFoundError(f"Resource not found during {operation}")
|
||||
elif status_code == 400:
|
||||
raise LumeConfigError(f"Invalid configuration for {operation}: {response_text}")
|
||||
elif status_code >= 500:
|
||||
raise LumeServerError(
|
||||
f"Server error during {operation}",
|
||||
status_code=status_code,
|
||||
response_text=response_text,
|
||||
)
|
||||
else:
|
||||
raise LumeServerError(
|
||||
f"Error during {operation}", status_code=status_code, response_text=response_text
|
||||
)
|
||||
|
||||
async def _read_output(self) -> None:
|
||||
"""Read and log server output."""
|
||||
try:
|
||||
while True:
|
||||
if not self.server.server_process or self.server.server_process.poll() is not None:
|
||||
self._log_debug("Server process ended")
|
||||
break
|
||||
|
||||
# Read stdout without blocking
|
||||
if self.server.server_process.stdout:
|
||||
while True:
|
||||
line = self.server.server_process.stdout.readline()
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
self._log_debug(f"Server stdout: {line}")
|
||||
if "Server started" in line.decode("utf-8"):
|
||||
self._log_debug("Detected server started message")
|
||||
return
|
||||
|
||||
# Read stderr without blocking
|
||||
if self.server.server_process.stderr:
|
||||
while True:
|
||||
line = self.server.server_process.stderr.readline()
|
||||
if not line:
|
||||
break
|
||||
line = line.strip()
|
||||
self._log_debug(f"Server stderr: {line}")
|
||||
if "error" in line.decode("utf-8").lower():
|
||||
raise RuntimeError(f"Server error: {line}")
|
||||
|
||||
await asyncio.sleep(0.1) # Small delay to prevent CPU spinning
|
||||
except Exception as e:
|
||||
self._log_debug(f"Error in output reader: {str(e)}")
|
||||
raise
|
||||
|
||||
@ensure_server
|
||||
async def create_vm(self, spec: Union[VMConfig, dict]) -> None:
|
||||
"""Create a VM with the given configuration."""
|
||||
# Ensure client is initialized
|
||||
await self._init_client()
|
||||
|
||||
if isinstance(spec, VMConfig):
|
||||
spec = spec.model_dump(by_alias=True, exclude_none=True)
|
||||
|
||||
# Suppress optional attribute access errors
|
||||
self.client.print_curl("POST", "/vms", spec) # type: ignore[attr-defined]
|
||||
await self.client.post("/vms", spec) # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def run_vm(self, name: str, opts: Optional[Union[VMRunOpts, dict]] = None) -> None:
|
||||
"""Run a VM."""
|
||||
if opts is None:
|
||||
opts = VMRunOpts(no_display=False) # type: ignore[attr-defined]
|
||||
elif isinstance(opts, dict):
|
||||
opts = VMRunOpts(**opts)
|
||||
|
||||
payload = opts.model_dump(by_alias=True, exclude_none=True)
|
||||
self.client.print_curl("POST", f"/vms/{name}/run", payload) # type: ignore[attr-defined]
|
||||
await self.client.post(f"/vms/{name}/run", payload) # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def list_vms(self) -> List[VMStatus]:
|
||||
"""List all VMs."""
|
||||
data = await self.client.get("/vms") # type: ignore[attr-defined]
|
||||
return [VMStatus.model_validate(vm) for vm in data]
|
||||
|
||||
@ensure_server
|
||||
async def get_vm(self, name: str) -> VMStatus:
|
||||
"""Get VM details."""
|
||||
data = await self.client.get(f"/vms/{name}") # type: ignore[attr-defined]
|
||||
return VMStatus.model_validate(data)
|
||||
|
||||
@ensure_server
|
||||
async def update_vm(self, name: str, params: Union[VMUpdateOpts, dict]) -> None:
|
||||
"""Update VM settings."""
|
||||
if isinstance(params, dict):
|
||||
params = VMUpdateOpts(**params)
|
||||
|
||||
payload = params.model_dump(by_alias=True, exclude_none=True)
|
||||
self.client.print_curl("PATCH", f"/vms/{name}", payload) # type: ignore[attr-defined]
|
||||
await self.client.patch(f"/vms/{name}", payload) # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def stop_vm(self, name: str) -> None:
|
||||
"""Stop a VM."""
|
||||
await self.client.post(f"/vms/{name}/stop") # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def delete_vm(self, name: str) -> None:
|
||||
"""Delete a VM."""
|
||||
await self.client.delete(f"/vms/{name}") # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def pull_image(
|
||||
self, spec: Union[ImageRef, dict, str], name: Optional[str] = None
|
||||
) -> None:
|
||||
"""Pull a VM image."""
|
||||
await self._init_client()
|
||||
if isinstance(spec, str):
|
||||
if ":" in spec:
|
||||
image_str = spec
|
||||
else:
|
||||
image_str = f"{spec}:latest"
|
||||
registry = "ghcr.io"
|
||||
organization = "trycua"
|
||||
elif isinstance(spec, dict):
|
||||
image = spec.get("image", "")
|
||||
tag = spec.get("tag", "latest")
|
||||
image_str = f"{image}:{tag}"
|
||||
registry = spec.get("registry", "ghcr.io")
|
||||
organization = spec.get("organization", "trycua")
|
||||
else:
|
||||
image_str = f"{spec.image}:{spec.tag}"
|
||||
registry = spec.registry
|
||||
organization = spec.organization
|
||||
|
||||
payload = {
|
||||
"image": image_str,
|
||||
"name": name,
|
||||
"registry": registry,
|
||||
"organization": organization,
|
||||
}
|
||||
|
||||
self.client.print_curl("POST", "/pull", payload) # type: ignore[attr-defined]
|
||||
await self.client.post("/pull", payload, timeout=300.0) # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def clone_vm(self, name: str, new_name: str) -> None:
|
||||
"""Clone a VM with the given name to a new VM with new_name."""
|
||||
config = CloneSpec(name=name, newName=new_name)
|
||||
self.client.print_curl("POST", "/vms/clone", config.model_dump()) # type: ignore[attr-defined]
|
||||
await self.client.post("/vms/clone", config.model_dump()) # type: ignore[attr-defined]
|
||||
|
||||
@ensure_server
|
||||
async def get_latest_ipsw_url(self) -> str:
|
||||
"""Get the latest IPSW URL."""
|
||||
await self._init_client()
|
||||
data = await self.client.get("/ipsw") # type: ignore[attr-defined]
|
||||
return data["url"]
|
||||
|
||||
@ensure_server
|
||||
async def get_images(self, organization: Optional[str] = None) -> ImageList:
|
||||
"""Get list of available images."""
|
||||
await self._init_client()
|
||||
params = {"organization": organization} if organization else None
|
||||
data = await self.client.get("/images", params) # type: ignore[attr-defined]
|
||||
return ImageList(root=data)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the client and stop the server."""
|
||||
if self.client is not None:
|
||||
await self.client.close()
|
||||
self.client = None
|
||||
await asyncio.sleep(1)
|
||||
await self.server.stop()
|
||||
|
||||
async def _ensure_client(self) -> None:
|
||||
"""Ensure client is initialized."""
|
||||
if self.client is None:
|
||||
await self._init_client()
|
||||
@@ -1,481 +0,0 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import shlex
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
from logging import getLogger
|
||||
from typing import Optional
|
||||
|
||||
from .exceptions import LumeConnectionError
|
||||
|
||||
|
||||
class LumeServer:
|
||||
def __init__(
|
||||
self,
|
||||
debug: bool = False,
|
||||
server_start_timeout: int = 60,
|
||||
port: Optional[int] = None,
|
||||
use_existing_server: bool = False,
|
||||
host: str = "localhost",
|
||||
):
|
||||
"""Initialize the LumeServer.
|
||||
|
||||
Args:
|
||||
debug: Enable debug logging
|
||||
server_start_timeout: Timeout in seconds to wait for server to start
|
||||
port: Specific port to use for the server
|
||||
use_existing_server: If True, will try to connect to an existing server
|
||||
instead of starting a new one
|
||||
host: Host to use for connections (e.g., "localhost", "127.0.0.1", "host.docker.internal")
|
||||
"""
|
||||
self.debug = debug
|
||||
self.server_start_timeout = server_start_timeout
|
||||
self.server_process = None
|
||||
self.output_file = None
|
||||
self.requested_port = port
|
||||
self.port = None
|
||||
self.base_url = None
|
||||
self.use_existing_server = use_existing_server
|
||||
self.host = host
|
||||
|
||||
# Configure logging
|
||||
self.logger = getLogger("pylume.server")
|
||||
if not self.logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||
handler.setFormatter(formatter)
|
||||
self.logger.addHandler(handler)
|
||||
self.logger.setLevel(logging.DEBUG if debug else logging.INFO)
|
||||
|
||||
self.logger.debug(f"Server initialized with host: {self.host}")
|
||||
|
||||
def _check_port_available(self, port: int) -> bool:
|
||||
"""Check if a port is available."""
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(0.5)
|
||||
result = s.connect_ex(("127.0.0.1", port))
|
||||
if result == 0: # Port is in use on localhost
|
||||
return False
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check the specified host (e.g., "host.docker.internal") if it's not a localhost alias
|
||||
if self.host not in ["localhost", "127.0.0.1"]:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.settimeout(0.5)
|
||||
result = s.connect_ex((self.host, port))
|
||||
if result == 0: # Port is in use on host
|
||||
return False
|
||||
except:
|
||||
pass
|
||||
|
||||
return True
|
||||
|
||||
def _get_server_port(self) -> int:
|
||||
"""Get an available port for the server."""
|
||||
# Use requested port if specified
|
||||
if self.requested_port is not None:
|
||||
if not self._check_port_available(self.requested_port):
|
||||
raise RuntimeError(f"Requested port {self.requested_port} is not available")
|
||||
return self.requested_port
|
||||
|
||||
# Find a free port
|
||||
for _ in range(10): # Try up to 10 times
|
||||
port = random.randint(49152, 65535)
|
||||
if self._check_port_available(port):
|
||||
return port
|
||||
|
||||
raise RuntimeError("Could not find an available port")
|
||||
|
||||
async def _ensure_server_running(self) -> None:
|
||||
"""Ensure the lume server is running, start it if it's not."""
|
||||
try:
|
||||
self.logger.debug("Checking if lume server is running...")
|
||||
# Try to connect to the server with a short timeout
|
||||
cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "5", f"{self.base_url}/vms"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
response = stdout.decode()
|
||||
status_code = int(response[-3:])
|
||||
if status_code == 200:
|
||||
self.logger.debug("PyLume server is running")
|
||||
return
|
||||
|
||||
self.logger.debug("PyLume server not running, attempting to start it")
|
||||
# Server not running, try to start it
|
||||
lume_path = os.path.join(os.path.dirname(__file__), "lume")
|
||||
if not os.path.exists(lume_path):
|
||||
raise RuntimeError(f"Could not find lume binary at {lume_path}")
|
||||
|
||||
# Make sure the file is executable
|
||||
os.chmod(lume_path, 0o755)
|
||||
|
||||
# Create a temporary file for server output
|
||||
self.output_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)
|
||||
self.logger.debug(f"Using temporary file for server output: {self.output_file.name}")
|
||||
|
||||
# Start the server
|
||||
self.logger.debug(f"Starting lume server with: {lume_path} serve --port {self.port}")
|
||||
|
||||
# Start server in background using subprocess.Popen
|
||||
try:
|
||||
self.server_process = subprocess.Popen(
|
||||
[lume_path, "serve", "--port", str(self.port)],
|
||||
stdout=self.output_file,
|
||||
stderr=self.output_file,
|
||||
cwd=os.path.dirname(lume_path),
|
||||
start_new_session=True, # Run in new session to avoid blocking
|
||||
)
|
||||
except Exception as e:
|
||||
self.output_file.close()
|
||||
os.unlink(self.output_file.name)
|
||||
raise RuntimeError(f"Failed to start lume server process: {str(e)}")
|
||||
|
||||
# Wait for server to start
|
||||
self.logger.debug(
|
||||
f"Waiting up to {self.server_start_timeout} seconds for server to start..."
|
||||
)
|
||||
start_time = time.time()
|
||||
server_ready = False
|
||||
last_size = 0
|
||||
|
||||
while time.time() - start_time < self.server_start_timeout:
|
||||
if self.server_process.poll() is not None:
|
||||
# Process has terminated
|
||||
self.output_file.seek(0)
|
||||
output = self.output_file.read()
|
||||
self.output_file.close()
|
||||
os.unlink(self.output_file.name)
|
||||
error_msg = (
|
||||
f"Server process terminated unexpectedly.\n"
|
||||
f"Exit code: {self.server_process.returncode}\n"
|
||||
f"Output: {output}"
|
||||
)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Check output file for server ready message
|
||||
self.output_file.seek(0, os.SEEK_END)
|
||||
size = self.output_file.tell()
|
||||
if size > last_size: # Only read if there's new content
|
||||
self.output_file.seek(last_size)
|
||||
new_output = self.output_file.read()
|
||||
if new_output.strip(): # Only log non-empty output
|
||||
self.logger.debug(f"Server output: {new_output.strip()}")
|
||||
last_size = size
|
||||
|
||||
if "Server started" in new_output:
|
||||
server_ready = True
|
||||
self.logger.debug("Server startup detected")
|
||||
break
|
||||
|
||||
# Try to connect to the server periodically
|
||||
try:
|
||||
cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "5", f"{self.base_url}/vms"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode == 0:
|
||||
response = stdout.decode()
|
||||
status_code = int(response[-3:])
|
||||
if status_code == 200:
|
||||
server_ready = True
|
||||
self.logger.debug("Server is responding to requests")
|
||||
break
|
||||
except:
|
||||
pass # Server not ready yet
|
||||
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
if not server_ready:
|
||||
# Cleanup if server didn't start
|
||||
if self.server_process:
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.server_process.kill()
|
||||
self.output_file.close()
|
||||
os.unlink(self.output_file.name)
|
||||
raise RuntimeError(
|
||||
f"Failed to start lume server after {self.server_start_timeout} seconds. "
|
||||
"Check the debug output for more details."
|
||||
)
|
||||
|
||||
# Give the server a moment to fully initialize
|
||||
await asyncio.sleep(2.0)
|
||||
|
||||
# Verify server is responding
|
||||
try:
|
||||
cmd = ["curl", "-s", "-w", "%{http_code}", "-m", "10", f"{self.base_url}/vms"]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
raise RuntimeError(f"Curl command failed: {stderr.decode()}")
|
||||
|
||||
response = stdout.decode()
|
||||
status_code = int(response[-3:])
|
||||
|
||||
if status_code != 200:
|
||||
raise RuntimeError(f"Server returned status code {status_code}")
|
||||
|
||||
self.logger.debug("PyLume server started successfully")
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Server verification failed: {str(e)}")
|
||||
if self.server_process:
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.server_process.kill()
|
||||
self.output_file.close()
|
||||
os.unlink(self.output_file.name)
|
||||
raise RuntimeError(f"Server started but is not responding: {str(e)}")
|
||||
|
||||
self.logger.debug("Server startup completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to start lume server: {str(e)}")
|
||||
|
||||
async def _start_server(self) -> None:
|
||||
"""Start the lume server using the lume executable."""
|
||||
self.logger.debug("Starting PyLume server")
|
||||
|
||||
# Get absolute path to lume executable in the same directory as this file
|
||||
lume_path = os.path.join(os.path.dirname(__file__), "lume")
|
||||
if not os.path.exists(lume_path):
|
||||
raise RuntimeError(f"Could not find lume binary at {lume_path}")
|
||||
|
||||
try:
|
||||
# Make executable
|
||||
os.chmod(lume_path, 0o755)
|
||||
|
||||
# Get and validate port
|
||||
self.port = self._get_server_port()
|
||||
self.base_url = f"http://{self.host}:{self.port}/lume"
|
||||
|
||||
# Set up output handling
|
||||
self.output_file = tempfile.NamedTemporaryFile(mode="w+", delete=False)
|
||||
|
||||
# Start the server process with the lume executable
|
||||
env = os.environ.copy()
|
||||
env["RUST_BACKTRACE"] = "1" # Enable backtrace for better error reporting
|
||||
|
||||
# Specify the host to bind to (0.0.0.0 to allow external connections)
|
||||
self.server_process = subprocess.Popen(
|
||||
[lume_path, "serve", "--port", str(self.port)],
|
||||
stdout=self.output_file,
|
||||
stderr=subprocess.STDOUT,
|
||||
cwd=os.path.dirname(lume_path), # Run from same directory as executable
|
||||
env=env,
|
||||
)
|
||||
|
||||
# Wait for server to initialize
|
||||
await asyncio.sleep(2)
|
||||
await self._wait_for_server()
|
||||
|
||||
except Exception as e:
|
||||
await self._cleanup()
|
||||
raise RuntimeError(f"Failed to start lume server process: {str(e)}")
|
||||
|
||||
async def _tail_log(self) -> None:
|
||||
"""Read and display server log output in debug mode."""
|
||||
while True:
|
||||
try:
|
||||
self.output_file.seek(0, os.SEEK_END) # type: ignore[attr-defined]
|
||||
line = self.output_file.readline() # type: ignore[attr-defined]
|
||||
if line:
|
||||
line = line.strip()
|
||||
if line:
|
||||
print(f"SERVER: {line}")
|
||||
if self.server_process.poll() is not None: # type: ignore[attr-defined]
|
||||
print("Server process ended")
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
except Exception as e:
|
||||
print(f"Error reading log: {e}")
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
async def _wait_for_server(self) -> None:
|
||||
"""Wait for server to start and become responsive with increased timeout."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < self.server_start_timeout:
|
||||
if self.server_process.poll() is not None: # type: ignore[attr-defined]
|
||||
error_msg = await self._get_error_output()
|
||||
await self._cleanup()
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
try:
|
||||
await self._verify_server()
|
||||
self.logger.debug("Server is now responsive")
|
||||
return
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Server not ready yet: {str(e)}")
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
await self._cleanup()
|
||||
raise RuntimeError(f"Server failed to start after {self.server_start_timeout} seconds")
|
||||
|
||||
async def _verify_server(self) -> None:
|
||||
"""Verify server is responding to requests."""
|
||||
try:
|
||||
cmd = [
|
||||
"curl",
|
||||
"-s",
|
||||
"-w",
|
||||
"%{http_code}",
|
||||
"-m",
|
||||
"10",
|
||||
f"http://{self.host}:{self.port}/lume/vms",
|
||||
]
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
raise RuntimeError(f"Curl command failed: {stderr.decode()}")
|
||||
|
||||
response = stdout.decode()
|
||||
status_code = int(response[-3:])
|
||||
|
||||
if status_code != 200:
|
||||
raise RuntimeError(f"Server returned status code {status_code}")
|
||||
|
||||
self.logger.debug("PyLume server started successfully")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Server not responding: {str(e)}")
|
||||
|
||||
async def _get_error_output(self) -> str:
|
||||
"""Get error output from the server process."""
|
||||
if not self.output_file:
|
||||
return "No output available"
|
||||
self.output_file.seek(0)
|
||||
output = self.output_file.read()
|
||||
return (
|
||||
f"Server process terminated unexpectedly.\n"
|
||||
f"Exit code: {self.server_process.returncode}\n" # type: ignore[attr-defined]
|
||||
f"Output: {output}"
|
||||
)
|
||||
|
||||
async def _cleanup(self) -> None:
|
||||
"""Clean up all server resources."""
|
||||
if self.server_process:
|
||||
try:
|
||||
self.server_process.terminate()
|
||||
try:
|
||||
self.server_process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
self.server_process.kill()
|
||||
except:
|
||||
pass
|
||||
self.server_process = None
|
||||
|
||||
# Clean up output file
|
||||
if self.output_file:
|
||||
try:
|
||||
self.output_file.close()
|
||||
os.unlink(self.output_file.name)
|
||||
except Exception as e:
|
||||
self.logger.debug(f"Error cleaning up output file: {e}")
|
||||
self.output_file = None
|
||||
|
||||
async def ensure_running(self) -> None:
|
||||
"""Ensure the server is running.
|
||||
|
||||
If use_existing_server is True, will only try to connect to an existing server.
|
||||
Otherwise will:
|
||||
1. Try to connect to an existing server on the specified port
|
||||
2. If that fails and not in Docker, start a new server
|
||||
3. If in Docker and no existing server is found, raise an error
|
||||
"""
|
||||
# First check if we're in Docker
|
||||
in_docker = os.path.exists("/.dockerenv") or (
|
||||
os.path.exists("/proc/1/cgroup") and "docker" in open("/proc/1/cgroup", "r").read()
|
||||
)
|
||||
|
||||
# If using a non-localhost host like host.docker.internal, set up the connection details
|
||||
if self.host not in ["localhost", "127.0.0.1"]:
|
||||
if self.requested_port is None:
|
||||
raise RuntimeError("Port must be specified when using a remote host")
|
||||
|
||||
self.port = self.requested_port
|
||||
self.base_url = f"http://{self.host}:{self.port}/lume"
|
||||
self.logger.debug(f"Using remote host server at {self.base_url}")
|
||||
|
||||
# Try to verify the server is accessible
|
||||
try:
|
||||
await self._verify_server()
|
||||
self.logger.debug("Successfully connected to remote server")
|
||||
return
|
||||
except Exception as e:
|
||||
if self.use_existing_server or in_docker:
|
||||
# If explicitly requesting an existing server or in Docker, we can't start a new one
|
||||
raise RuntimeError(
|
||||
f"Failed to connect to remote server at {self.base_url}: {str(e)}"
|
||||
)
|
||||
else:
|
||||
self.logger.debug(f"Remote server not available at {self.base_url}: {str(e)}")
|
||||
# Fall back to localhost for starting a new server
|
||||
self.host = "localhost"
|
||||
|
||||
# If explicitly using an existing server, verify it's running
|
||||
if self.use_existing_server:
|
||||
if self.requested_port is None:
|
||||
raise RuntimeError("Port must be specified when using an existing server")
|
||||
|
||||
self.port = self.requested_port
|
||||
self.base_url = f"http://{self.host}:{self.port}/lume"
|
||||
|
||||
try:
|
||||
await self._verify_server()
|
||||
self.logger.debug("Successfully connected to existing server")
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to connect to existing server at {self.base_url}: {str(e)}"
|
||||
)
|
||||
else:
|
||||
# Try to connect to an existing server first
|
||||
if self.requested_port is not None:
|
||||
self.port = self.requested_port
|
||||
self.base_url = f"http://{self.host}:{self.port}/lume"
|
||||
|
||||
try:
|
||||
await self._verify_server()
|
||||
self.logger.debug("Successfully connected to existing server")
|
||||
return
|
||||
except Exception:
|
||||
self.logger.debug(f"No existing server found at {self.base_url}")
|
||||
|
||||
# If in Docker and can't connect to existing server, raise an error
|
||||
if in_docker:
|
||||
raise RuntimeError(
|
||||
f"Failed to connect to server at {self.base_url} and cannot start a new server in Docker"
|
||||
)
|
||||
|
||||
# Start a new server
|
||||
self.logger.debug("Starting a new server instance")
|
||||
await self._start_server()
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the server if we're managing it."""
|
||||
if not self.use_existing_server:
|
||||
self.logger.debug("Stopping lume server...")
|
||||
await self._cleanup()
|
||||
@@ -1,51 +0,0 @@
|
||||
[build-system]
|
||||
build-backend = "pdm.backend"
|
||||
requires = ["pdm-backend"]
|
||||
|
||||
[project]
|
||||
authors = [{ name = "TryCua", email = "gh@trycua.com" }]
|
||||
classifiers = [
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: MacOS :: MacOS X",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
]
|
||||
dependencies = ["pydantic>=2.11.1"]
|
||||
description = "Python SDK for lume - run macOS and Linux VMs on Apple Silicon"
|
||||
dynamic = ["version"]
|
||||
keywords = ["apple-silicon", "macos", "virtualization", "vm"]
|
||||
license = { text = "MIT" }
|
||||
name = "pylume"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
|
||||
[tool.pdm.version]
|
||||
path = "pylume/__init__.py"
|
||||
source = "file"
|
||||
|
||||
[project.urls]
|
||||
homepage = "https://github.com/trycua/pylume"
|
||||
repository = "https://github.com/trycua/pylume"
|
||||
|
||||
[tool.pdm]
|
||||
distribution = true
|
||||
|
||||
[tool.pdm.dev-dependencies]
|
||||
dev = [
|
||||
"black>=23.0.0",
|
||||
"isort>=5.12.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"pytest>=7.0.0",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
python_files = "test_*.py"
|
||||
testpaths = ["tests"]
|
||||
|
||||
[tool.pdm.build]
|
||||
includes = ["pylume/"]
|
||||
source-includes = ["LICENSE", "README.md", "tests/"]
|
||||
23
libs/python/pylume/tests/conftest.py
Normal file
23
libs/python/pylume/tests/conftest.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Pytest configuration for pylume tests.
|
||||
|
||||
This module provides test fixtures for the pylume package.
|
||||
Note: This package has macOS-specific dependencies and will skip tests
|
||||
if the required modules are not available.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_subprocess():
|
||||
with patch("subprocess.run") as mock_run:
|
||||
mock_run.return_value = Mock(returncode=0, stdout="", stderr="")
|
||||
yield mock_run
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_requests():
|
||||
with patch("requests.get") as mock_get, patch("requests.post") as mock_post:
|
||||
yield {"get": mock_get, "post": mock_post}
|
||||
38
libs/python/pylume/tests/test_pylume.py
Normal file
38
libs/python/pylume/tests/test_pylume.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""Unit tests for pylume package.
|
||||
|
||||
This file tests ONLY basic pylume functionality.
|
||||
Following SRP: This file tests pylume module imports and basic operations.
|
||||
All external dependencies are mocked.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestPylumeImports:
|
||||
"""Test pylume module imports (SRP: Only tests imports)."""
|
||||
|
||||
def test_pylume_module_exists(self):
|
||||
"""Test that pylume module can be imported."""
|
||||
try:
|
||||
import pylume
|
||||
|
||||
assert pylume is not None
|
||||
except ImportError:
|
||||
pytest.skip("pylume module not installed")
|
||||
|
||||
|
||||
class TestPylumeInitialization:
|
||||
"""Test pylume initialization (SRP: Only tests initialization)."""
|
||||
|
||||
def test_pylume_can_be_imported(self):
|
||||
"""Basic smoke test: verify pylume components can be imported."""
|
||||
try:
|
||||
import pylume
|
||||
|
||||
# Check for basic attributes
|
||||
assert pylume is not None
|
||||
except ImportError:
|
||||
pytest.skip("pylume module not available")
|
||||
except Exception as e:
|
||||
# Some initialization errors are acceptable in unit tests
|
||||
pytest.skip(f"pylume initialization requires specific setup: {e}")
|
||||
24
libs/python/som/tests/conftest.py
Normal file
24
libs/python/som/tests/conftest.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""Pytest configuration for som tests.
|
||||
|
||||
This module provides test fixtures for the som (Set-of-Mark) package.
|
||||
The som package depends on heavy ML models and will skip tests if not available.
|
||||
"""
|
||||
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_torch():
|
||||
with patch("torch.load") as mock_load:
|
||||
mock_load.return_value = Mock()
|
||||
yield mock_load
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_icon_detector():
|
||||
with patch("omniparser.IconDetector") as mock_detector:
|
||||
instance = Mock()
|
||||
mock_detector.return_value = instance
|
||||
yield instance
|
||||
@@ -1,13 +1,73 @@
|
||||
# """Basic tests for the omniparser package."""
|
||||
"""Unit tests for som package (Set-of-Mark).
|
||||
|
||||
# import pytest
|
||||
# from omniparser import IconDetector
|
||||
This file tests ONLY basic som functionality.
|
||||
Following SRP: This file tests som module imports and basic operations.
|
||||
All external dependencies (ML models, OCR) are mocked.
|
||||
"""
|
||||
|
||||
# def test_icon_detector_import():
|
||||
# """Test that we can import the IconDetector class."""
|
||||
# assert IconDetector is not None
|
||||
import pytest
|
||||
|
||||
# def test_icon_detector_init():
|
||||
# """Test that we can create an IconDetector instance."""
|
||||
# detector = IconDetector(force_cpu=True)
|
||||
# assert detector is not None
|
||||
|
||||
class TestSomImports:
|
||||
"""Test som module imports (SRP: Only tests imports)."""
|
||||
|
||||
def test_som_module_exists(self):
|
||||
"""Test that som module can be imported."""
|
||||
try:
|
||||
import som
|
||||
|
||||
assert som is not None
|
||||
except ImportError:
|
||||
pytest.skip("som module not installed")
|
||||
|
||||
def test_omniparser_import(self):
|
||||
"""Test that OmniParser can be imported."""
|
||||
try:
|
||||
from som import OmniParser
|
||||
|
||||
assert OmniParser is not None
|
||||
except ImportError:
|
||||
pytest.skip("som module not available")
|
||||
except Exception as e:
|
||||
pytest.skip(f"som initialization requires ML models: {e}")
|
||||
|
||||
def test_models_import(self):
|
||||
"""Test that model classes can be imported."""
|
||||
try:
|
||||
from som import BoundingBox, ParseResult, UIElement
|
||||
|
||||
assert BoundingBox is not None
|
||||
assert UIElement is not None
|
||||
assert ParseResult is not None
|
||||
except ImportError:
|
||||
pytest.skip("som models not available")
|
||||
except Exception as e:
|
||||
pytest.skip(f"som models require dependencies: {e}")
|
||||
|
||||
|
||||
class TestSomModels:
|
||||
"""Test som data models (SRP: Only tests model structure)."""
|
||||
|
||||
def test_bounding_box_structure(self):
|
||||
"""Test BoundingBox class structure."""
|
||||
try:
|
||||
from som import BoundingBox
|
||||
|
||||
# Check the class exists and has expected structure
|
||||
assert hasattr(BoundingBox, "__init__")
|
||||
except ImportError:
|
||||
pytest.skip("som models not available")
|
||||
except Exception as e:
|
||||
pytest.skip(f"som models require dependencies: {e}")
|
||||
|
||||
def test_ui_element_structure(self):
|
||||
"""Test UIElement class structure."""
|
||||
try:
|
||||
from som import UIElement
|
||||
|
||||
# Check the class exists and has expected structure
|
||||
assert hasattr(UIElement, "__init__")
|
||||
except ImportError:
|
||||
pytest.skip("som models not available")
|
||||
except Exception as e:
|
||||
pytest.skip(f"som models require dependencies: {e}")
|
||||
|
||||
Reference in New Issue
Block a user