Merge upstream/main to resolve conflicts with trycua/cua

Sarina Li
2025-10-27 16:06:06 -07:00
344 changed files with 30505 additions and 16871 deletions

View File

@@ -0,0 +1,10 @@
[bumpversion]
current_version = 0.4.35
commit = True
tag = True
tag_name = agent-v{new_version}
message = Bump cua-agent to v{new_version}
[bumpversion:file:pyproject.toml]
search = version = "{current_version}"
replace = version = "{new_version}"
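For reference, the `search`/`replace` patterns above target a version line of this shape in `pyproject.toml` (a minimal sketch; the `name` field and surrounding content are assumed):

[project]
name = "cua-agent"   # assumed; only the version line below is matched
version = "0.4.35"   # rewritten on each bump

Running `bumpversion patch` would rewrite it to `version = "0.4.36"`, commit the change, and create the tag `agent-v0.4.36`.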

View File

@@ -8,10 +8,11 @@
</picture>
</div>
[![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#)
[![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85)
[![PyPI](https://img.shields.io/pypi/v/cua-computer?color=333333)](https://pypi.org/project/cua-computer/)
</h1>
</div>
@@ -47,7 +48,7 @@ async def main():
name=os.getenv("CUA_CONTAINER_NAME"),
api_key=os.getenv("CUA_API_KEY")
) as computer:
# Create agent
agent = ComputerAgent(
model="anthropic/claude-3-5-sonnet-20241022",
@@ -56,10 +57,10 @@ async def main():
trajectory_dir="trajectories",
max_trajectory_budget=5.0 # $5 budget limit
)
# Run agent
messages = [{"role": "user", "content": "Take a screenshot and tell me what you see"}]
async for result in agent.run(messages):
for item in result["output"]:
if item["type"] == "message":
@@ -84,4 +85,4 @@ if __name__ == "__main__":
## License
MIT License - see LICENSE file for details.

View File

@@ -5,19 +5,13 @@ agent - Decorator-based Computer Use Agent with liteLLM integration
import logging
import sys
# Import loops to register them
from . import loops
from .agent import ComputerAgent
from .decorators import register_agent
from .types import AgentResponse, Messages
__all__ = ["register_agent", "ComputerAgent", "Messages", "AgentResponse"]
__version__ = "0.4.0"

View File

@@ -5,8 +5,9 @@ Usage:
python -m agent.cli <model_string>
"""
import asyncio
import sys
from .cli import main
if __name__ == "__main__":

View File

@@ -2,27 +2,30 @@ import asyncio
import functools
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm.types.utils import GenericStreamingChunk, ModelResponse
# Try to import HuggingFace dependencies
try:
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
from .models import load_model as load_model_handler
class HuggingFaceLocalAdapter(CustomLLM):
"""HuggingFace Local Adapter for running vision-language models locally."""
def __init__(self, device: str = "auto", trust_remote_code: bool = False, **kwargs):
"""Initialize the adapter.
Args:
device: Device to load model on ("auto", "cuda", "cpu", etc.)
trust_remote_code: Whether to trust remote code
@@ -34,129 +37,120 @@ class HuggingFaceLocalAdapter(CustomLLM):
# Cache for model handlers keyed by model_name
self._handlers: Dict[str, Any] = {}
self._executor = ThreadPoolExecutor(max_workers=1) # Single thread pool
def _get_handler(self, model_name: str):
"""Get or create a model handler for the given model name."""
if model_name not in self._handlers:
self._handlers[model_name] = load_model_handler(
model_name=model_name, device=self.device, trust_remote_code=self.trust_remote_code
)
return self._handlers[model_name]
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Convert OpenAI format messages to HuggingFace format.
Args:
messages: Messages in OpenAI format
Returns:
Messages in HuggingFace format
"""
converted_messages = []
for message in messages:
converted_message = {"role": message["role"], "content": []}
content = message.get("content", [])
if isinstance(content, str):
# Simple text content
converted_message["content"].append({
"type": "text",
"text": content
})
converted_message["content"].append({"type": "text", "text": content})
elif isinstance(content, list):
# Multi-modal content
for item in content:
if item.get("type") == "text":
converted_message["content"].append({
"type": "text",
"text": item.get("text", "")
})
converted_message["content"].append(
{"type": "text", "text": item.get("text", "")}
)
elif item.get("type") == "image_url":
# Convert image_url format to image format
image_url = item.get("image_url", {}).get("url", "")
converted_message["content"].append({
"type": "image",
"image": image_url
})
converted_message["content"].append({"type": "image", "image": image_url})
converted_messages.append(converted_message)
return converted_messages
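Concretely, `_convert_messages` turns one OpenAI-style message into the HuggingFace chat format like this (a sketch with illustrative values):

# OpenAI-style input
openai_msg = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is on screen?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBOR..."}},
    ],
}
# HuggingFace-style result: text items pass through; image_url becomes "image"
hf_msg = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is on screen?"},
        {"type": "image", "image": "data:image/png;base64,iVBOR..."},
    ],
}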
def _generate(self, **kwargs) -> str:
"""Generate response using the local HuggingFace model.
Args:
**kwargs: Keyword arguments containing messages and model info
Returns:
Generated text response
"""
if not HF_AVAILABLE:
raise ImportError(
"HuggingFace transformers dependencies not found. "
"Please install with: pip install \"cua-agent[uitars-hf]\""
'Please install with: pip install "cua-agent[uitars-hf]"'
)
# Extract messages and model from kwargs
messages = kwargs.get("messages", [])
model_name = kwargs.get("model", "ByteDance-Seed/UI-TARS-1.5-7B")
max_new_tokens = kwargs.get("max_tokens", 128)
# Warn about ignored kwargs
ignored_kwargs = set(kwargs.keys()) - {"messages", "model", "max_tokens"}
if ignored_kwargs:
warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")
# Convert messages to HuggingFace format
hf_messages = self._convert_messages(messages)
# Delegate to model handler
handler = self._get_handler(model_name)
generated_text = handler.generate(hf_messages, max_new_tokens=max_new_tokens)
return generated_text
def completion(self, *args, **kwargs) -> ModelResponse:
"""Synchronous completion method.
Returns:
ModelResponse with generated text
"""
generated_text = self._generate(**kwargs)
return completion(
model=f"huggingface-local/{kwargs['model']}",
mock_response=generated_text,
)
async def acompletion(self, *args, **kwargs) -> ModelResponse:
"""Asynchronous completion method.
Returns:
ModelResponse with generated text
"""
# Run _generate in thread pool to avoid blocking
loop = asyncio.get_event_loop()
generated_text = await loop.run_in_executor(
self._executor, functools.partial(self._generate, **kwargs)
)
return await acompletion(
model=f"huggingface-local/{kwargs['model']}",
mock_response=generated_text,
)
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
"""Synchronous streaming method.
Returns:
Iterator of GenericStreamingChunk
"""
generated_text = self._generate(**kwargs)
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
@@ -165,22 +159,21 @@ class HuggingFaceLocalAdapter(CustomLLM):
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
yield generic_streaming_chunk
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
"""Asynchronous streaming method.
Returns:
AsyncIterator of GenericStreamingChunk
"""
# Run _generate in thread pool to avoid blocking
loop = asyncio.get_event_loop()
generated_text = await loop.run_in_executor(
self._executor, functools.partial(self._generate, **kwargs)
)
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
@@ -189,5 +182,5 @@ class HuggingFaceLocalAdapter(CustomLLM):
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
yield generic_streaming_chunk
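For orientation, the adapter is consumed through LiteLLM's custom-provider mechanism, the same registration `ComputerAgent` performs later in this commit. A minimal standalone sketch (the import path and prompt are assumptions):

import litellm
from agent.adapters import HuggingFaceLocalAdapter  # assumed import path

litellm.custom_provider_map = [
    {"provider": "huggingface-local", "custom_handler": HuggingFaceLocalAdapter(device="auto")}
]
response = litellm.completion(
    model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
    messages=[{"role": "user", "content": "Describe the screenshot."}],
    max_tokens=128,
)
print(response.choices[0].message.content)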

View File

@@ -1,22 +1,23 @@
import asyncio
import os
from typing import Any, AsyncIterator, Dict, Iterator, List
import requests
from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm.types.utils import GenericStreamingChunk, ModelResponse
class HumanAdapter(CustomLLM):
"""Human Adapter for human-in-the-loop completions.
This adapter sends completion requests to a human completion server
where humans can review and respond to AI requests.
"""
def __init__(self, base_url: str | None = None, timeout: float = 300.0, **kwargs):
"""Initialize the human adapter.
Args:
base_url: Base URL for the human completion server.
Defaults to HUMAN_BASE_URL environment variable or http://localhost:8002
@@ -24,60 +25,58 @@ class HumanAdapter(CustomLLM):
**kwargs: Additional arguments
"""
super().__init__()
self.base_url = base_url or os.getenv("HUMAN_BASE_URL", "http://localhost:8002")
self.timeout = timeout
# Ensure base_url doesn't end with slash
self.base_url = self.base_url.rstrip("/")
def _queue_completion(self, messages: List[Dict[str, Any]], model: str) -> str:
"""Queue a completion request and return the call ID.
Args:
messages: Messages in OpenAI format
model: Model name
Returns:
Call ID for tracking the request
Raises:
Exception: If queueing fails
"""
try:
response = requests.post(
f"{self.base_url}/queue",
json={"messages": messages, "model": model},
timeout=10
f"{self.base_url}/queue", json={"messages": messages, "model": model}, timeout=10
)
response.raise_for_status()
return response.json()["id"]
except requests.RequestException as e:
raise Exception(f"Failed to queue completion request: {e}")
def _wait_for_completion(self, call_id: str) -> Dict[str, Any]:
"""Wait for human to complete the call.
Args:
call_id: ID of the queued completion call
Returns:
Dict containing response and/or tool_calls
Raises:
TimeoutError: If timeout is exceeded
Exception: If completion fails
"""
import time
start_time = time.time()
while True:
try:
# Check status
status_response = requests.get(f"{self.base_url}/status/{call_id}")
status_response.raise_for_status()
status_data = status_response.json()
if status_data["status"] == "completed":
result = {}
if "response" in status_data and status_data["response"]:
@@ -88,38 +87,41 @@ class HumanAdapter(CustomLLM):
elif status_data["status"] == "failed":
error_msg = status_data.get("error", "Unknown error")
raise Exception(f"Completion failed: {error_msg}")
# Check timeout
if time.time() - start_time > self.timeout:
raise TimeoutError(f"Timeout waiting for human response after {self.timeout} seconds")
raise TimeoutError(
f"Timeout waiting for human response after {self.timeout} seconds"
)
# Wait before checking again
time.sleep(1.0)
except requests.RequestException as e:
if time.time() - start_time > self.timeout:
raise TimeoutError(f"Timeout waiting for human response: {e}")
# Continue trying if we haven't timed out
time.sleep(1.0)
async def _async_wait_for_completion(self, call_id: str) -> Dict[str, Any]:
"""Async version of wait_for_completion.
Args:
call_id: ID of the queued completion call
Returns:
Dict containing response and/or tool_calls
Raises:
TimeoutError: If timeout is exceeded
Exception: If completion fails
"""
import time
import aiohttp
start_time = time.time()
async with aiohttp.ClientSession() as session:
while True:
try:
@@ -127,7 +129,7 @@ class HumanAdapter(CustomLLM):
async with session.get(f"{self.base_url}/status/{call_id}") as response:
response.raise_for_status()
status_data = await response.json()
if status_data["status"] == "completed":
result = {}
if "response" in status_data and status_data["response"]:
@@ -138,166 +140,158 @@ class HumanAdapter(CustomLLM):
elif status_data["status"] == "failed":
error_msg = status_data.get("error", "Unknown error")
raise Exception(f"Completion failed: {error_msg}")
# Check timeout
if time.time() - start_time > self.timeout:
raise TimeoutError(f"Timeout waiting for human response after {self.timeout} seconds")
raise TimeoutError(
f"Timeout waiting for human response after {self.timeout} seconds"
)
# Wait before checking again
await asyncio.sleep(1.0)
except Exception as e:
if time.time() - start_time > self.timeout:
raise TimeoutError(f"Timeout waiting for human response: {e}")
# Continue trying if we haven't timed out
await asyncio.sleep(1.0)
def _generate_response(self, messages: List[Dict[str, Any]], model: str) -> Dict[str, Any]:
"""Generate a human response for the given messages.
Args:
messages: Messages in OpenAI format
model: Model name
Returns:
Dict containing response and/or tool_calls
"""
# Queue the completion request
call_id = self._queue_completion(messages, model)
# Wait for human response
response = self._wait_for_completion(call_id)
return response
async def _async_generate_response(
self, messages: List[Dict[str, Any]], model: str
) -> Dict[str, Any]:
"""Async version of _generate_response.
Args:
messages: Messages in OpenAI format
model: Model name
Returns:
Dict containing response and/or tool_calls
"""
# Queue the completion request (sync operation)
call_id = self._queue_completion(messages, model)
# Wait for human response (async)
response = await self._async_wait_for_completion(call_id)
return response
def completion(self, *args, **kwargs) -> ModelResponse:
"""Synchronous completion method.
Returns:
ModelResponse with human-generated text or tool calls
"""
messages = kwargs.get("messages", [])
model = kwargs.get("model", "human")
# Generate human response
human_response_data = self._generate_response(messages, model)
# Create ModelResponse with proper structure
import time
import uuid
from litellm.types.utils import Choices, Message, ModelResponse
# Create message content based on response type
if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
# Tool calls response
message = Message(
role="assistant",
content=human_response_data.get("response", ""),
tool_calls=human_response_data["tool_calls"]
tool_calls=human_response_data["tool_calls"],
)
else:
# Text response
message = Message(role="assistant", content=human_response_data.get("response", ""))
choice = Choices(finish_reason="stop", index=0, message=message)
result = ModelResponse(
id=f"human-{uuid.uuid4()}",
choices=[choice],
created=int(time.time()),
model=f"human/{model}",
object="chat.completion"
object="chat.completion",
)
return result
async def acompletion(self, *args, **kwargs) -> ModelResponse:
"""Asynchronous completion method.
Returns:
ModelResponse with human-generated text or tool calls
"""
messages = kwargs.get("messages", [])
model = kwargs.get("model", "human")
# Generate human response
human_response_data = await self._async_generate_response(messages, model)
# Create ModelResponse with proper structure
import time
import uuid
from litellm.types.utils import Choices, Message, ModelResponse
# Create message content based on response type
if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
# Tool calls response
message = Message(
role="assistant",
content=human_response_data.get("response", ""),
tool_calls=human_response_data["tool_calls"]
tool_calls=human_response_data["tool_calls"],
)
else:
# Text response
message = Message(role="assistant", content=human_response_data.get("response", ""))
choice = Choices(finish_reason="stop", index=0, message=message)
result = ModelResponse(
id=f"human-{uuid.uuid4()}",
choices=[choice],
created=int(time.time()),
model=f"human/{model}",
object="chat.completion"
object="chat.completion",
)
return result
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
"""Synchronous streaming method.
Yields:
Streaming chunks with human-generated text or tool calls
"""
messages = kwargs.get("messages", [])
model = kwargs.get("model", "human")
# Generate human response
human_response_data = self._generate_response(messages, model)
import time
# Handle tool calls vs text response
if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
# Stream tool calls as a single chunk
@@ -319,22 +313,26 @@ class HumanAdapter(CustomLLM):
"is_finished": True,
"text": response_text,
"tool_use": None,
"usage": {"completion_tokens": len(response_text.split()), "prompt_tokens": 0, "total_tokens": len(response_text.split())},
"usage": {
"completion_tokens": len(response_text.split()),
"prompt_tokens": 0,
"total_tokens": len(response_text.split()),
},
}
yield generic_chunk
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
"""Asynchronous streaming method.
Yields:
Streaming chunks with human-generated text or tool calls
"""
messages = kwargs.get("messages", [])
model = kwargs.get("model", "human")
# Generate human response
human_response_data = await self._async_generate_response(messages, model)
# The helper returns a dict; extract the text for the streaming chunk below
human_response = human_response_data.get("response", "")
# Return as single streaming chunk
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
@@ -342,7 +340,11 @@ class HumanAdapter(CustomLLM):
"is_finished": True,
"text": human_response,
"tool_use": None,
"usage": {"completion_tokens": len(human_response.split()), "prompt_tokens": 0, "total_tokens": len(human_response.split())},
"usage": {
"completion_tokens": len(human_response.split()),
"prompt_tokens": 0,
"total_tokens": len(human_response.split()),
},
}
yield generic_streaming_chunk
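Taken together, `_queue_completion` and the wait helpers imply a small HTTP contract; condensed into a sketch using the same endpoints (payload values illustrative, no retry or timeout handling):

import requests

BASE = "http://localhost:8002"  # default HUMAN_BASE_URL

# POST /queue -> {"id": "<call-id>"}
call_id = requests.post(
    f"{BASE}/queue",
    json={"messages": [{"role": "user", "content": "Approve this action?"}], "model": "human"},
    timeout=10,
).json()["id"]

# GET /status/<call-id> -> {"status": "completed", "response": ..., "tool_calls": ...}
#                       or {"status": "failed", "error": ...}
status = requests.get(f"{BASE}/status/{call_id}").json()
if status["status"] == "completed":
    print(status.get("response"), status.get("tool_calls"))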

View File

@@ -1,24 +1,26 @@
import asyncio
import base64
import functools
import io
import math
import re
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, cast
from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from PIL import Image
# Try to import MLX dependencies
try:
import mlx.core as mx
from mlx_vlm import generate, load
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config
from transformers.tokenization_utils import PreTrainedTokenizer
MLX_AVAILABLE = True
except ImportError:
MLX_AVAILABLE = False
@@ -29,20 +31,28 @@ MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
def round_by_factor(number: float, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: float, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: float, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
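A quick worked example of the three helpers with the default factor of 28 (`IMAGE_FACTOR`), since 1080 / 28 = 38.57...:

round_by_factor(1080, 28)  # 39 * 28 = 1092
ceil_by_factor(1080, 28)   # 39 * 28 = 1092
floor_by_factor(1080, 28)  # 38 * 28 = 1064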
def smart_resize(
height: int,
width: int,
factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
@@ -70,61 +80,62 @@ def smart_resize(
class MLXVLMAdapter(CustomLLM):
"""MLX VLM Adapter for running vision-language models locally using MLX."""
def __init__(self, **kwargs):
"""Initialize the adapter.
Args:
**kwargs: Additional arguments
"""
super().__init__()
self.models = {} # Cache for loaded models
self.processors = {} # Cache for loaded processors
self.configs = {} # Cache for loaded configs
self._executor = ThreadPoolExecutor(max_workers=1) # Single thread pool
def _load_model_and_processor(self, model_name: str):
"""Load model and processor if not already cached.
Args:
model_name: Name of the model to load
Returns:
Tuple of (model, processor, config)
"""
if not MLX_AVAILABLE:
raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")
if model_name not in self.models:
# Load model and processor
model_obj, processor = load(
model_name, processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
)
config = load_config(model_name)
# Cache them
self.models[model_name] = model_obj
self.processors[model_name] = processor
self.configs[model_name] = config
return self.models[model_name], self.processors[model_name], self.configs[model_name]
def _process_coordinates(
self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]
) -> str:
"""Process coordinates in box tokens based on image resizing using smart_resize approach.
Args:
text: Text containing box tokens
original_size: Original image size (width, height)
model_size: Model processed image size (width, height)
Returns:
Text with processed coordinates
"""
# Find all box tokens
box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
def process_coords(match):
model_x, model_y = int(match.group(1)), int(match.group(2))
# Scale coordinates from model space to original image space
@@ -132,15 +143,20 @@ class MLXVLMAdapter(CustomLLM):
new_x = int(model_x * original_size[0] / model_size[0]) # Width
new_y = int(model_y * original_size[1] / model_size[1]) # Height
return f"<|box_start|>({new_x},{new_y})<|box_end|>"
return re.sub(box_pattern, process_coords, text)
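For example, if the original screenshot is 1920x1080 and smart_resize mapped it to 1932x1092, a box token in model space scales back to original space like this (worked sketch):

# _process_coordinates('<|box_start|>(966,546)<|box_end|>',
#                      original_size=(1920, 1080), model_size=(1932, 1092))
# new_x = int(966 * 1920 / 1932) = 960
# new_y = int(546 * 1080 / 1092) = 540
# -> '<|box_start|>(960,540)<|box_end|>'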
def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[
List[Dict[str, Any]],
List[Image.Image],
Dict[int, Tuple[int, int]],
Dict[int, Tuple[int, int]],
]:
"""Convert OpenAI format messages to MLX VLM format and extract images.
Args:
messages: Messages in OpenAI format
Returns:
Tuple of (processed_messages, images, original_sizes, model_sizes)
"""
@@ -149,13 +165,10 @@ class MLXVLMAdapter(CustomLLM):
original_sizes = {} # Track original sizes of images for coordinate mapping
model_sizes = {} # Track model processed sizes
image_index = 0
for message in messages:
processed_message = {"role": message["role"], "content": []}
content = message.get("content", [])
if isinstance(content, str):
# Simple text content
@@ -165,164 +178,163 @@ class MLXVLMAdapter(CustomLLM):
processed_content = []
for item in content:
if item.get("type") == "text":
processed_content.append({"type": "text", "text": item.get("text", "")})
elif item.get("type") == "image_url":
image_url = item.get("image_url", {}).get("url", "")
pil_image = None
if image_url.startswith("data:image/"):
# Extract base64 data
base64_data = image_url.split(",")[1]
# Convert base64 to PIL Image
image_data = base64.b64decode(base64_data)
pil_image = Image.open(io.BytesIO(image_data))
else:
# Handle file path or URL
pil_image = Image.open(image_url)
# Store original image size for coordinate mapping
original_size = pil_image.size
original_sizes[image_index] = original_size
# Use smart_resize to determine model size
# Note: smart_resize expects (height, width) but PIL gives (width, height)
height, width = original_size[1], original_size[0]
new_height, new_width = smart_resize(height, width)
# Store model size in (width, height) format for consistent coordinate processing
model_sizes[image_index] = (new_width, new_height)
# Resize the image using the calculated dimensions from smart_resize
resized_image = pil_image.resize((new_width, new_height))
images.append(resized_image)
# Add image placeholder to content
processed_content.append({"type": "image"})
image_index += 1
processed_message["content"] = processed_content
processed_messages.append(processed_message)
return processed_messages, images, original_sizes, model_sizes
def _generate(self, **kwargs) -> str:
"""Generate response using the local MLX VLM model.
Args:
**kwargs: Keyword arguments containing messages and model info
Returns:
Generated text response
"""
messages = kwargs.get("messages", [])
model_name = kwargs.get("model", "mlx-community/UI-TARS-1.5-7B-4bit")
max_tokens = kwargs.get("max_tokens", 128)
# Warn about ignored kwargs
ignored_kwargs = set(kwargs.keys()) - {"messages", "model", "max_tokens"}
if ignored_kwargs:
warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")
# Load model and processor
model, processor, config = self._load_model_and_processor(model_name)
# Convert messages and extract images
processed_messages, images, original_sizes, model_sizes = self._convert_messages(messages)
# Process user text input with box coordinates after image processing
# Swap original_size and model_size arguments for inverse transformation
for msg_idx, msg in enumerate(processed_messages):
if msg.get("role") == "user" and isinstance(msg.get("content"), str):
content = msg.get("content", "")
if "<|box_start|>" in content and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
if (
"<|box_start|>" in content
and original_sizes
and model_sizes
and 0 in original_sizes
and 0 in model_sizes
):
orig_size = original_sizes[0]
model_size = model_sizes[0]
# Swap arguments to perform inverse transformation for user input
processed_messages[msg_idx]["content"] = self._process_coordinates(content, model_size, orig_size)
processed_messages[msg_idx]["content"] = self._process_coordinates(
content, model_size, orig_size
)
try:
# Format prompt according to model requirements using the processor directly
prompt = processor.apply_chat_template(
processed_messages, tokenize=False, add_generation_prompt=True, return_tensors="pt"
)
tokenizer = cast(PreTrainedTokenizer, processor)
# Generate response
text_content, usage = generate(
model,
tokenizer,
str(prompt),
images, # type: ignore
verbose=False,
max_tokens=max_tokens,
)
except Exception as e:
raise RuntimeError(f"Error generating response: {str(e)}") from e
# Process coordinates in the response back to original image space
if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
# Get original image size and model size (using the first image)
orig_size = original_sizes[0]
model_size = model_sizes[0]
# Check if output contains box tokens that need processing
if "<|box_start|>" in text_content:
# Process coordinates from model space back to original image space
text_content = self._process_coordinates(text_content, orig_size, model_size)
return text_content
def completion(self, *args, **kwargs) -> ModelResponse:
"""Synchronous completion method.
Returns:
ModelResponse with generated text
"""
generated_text = self._generate(**kwargs)
result = completion(
model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
mock_response=generated_text,
)
return cast(ModelResponse, result)
async def acompletion(self, *args, **kwargs) -> ModelResponse:
"""Asynchronous completion method.
Returns:
ModelResponse with generated text
"""
# Run _generate in thread pool to avoid blocking
loop = asyncio.get_event_loop()
generated_text = await loop.run_in_executor(
self._executor, functools.partial(self._generate, **kwargs)
)
result = await acompletion(
model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
mock_response=generated_text,
)
return cast(ModelResponse, result)
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
"""Synchronous streaming method.
Returns:
Iterator of GenericStreamingChunk
"""
generated_text = self._generate(**kwargs)
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
@@ -331,22 +343,21 @@ class MLXVLMAdapter(CustomLLM):
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
yield generic_streaming_chunk
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
"""Asynchronous streaming method.
Returns:
AsyncIterator of GenericStreamingChunk
"""
# Run _generate in thread pool to avoid blocking
loop = asyncio.get_event_loop()
generated_text = await loop.run_in_executor(
self._executor, functools.partial(self._generate, **kwargs)
)
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
@@ -355,5 +366,5 @@ class MLXVLMAdapter(CustomLLM):
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
yield generic_streaming_chunk

View File

@@ -2,32 +2,40 @@ from typing import Optional
try:
from transformers import AutoConfig
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
from .generic import GenericHFModel
from .internvl import InternVLModel
from .opencua import OpenCUAModel
from .qwen2_5_vl import Qwen2_5_VLModel
def load_model(model_name: str, device: str = "auto", trust_remote_code: bool = False):
"""Factory function to load and return the right model handler instance.
- If the underlying transformers config class matches OpenCUA, return OpenCUAModel
- Otherwise, return GenericHFModel
"""
if not HF_AVAILABLE:
raise ImportError(
"HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
)
cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
cls = cfg.__class__.__name__
print(f"cls: {cls}")
if "OpenCUA" in cls:
return OpenCUAModel(
model_name=model_name, device=device, trust_remote_code=trust_remote_code
)
elif "Qwen2_5_VL" in cls:
return Qwen2_5_VLModel(
model_name=model_name, device=device, trust_remote_code=trust_remote_code
)
elif "InternVL" in cls:
return InternVLModel(
model_name=model_name, device=device, trust_remote_code=trust_remote_code
)
return GenericHFModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
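A minimal usage sketch of the factory (import path and model name are assumptions; every handler exposes the `generate(messages, max_new_tokens)` method the adapter delegates to):

from agent.adapters.models import load_model  # assumed import path

handler = load_model("Qwen/Qwen2.5-VL-7B-Instruct", device="auto")  # -> Qwen2_5_VLModel
text = handler.generate(
    [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
    max_new_tokens=128,
)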

View File

@@ -1,9 +1,10 @@
from typing import Any, Dict, List, Optional
# Hugging Face imports are local to avoid hard dependency at module import
try:
import torch # type: ignore
from transformers import AutoModel, AutoProcessor # type: ignore
HF_AVAILABLE = True
except Exception:
HF_AVAILABLE = False
@@ -14,10 +15,12 @@ class GenericHFModel:
Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
"""
def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
def __init__(
self, model_name: str, device: str = "auto", trust_remote_code: bool = False
) -> None:
if not HF_AVAILABLE:
raise ImportError(
"HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
)
self.model_name = model_name
self.device = device
@@ -64,7 +67,7 @@ class GenericHFModel:
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
# Trim prompt tokens from output
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
# Decode
output_text = self.processor.batch_decode(

View File

@@ -1,19 +1,22 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
# Hugging Face imports are local to avoid hard dependency at module import
try:
import base64 # type: ignore
from io import BytesIO # type: ignore
# Attempt to import InternVL's model dependencies
import einops as _ # type: ignore
import requests # type: ignore
import timm as _ # type: ignore
import torch # type: ignore
import torchvision.transforms as T # type: ignore
from PIL import Image # type: ignore
from torchvision.transforms.functional import InterpolationMode # type: ignore
from transformers import AutoModel, AutoTokenizer # type: ignore
HF_AVAILABLE = True
except Exception:
HF_AVAILABLE = False
@@ -25,10 +28,12 @@ class InternVLModel:
Provides preprocessing to support multi-turn conversations with multiple images.
"""
def __init__(
self, model_name: str, device: str = "auto", trust_remote_code: bool = False
) -> None:
if not HF_AVAILABLE:
raise ImportError(
"InternVL dependencies not found. Install with: pip install \"cua-agent[internvl-hf]\""
'InternVL dependencies not found. Install with: pip install "cua-agent[internvl-hf]"'
)
self.model_name = model_name
self.device = device
@@ -60,16 +65,25 @@ class InternVLModel:
def _build_transform(self, input_size: int) -> T.Compose:
MEAN, STD = self.IMAGENET_MEAN, self.IMAGENET_STD
transform = T.Compose(
[
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD),
]
)
return transform
def _find_closest_aspect_ratio(
self,
aspect_ratio: float,
target_ratios: List[tuple],
width: int,
height: int,
image_size: int,
):
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
@@ -83,17 +97,29 @@ class InternVLModel:
best_ratio = ratio
return best_ratio
def _dynamic_preprocess(
self,
image: Image.Image,
min_num: int = 1,
max_num: int = 12,
image_size: int = 448,
use_thumbnail: bool = True,
) -> List[Image.Image]:
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
target_ratios = set(
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
target_aspect_ratio = self._find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
@@ -106,7 +132,7 @@ class InternVLModel:
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size,
)
split_img = resized_img.crop(box)
processed_images.append(split_img)
@@ -122,20 +148,24 @@ class InternVLModel:
# data URL base64
header, b64data = src.split(",", 1)
img_bytes = base64.b64decode(b64data)
return Image.open(BytesIO(img_bytes)).convert("RGB")
if src.startswith("http://") or src.startswith("https://"):
resp = requests.get(src, timeout=10)
resp.raise_for_status()
return Image.open(BytesIO(resp.content)).convert("RGB")
# Assume local file path
return Image.open(src).convert("RGB")
def _images_to_pixel_values(
self, images: List[Image.Image], input_size: int = 448, max_num: int = 12
):
transform = self._build_transform(input_size=input_size)
pixel_values_list = []
num_patches_list: List[int] = []
for img in images:
tiles = self._dynamic_preprocess(
img, image_size=input_size, use_thumbnail=True, max_num=max_num
)
pv = [transform(tile) for tile in tiles]
pv = torch.stack(pv)
num_patches_list.append(pv.shape[0])
@@ -191,7 +221,9 @@ class InternVLModel:
last_user_text_parts = parts_text or last_user_text_parts
elif role == "assistant":
# Only keep text content for history
parts_text = [item.get("text", "") for item in content_items if item.get("type") == "text"]
parts_text = [
item.get("text", "") for item in content_items if item.get("type") == "text"
]
text = "\n".join(parts_text).strip()
if text:
context_lines.append(f"Assistant: {text}")
@@ -200,7 +232,9 @@ class InternVLModel:
pixel_values = None
num_patches_list: List[int] = []
if all_images:
pixel_values, num_patches_list = self._images_to_pixel_values(
all_images, input_size=448, max_num=12
)
if pixel_values is not None:
# Convert dtype/device as in docs
pixel_values = pixel_values.to(torch.bfloat16)
@@ -246,7 +280,9 @@ class InternVLModel:
num_patches_list=num_patches_list,
)
else:
response = self.model.chat(
self.tokenizer, pixel_values, question, generation_config
)
except Exception as e:
# Fallback: return empty string to avoid crashing the adapter
return ""

View File

@@ -1,13 +1,18 @@
import base64
import re
from io import BytesIO
from typing import Any, Dict, List
try:
import torch # type: ignore
from PIL import Image # type: ignore
import blobfile as _ # assert blobfile is installed
from transformers import ( # type: ignore
AutoImageProcessor,
AutoModel,
AutoTokenizer,
)
OPENCUA_AVAILABLE = True
except Exception:
OPENCUA_AVAILABLE = False
@@ -16,10 +21,12 @@ except Exception:
class OpenCUAModel:
"""OpenCUA model handler using AutoTokenizer, AutoModel and AutoImageProcessor."""
def __init__(
self, model_name: str, device: str = "auto", trust_remote_code: bool = False
) -> None:
if not OPENCUA_AVAILABLE:
raise ImportError(
"OpenCUA requirements not found. Install with: pip install \"cua-agent[opencua-hf]\""
'OpenCUA requirements not found. Install with: pip install "cua-agent[opencua-hf]"'
)
self.model_name = model_name
self.device = device
@@ -56,7 +63,11 @@ class OpenCUAModel:
return ""
def generate(self, messages: List[Dict[str, Any]], max_new_tokens: int = 512) -> str:
assert (
self.model is not None
and self.tokenizer is not None
and self.image_processor is not None
)
# Tokenize text side using chat template
input_ids = self.tokenizer.apply_chat_template(
@@ -74,7 +85,11 @@ class OpenCUAModel:
pixel_values = torch.tensor(image_info["pixel_values"]).to(
dtype=torch.bfloat16, device=self.model.device
)
grid_thws = torch.tensor(image_info["image_grid_thw"]) if "image_grid_thw" in image_info else None
grid_thws = (
torch.tensor(image_info["image_grid_thw"])
if "image_grid_thw" in image_info
else None
)
gen_kwargs: Dict[str, Any] = {
"max_new_tokens": max_new_tokens,

View File

@@ -1,9 +1,10 @@
from typing import Any, Dict, List, Optional
# Hugging Face imports are local to avoid hard dependency at module import
try:
import torch # type: ignore
from transformers import AutoModelForImageTextToText, AutoProcessor # type: ignore
HF_AVAILABLE = True
except Exception:
HF_AVAILABLE = False
@@ -14,10 +15,12 @@ class Qwen2_5_VLModel:
Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
"""
def __init__(
self, model_name: str, device: str = "auto", trust_remote_code: bool = False
) -> None:
if not HF_AVAILABLE:
raise ImportError(
"HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
)
self.model_name = model_name
self.device = device
@@ -64,7 +67,7 @@ class Qwen2_5_VLModel:
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
# Trim prompt tokens from output
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
# Decode
output_text = self.processor.batch_decode(

View File

@@ -3,76 +3,83 @@ ComputerAgent - Main agent class that selects and runs agent loops
"""
import asyncio
import inspect
import json
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
import litellm
import litellm.utils
from litellm.responses.utils import Usage
from .adapters import (
HuggingFaceLocalAdapter,
HumanAdapter,
MLXVLMAdapter,
)
from .callbacks import (
BudgetManagerCallback,
ImageRetentionCallback,
LoggingCallback,
OperatorNormalizerCallback,
PromptInstructionsCallback,
TelemetryCallback,
TrajectorySaverCallback,
)
from .computers import AsyncComputerHandler, is_agent_computer, make_computer_handler
from .decorators import find_agent_config
from .responses import (
make_tool_error_item,
replace_failed_computer_calls_with_function_calls,
)
from .types import AgentCapability, IllegalArgumentError, Messages, ToolError
def assert_callable_with(f, *args, **kwargs):
"""Check if function can be called with given arguments."""
try:
inspect.signature(f).bind(*args, **kwargs)
return True
except TypeError as e:
sig = inspect.signature(f)
raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
"""Check if function can be called with given arguments."""
try:
inspect.signature(f).bind(*args, **kwargs)
return True
except TypeError as e:
sig = inspect.signature(f)
raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
def get_json(obj: Any, max_depth: int = 10) -> Any:
def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
if seen is None:
seen = set()
# Use model_dump() if available
if hasattr(o, "model_dump"):
return o.model_dump()
# Check depth limit
if depth > max_depth:
return f"<max_depth_exceeded:{max_depth}>"
# Check for circular references using object id
obj_id = id(o)
if obj_id in seen:
return f"<circular_reference:{type(o).__name__}>"
# Handle Computer objects
if hasattr(o, "__class__") and "computer" in o.__class__.__name__.lower():
return f"<computer:{o.__class__.__name__}>"
# Handle objects with __dict__
if hasattr(o, "__dict__"):
seen.add(obj_id)
try:
result = {}
@@ -84,7 +91,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
return result
finally:
seen.discard(obj_id)
# Handle common types that might contain nested objects
elif isinstance(o, dict):
seen.add(obj_id)
@@ -96,7 +103,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
}
finally:
seen.discard(obj_id)
elif isinstance(o, (list, tuple, set)):
seen.add(obj_id)
try:
@@ -107,32 +114,33 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
]
finally:
seen.discard(obj_id)
# For basic types that json.dumps can handle
elif isinstance(o, (str, int, float, bool)) or o is None:
return o
# Fallback to string representation
else:
return str(o)
def remove_nones(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: remove_nones(v) for k, v in obj.items() if v is not None}
elif isinstance(obj, list):
return [remove_nones(item) for item in obj if item is not None]
return obj
# Serialize with circular reference and depth protection
serialized = custom_serializer(obj)
# Convert to JSON string and back to ensure JSON compatibility
json_str = json.dumps(serialized)
parsed = json.loads(json_str)
# Final cleanup of any remaining None values
return remove_nones(parsed)
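A small sketch of `get_json` on an object that trips the circular-reference guard above (illustrative class):

class Node:
    def __init__(self):
        self.name = "a"
        self.next = self  # cycle back to the same object

get_json(Node())
# -> {"name": "a", "next": "<circular_reference:Node>"}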
def sanitize_message(msg: Any) -> Any:
"""Return a copy of the message with image_url omitted for computer_call_output messages."""
if msg.get("type") == "computer_call_output":
@@ -143,19 +151,24 @@ def sanitize_message(msg: Any) -> Any:
return sanitized
return msg
def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
call_ids = []
for message in messages:
if message.get("type") == "computer_call_output" or message.get("type") == "function_call_output":
if (
message.get("type") == "computer_call_output"
or message.get("type") == "function_call_output"
):
call_ids.append(message.get("call_id"))
return call_ids
class ComputerAgent:
"""
Main agent class that automatically selects the appropriate agent loop
based on the model and executes tool calls.
"""
def __init__(
self,
model: str,
@@ -172,11 +185,11 @@ class ComputerAgent:
max_trajectory_budget: Optional[float | dict] = None,
telemetry_enabled: Optional[bool] = True,
trust_remote_code: Optional[bool] = False,
**kwargs,
):
"""
Initialize ComputerAgent.
Args:
model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
tools: List of tools (computer objects, decorated functions, etc.)
@@ -193,11 +206,11 @@ class ComputerAgent:
telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
**kwargs: Additional arguments passed to the agent loop
"""
"""
# If the loop is "human/human", we need to prefix a grounding model fallback
if model in ["human/human", "human"]:
model = "openai/computer-use-preview+human/human"
self.model = model
self.tools = tools or []
self.custom_loop = custom_loop
@@ -236,34 +249,33 @@ class ComputerAgent:
# Add image retention callback if only_n_most_recent_images is set
if self.only_n_most_recent_images:
self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
# Add trajectory saver callback if trajectory_dir is set
if self.trajectory_dir:
if isinstance(self.trajectory_dir, dict):
self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
elif isinstance(self.trajectory_dir, (str, Path)):
self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))
# Add budget manager if max_trajectory_budget is set
if max_trajectory_budget:
if isinstance(max_trajectory_budget, dict):
self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
else:
self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
# == Enable local model providers w/ LiteLLM ==
# Register local model providers
hf_adapter = HuggingFaceLocalAdapter(
device="auto",
trust_remote_code=self.trust_remote_code or False
device="auto", trust_remote_code=self.trust_remote_code or False
)
human_adapter = HumanAdapter()
mlx_adapter = MLXVLMAdapter()
litellm.custom_provider_map = [
{"provider": "huggingface-local", "custom_handler": hf_adapter},
{"provider": "human", "custom_handler": human_adapter},
{"provider": "mlx", "custom_handler": mlx_adapter}
{"provider": "mlx", "custom_handler": mlx_adapter},
]
litellm.suppress_debug_info = True
@@ -280,16 +292,16 @@ class ComputerAgent:
# Instantiate the agent config class
self.agent_loop = config_info.agent_class()
self.agent_config_info = config_info
self.tool_schemas = []
self.computer_handler = None
async def _initialize_computers(self):
"""Initialize computer objects"""
if not self.tool_schemas:
# Process tools and create tool schemas
self.tool_schemas = self._process_tools()
# Find computer tool and create interface adapter
computer_handler = None
for schema in self.tool_schemas:
@@ -297,7 +309,7 @@ class ComputerAgent:
computer_handler = await make_computer_handler(schema["computer"])
break
self.computer_handler = computer_handler
def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
"""Process input messages and create schemas for the agent loop"""
if isinstance(input, str):
@@ -307,69 +319,73 @@ class ComputerAgent:
def _process_tools(self) -> List[Dict[str, Any]]:
"""Process tools and create schemas for the agent loop"""
schemas = []
for tool in self.tools:
# Check if it's a computer object (has interface attribute)
if is_agent_computer(tool):
# This is a computer tool - will be handled by agent loop
schemas.append({"type": "computer", "computer": tool})
elif callable(tool):
# Use litellm.utils.function_to_dict to extract schema from docstring
try:
function_schema = litellm.utils.function_to_dict(tool)
schemas.append({"type": "function", "function": function_schema})
except Exception as e:
print(f"Warning: Could not process tool {tool}: {e}")
else:
print(f"Warning: Unknown tool type: {tool}")
return schemas
def _get_tool(self, name: str) -> Optional[Callable]:
"""Get a tool by name"""
for tool in self.tools:
if hasattr(tool, "__name__") and tool.__name__ == name:
return tool
elif hasattr(tool, "func") and tool.func.__name__ == name:
return tool
return None
# ============================================================================
# AGENT RUN LOOP LIFECYCLE HOOKS
# ============================================================================
async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Initialize run tracking by calling callbacks."""
for callback in self.callbacks:
if hasattr(callback, "on_run_start"):
await callback.on_run_start(kwargs, old_items)
async def _on_run_end(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> None:
"""Finalize run tracking by calling callbacks."""
for callback in self.callbacks:
if hasattr(callback, "on_run_end"):
await callback.on_run_end(kwargs, old_items, new_items)
async def _on_run_continue(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> bool:
"""Check if run should continue by calling callbacks."""
for callback in self.callbacks:
if hasattr(callback, "on_run_continue"):
should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
if not should_continue:
return False
return True
async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Prepare messages for the LLM call by applying callbacks."""
result = messages
for callback in self.callbacks:
if hasattr(callback, "on_llm_start"):
result = await callback.on_llm_start(result)
return result
@@ -377,82 +393,91 @@ class ComputerAgent:
"""Postprocess messages after the LLM call by applying callbacks."""
result = messages
for callback in self.callbacks:
if hasattr(callback, "on_llm_end"):
result = await callback.on_llm_end(result)
return result
async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
"""Called when responses are received."""
for callback in self.callbacks:
if hasattr(callback, "on_responses"):
await callback.on_responses(get_json(kwargs), get_json(responses))
async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a computer call is about to start."""
for callback in self.callbacks:
if hasattr(callback, "on_computer_call_start"):
await callback.on_computer_call_start(get_json(item))
async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
async def _on_computer_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""Called when a computer call has completed."""
for callback in self.callbacks:
if hasattr(callback, 'on_computer_call_end'):
if hasattr(callback, "on_computer_call_end"):
await callback.on_computer_call_end(get_json(item), get_json(result))
async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a function call is about to start."""
for callback in self.callbacks:
if hasattr(callback, 'on_function_call_start'):
if hasattr(callback, "on_function_call_start"):
await callback.on_function_call_start(get_json(item))
async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
async def _on_function_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""Called when a function call has completed."""
for callback in self.callbacks:
if hasattr(callback, 'on_function_call_end'):
if hasattr(callback, "on_function_call_end"):
await callback.on_function_call_end(get_json(item), get_json(result))
async def _on_text(self, item: Dict[str, Any]) -> None:
"""Called when a text message is encountered."""
for callback in self.callbacks:
if hasattr(callback, 'on_text'):
if hasattr(callback, "on_text"):
await callback.on_text(get_json(item))
async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
"""Called when an LLM API call is about to start."""
for callback in self.callbacks:
if hasattr(callback, 'on_api_start'):
if hasattr(callback, "on_api_start"):
await callback.on_api_start(get_json(kwargs))
async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""Called when an LLM API call has completed."""
for callback in self.callbacks:
if hasattr(callback, 'on_api_end'):
if hasattr(callback, "on_api_end"):
await callback.on_api_end(get_json(kwargs), get_json(result))
async def _on_usage(self, usage: Dict[str, Any]) -> None:
"""Called when usage information is received."""
for callback in self.callbacks:
if hasattr(callback, 'on_usage'):
if hasattr(callback, "on_usage"):
await callback.on_usage(get_json(usage))
async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
"""Called when a screenshot is taken."""
for callback in self.callbacks:
if hasattr(callback, 'on_screenshot'):
if hasattr(callback, "on_screenshot"):
await callback.on_screenshot(screenshot, name)
# ============================================================================
# AGENT OUTPUT PROCESSING
# ============================================================================
async def _handle_item(self, item: Any, computer: Optional[AsyncComputerHandler] = None, ignore_call_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
async def _handle_item(
self,
item: Any,
computer: Optional[AsyncComputerHandler] = None,
ignore_call_ids: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""Handle each item; may cause a computer action + screenshot."""
call_id = item.get("call_id")
if ignore_call_ids and call_id and call_id in ignore_call_ids:
return []
item_type = item.get("type", None)
if item_type == "message":
await self._on_text(item)
# # Print messages
@@ -461,7 +486,7 @@ class ComputerAgent:
# if content_item.get("text"):
# print(content_item.get("text"))
return []
try:
if item_type == "computer_call":
await self._on_computer_call_start(item)
@@ -472,14 +497,16 @@ class ComputerAgent:
action = item.get("action")
action_type = action.get("type")
if action_type is None:
print(f"Action type cannot be `None`: action={action}, action_type={action_type}")
print(
f"Action type cannot be `None`: action={action}, action_type={action_type}"
)
return []
# Extract action arguments (all fields except 'type')
action_args = {k: v for k, v in action.items() if k != "type"}
# print(f"{action_type}({action_args})")
# Execute the computer action
computer_method = getattr(computer, action_type, None)
if computer_method:
@@ -487,13 +514,13 @@ class ComputerAgent:
await computer_method(**action_args)
else:
raise ToolError(f"Unknown computer action: {action_type}")
# Take screenshot after action
if self.screenshot_delay and self.screenshot_delay > 0:
await asyncio.sleep(self.screenshot_delay)
screenshot_base64 = await computer.screenshot()
await self._on_screenshot(screenshot_base64, "screenshot_after")
# Handle safety checks
pending_checks = item.get("pending_safety_checks", [])
acknowledged_checks = []
@@ -505,7 +532,7 @@ class ComputerAgent:
# acknowledged_checks.append(check)
# else:
# raise ValueError(f"Safety check failed: {check_message}")
# Create call output
call_output = {
"type": "computer_call_output",
@@ -516,25 +543,25 @@ class ComputerAgent:
"image_url": f"data:image/png;base64,{screenshot_base64}",
},
}
# # Additional URL safety checks for browser environments
# if await computer.get_environment() == "browser":
# current_url = await computer.get_current_url()
# call_output["output"]["current_url"] = current_url
# # TODO: implement a callback for URL safety checks
# # check_blocklisted_url(current_url)
result = [call_output]
await self._on_computer_call_end(item, result)
return result
if item_type == "function_call":
await self._on_function_call_start(item)
# Perform function call
function = self._get_tool(item.get("name"))
if not function:
raise ToolError(f"Function {item.get("name")} not found")
raise ToolError(f"Function {item.get('name')} not found")
args = json.loads(item.get("arguments"))
# Validate arguments before execution
@@ -545,14 +572,14 @@ class ComputerAgent:
result = await function(**args)
else:
result = await asyncio.to_thread(function, **args)
# Create function call output
call_output = {
"type": "function_call_output",
"call_id": item.get("call_id"),
"output": str(result),
}
result = [call_output]
await self._on_function_call_end(item, result)
return result
@@ -564,36 +591,35 @@ class ComputerAgent:
# ============================================================================
# MAIN AGENT LOOP
# ============================================================================
async def run(
self,
messages: Messages,
stream: bool = False,
**kwargs
self, messages: Messages, stream: bool = False, **kwargs
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Run the agent with the given messages using the Computer protocol handler pattern.
Args:
messages: List of message dictionaries
stream: Whether to stream the response
**kwargs: Additional arguments
Returns:
AsyncGenerator that yields response chunks
"""
if not self.agent_config_info:
raise ValueError("Agent configuration not found")
capabilities = self.get_capabilities()
if "step" not in capabilities:
raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions")
raise ValueError(
f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions"
)
await self._initialize_computers()
# Merge kwargs
merged_kwargs = {**self.kwargs, **kwargs}
old_items = self._process_input(messages)
new_items = []
@@ -603,7 +629,7 @@ class ComputerAgent:
"stream": stream,
"model": self.model,
"agent_loop": self.agent_config_info.agent_class.__name__,
**merged_kwargs
**merged_kwargs,
}
await self._on_run_start(run_kwargs, old_items)
@@ -620,7 +646,7 @@ class ComputerAgent:
combined_messages = old_items + new_items
combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
preprocessed_messages = await self._on_llm_start(combined_messages)
loop_kwargs = {
"messages": preprocessed_messages,
"model": self.model,
@@ -629,7 +655,7 @@ class ComputerAgent:
"computer_handler": self.computer_handler,
"max_retries": self.max_retries,
"use_prompt_caching": self.use_prompt_caching,
**merged_kwargs
**merged_kwargs,
}
# Run agent loop iteration
@@ -641,13 +667,13 @@ class ComputerAgent:
_on_screenshot=self._on_screenshot,
)
result = get_json(result)
# Lifecycle hook: Postprocess messages after the LLM call
# Use cases:
# - PII deanonymization (if you want tool calls to see PII)
result["output"] = await self._on_llm_end(result.get("output", []))
await self._on_responses(loop_kwargs, result)
# Yield agent response
yield result
@@ -659,7 +685,9 @@ class ComputerAgent:
# Handle computer actions
for item in result.get("output"):
partial_items = await self._handle_item(item, self.computer_handler, ignore_call_ids=output_call_ids)
partial_items = await self._handle_item(
item, self.computer_handler, ignore_call_ids=output_call_ids
)
new_items += partial_items
# Yield partial response
@@ -669,54 +697,52 @@ class ComputerAgent:
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
),
}
await self._on_run_end(loop_kwargs, old_items, new_items)
async def predict_click(
self,
instruction: str,
image_b64: Optional[str] = None
self, instruction: str, image_b64: Optional[str] = None
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.
Args:
instruction: Instruction for where to click
image_b64: Base64 encoded image (optional, will take screenshot if not provided)
Returns:
None or tuple with (x, y) coordinates
"""
if not self.agent_config_info:
raise ValueError("Agent configuration not found")
capabilities = self.get_capabilities()
if "click" not in capabilities:
raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions")
if hasattr(self.agent_loop, 'predict_click'):
raise ValueError(
f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions"
)
if hasattr(self.agent_loop, "predict_click"):
if not image_b64:
if not self.computer_handler:
raise ValueError("Computer tool or image_b64 is required for predict_click")
image_b64 = await self.computer_handler.screenshot()
return await self.agent_loop.predict_click(
model=self.model,
image_b64=image_b64,
instruction=instruction
model=self.model, image_b64=image_b64, instruction=instruction
)
return None
def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by the current agent config.
Returns:
List of capability strings (e.g., ["step", "click"])
"""
if not self.agent_config_info:
raise ValueError("Agent configuration not found")
if hasattr(self.agent_loop, 'get_capabilities'):
if hasattr(self.agent_loop, "get_capabilities"):
return self.agent_loop.get_capabilities()
return ["step"] # Default capability
return ["step"] # Default capability

View File

@@ -3,17 +3,17 @@ Callback system for ComputerAgent preprocessing and postprocessing hooks.
"""
from .base import AsyncCallbackHandler
from .budget_manager import BudgetManagerCallback
from .image_retention import ImageRetentionCallback
from .logging import LoggingCallback
from .trajectory_saver import TrajectorySaverCallback
from .budget_manager import BudgetManagerCallback
from .telemetry import TelemetryCallback
from .operator_validator import OperatorNormalizerCallback
from .prompt_instructions import PromptInstructionsCallback
from .telemetry import TelemetryCallback
from .trajectory_saver import TrajectorySaverCallback
__all__ = [
"AsyncCallbackHandler",
"ImageRetentionCallback",
"ImageRetentionCallback",
"LoggingCallback",
"TrajectorySaverCallback",
"BudgetManagerCallback",

View File

@@ -3,7 +3,7 @@ Base callback handler interface for ComputerAgent preprocessing and postprocessi
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union
from typing import Any, Dict, List, Optional, Union
class AsyncCallbackHandler(ABC):
@@ -16,42 +16,52 @@ class AsyncCallbackHandler(ABC):
"""Called at the start of an agent run loop."""
pass
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
async def on_run_end(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> None:
"""Called at the end of an agent run loop."""
pass
async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
async def on_run_continue(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> bool:
"""Called during agent run loop to determine if execution should continue.
Args:
kwargs: Run arguments
old_items: Original messages
new_items: New messages generated during run
Returns:
True to continue execution, False to stop
"""
return True
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Called before messages are sent to the agent loop.
Args:
messages: List of message dictionaries to preprocess
Returns:
List of preprocessed message dictionaries
"""
return messages
async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Called after the agent loop returns output.
Args:
output: List of output message dictionaries to postprocess
Returns:
List of postprocessed output dictionaries
"""
@@ -60,63 +70,67 @@ class AsyncCallbackHandler(ABC):
async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
"""
Called when a computer call is about to start.
Args:
item: The computer call item dictionary
"""
pass
async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
async def on_computer_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""
Called when a computer call has completed.
Args:
item: The computer call item dictionary
result: The result of the computer call
"""
pass
async def on_function_call_start(self, item: Dict[str, Any]) -> None:
"""
Called when a function call is about to start.
Args:
item: The function call item dictionary
"""
pass
async def on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
async def on_function_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""
Called when a function call has completed.
Args:
item: The function call item dictionary
result: The result of the function call
"""
pass
async def on_text(self, item: Dict[str, Any]) -> None:
"""
Called when a text message is encountered.
Args:
item: The message item dictionary
"""
pass
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
"""
Called when an API call is about to start.
Args:
kwargs: The kwargs being passed to the API call
"""
pass
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""
Called when an API call has completed.
Args:
kwargs: The kwargs that were passed to the API call
result: The result of the API call
@@ -126,7 +140,7 @@ class AsyncCallbackHandler(ABC):
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""
Called when usage information is received.
Args:
usage: The usage information
"""
@@ -135,7 +149,7 @@ class AsyncCallbackHandler(ABC):
async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
"""
Called when a screenshot is taken.
Args:
screenshot: The screenshot image
name: The name of the screenshot
@@ -145,9 +159,9 @@ class AsyncCallbackHandler(ABC):
async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
"""
Called when responses are received.
Args:
kwargs: The kwargs being passed to the agent loop
responses: The responses received
"""
pass
pass
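
Since every hook above has a no-op default, a subclass only overrides what it needs. A hypothetical step-limiting callback, for illustration only:

```python
# Hypothetical callback: stops a run after N computer actions by
# returning False from on_run_continue (same mechanism as
# BudgetManagerCallback below). Names are illustrative.
from typing import Any, Dict, List

from agent.callbacks import AsyncCallbackHandler


class StepLimitCallback(AsyncCallbackHandler):
    def __init__(self, max_steps: int = 20):
        self.max_steps = max_steps
        self.steps = 0

    async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
        # Count each computer action as one step
        self.steps += 1

    async def on_run_continue(
        self,
        kwargs: Dict[str, Any],
        old_items: List[Dict[str, Any]],
        new_items: List[Dict[str, Any]],
    ) -> bool:
        return self.steps < self.max_steps
```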

View File

@@ -1,17 +1,23 @@
from typing import Dict, List, Any
from typing import Any, Dict, List
from .base import AsyncCallbackHandler
class BudgetExceededError(Exception):
"""Exception raised when budget is exceeded."""
pass
class BudgetManagerCallback(AsyncCallbackHandler):
"""Budget manager callback that tracks usage costs and can stop execution when budget is exceeded."""
def __init__(self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False):
def __init__(
self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False
):
"""
Initialize BudgetManagerCallback.
Args:
max_budget: Maximum budget allowed
reset_after_each_run: Whether to reset budget after each run
@@ -21,24 +27,30 @@ class BudgetManagerCallback(AsyncCallbackHandler):
self.reset_after_each_run = reset_after_each_run
self.raise_error = raise_error
self.total_cost = 0.0
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Reset budget if configured to do so."""
if self.reset_after_each_run:
self.total_cost = 0.0
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""Track usage costs."""
if "response_cost" in usage:
self.total_cost += usage["response_cost"]
async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
async def on_run_continue(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> bool:
"""Check if budget allows continuation."""
if self.total_cost >= self.max_budget:
if self.raise_error:
raise BudgetExceededError(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
raise BudgetExceededError(
f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}"
)
else:
print(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
return False
return True
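
Wiring this into an agent is then one constructor argument. A sketch, assuming callbacks are passed via a `callbacks` keyword and with a placeholder model string:

```python
# Sketch: stop an agent run once ~$2.00 has been spent.
from agent import ComputerAgent
from agent.callbacks import BudgetManagerCallback

agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",  # placeholder
    callbacks=[
        # raise_error=True surfaces BudgetExceededError instead of
        # printing a warning and stopping the run quietly.
        BudgetManagerCallback(max_budget=2.0, raise_error=True),
    ],
)
```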

View File

@@ -2,7 +2,8 @@
Image retention callback handler that limits the number of recent images in message history.
"""
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional
from .base import AsyncCallbackHandler
@@ -11,40 +12,40 @@ class ImageRetentionCallback(AsyncCallbackHandler):
Callback handler that applies image retention policy to limit the number
of recent images in message history to prevent context window overflow.
"""
def __init__(self, only_n_most_recent_images: Optional[int] = None):
"""
Initialize the image retention callback.
Args:
only_n_most_recent_images: If set, only keep the N most recent images in message history
"""
self.only_n_most_recent_images = only_n_most_recent_images
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Apply image retention policy to messages before sending to agent loop.
Args:
messages: List of message dictionaries
Returns:
List of messages with image retention policy applied
"""
if self.only_n_most_recent_images is None:
return messages
return self._apply_image_retention(messages)
def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Apply image retention policy to keep only the N most recent images.
Removes computer_call_output items with image_url and their corresponding computer_call items,
keeping only the N most recent image pairs based on the only_n_most_recent_images setting.
Args:
messages: List of message dictionaries
Returns:
Filtered list of messages with image retention applied
"""
@@ -78,7 +79,11 @@ class ImageRetentionCallback(AsyncCallbackHandler):
# Remove the immediately preceding computer_call with matching call_id (if present)
call_id = messages[idx].get("call_id")
prev_idx = idx - 1
if prev_idx >= 0 and messages[prev_idx].get("type") == "computer_call" and messages[prev_idx].get("call_id") == call_id:
if (
prev_idx >= 0
and messages[prev_idx].get("type") == "computer_call"
and messages[prev_idx].get("call_id") == call_id
):
to_remove.add(prev_idx)
# Check a single reasoning immediately before that computer_call
r_idx = prev_idx - 1
@@ -87,4 +92,4 @@ class ImageRetentionCallback(AsyncCallbackHandler):
# Construct filtered list
filtered = [m for i, m in enumerate(messages) if i not in to_remove]
return filtered
return filtered
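
Concretely, with `only_n_most_recent_images=1` the older screenshot pair is dropped wholesale. The message shapes below are hypothetical but follow the types this filter matches on:

```python
# Hypothetical two-screenshot history; retention keeps only the "b" pair.
messages = [
    {"type": "computer_call", "call_id": "a", "action": {"type": "screenshot"}},
    {"type": "computer_call_output", "call_id": "a",
     "output": {"type": "input_image", "image_url": "data:image/png;base64,..."}},
    {"type": "computer_call", "call_id": "b", "action": {"type": "screenshot"}},
    {"type": "computer_call_output", "call_id": "b",
     "output": {"type": "input_image", "image_url": "data:image/png;base64,..."}},
]
# After: await ImageRetentionCallback(only_n_most_recent_images=1).on_llm_start(messages)
# the "a" computer_call and its output are removed; only the "b" pair remains.
```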

View File

@@ -4,17 +4,18 @@ Logging callback for ComputerAgent that provides configurable logging of agent l
import json
import logging
from typing import Dict, List, Any, Optional, Union
from typing import Any, Dict, List, Optional, Union
from .base import AsyncCallbackHandler
def sanitize_image_urls(data: Any) -> Any:
"""
Recursively search for 'image_url' keys and set their values to '[omitted]'.
Args:
data: Any data structure (dict, list, or primitive type)
Returns:
A deep copy of the data with all 'image_url' values replaced with '[omitted]'
"""
@@ -28,11 +29,11 @@ def sanitize_image_urls(data: Any) -> Any:
# Recursively sanitize the value
sanitized[key] = sanitize_image_urls(value)
return sanitized
elif isinstance(data, list):
# Recursively sanitize each item in the list
return [sanitize_image_urls(item) for item in data]
else:
# For primitive types (str, int, bool, None, etc.), return as-is
return data
@@ -41,37 +42,36 @@ def sanitize_image_urls(data: Any) -> Any:
class LoggingCallback(AsyncCallbackHandler):
"""
Callback handler that logs agent lifecycle events with configurable verbosity.
Logging levels:
- DEBUG: All events including API calls, message preprocessing, and detailed outputs
- INFO: Major lifecycle events (start/end, messages, outputs)
- INFO: Major lifecycle events (start/end, messages, outputs)
- WARNING: Only warnings and errors
- ERROR: Only errors
"""
def __init__(self, logger: Optional[logging.Logger] = None, level: int = logging.INFO):
"""
Initialize the logging callback.
Args:
logger: Logger instance to use. If None, creates a logger named 'agent.ComputerAgent'
level: Logging level (logging.DEBUG, logging.INFO, etc.)
"""
self.logger = logger or logging.getLogger('agent.ComputerAgent')
self.logger = logger or logging.getLogger("agent.ComputerAgent")
self.level = level
# Set up logger if it doesn't have handlers
if not self.logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(level)
def _update_usage(self, usage: Dict[str, Any]) -> None:
"""Update total usage statistics."""
def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
for key, value in source.items():
if isinstance(value, dict):
@@ -82,18 +82,25 @@ class LoggingCallback(AsyncCallbackHandler):
if key not in target:
target[key] = 0
target[key] += value
add_dicts(self.total_usage, usage)
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Called before the run starts."""
self.total_usage = {}
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""Called when usage information is received."""
self._update_usage(usage)
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
async def on_run_end(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> None:
"""Called after the run ends."""
def format_dict(d, indent=0):
lines = []
prefix = f" - {' ' * indent}"
@@ -106,10 +113,10 @@ class LoggingCallback(AsyncCallbackHandler):
else:
lines.append(f"{prefix}{key}: {value}")
return lines
formatted_output = "\n".join(format_dict(self.total_usage))
self.logger.info(f"Total usage:\n{formatted_output}")
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Called before LLM processing starts."""
if self.logger.isEnabledFor(logging.INFO):
@@ -118,27 +125,27 @@ class LoggingCallback(AsyncCallbackHandler):
sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
self.logger.debug(f"LLM input messages: {json.dumps(sanitized_messages, indent=2)}")
return messages
async def on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Called after LLM processing ends."""
if self.logger.isEnabledFor(logging.DEBUG):
sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
self.logger.debug(f"LLM output: {json.dumps(sanitized_messages, indent=2)}")
return messages
async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a computer call starts."""
action = item.get("action", {})
action_type = action.get("type", "unknown")
action_args = {k: v for k, v in action.items() if k != "type"}
# INFO level logging for the action
self.logger.info(f"Computer: {action_type}({action_args})")
# DEBUG level logging for full details
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Computer call started: {json.dumps(action, indent=2)}")
async def on_computer_call_end(self, item: Dict[str, Any], result: Any) -> None:
"""Called when a computer call ends."""
if self.logger.isEnabledFor(logging.DEBUG):
@@ -147,48 +154,52 @@ class LoggingCallback(AsyncCallbackHandler):
if result:
sanitized_result = sanitize_image_urls(result)
self.logger.debug(f"Computer call result: {json.dumps(sanitized_result, indent=2)}")
async def on_function_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a function call starts."""
name = item.get("name", "unknown")
arguments = item.get("arguments", "{}")
# INFO level logging for the function call
self.logger.info(f"Function: {name}({arguments})")
# DEBUG level logging for full details
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Function call started: {name}")
async def on_function_call_end(self, item: Dict[str, Any], result: Any) -> None:
"""Called when a function call ends."""
# INFO level logging for function output (similar to function_call_output)
if result:
# Handle both list and direct result formats
if isinstance(result, list) and len(result) > 0:
output = result[0].get("output", str(result)) if isinstance(result[0], dict) else str(result[0])
output = (
result[0].get("output", str(result))
if isinstance(result[0], dict)
else str(result[0])
)
else:
output = str(result)
# Truncate long outputs
if len(output) > 100:
output = output[:100] + "..."
self.logger.info(f"Output: {output}")
# DEBUG level logging for full details
if self.logger.isEnabledFor(logging.DEBUG):
name = item.get("name", "unknown")
self.logger.debug(f"Function call completed: {name}")
if result:
self.logger.debug(f"Function call result: {json.dumps(result, indent=2)}")
async def on_text(self, item: Dict[str, Any]) -> None:
"""Called when a text message is encountered."""
# Get the role to determine if it's Agent or User
role = item.get("role", "unknown")
content_items = item.get("content", [])
# Process content items to build display text
text_parts = []
for content_item in content_items:
@@ -206,10 +217,10 @@ class LoggingCallback(AsyncCallbackHandler):
else:
# Non-text content, show as [type]
text_parts.append(f"[{content_type}]")
# Join all text parts
display_text = ''.join(text_parts) if text_parts else "[empty]"
display_text = "".join(text_parts) if text_parts else "[empty]"
# Log with appropriate level and format
if role == "assistant":
self.logger.info(f"Agent: {display_text}")
@@ -219,7 +230,7 @@ class LoggingCallback(AsyncCallbackHandler):
# Fallback for unknown roles, use debug level
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Text message ({role}): {display_text}")
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
"""Called when an API call is about to start."""
if self.logger.isEnabledFor(logging.DEBUG):
@@ -232,16 +243,18 @@ class LoggingCallback(AsyncCallbackHandler):
elif "input" in kwargs:
sanitized_input = sanitize_image_urls(kwargs["input"])
self.logger.debug(f"API call input: {json.dumps(sanitized_input, indent=2)}")
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""Called when an API call has completed."""
if self.logger.isEnabledFor(logging.DEBUG):
model = kwargs.get("model", "unknown")
self.logger.debug(f"API call completed for model: {model}")
self.logger.debug(f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}")
self.logger.debug(
f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}"
)
async def on_screenshot(self, item: Union[str, bytes], name: str = "screenshot") -> None:
"""Called when a screenshot is taken."""
if self.logger.isEnabledFor(logging.DEBUG):
image_size = len(item) / 1024
self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")
self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")

View File

@@ -9,6 +9,7 @@ Ensures agent output actions conform to expected schemas by fixing common issues
This runs in on_llm_end, which receives the output array (AgentMessage[] as dicts).
The purpose is to avoid spending another LLM call to fix broken computer call syntax when possible.
"""
from __future__ import annotations
from typing import Any, Dict, List
@@ -48,6 +49,7 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
action["type"] = "type"
action_type = action.get("type")
def _keep_keys(action: Dict[str, Any], keys_to_keep: List[str]):
"""Keep only the provided keys on action; delete everything else.
Always ensures required 'type' is present if listed in keys_to_keep.
@@ -55,6 +57,7 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
for key in list(action.keys()):
if key not in keys_to_keep:
del action[key]
# rename "coordinate" to "x", "y"
if "coordinate" in action:
action["x"] = action["coordinate"][0]
@@ -100,7 +103,6 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
keep = required_keys_by_type.get(action_type or "")
if keep:
_keep_keys(action, keep)
# # Second pass: if an assistant message is immediately followed by a computer_call,
# # replace the assistant message itself with a reasoning message with summary text.
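
The `coordinate` rename runs before key filtering, so a model emitting coordinate-pair actions still yields the expected schema. Roughly (the action dict is hypothetical):

```python
# Hypothetical model output before normalization:
action = {"type": "click", "coordinate": [412, 230], "stray_key": "dropped"}
# After normalization: {"type": "click", "x": 412, "y": 230}
# ("coordinate" is renamed to x/y, then keys outside the required
#  set for the action type are deleted by _keep_keys).
```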

View File

@@ -2,38 +2,41 @@
PII anonymization callback handler using Microsoft Presidio for text and image redaction.
"""
from typing import List, Dict, Any, Optional, Tuple
from .base import AsyncCallbackHandler
import base64
import io
import logging
from typing import Any, Dict, List, Optional, Tuple
from .base import AsyncCallbackHandler
try:
# TODO: Add Presidio dependencies
from PIL import Image
PRESIDIO_AVAILABLE = True
except ImportError:
PRESIDIO_AVAILABLE = False
logger = logging.getLogger(__name__)
class PIIAnonymizationCallback(AsyncCallbackHandler):
"""
Callback handler that anonymizes PII in text and images using Microsoft Presidio.
This handler:
1. Anonymizes PII in messages before sending to the agent loop
2. Deanonymizes PII in tool calls and message outputs after the agent loop
3. Redacts PII from images in computer_call_output messages
"""
def __init__(
self,
# TODO: Any extra kwargs if needed
):
"""
Initialize the PII anonymization callback.
Args:
anonymize_text: Whether to anonymize text content
anonymize_images: Whether to redact images
@@ -46,16 +49,16 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
"Presidio is not available. Install with: "
"pip install cua-agent[pii-anonymization]"
)
# TODO: Implement __init__
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Anonymize PII in messages before sending to agent loop.
Args:
messages: List of message dictionaries
Returns:
List of messages with PII anonymized
"""
@@ -63,16 +66,16 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
for msg in messages:
anonymized_msg = await self._anonymize_message(msg)
anonymized_messages.append(anonymized_msg)
return anonymized_messages
async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Deanonymize PII in tool calls and message outputs after agent loop.
Args:
output: List of output dictionaries
Returns:
List of output with PII deanonymized for tool calls
"""
@@ -84,13 +87,13 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
deanonymized_output.append(deanonymized_item)
else:
deanonymized_output.append(item)
return deanonymized_output
async def _anonymize_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
# TODO: Implement _anonymize_message
return message
async def _deanonymize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
# TODO: Implement _deanonymize_item
return item

View File

@@ -2,17 +2,17 @@
Telemetry callback handler for Computer-Use Agent (cua-agent)
"""
import platform
import time
import uuid
from typing import List, Dict, Any, Optional, Union
from typing import Any, Dict, List, Optional, Union
from .base import AsyncCallbackHandler
from core.telemetry import (
record_event,
is_telemetry_enabled,
record_event,
)
import platform
from .base import AsyncCallbackHandler
SYSTEM_INFO = {
"os": platform.system().lower(),
@@ -20,32 +20,29 @@ SYSTEM_INFO = {
"python_version": platform.python_version(),
}
class TelemetryCallback(AsyncCallbackHandler):
"""
Telemetry callback handler for Computer-Use Agent (cua-agent)
Tracks agent usage, performance metrics, and optionally trajectory data.
"""
def __init__(
self,
agent,
log_trajectory: bool = False
):
def __init__(self, agent, log_trajectory: bool = False):
"""
Initialize telemetry callback.
Args:
agent: The ComputerAgent instance
log_trajectory: Whether to log full trajectory items (opt-in)
"""
self.agent = agent
self.log_trajectory = log_trajectory
# Generate session/run IDs
self.session_id = str(uuid.uuid4())
self.run_id = None
# Track timing and metrics
self.run_start_time = None
self.step_count = 0
@@ -54,126 +51,133 @@ class TelemetryCallback(AsyncCallbackHandler):
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"response_cost": 0.0
"response_cost": 0.0,
}
# Record agent initialization
if is_telemetry_enabled():
self._record_agent_initialization()
def _record_agent_initialization(self) -> None:
"""Record agent type/model and session initialization."""
agent_info = {
"session_id": self.session_id,
"agent_type": self.agent.agent_loop.__name__ if hasattr(self.agent, 'agent_loop') else 'unknown',
"model": getattr(self.agent, 'model', 'unknown'),
**SYSTEM_INFO
"agent_type": (
self.agent.agent_loop.__name__ if hasattr(self.agent, "agent_loop") else "unknown"
),
"model": getattr(self.agent, "model", "unknown"),
**SYSTEM_INFO,
}
record_event("agent_session_start", agent_info)
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Called at the start of an agent run loop."""
if not is_telemetry_enabled():
return
self.run_id = str(uuid.uuid4())
self.run_start_time = time.time()
self.step_count = 0
# Calculate input context size
input_context_size = self._calculate_context_size(old_items)
run_data = {
"session_id": self.session_id,
"run_id": self.run_id,
"start_time": self.run_start_time,
"input_context_size": input_context_size,
"num_existing_messages": len(old_items)
"num_existing_messages": len(old_items),
}
# Log trajectory if opted in
if self.log_trajectory:
trajectory = self._extract_trajectory(old_items)
if trajectory:
run_data["uploaded_trajectory"] = trajectory
record_event("agent_run_start", run_data)
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
async def on_run_end(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> None:
"""Called at the end of an agent run loop."""
if not is_telemetry_enabled() or not self.run_start_time:
return
run_duration = time.time() - self.run_start_time
run_data = {
"session_id": self.session_id,
"run_id": self.run_id,
"end_time": time.time(),
"duration_seconds": run_duration,
"num_steps": self.step_count,
"total_usage": self.total_usage.copy()
"total_usage": self.total_usage.copy(),
}
# Log trajectory if opted in
if self.log_trajectory:
trajectory = self._extract_trajectory(new_items)
if trajectory:
run_data["uploaded_trajectory"] = trajectory
record_event("agent_run_end", run_data)
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""Called when usage information is received."""
if not is_telemetry_enabled():
return
# Accumulate usage stats
self.total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
self.total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
self.total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
self.total_usage["total_tokens"] += usage.get("total_tokens", 0)
self.total_usage["response_cost"] += usage.get("response_cost", 0.0)
# Record individual usage event
usage_data = {
"session_id": self.session_id,
"run_id": self.run_id,
"step": self.step_count,
**usage
**usage,
}
record_event("agent_usage", usage_data)
async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
"""Called when responses are received."""
if not is_telemetry_enabled():
return
self.step_count += 1
step_duration = None
if self.step_start_time:
step_duration = time.time() - self.step_start_time
self.step_start_time = time.time()
step_data = {
"session_id": self.session_id,
"run_id": self.run_id,
"step": self.step_count,
"timestamp": self.step_start_time
"timestamp": self.step_start_time,
}
if step_duration is not None:
step_data["duration_seconds"] = step_duration
record_event("agent_step", step_data)
def _calculate_context_size(self, items: List[Dict[str, Any]]) -> int:
"""Calculate approximate context size in tokens/characters."""
total_size = 0
for item in items:
if item.get("type") == "message" and "content" in item:
content = item["content"]
@@ -185,25 +189,27 @@ class TelemetryCallback(AsyncCallbackHandler):
total_size += len(part["text"])
elif "content" in item and isinstance(item["content"], str):
total_size += len(item["content"])
return total_size
def _extract_trajectory(self, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Extract trajectory items that should be logged."""
trajectory = []
for item in items:
# Include user messages, assistant messages, reasoning, computer calls, and computer outputs
if (
item.get("role") == "user" or # User inputs
(item.get("type") == "message" and item.get("role") == "assistant") or # Model outputs
item.get("type") == "reasoning" or # Reasoning traces
item.get("type") == "computer_call" or # Computer actions
item.get("type") == "computer_call_output" # Computer outputs
item.get("role") == "user" # User inputs
or (
item.get("type") == "message" and item.get("role") == "assistant"
) # Model outputs
or item.get("type") == "reasoning" # Reasoning traces
or item.get("type") == "computer_call" # Computer actions
or item.get("type") == "computer_call_output" # Computer outputs
):
# Create a copy of the item with timestamp
trajectory_item = item.copy()
trajectory_item["logged_at"] = time.time()
trajectory.append(trajectory_item)
return trajectory
return trajectory
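
Trajectory upload stays opt-in; constructing the callback with it enabled looks like this (`agent` is whatever `ComputerAgent` instance owns the callback):

```python
from agent.callbacks import TelemetryCallback

# log_trajectory defaults to False; full trajectory items are only
# recorded when explicitly opted in, and only if telemetry is enabled.
telemetry = TelemetryCallback(agent, log_trajectory=True)
```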

View File

@@ -2,26 +2,28 @@
Trajectory saving callback handler for ComputerAgent.
"""
import os
import json
import uuid
from datetime import datetime
import base64
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, override
from PIL import Image, ImageDraw
import io
import json
import os
import uuid
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, override
from PIL import Image, ImageDraw
from .base import AsyncCallbackHandler
def sanitize_image_urls(data: Any) -> Any:
"""
Recursively search for 'image_url' keys and set their values to '[omitted]'.
Args:
data: Any data structure (dict, list, or primitive type)
Returns:
A deep copy of the data with all 'image_url' values replaced with '[omitted]'
"""
@@ -35,17 +37,19 @@ def sanitize_image_urls(data: Any) -> Any:
# Recursively sanitize the value
sanitized[key] = sanitize_image_urls(value)
return sanitized
elif isinstance(data, list):
# Recursively sanitize each item in the list
return [sanitize_image_urls(item) for item in data]
else:
# For primitive types (str, int, bool, None, etc.), return as-is
return data
def extract_computer_call_outputs(items: List[Dict[str, Any]], screenshot_dir: Optional[Path]) -> List[Dict[str, Any]]:
def extract_computer_call_outputs(
items: List[Dict[str, Any]], screenshot_dir: Optional[Path]
) -> List[Dict[str, Any]]:
"""
Save any base64-encoded screenshots from computer_call_output entries to files and
replace their image_url with the saved file path when a call_id is present.
@@ -103,18 +107,21 @@ def extract_computer_call_outputs(items: List[Dict[str, Any]], screenshot_dir: O
updated.append(msg)
return updated
class TrajectorySaverCallback(AsyncCallbackHandler):
"""
Callback handler that saves agent trajectories to disk.
Saves each run as a separate trajectory with a unique ID, and each turn
within the trajectory gets its own folder with screenshots and responses.
"""
def __init__(self, trajectory_dir: str, reset_on_run: bool = True, screenshot_dir: Optional[str] = None):
def __init__(
self, trajectory_dir: str, reset_on_run: bool = True, screenshot_dir: Optional[str] = None
):
"""
Initialize trajectory saver.
Args:
trajectory_dir: Base directory to save trajectories
reset_on_run: If True, reset trajectory_id/turn/artifact on each run.
@@ -129,7 +136,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
self.reset_on_run = reset_on_run
# Optional directory to store extracted screenshots from metadata/new_items
self.screenshot_dir: Optional[Path] = Path(screenshot_dir) if screenshot_dir else None
# Ensure trajectory directory exists
self.trajectory_dir.mkdir(parents=True, exist_ok=True)
@@ -137,7 +144,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
"""Get the directory for the current turn."""
if not self.trajectory_id:
raise ValueError("Trajectory not initialized - call _on_run_start first")
# format: trajectory_id/turn_000
turn_dir = self.trajectory_dir / self.trajectory_id / f"turn_{self.current_turn:03d}"
turn_dir.mkdir(parents=True, exist_ok=True)
@@ -166,6 +173,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
def _update_usage(self, usage: Dict[str, Any]) -> None:
"""Update total usage statistics."""
def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
for key, value in source.items():
if isinstance(value, dict):
@@ -176,20 +184,21 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
if key not in target:
target[key] = 0
target[key] += value
add_dicts(self.total_usage, usage)
@override
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Initialize trajectory tracking for a new run."""
model = kwargs.get("model", "unknown")
# Only reset trajectory state if reset_on_run is True or no trajectory exists
if self.reset_on_run or not self.trajectory_id:
model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
if "+" in model:
model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
# strip non-alphanumeric characters from model_name_short
model_name_short = ''.join(c for c in model_name_short if c.isalnum() or c == '_')
model_name_short = "".join(c for c in model_name_short if c.isalnum() or c == "_")
# id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
now = datetime.now()
@@ -198,11 +207,11 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
self.current_artifact = 0
self.model = model
self.total_usage = {}
# Create trajectory directory
trajectory_path = self.trajectory_dir / self.trajectory_id
trajectory_path.mkdir(parents=True, exist_ok=True)
# Save trajectory metadata (optionally extract screenshots to screenshot_dir)
kwargs_to_save = kwargs.copy()
try:
@@ -219,7 +228,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
"status": "running",
"kwargs": kwargs_to_save,
}
with open(trajectory_path / "metadata.json", "w") as f:
json.dump(metadata, f, indent=2)
else:
@@ -227,22 +236,27 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
self.model = model
@override
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
async def on_run_end(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> None:
"""Finalize run tracking by updating metadata with completion status, usage, and new items."""
if not self.trajectory_id:
return
# Update metadata with completion status, total usage, and new items
trajectory_path = self.trajectory_dir / self.trajectory_id
metadata_path = trajectory_path / "metadata.json"
# Read existing metadata
if metadata_path.exists():
with open(metadata_path, "r") as f:
metadata = json.load(f)
else:
metadata = {}
# Update metadata with completion info
# Optionally extract screenshots from new_items before persisting
new_items_to_save = new_items
@@ -251,32 +265,34 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
except Exception:
pass
metadata.update({
"status": "completed",
"completed_at": str(uuid.uuid1().time),
"total_usage": self.total_usage,
"new_items": new_items_to_save,
"total_turns": self.current_turn
})
metadata.update(
{
"status": "completed",
"completed_at": str(uuid.uuid1().time),
"total_usage": self.total_usage,
"new_items": new_items_to_save,
"total_turns": self.current_turn,
}
)
# Save updated metadata
with open(metadata_path, "w") as f:
json.dump(metadata, f, indent=2)
@override
@override
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
if not self.trajectory_id:
return
self._save_artifact("api_start", { "kwargs": kwargs })
self._save_artifact("api_start", {"kwargs": kwargs})
@override
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""Save API call result."""
if not self.trajectory_id:
return
self._save_artifact("api_result", { "kwargs": kwargs, "result": result })
self._save_artifact("api_result", {"kwargs": kwargs, "result": result})
@override
async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
@@ -295,77 +311,83 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
"""Save responses to the current turn directory and update usage statistics."""
if not self.trajectory_id:
return
# Save responses
turn_dir = self._get_turn_dir()
response_data = {
"timestamp": str(uuid.uuid1().time),
"model": self.model,
"kwargs": kwargs,
"response": responses
"response": responses,
}
self._save_artifact("agent_response", response_data)
# Increment turn counter
self.current_turn += 1
def _draw_crosshair_on_image(self, image_bytes: bytes, x: int, y: int) -> bytes:
"""
Draw a red dot and crosshair at the specified coordinates on the image.
Args:
image_bytes: The original image as bytes
x: X coordinate for the crosshair
y: Y coordinate for the crosshair
Returns:
Modified image as bytes with red dot and crosshair
"""
# Open the image
image = Image.open(io.BytesIO(image_bytes))
draw = ImageDraw.Draw(image)
# Draw crosshair lines (red, 2px thick)
crosshair_size = 20
line_width = 2
color = "red"
# Horizontal line
draw.line([(x - crosshair_size, y), (x + crosshair_size, y)], fill=color, width=line_width)
# Vertical line
draw.line([(x, y - crosshair_size), (x, y + crosshair_size)], fill=color, width=line_width)
# Draw center dot (filled circle)
dot_radius = 3
draw.ellipse([(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color)
draw.ellipse(
[(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color
)
# Convert back to bytes
output = io.BytesIO()
image.save(output, format='PNG')
image.save(output, format="PNG")
return output.getvalue()
@override
async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
async def on_computer_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""
Called when a computer call has completed.
Saves screenshots and computer call output.
"""
if not self.trajectory_id:
return
self._save_artifact("computer_call_result", { "item": item, "result": result })
self._save_artifact("computer_call_result", {"item": item, "result": result})
# Check if action has x/y coordinates and there's a screenshot in the result
action = item.get("action", {})
if "x" in action and "y" in action:
# Look for screenshot in the result
for result_item in result:
if (result_item.get("type") == "computer_call_output" and
result_item.get("output", {}).get("type") == "input_image"):
if (
result_item.get("type") == "computer_call_output"
and result_item.get("output", {}).get("type") == "input_image"
):
image_url = result_item["output"]["image_url"]
# Extract base64 image data
if image_url.startswith("data:image/"):
# Format: data:image/png;base64,<base64_data>
@@ -373,26 +395,24 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
else:
# Assume it's just base64 data
base64_data = image_url
try:
# Decode the image
image_bytes = base64.b64decode(base64_data)
# Draw crosshair at the action coordinates
annotated_image = self._draw_crosshair_on_image(
image_bytes,
int(action["x"]),
int(action["y"])
image_bytes, int(action["x"]), int(action["y"])
)
# Save as screenshot_action
self._save_artifact("screenshot_action", annotated_image)
except Exception as e:
# If annotation fails, just log and continue
print(f"Failed to annotate screenshot: {e}")
break # Only process the first screenshot found
# Increment turn counter
self.current_turn += 1
self.current_turn += 1
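
Instantiating the saver with a separate screenshot directory (paths are illustrative):

```python
from agent.callbacks import TrajectorySaverCallback

# One folder per run under ./trajectories (id format:
# yyyy-mm-dd_model_hhmmss_uuid[:4]), one turn_NNN subfolder per turn.
# Base64 screenshots found in saved items are extracted to ./screenshots.
saver = TrajectorySaverCallback(
    "trajectories",
    reset_on_run=True,          # new trajectory id on every run
    screenshot_dir="screenshots",
)
```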

View File

@@ -3,7 +3,7 @@ CLI chat interface for agent - Computer Use Agent
Usage:
python -m agent.cli <model_string>
Examples:
python -m agent.cli openai/computer-use-preview
python -m agent.cli anthropic/claude-3-5-sonnet-20241022
@@ -11,19 +11,22 @@ Examples:
"""
try:
import asyncio
import argparse
import os
import sys
import json
from typing import List, Dict, Any
import dotenv
import asyncio
import base64
import time
import json
import os
import platform
import sys
import time
from pathlib import Path
from typing import Any, Dict, List
import dotenv
try:
from PIL import Image, ImageDraw
PIL_AVAILABLE = True
except Exception:
PIL_AVAILABLE = False
@@ -31,36 +34,44 @@ try:
except ImportError:
if __name__ == "__main__":
raise ImportError(
"CLI dependencies not found. "
"Please install with: pip install \"cua-agent[cli]\""
"CLI dependencies not found. " 'Please install with: pip install "cua-agent[cli]"'
)
# Load environment variables
dotenv.load_dotenv()
# Color codes for terminal output
class Colors:
RESET = '\033[0m'
BOLD = '\033[1m'
DIM = '\033[2m'
# Text colors
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m'
BLUE = '\033[34m'
MAGENTA = '\033[35m'
CYAN = '\033[36m'
WHITE = '\033[37m'
GRAY = '\033[90m'
# Background colors
BG_RED = '\033[41m'
BG_GREEN = '\033[42m'
BG_YELLOW = '\033[43m'
BG_BLUE = '\033[44m'
RESET = "\033[0m"
BOLD = "\033[1m"
DIM = "\033[2m"
def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = False, end: str = "\n", right: str = ""):
# Text colors
RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
MAGENTA = "\033[35m"
CYAN = "\033[36m"
WHITE = "\033[37m"
GRAY = "\033[90m"
# Background colors
BG_RED = "\033[41m"
BG_GREEN = "\033[42m"
BG_YELLOW = "\033[43m"
BG_BLUE = "\033[44m"
def print_colored(
text: str,
color: str = "",
bold: bool = False,
dim: bool = False,
end: str = "\n",
right: str = "",
):
"""Print colored text to terminal with optional right-aligned text."""
prefix = ""
if bold:
@@ -69,24 +80,25 @@ def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = Fa
prefix += Colors.DIM
if color:
prefix += color
if right:
# Get terminal width (default to 80 if unable to determine)
try:
import shutil
terminal_width = shutil.get_terminal_size().columns
except:
terminal_width = 80
# Add right margin
terminal_width -= 1
# Calculate padding needed
# Account for ANSI escape codes not taking visual space
visible_left_len = len(text)
visible_right_len = len(right)
padding = terminal_width - visible_left_len - visible_right_len
if padding > 0:
output = f"{prefix}{text}{' ' * padding}{right}{Colors.RESET}"
else:
@@ -94,7 +106,7 @@ def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = Fa
output = f"{prefix}{text} {right}{Colors.RESET}"
else:
output = f"{prefix}{text}{Colors.RESET}"
print(output, end=end)
@@ -113,29 +125,34 @@ def print_action(action_type: str, details: Dict[str, Any], total_cost: float):
args_str = f"('{details['text']}')"
elif action_type == "scroll" and "x" in details and "y" in details:
args_str = f"({details['x']}, {details['y']})"
if total_cost > 0:
print_colored(f"🛠️ {action_type}{args_str}", dim=True, right=f"💸 ${total_cost:.2f}")
else:
print_colored(f"🛠️ {action_type}{args_str}", dim=True)
def print_welcome(model: str, agent_loop: str, container_name: str):
"""Print welcome message."""
print_colored(f"Connected to {container_name} ({model}, {agent_loop})")
print_colored("Type 'exit' to quit.", dim=True)
async def ainput(prompt: str = ""):
return await asyncio.to_thread(input, prompt)
async def chat_loop(agent, model: str, container_name: str, initial_prompt: str = "", show_usage: bool = True):
async def chat_loop(
agent, model: str, container_name: str, initial_prompt: str = "", show_usage: bool = True
):
"""Main chat loop with the agent."""
print_welcome(model, agent.agent_config_info.agent_class.__name__, container_name)
history = []
if initial_prompt:
history.append({"role": "user", "content": initial_prompt})
total_cost = 0
while True:
@@ -143,31 +160,31 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
# Get user input with prompt
print_colored("> ", end="")
user_input = await ainput()
if user_input.lower() in ['exit', 'quit', 'q']:
if user_input.lower() in ["exit", "quit", "q"]:
print_colored("\n👋 Goodbye!")
break
if not user_input:
continue
# Add user message to history
history.append({"role": "user", "content": user_input})
# Stream responses from the agent with spinner
with yaspin(text="Thinking...", spinner="line", attrs=["dark"]) as spinner:
spinner.hide()
async for result in agent.run(history):
# Add agent responses to history
history.extend(result.get("output", []))
if show_usage:
total_cost += result.get("usage", {}).get("response_cost", 0)
# Process and display the output
for item in result.get("output", []):
if item.get("type") == "message":
if item.get("type") == "message" and item.get("role") == "assistant":
# Display agent text response
content = item.get("content", [])
for content_part in content:
@@ -176,7 +193,7 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
if text:
spinner.hide()
print_colored(text)
elif item.get("type") == "computer_call":
# Display computer action
action = item.get("action", {})
@@ -186,7 +203,7 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
print_action(action_type, action, total_cost)
spinner.text = f"Performing {action_type}..."
spinner.show()
elif item.get("type") == "function_call":
# Display function call
function_name = item.get("name", "")
@@ -194,18 +211,18 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
print_colored(f"🔧 Calling function: {function_name}", dim=True)
spinner.text = f"Calling {function_name}..."
spinner.show()
elif item.get("type") == "function_call_output":
# Display function output (dimmed)
output = item.get("output", "")
if output and len(output.strip()) > 0:
spinner.hide()
print_colored(f"📤 {output}", dim=True)
spinner.hide()
if show_usage and total_cost > 0:
print_colored(f"Total cost: ${total_cost:.2f}", dim=True)
async def main():
"""Main CLI function."""
@@ -218,104 +235,103 @@ Examples:
python -m agent.cli anthropic/claude-3-5-sonnet-20241022
python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
python -m agent.cli huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B
"""
""",
)
parser.add_argument(
"model",
help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')"
help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')",
)
parser.add_argument(
"--provider",
choices=["cloud", "lume", "winsandbox", "docker"],
default="cloud",
help="Computer provider to use: cloud (default), lume, winsandbox, or docker",
)
parser.add_argument(
"--images",
type=int,
default=3,
help="Number of recent images to keep in context (default: 3)"
help="Number of recent images to keep in context (default: 3)",
)
parser.add_argument("--trajectory", action="store_true", help="Save trajectory for debugging")
parser.add_argument("--budget", type=float, help="Maximum budget for the session (in dollars)")
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
parser.add_argument(
"--trajectory",
action="store_true",
help="Save trajectory for debugging"
)
parser.add_argument(
"--budget",
type=float,
help="Maximum budget for the session (in dollars)"
)
parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose logging"
"-p",
"--prompt",
type=str,
help="Initial prompt to send to the agent. Leave blank for interactive mode.",
)
parser.add_argument(
"-p", "--prompt",
type=str,
help="Initial prompt to send to the agent. Leave blank for interactive mode."
"--prompt-file",
type=Path,
help="Path to a UTF-8 text file whose contents will be used as the initial prompt. If provided, overrides --prompt.",
)
parser.add_argument(
"--predict-click",
dest="predict_click",
type=str,
help="Instruction for click prediction. If set, runs predict_click, draws crosshair on a fresh screenshot, saves and opens it."
help="Instruction for click prediction. If set, runs predict_click, draws crosshair on a fresh screenshot, saves and opens it.",
)
parser.add_argument("-c", "--cache", action="store_true", help="Tell the API to enable caching")
parser.add_argument(
"-u", "--usage", action="store_true", help="Show total cost of the agent runs"
)
parser.add_argument(
"-c", "--cache",
action="store_true",
help="Tell the API to enable caching"
)
parser.add_argument(
"-u", "--usage",
action="store_true",
help="Show total cost of the agent runs"
)
parser.add_argument(
"-r", "--max-retries",
"-r",
"--max-retries",
type=int,
default=3,
help="Maximum number of retries for the LLM API calls"
help="Maximum number of retries for the LLM API calls",
)
args = parser.parse_args()
# Check for required environment variables
container_name = os.getenv("CUA_CONTAINER_NAME")
cua_api_key = os.getenv("CUA_API_KEY")
# Prompt for missing environment variables
# Prompt for missing environment variables (container name always required)
if not container_name:
print_colored("CUA_CONTAINER_NAME not set.", dim=True)
print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
container_name = input("Enter your CUA container name: ").strip()
if not container_name:
print_colored("❌ Container name is required.")
sys.exit(1)
if not cua_api_key:
if args.provider == "cloud":
print_colored("CUA_CONTAINER_NAME not set.", dim=True)
print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
container_name = input("Enter your CUA container name: ").strip()
if not container_name:
print_colored("❌ Container name is required.")
sys.exit(1)
else:
container_name = "cli-sandbox"
# Only require API key for cloud provider
if args.provider == "cloud" and not cua_api_key:
print_colored("CUA_API_KEY not set.", dim=True)
cua_api_key = input("Enter your CUA API key: ").strip()
if not cua_api_key:
print_colored("❌ API key is required.")
print_colored("❌ API key is required for cloud provider.")
sys.exit(1)
# Check for provider-specific API keys based on model
provider_api_keys = {
"openai/": "OPENAI_API_KEY",
"anthropic/": "ANTHROPIC_API_KEY",
"omniparser+": "OPENAI_API_KEY",
"omniparser+": "ANTHROPIC_API_KEY",
}
# Find matching provider and check for API key
for prefix, env_var in provider_api_keys.items():
if args.model.startswith(prefix):
if prefix in args.model:
if not os.getenv(env_var):
print_colored(f"{env_var} not set.", dim=True)
api_key = input(f"Enter your {env_var.replace('_', ' ').title()}: ").strip()
@@ -325,7 +341,7 @@ Examples:
# Set the environment variable for the session
os.environ[env_var] = api_key
break
# Import here to avoid import errors if dependencies are missing
try:
from agent import ComputerAgent
@@ -334,46 +350,62 @@ Examples:
print_colored(f"❌ Import error: {e}", Colors.RED, bold=True)
print_colored("Make sure agent and computer libraries are installed.", Colors.YELLOW)
sys.exit(1)
# Resolve provider -> os_type, provider_type, api key requirement
provider_map = {
"cloud": ("linux", "cloud", True),
"lume": ("macos", "lume", False),
"winsandbox": ("windows", "winsandbox", False),
"docker": ("linux", "docker", False),
}
os_type, provider_type, needs_api_key = provider_map[args.provider]
computer_kwargs = {
"os_type": os_type,
"provider_type": provider_type,
"name": container_name,
}
if needs_api_key:
computer_kwargs["api_key"] = cua_api_key # type: ignore
# Create computer instance
async with Computer(
os_type="linux",
provider_type="cloud",
name=container_name,
api_key=cua_api_key
) as computer:
async with Computer(**computer_kwargs) as computer: # type: ignore
# Create agent
agent_kwargs = {
"model": args.model,
"tools": [computer],
"trust_remote_code": True, # needed for some local models (e.g., InternVL, OpenCUA)
"trust_remote_code": True, # needed for some local models (e.g., InternVL, OpenCUA)
"verbosity": 20 if args.verbose else 30, # DEBUG vs WARNING
"max_retries": args.max_retries
"max_retries": args.max_retries,
}
if args.images > 0:
agent_kwargs["only_n_most_recent_images"] = args.images
if args.trajectory:
agent_kwargs["trajectory_dir"] = "trajectories"
if args.budget:
agent_kwargs["max_trajectory_budget"] = {
"max_budget": args.budget,
"raise_error": True,
"reset_after_each_run": False
"reset_after_each_run": False,
}
if args.cache:
agent_kwargs["use_prompt_caching"] = True
agent = ComputerAgent(**agent_kwargs)
# If predict-click mode is requested, run once and exit
if args.predict_click:
if not PIL_AVAILABLE:
print_colored("❌ Pillow (PIL) is required for --predict-click visualization. Install with: pip install pillow", Colors.RED, bold=True)
print_colored(
"❌ Pillow (PIL) is required for --predict-click visualization. Install with: pip install pillow",
Colors.RED,
bold=True,
)
sys.exit(1)
instruction = args.predict_click
@@ -408,6 +440,7 @@ Examples:
try:
from io import BytesIO
with Image.open(BytesIO(img_bytes)) as img:
img = img.convert("RGB")
draw = ImageDraw.Draw(img)
@@ -430,9 +463,9 @@ Examples:
if system == "windows":
os.startfile(str(out_path)) # type: ignore[attr-defined]
elif system == "darwin":
os.system(f"open \"{out_path}\"")
os.system(f'open "{out_path}"')
else:
os.system(f"xdg-open \"{out_path}\"")
os.system(f'xdg-open "{out_path}"')
except Exception:
pass
except Exception as e:
@@ -442,13 +475,21 @@ Examples:
# Done
sys.exit(0)
# Start chat loop (default interactive mode)
await chat_loop(agent, args.model, container_name, args.prompt, args.usage)
# Resolve initial prompt from --prompt-file or --prompt
initial_prompt = args.prompt or ""
if args.prompt_file:
try:
initial_prompt = args.prompt_file.read_text(encoding="utf-8")
except Exception as e:
print_colored(f"❌ Failed to read --prompt-file: {e}", Colors.RED, bold=True)
sys.exit(1)
# Start chat loop (default interactive mode)
await chat_loop(agent, args.model, container_name, initial_prompt, args.usage)
if __name__ == "__main__":
try:
asyncio.run(main())
except (KeyboardInterrupt, EOFError) as _:
print_colored("\n\n👋 Goodbye!")
print_colored("\n\n👋 Goodbye!")

View File

@@ -6,27 +6,32 @@ computer interface types, supporting both the ComputerHandler protocol and the
Computer library interface.
"""
from computer import Computer as cuaComputer
from .base import AsyncComputerHandler
from .cua import cuaComputerHandler
from .custom import CustomComputerHandler
from computer import Computer as cuaComputer
def is_agent_computer(computer):
"""Check if the given computer is a ComputerHandler or CUA Computer."""
return isinstance(computer, AsyncComputerHandler) or \
isinstance(computer, cuaComputer) or \
(isinstance(computer, dict)) #and "screenshot" in computer)
return (
isinstance(computer, AsyncComputerHandler)
or isinstance(computer, cuaComputer)
or (isinstance(computer, dict))
) # and "screenshot" in computer)
async def make_computer_handler(computer):
"""
Create a computer handler from a computer interface.
Args:
computer: Either a ComputerHandler instance, Computer instance, or dict of functions
Returns:
ComputerHandler: A computer handler instance
Raises:
ValueError: If the computer type is not supported
"""
@@ -38,4 +43,4 @@ async def make_computer_handler(computer):
return computer_handler
if isinstance(computer, dict):
return CustomComputerHandler(computer)
raise ValueError(f"Unsupported computer type: {type(computer)}")
raise ValueError(f"Unsupported computer type: {type(computer)}")

View File

@@ -2,23 +2,32 @@
Base computer interface protocol for agent interactions.
"""
from typing import Protocol, Literal, List, Dict, Any, Union, Optional, runtime_checkable
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
)
@runtime_checkable
class AsyncComputerHandler(Protocol):
"""Protocol defining the interface for computer interactions."""
# ==== Computer-Use-Preview Action Space ====
# ==== Computer-Use-Preview Action Space ====
async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
"""Get the current environment type."""
...
async def get_dimensions(self) -> tuple[int, int]:
"""Get screen dimensions as (width, height)."""
...
async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.
@@ -26,49 +35,49 @@ class AsyncComputerHandler(Protocol):
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
"""
...
async def click(self, x: int, y: int, button: str = "left") -> None:
"""Click at coordinates with specified button."""
...
async def double_click(self, x: int, y: int) -> None:
"""Double click at coordinates."""
...
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
"""Scroll at coordinates with specified scroll amounts."""
...
async def type(self, text: str) -> None:
"""Type text."""
...
async def wait(self, ms: int = 1000) -> None:
"""Wait for specified milliseconds."""
...
async def move(self, x: int, y: int) -> None:
"""Move cursor to coordinates."""
...
async def keypress(self, keys: Union[List[str], str]) -> None:
"""Press key combination."""
...
async def drag(self, path: List[Dict[str, int]]) -> None:
"""Drag along specified path."""
...
async def get_current_url(self) -> str:
"""Get current URL (for browser environments)."""
...
# ==== Anthropic Action Space ====
# ==== Anthropic Action Space ====
async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse down at coordinates."""
...
async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse up at coordinates."""
...
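Because the protocol is runtime_checkable, isinstance() only verifies that the methods exist (not their signatures); a minimal conforming stub might look like this (illustrative, not from the diff):

    class NullComputer:
        async def get_environment(self): return "linux"
        async def get_dimensions(self): return (1280, 720)
        async def screenshot(self, text=None): return ""  # empty base64 image
        async def click(self, x, y, button="left"): pass
        # ...the remaining protocol methods would be stubbed the same way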

View File

@@ -3,24 +3,27 @@ Computer handler implementation for OpenAI computer-use-preview protocol.
"""
import base64
from typing import Dict, List, Any, Literal, Union, Optional
from .base import AsyncComputerHandler
from typing import Any, Dict, List, Literal, Optional, Union
from computer import Computer
from .base import AsyncComputerHandler
class cuaComputerHandler(AsyncComputerHandler):
"""Computer handler that implements the Computer protocol using the computer interface."""
def __init__(self, cua_computer: Computer):
"""Initialize with a computer interface (from tool schema)."""
self.cua_computer = cua_computer
self.interface = None
async def _initialize(self):
if hasattr(self.cua_computer, '_initialized') and not self.cua_computer._initialized:
if hasattr(self.cua_computer, "_initialized") and not self.cua_computer._initialized:
await self.cua_computer.run()
self.interface = self.cua_computer.interface
# ==== Computer-Use-Preview Action Space ====
# ==== Computer-Use-Preview Action Space ====
async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
"""Get the current environment type."""
@@ -32,7 +35,7 @@ class cuaComputerHandler(AsyncComputerHandler):
assert self.interface is not None
screen_size = await self.interface.get_screen_size()
return screen_size["width"], screen_size["height"]
async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.
@@ -41,8 +44,8 @@ class cuaComputerHandler(AsyncComputerHandler):
"""
assert self.interface is not None
screenshot_bytes = await self.interface.screenshot()
return base64.b64encode(screenshot_bytes).decode('utf-8')
return base64.b64encode(screenshot_bytes).decode("utf-8")
async def click(self, x: int, y: int, button: str = "left") -> None:
"""Click at coordinates with specified button."""
assert self.interface is not None
@@ -53,34 +56,35 @@ class cuaComputerHandler(AsyncComputerHandler):
else:
# Default to left click for unknown buttons
await self.interface.left_click(x, y)
async def double_click(self, x: int, y: int) -> None:
"""Double click at coordinates."""
assert self.interface is not None
await self.interface.double_click(x, y)
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
"""Scroll at coordinates with specified scroll amounts."""
assert self.interface is not None
await self.interface.move_cursor(x, y)
await self.interface.scroll(scroll_x, scroll_y)
async def type(self, text: str) -> None:
"""Type text."""
assert self.interface is not None
await self.interface.type_text(text)
async def wait(self, ms: int = 1000) -> None:
"""Wait for specified milliseconds."""
assert self.interface is not None
import asyncio
await asyncio.sleep(ms / 1000.0)
async def move(self, x: int, y: int) -> None:
"""Move cursor to coordinates."""
assert self.interface is not None
await self.interface.move_cursor(x, y)
async def keypress(self, keys: Union[List[str], str]) -> None:
"""Press key combination."""
assert self.interface is not None
@@ -91,38 +95,38 @@ class cuaComputerHandler(AsyncComputerHandler):
else:
# Handle key combinations
await self.interface.hotkey(*keys)
async def drag(self, path: List[Dict[str, int]]) -> None:
"""Drag along specified path."""
assert self.interface is not None
if not path:
return
# Start drag from first point
start = path[0]
await self.interface.mouse_down(start["x"], start["y"])
# Move through path
for point in path[1:]:
await self.interface.move_cursor(point["x"], point["y"])
# End drag at last point
end = path[-1]
await self.interface.mouse_up(end["x"], end["y"])
async def get_current_url(self) -> str:
"""Get current URL (for browser environments)."""
# This would need to be implemented based on the specific browser interface
# For now, return empty string
return ""
# ==== Anthropic Computer Action Space ====
# ==== Anthropic Computer Action Space ====
async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse down at coordinates."""
assert self.interface is not None
await self.interface.mouse_down(x, y, button="left")
async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse up at coordinates."""
assert self.interface is not None
await self.interface.mouse_up(x, y, button="left")
await self.interface.mouse_up(x, y, button="left")
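The drag implementation above presses at path[0], moves the cursor through the intermediate points, and releases at path[-1]; for example (assuming an initialized handler):

    await handler.drag([{"x": 100, "y": 100}, {"x": 180, "y": 140}, {"x": 300, "y": 240}])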

View File

@@ -3,47 +3,49 @@ Custom computer handler implementation that accepts a dictionary of functions.
"""
import base64
from typing import Dict, List, Any, Literal, Union, Optional, Callable
from PIL import Image
import io
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from PIL import Image
from .base import AsyncComputerHandler
class CustomComputerHandler(AsyncComputerHandler):
"""Computer handler that implements the Computer protocol using a dictionary of custom functions."""
def __init__(self, functions: Dict[str, Callable]):
"""
Initialize with a dictionary of functions.
Args:
functions: Dictionary where keys are method names and values are callable functions.
Only 'screenshot' is required; all others are optional.
Raises:
ValueError: If required 'screenshot' function is not provided.
"""
if 'screenshot' not in functions:
if "screenshot" not in functions:
raise ValueError("'screenshot' function is required in functions dictionary")
self.functions = functions
self._last_screenshot_size: Optional[tuple[int, int]] = None
async def _call_function(self, func, *args, **kwargs):
"""
Call a function, handling both async and sync functions.
Args:
func: The function to call
*args: Positional arguments to pass to the function
**kwargs: Keyword arguments to pass to the function
Returns:
The result of the function call
"""
import asyncio
import inspect
if callable(func):
if inspect.iscoroutinefunction(func):
return await func(*args, **kwargs)
@@ -51,14 +53,14 @@ class CustomComputerHandler(AsyncComputerHandler):
return func(*args, **kwargs)
else:
return func
async def _get_value(self, attribute: str):
"""
Get value for an attribute, checking both 'get_{attribute}' and '{attribute}' keys.
Args:
attribute: The attribute name to look for
Returns:
The value from the functions dict, called if callable, returned directly if not
"""
@@ -66,20 +68,20 @@ class CustomComputerHandler(AsyncComputerHandler):
get_key = f"get_{attribute}"
if get_key in self.functions:
return await self._call_function(self.functions[get_key])
# Check for '{attribute}'
# Check for '{attribute}'
if attribute in self.functions:
return await self._call_function(self.functions[attribute])
return None
def _to_b64_str(self, img: Union[bytes, Image.Image, str]) -> str:
"""
Convert image to base64 string.
Args:
img: Image as bytes, PIL Image, or base64 string
Returns:
str: Base64 encoded image string
"""
@@ -88,47 +90,47 @@ class CustomComputerHandler(AsyncComputerHandler):
return img
elif isinstance(img, bytes):
# Raw bytes
return base64.b64encode(img).decode('utf-8')
return base64.b64encode(img).decode("utf-8")
elif isinstance(img, Image.Image):
# PIL Image
buffer = io.BytesIO()
img.save(buffer, format='PNG')
return base64.b64encode(buffer.getvalue()).decode('utf-8')
img.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
else:
raise ValueError(f"Unsupported image type: {type(img)}")
# ==== Computer-Use-Preview Action Space ====
# ==== Computer-Use-Preview Action Space ====
async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
"""Get the current environment type."""
result = await self._get_value('environment')
result = await self._get_value("environment")
if result is None:
return "linux"
assert result in ["windows", "mac", "linux", "browser"]
return result # type: ignore
return result # type: ignore
async def get_dimensions(self) -> tuple[int, int]:
"""Get screen dimensions as (width, height)."""
result = await self._get_value('dimensions')
result = await self._get_value("dimensions")
if result is not None:
return result # type: ignore
return result # type: ignore
# Fallback: use last screenshot size if available
if not self._last_screenshot_size:
await self.screenshot()
assert self._last_screenshot_size is not None, "Failed to get screenshot size"
return self._last_screenshot_size
async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.
Args:
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
"""
result = await self._call_function(self.functions['screenshot'])
b64_str = self._to_b64_str(result) # type: ignore
result = await self._call_function(self.functions["screenshot"])
b64_str = self._to_b64_str(result) # type: ignore
# Try to extract dimensions for fallback use
try:
if isinstance(result, Image.Image):
@@ -140,74 +142,75 @@ class CustomComputerHandler(AsyncComputerHandler):
except Exception:
# If we can't get dimensions, that's okay
pass
return b64_str
async def click(self, x: int, y: int, button: str = "left") -> None:
"""Click at coordinates with specified button."""
if 'click' in self.functions:
await self._call_function(self.functions['click'], x, y, button)
if "click" in self.functions:
await self._call_function(self.functions["click"], x, y, button)
# No-op if not implemented
async def double_click(self, x: int, y: int) -> None:
"""Double click at coordinates."""
if 'double_click' in self.functions:
await self._call_function(self.functions['double_click'], x, y)
if "double_click" in self.functions:
await self._call_function(self.functions["double_click"], x, y)
# No-op if not implemented
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
"""Scroll at coordinates with specified scroll amounts."""
if 'scroll' in self.functions:
await self._call_function(self.functions['scroll'], x, y, scroll_x, scroll_y)
if "scroll" in self.functions:
await self._call_function(self.functions["scroll"], x, y, scroll_x, scroll_y)
# No-op if not implemented
async def type(self, text: str) -> None:
"""Type text."""
if 'type' in self.functions:
await self._call_function(self.functions['type'], text)
if "type" in self.functions:
await self._call_function(self.functions["type"], text)
# No-op if not implemented
async def wait(self, ms: int = 1000) -> None:
"""Wait for specified milliseconds."""
if 'wait' in self.functions:
await self._call_function(self.functions['wait'], ms)
if "wait" in self.functions:
await self._call_function(self.functions["wait"], ms)
else:
# Default implementation
import asyncio
await asyncio.sleep(ms / 1000.0)
async def move(self, x: int, y: int) -> None:
"""Move cursor to coordinates."""
if 'move' in self.functions:
await self._call_function(self.functions['move'], x, y)
if "move" in self.functions:
await self._call_function(self.functions["move"], x, y)
# No-op if not implemented
async def keypress(self, keys: Union[List[str], str]) -> None:
"""Press key combination."""
if 'keypress' in self.functions:
await self._call_function(self.functions['keypress'], keys)
if "keypress" in self.functions:
await self._call_function(self.functions["keypress"], keys)
# No-op if not implemented
async def drag(self, path: List[Dict[str, int]]) -> None:
"""Drag along specified path."""
if 'drag' in self.functions:
await self._call_function(self.functions['drag'], path)
if "drag" in self.functions:
await self._call_function(self.functions["drag"], path)
# No-op if not implemented
async def get_current_url(self) -> str:
"""Get current URL (for browser environments)."""
if 'get_current_url' in self.functions:
return await self._get_value('current_url') # type: ignore
if "get_current_url" in self.functions:
return await self._get_value("current_url") # type: ignore
return "" # Default fallback
async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse down at coordinates."""
if 'left_mouse_down' in self.functions:
await self._call_function(self.functions['left_mouse_down'], x, y)
if "left_mouse_down" in self.functions:
await self._call_function(self.functions["left_mouse_down"], x, y)
# No-op if not implemented
async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse up at coordinates."""
if 'left_mouse_up' in self.functions:
await self._call_function(self.functions['left_mouse_up'], x, y)
if "left_mouse_up" in self.functions:
await self._call_function(self.functions["left_mouse_up"], x, y)
# No-op if not implemented
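Because _call_function accepts both sync and async callables and _to_b64_str PNG-encodes PIL images, a minimal handler can be built from a single plain function; a sketch with a placeholder frame:

    from PIL import Image
    from agent.computers import CustomComputerHandler

    def screenshot() -> Image.Image:
        return Image.new("RGB", (1280, 720))  # placeholder frame

    handler = CustomComputerHandler({"screenshot": screenshot})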

View File

@@ -3,47 +3,56 @@ Decorators for agent - agent_loop decorator
"""
from typing import List, Optional
from .types import AgentConfigInfo
# Global registry
_agent_configs: List[AgentConfigInfo] = []
def register_agent(models: str, priority: int = 0):
"""
Decorator to register an AsyncAgentConfig class.
Args:
models: Regex pattern to match supported models
priority: Priority for agent selection (higher = more priority)
"""
def decorator(agent_class: type):
# Validate that the class implements AsyncAgentConfig protocol
if not hasattr(agent_class, 'predict_step'):
raise ValueError(f"Agent class {agent_class.__name__} must implement predict_step method")
if not hasattr(agent_class, 'predict_click'):
raise ValueError(f"Agent class {agent_class.__name__} must implement predict_click method")
if not hasattr(agent_class, 'get_capabilities'):
raise ValueError(f"Agent class {agent_class.__name__} must implement get_capabilities method")
if not hasattr(agent_class, "predict_step"):
raise ValueError(
f"Agent class {agent_class.__name__} must implement predict_step method"
)
if not hasattr(agent_class, "predict_click"):
raise ValueError(
f"Agent class {agent_class.__name__} must implement predict_click method"
)
if not hasattr(agent_class, "get_capabilities"):
raise ValueError(
f"Agent class {agent_class.__name__} must implement get_capabilities method"
)
# Register the agent config
config_info = AgentConfigInfo(
agent_class=agent_class,
models_regex=models,
priority=priority
agent_class=agent_class, models_regex=models, priority=priority
)
_agent_configs.append(config_info)
# Sort by priority (highest first)
_agent_configs.sort(key=lambda x: x.priority, reverse=True)
return agent_class
return decorator
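A registration sketch (class and model-pattern names are illustrative; the decorator only checks that the three methods exist):

    @register_agent(models=r"acme/.*", priority=10)
    class AcmeAgentConfig:
        async def predict_step(self, *args, **kwargs): ...
        async def predict_click(self, *args, **kwargs): ...
        def get_capabilities(self): ...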
def get_agent_configs() -> List[AgentConfigInfo]:
"""Get all registered agent configs"""
return _agent_configs.copy()
def find_agent_config(model: str) -> Optional[AgentConfigInfo]:
"""Find the best matching agent config for a model"""
for config_info in _agent_configs:

View File

@@ -12,7 +12,7 @@ Components:
Usage:
# Run the server and UI
python -m agent.human_tool
# Or run components separately
python -m agent.human_tool.server # API server only
python -m agent.human_tool.ui # UI only
@@ -21,9 +21,4 @@ Usage:
from .server import CompletionQueue, completion_queue
from .ui import HumanCompletionUI, create_ui
__all__ = [
"CompletionQueue",
"completion_queue",
"HumanCompletionUI",
"create_ui"
]
__all__ = ["CompletionQueue", "completion_queue", "HumanCompletionUI", "create_ui"]

View File

@@ -8,6 +8,7 @@ with a Gradio UI for human interaction.
import gradio as gr
from fastapi import FastAPI
from .server import app as fastapi_app
from .ui import create_ui
@@ -18,6 +19,7 @@ gradio_demo = create_ui()
CUSTOM_PATH = "/gradio"
app = gr.mount_gradio_app(fastapi_app, gradio_demo, path=CUSTOM_PATH)
# Add a redirect from root to Gradio UI
@fastapi_app.get("/")
async def redirect_to_ui():
@@ -25,14 +27,16 @@ async def redirect_to_ui():
return {
"message": "Human Completion Server is running",
"ui_url": "/gradio",
"api_docs": "/docs"
"api_docs": "/docs",
}
if __name__ == "__main__":
import uvicorn
print("🚀 Starting Human-in-the-Loop Completion Server...")
print("📊 API Server: http://localhost:8002")
print("🎨 Gradio UI: http://localhost:8002/gradio")
print("📚 API Docs: http://localhost:8002/docs")
uvicorn.run(app, host="0.0.0.0", port=8002)

View File

@@ -1,9 +1,9 @@
import asyncio
import uuid
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
@@ -49,7 +49,7 @@ class CompletionQueue:
self._queue: Dict[str, CompletionCall] = {}
self._pending_order: List[str] = []
self._lock = asyncio.Lock()
async def add_completion(self, messages: List[Dict[str, Any]], model: str) -> str:
"""Add a completion call to the queue."""
async with self._lock:
@@ -59,42 +59,47 @@ class CompletionQueue:
messages=messages,
model=model,
status=CompletionStatus.PENDING,
created_at=datetime.now()
created_at=datetime.now(),
)
self._queue[call_id] = completion_call
self._pending_order.append(call_id)
return call_id
async def get_pending_calls(self) -> List[Dict[str, Any]]:
"""Get all pending completion calls."""
async with self._lock:
pending_calls = []
for call_id in self._pending_order:
if call_id in self._queue and self._queue[call_id].status == CompletionStatus.PENDING:
if (
call_id in self._queue
and self._queue[call_id].status == CompletionStatus.PENDING
):
call = self._queue[call_id]
pending_calls.append({
"id": call.id,
"model": call.model,
"created_at": call.created_at.isoformat(),
"messages": call.messages
})
pending_calls.append(
{
"id": call.id,
"model": call.model,
"created_at": call.created_at.isoformat(),
"messages": call.messages,
}
)
return pending_calls
async def get_call_status(self, call_id: str) -> Optional[Dict[str, Any]]:
"""Get the status of a specific completion call."""
async with self._lock:
if call_id not in self._queue:
return None
call = self._queue[call_id]
result = {
"id": call.id,
"status": call.status.value,
"created_at": call.created_at.isoformat(),
"model": call.model,
"messages": call.messages
"messages": call.messages,
}
if call.completed_at:
result["completed_at"] = call.completed_at.isoformat()
if call.response:
@@ -103,69 +108,74 @@ class CompletionQueue:
result["tool_calls"] = call.tool_calls
if call.error:
result["error"] = call.error
return result
async def complete_call(self, call_id: str, response: Optional[str] = None, tool_calls: Optional[List[Dict[str, Any]]] = None) -> bool:
async def complete_call(
self,
call_id: str,
response: Optional[str] = None,
tool_calls: Optional[List[Dict[str, Any]]] = None,
) -> bool:
"""Mark a completion call as completed with a response or tool calls."""
async with self._lock:
if call_id not in self._queue:
return False
call = self._queue[call_id]
if call.status != CompletionStatus.PENDING:
return False
call.status = CompletionStatus.COMPLETED
call.completed_at = datetime.now()
call.response = response
call.tool_calls = tool_calls
# Remove from pending order
if call_id in self._pending_order:
self._pending_order.remove(call_id)
return True
async def fail_call(self, call_id: str, error: str) -> bool:
"""Mark a completion call as failed with an error."""
async with self._lock:
if call_id not in self._queue:
return False
call = self._queue[call_id]
if call.status != CompletionStatus.PENDING:
return False
call.status = CompletionStatus.FAILED
call.completed_at = datetime.now()
call.error = error
# Remove from pending order
if call_id in self._pending_order:
self._pending_order.remove(call_id)
return True
async def wait_for_completion(self, call_id: str, timeout: float = 300.0) -> Optional[str]:
"""Wait for a completion call to be completed and return the response."""
start_time = asyncio.get_event_loop().time()
while True:
status = await self.get_call_status(call_id)
if not status:
return None
if status["status"] == CompletionStatus.COMPLETED.value:
return status.get("response")
elif status["status"] == CompletionStatus.FAILED.value:
raise Exception(f"Completion failed: {status.get('error', 'Unknown error')}")
# Check timeout
if asyncio.get_event_loop().time() - start_time > timeout:
await self.fail_call(call_id, "Timeout waiting for human response")
raise TimeoutError("Timeout waiting for human response")
# Wait a bit before checking again
await asyncio.sleep(0.5)
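Taken together, the queue supports a simple enqueue-then-poll pattern; a sketch (the model label is illustrative):

    call_id = await completion_queue.add_completion(
        messages=[{"role": "user", "content": "Approve this step?"}],
        model="human/operator",
    )
    reply = await completion_queue.wait_for_completion(call_id, timeout=60.0)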
@@ -204,9 +214,7 @@ async def get_status(call_id: str):
async def complete_call(call_id: str, response: CompletionResponse):
"""Complete a call with a human response."""
success = await completion_queue.complete_call(
call_id,
response=response.response,
tool_calls=response.tool_calls
call_id, response=response.response, tool_calls=response.tool_calls
)
if success:
return {"status": "success", "message": "Call completed"}
@@ -219,7 +227,9 @@ async def fail_call(call_id: str, error: Dict[str, str]):
"""Mark a call as failed."""
success = await completion_queue.fail_call(call_id, error.get("error", "Unknown error"))
if not success:
raise HTTPException(status_code=404, detail="Completion call not found or already completed")
raise HTTPException(
status_code=404, detail="Completion call not found or already completed"
)
return {"status": "failed"}
@@ -231,4 +241,5 @@ async def root():
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8002)
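The same completion can be supplied over HTTP, which is how the Gradio UI answers calls; a hedged sketch against the /complete endpoint (call_id comes from a pending call):

    import requests

    call_id = "..."  # id of a pending call
    requests.post(
        f"http://localhost:8002/complete/{call_id}",
        json={"response": "Looks good, proceed."},
        timeout=10,
    ).raise_for_status()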

View File

@@ -1,14 +1,17 @@
import gradio as gr
import json
import time
from typing import List, Dict, Any, Optional
from datetime import datetime
import requests
from .server import completion_queue
import base64
import io
import json
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
import gradio as gr
import requests
from PIL import Image
from .server import completion_queue
class HumanCompletionUI:
def __init__(self, server_url: str = "http://localhost:8002"):
self.server_url = server_url
@@ -20,7 +23,7 @@ class HumanCompletionUI:
self.current_button: str = "left"
self.current_scroll_x: int = 0
self.current_scroll_y: int = -120
def format_messages_for_chatbot(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Format messages for display in gr.Chatbot with type='messages'."""
formatted = []
@@ -28,7 +31,7 @@ class HumanCompletionUI:
role = msg.get("role", "user")
content = msg.get("content", "")
tool_calls = msg.get("tool_calls", [])
# Handle different content formats
if isinstance(content, list):
# Multi-modal content - can include text and images
@@ -55,7 +58,7 @@ class HumanCompletionUI:
else:
# For URL images, create gr.Image with URL
formatted_content.append(gr.Image(value=image_url))
# Determine final content format
if len(formatted_content) == 1:
content = formatted_content[0]
@@ -63,28 +66,28 @@ class HumanCompletionUI:
content = formatted_content
else:
content = "[Empty content]"
# Ensure role is valid for Gradio Chatbot
if role not in ["user", "assistant"]:
role = "assistant" if role == "system" else "user"
# Invert roles for better display in human UI context
# (what the AI says becomes "user", what human should respond becomes "assistant")
if role == "user":
role = "assistant"
else:
role = "user"
# Add the main message if it has content
if content and str(content).strip():
formatted.append({"role": role, "content": content})
# Handle tool calls - create separate messages for each tool call
if tool_calls:
for tool_call in tool_calls:
function_name = tool_call.get("function", {}).get("name", "unknown")
arguments_str = tool_call.get("function", {}).get("arguments", "{}")
try:
# Parse arguments to format them nicely
arguments = json.loads(arguments_str)
@@ -92,18 +95,20 @@ class HumanCompletionUI:
except json.JSONDecodeError:
# If parsing fails, use the raw string
formatted_args = arguments_str
# Create a formatted message for the tool call
tool_call_content = f"```json\n{formatted_args}\n```"
formatted.append({
"role": role,
"content": tool_call_content,
"metadata": {"title": f"🛠️ Used {function_name}"}
})
formatted.append(
{
"role": role,
"content": tool_call_content,
"metadata": {"title": f"🛠️ Used {function_name}"},
}
)
return formatted
def get_pending_calls(self) -> List[Dict[str, Any]]:
"""Get pending calls from the server."""
try:
@@ -113,38 +118,39 @@ class HumanCompletionUI:
except Exception as e:
print(f"Error fetching pending calls: {e}")
return []
def complete_call_with_response(self, call_id: str, response: str) -> bool:
"""Complete a call with a text response."""
try:
response_data = {"response": response}
response_obj = requests.post(
f"{self.server_url}/complete/{call_id}",
json=response_data,
timeout=10
f"{self.server_url}/complete/{call_id}", json=response_data, timeout=10
)
response_obj.raise_for_status()
return True
except requests.RequestException as e:
print(f"Error completing call: {e}")
return False
def complete_call_with_tool_calls(self, call_id: str, tool_calls: List[Dict[str, Any]]) -> bool:
"""Complete a call with tool calls."""
try:
response_data = {"tool_calls": tool_calls}
response_obj = requests.post(
f"{self.server_url}/complete/{call_id}",
json=response_data,
timeout=10
f"{self.server_url}/complete/{call_id}", json=response_data, timeout=10
)
response_obj.raise_for_status()
return True
except requests.RequestException as e:
print(f"Error completing call: {e}")
return False
def complete_call(self, call_id: str, response: Optional[str] = None, tool_calls: Optional[List[Dict[str, Any]]] = None) -> bool:
def complete_call(
self,
call_id: str,
response: Optional[str] = None,
tool_calls: Optional[List[Dict[str, Any]]] = None,
) -> bool:
"""Complete a call with either a response or tool calls."""
try:
response_data = {}
@@ -152,25 +158,23 @@ class HumanCompletionUI:
response_data["response"] = response
if tool_calls:
response_data["tool_calls"] = tool_calls
response_obj = requests.post(
f"{self.server_url}/complete/{call_id}",
json=response_data,
timeout=10
f"{self.server_url}/complete/{call_id}", json=response_data, timeout=10
)
response_obj.raise_for_status()
return True
except requests.RequestException as e:
print(f"Error completing call: {e}")
return False
def get_last_image_from_messages(self, messages: List[Dict[str, Any]]) -> Optional[Any]:
"""Extract the last image from the messages for display above conversation."""
last_image = None
for msg in reversed(messages): # Start from the last message
content = msg.get("content", "")
if isinstance(content, list):
for item in reversed(content): # Get the last image in the message
if item.get("type") == "image_url":
@@ -189,13 +193,13 @@ class HumanCompletionUI:
else:
# For URL images, return the URL
return image_url
return last_image
def refresh_pending_calls(self):
"""Refresh the list of pending calls."""
pending_calls = self.get_pending_calls()
if not pending_calls:
return (
gr.update(choices=["latest"], value="latest"), # dropdown
@@ -205,27 +209,27 @@ class HumanCompletionUI:
gr.update(visible=False), # click_actions_group hidden
gr.update(visible=False), # actions_group hidden
)
# Sort pending calls by created_at to get oldest first
sorted_calls = sorted(pending_calls, key=lambda x: x.get("created_at", ""))
# Create choices for dropdown
choices = [("latest", "latest")] # Add "latest" option first
for call in sorted_calls:
call_id = call["id"]
model = call.get("model", "unknown")
created_at = call.get("created_at", "")
# Format timestamp
try:
dt = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
dt = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
time_str = dt.strftime("%H:%M:%S")
except:
time_str = created_at
choice_label = f"{call_id[:8]}... ({model}) - {time_str}"
choices.append((choice_label, call_id))
# Default to "latest" which shows the oldest pending conversation
selected_call_id = "latest"
if selected_call_id == "latest" and sorted_calls:
@@ -239,7 +243,7 @@ class HumanCompletionUI:
conversation = []
self.current_call_id = None
self.last_image = None
return (
gr.update(choices=choices, value="latest"),
gr.update(value=self.last_image),
@@ -248,7 +252,7 @@ class HumanCompletionUI:
gr.update(visible=True), # click_actions_group visible when there is a call
gr.update(visible=True), # actions_group visible when there is a call
)
def on_call_selected(self, selected_choice):
"""Handle when a call is selected from the dropdown."""
if not selected_choice:
@@ -259,7 +263,7 @@ class HumanCompletionUI:
gr.update(visible=False), # click_actions_group hidden
gr.update(visible=False), # actions_group hidden
)
pending_calls = self.get_pending_calls()
if not pending_calls:
return (
@@ -269,7 +273,7 @@ class HumanCompletionUI:
gr.update(visible=False), # click_actions_group hidden
gr.update(visible=False), # actions_group hidden
)
# Handle "latest" option
if selected_choice == "latest":
# Sort calls by created_at to get oldest first
@@ -284,17 +288,17 @@ class HumanCompletionUI:
if call_id_short in selected_choice:
call_id = call["id"]
break
if not call_id:
return (
gr.update(value=None), # no image
gr.update(value=[]), # empty chatbot
gr.update(interactive=False)
gr.update(interactive=False),
)
# Find the selected call
selected_call = next((c for c in pending_calls if c["id"] == call_id), None)
if not selected_call:
return (
gr.update(value=None), # no image
@@ -303,12 +307,12 @@ class HumanCompletionUI:
gr.update(visible=False), # click_actions_group hidden
gr.update(visible=False), # actions_group hidden
)
conversation = self.format_messages_for_chatbot(selected_call.get("messages", []))
self.current_call_id = call_id
# Get the last image from messages
self.last_image = self.get_last_image_from_messages(selected_call.get("messages", []))
return (
gr.update(value=self.last_image),
gr.update(value=conversation),
@@ -316,110 +320,111 @@ class HumanCompletionUI:
gr.update(visible=True), # click_actions_group visible
gr.update(visible=True), # actions_group visible
)
def submit_response(self, response_text: str):
"""Submit a text response to the current call."""
if not self.current_call_id:
return (
gr.update(value=response_text), # keep response text
gr.update(value="❌ No call selected") # status
gr.update(value="❌ No call selected"), # status
)
if not response_text.strip():
return (
gr.update(value=response_text), # keep response text
gr.update(value="❌ Response cannot be empty") # status
gr.update(value="❌ Response cannot be empty"), # status
)
success = self.complete_call_with_response(self.current_call_id, response_text)
if success:
status_msg = "✅ Response submitted successfully!"
return (
gr.update(value=""), # clear response text
gr.update(value=status_msg) # status
gr.update(value=status_msg), # status
)
else:
return (
gr.update(value=response_text), # keep response text
gr.update(value="❌ Failed to submit response") # status
gr.update(value="❌ Failed to submit response"), # status
)
def submit_action(self, action_type: str, **kwargs) -> str:
"""Submit a computer action as a tool call."""
if not self.current_call_id:
return "❌ No call selected"
import uuid
# Create tool call structure
action_data = {"type": action_type, **kwargs}
tool_call = {
"id": f"call_{uuid.uuid4().hex[:24]}",
"type": "function",
"function": {
"name": "computer",
"arguments": json.dumps(action_data)
}
"function": {"name": "computer", "arguments": json.dumps(action_data)},
}
success = self.complete_call_with_tool_calls(self.current_call_id, [tool_call])
if success:
return f"{action_type.capitalize()} action submitted as tool call"
else:
return f"❌ Failed to submit {action_type} action"
def submit_click_action(self, x: int, y: int, action_type: str = "click", button: str = "left") -> str:
def submit_click_action(
self, x: int, y: int, action_type: str = "click", button: str = "left"
) -> str:
"""Submit a coordinate-based action."""
if action_type == "click":
return self.submit_action(action_type, x=x, y=y, button=button)
else:
return self.submit_action(action_type, x=x, y=y)
def submit_type_action(self, text: str) -> str:
"""Submit a type action."""
return self.submit_action("type", text=text)
def submit_hotkey_action(self, keys: str) -> str:
"""Submit a hotkey action."""
return self.submit_action("keypress", keys=keys)
def submit_wait_action(self) -> str:
"""Submit a wait action with no kwargs."""
return self.submit_action("wait")
def submit_description_click(self, description: str, action_type: str = "click", button: str = "left") -> str:
def submit_description_click(
self, description: str, action_type: str = "click", button: str = "left"
) -> str:
"""Submit a description-based action."""
if action_type == "click":
return self.submit_action(action_type, element_description=description, button=button)
else:
return self.submit_action(action_type, element_description=description)
def wait_for_pending_calls(self, max_seconds: float = 10.0, check_interval: float = 0.2):
"""Wait for pending calls to appear or until max_seconds elapsed.
This method loops and checks for pending calls at regular intervals,
returning as soon as a pending call is found or the maximum wait time is reached.
Args:
max_seconds: Maximum number of seconds to wait
check_interval: How often to check for pending calls (in seconds)
"""
import time
start_time = time.time()
while time.time() - start_time < max_seconds:
# Check if there are any pending calls
pending_calls = self.get_pending_calls()
if pending_calls:
# Found pending calls, return immediately
return self.refresh_pending_calls()
# Wait before checking again
time.sleep(check_interval)
# Max wait time reached, return current state
return self.refresh_pending_calls()
@@ -427,79 +432,73 @@ class HumanCompletionUI:
def create_ui():
"""Create the Gradio interface."""
ui_handler = HumanCompletionUI()
with gr.Blocks(title="Human-in-the-Loop Agent Tool", fill_width=True) as demo:
gr.Markdown("# 🤖 Human-in-the-Loop Agent Tool")
gr.Markdown("Review AI conversation requests and provide human responses.")
with gr.Row():
with gr.Column(scale=2):
with gr.Group():
screenshot_image = gr.Image(
label="Interactive Screenshot",
interactive=False,
height=600
label="Interactive Screenshot", interactive=False, height=600
)
# Action type selection for image clicks (wrapped for visibility control)
with gr.Group(visible=False) as click_actions_group:
with gr.Row():
action_type_radio = gr.Dropdown(
label="Interactive Action",
choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down", "scroll"],
choices=[
"click",
"double_click",
"move",
"left_mouse_up",
"left_mouse_down",
"scroll",
],
value="click",
scale=2
scale=2,
)
action_button_radio = gr.Dropdown(
label="Button",
choices=["left", "right", "wheel", "back", "forward"],
value="left",
visible=True,
scale=1
scale=1,
)
scroll_x_input = gr.Number(
label="scroll_x",
value=0,
visible=False,
scale=1
label="scroll_x", value=0, visible=False, scale=1
)
scroll_y_input = gr.Number(
label="scroll_y",
value=-120,
visible=False,
scale=1
label="scroll_y", value=-120, visible=False, scale=1
)
conversation_chatbot = gr.Chatbot(
label="Conversation",
type="messages",
height=500,
show_copy_button=True
label="Conversation", type="messages", height=500, show_copy_button=True
)
with gr.Column(scale=1):
with gr.Group():
call_dropdown = gr.Dropdown(
label="Select a pending conversation request",
choices=["latest"],
interactive=True,
value="latest"
value="latest",
)
refresh_btn = gr.Button("🔄 Refresh", variant="secondary")
status_display = gr.Textbox(
label="Status",
interactive=False,
value="Ready to receive requests..."
label="Status", interactive=False, value="Ready to receive requests..."
)
with gr.Group():
response_text = gr.Textbox(
label="Message",
lines=3,
placeholder="Enter your message here..."
label="Message", lines=3, placeholder="Enter your message here..."
)
submit_btn = gr.Button("📤 Submit Message", variant="primary", interactive=False)
submit_btn = gr.Button(
"📤 Submit Message", variant="primary", interactive=False
)
# Action Accordions (wrapped for visibility control)
with gr.Group(visible=False) as actions_group:
with gr.Tabs():
@@ -507,58 +506,73 @@ def create_ui():
with gr.Group():
description_text = gr.Textbox(
label="Element Description",
placeholder="e.g., 'Privacy and security option in left sidebar'"
placeholder="e.g., 'Privacy and security option in left sidebar'",
)
with gr.Row():
description_action_type = gr.Dropdown(
label="Action",
choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down"],
value="click"
choices=[
"click",
"double_click",
"move",
"left_mouse_up",
"left_mouse_down",
],
value="click",
)
description_button = gr.Dropdown(
label="Button",
choices=["left", "right", "wheel", "back", "forward"],
value="left"
value="left",
)
description_submit_btn = gr.Button("Submit Click Action")
with gr.Tab("📝 Type Action"):
with gr.Group():
type_text = gr.Textbox(
label="Text to Type",
placeholder="Enter text to type..."
label="Text to Type", placeholder="Enter text to type..."
)
type_submit_btn = gr.Button("Submit Type")
with gr.Tab("⌨️ Keypress Action"):
with gr.Group():
keypress_text = gr.Textbox(
label="Keys",
placeholder="e.g., ctrl+c, alt+tab"
label="Keys", placeholder="e.g., ctrl+c, alt+tab"
)
keypress_submit_btn = gr.Button("Submit Keypress")
with gr.Tab("🧰 Misc Actions"):
with gr.Group():
misc_action_dropdown = gr.Dropdown(
label="Action",
choices=["wait"],
value="wait"
label="Action", choices=["wait"], value="wait"
)
misc_submit_btn = gr.Button("Submit Action")
# Event handlers
refresh_btn.click(
fn=ui_handler.refresh_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
call_dropdown.change(
fn=ui_handler.on_call_selected,
inputs=[call_dropdown],
outputs=[screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
def handle_image_click(evt: gr.SelectData):
if evt.index is not None:
x, y = evt.index
@@ -568,31 +582,44 @@ def create_ui():
sx_i = int(ui_handler.current_scroll_x or 0)
sy_i = int(ui_handler.current_scroll_y or 0)
# Submit a scroll action with x,y position and scroll deltas
result = ui_handler.submit_action("scroll", x=x, y=y, scroll_x=sx_i, scroll_y=sy_i)
result = ui_handler.submit_action(
"scroll", x=x, y=y, scroll_x=sx_i, scroll_y=sy_i
)
else:
result = ui_handler.submit_click_action(x, y, action_type, button)
ui_handler.wait_for_pending_calls()
return result
return "No coordinates selected"
screenshot_image.select(
fn=handle_image_click,
outputs=[status_display]
).then(
screenshot_image.select(fn=handle_image_click, outputs=[status_display]).then(
fn=ui_handler.wait_for_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
# Response submission
submit_btn.click(
fn=ui_handler.submit_response,
inputs=[response_text],
outputs=[response_text, status_display]
outputs=[response_text, status_display],
).then(
fn=ui_handler.refresh_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
# Toggle visibility of controls based on action type
def toggle_action_controls(action_type):
# Button visible only for click
@@ -603,59 +630,63 @@ def create_ui():
# Update state
ui_handler.current_action_type = action_type or "click"
return button_vis, scroll_x_vis, scroll_y_vis
action_type_radio.change(
fn=toggle_action_controls,
inputs=[action_type_radio],
outputs=[action_button_radio, scroll_x_input, scroll_y_input]
outputs=[action_button_radio, scroll_x_input, scroll_y_input],
)
# Keep other control values in ui_handler state
def on_button_change(val):
ui_handler.current_button = (val or "left")
action_button_radio.change(
fn=on_button_change,
inputs=[action_button_radio]
)
ui_handler.current_button = val or "left"
action_button_radio.change(fn=on_button_change, inputs=[action_button_radio])
def on_scroll_x_change(val):
try:
ui_handler.current_scroll_x = int(val) if val is not None else 0
except Exception:
ui_handler.current_scroll_x = 0
scroll_x_input.change(
fn=on_scroll_x_change,
inputs=[scroll_x_input]
)
scroll_x_input.change(fn=on_scroll_x_change, inputs=[scroll_x_input])
def on_scroll_y_change(val):
try:
ui_handler.current_scroll_y = int(val) if val is not None else 0
except Exception:
ui_handler.current_scroll_y = 0
scroll_y_input.change(
fn=on_scroll_y_change,
inputs=[scroll_y_input]
)
scroll_y_input.change(fn=on_scroll_y_change, inputs=[scroll_y_input])
type_submit_btn.click(
fn=ui_handler.submit_type_action,
inputs=[type_text],
outputs=[status_display]
fn=ui_handler.submit_type_action, inputs=[type_text], outputs=[status_display]
).then(
fn=ui_handler.wait_for_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
keypress_submit_btn.click(
fn=ui_handler.submit_hotkey_action,
inputs=[keypress_text],
outputs=[status_display]
fn=ui_handler.submit_hotkey_action, inputs=[keypress_text], outputs=[status_display]
).then(
fn=ui_handler.wait_for_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
def handle_description_submit(description, action_type, button):
if description:
result = ui_handler.submit_description_click(description, action_type, button)
@@ -666,12 +697,19 @@ def create_ui():
description_submit_btn.click(
fn=handle_description_submit,
inputs=[description_text, description_action_type, description_button],
outputs=[status_display]
outputs=[status_display],
).then(
fn=ui_handler.wait_for_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
# Misc action handler
def handle_misc_submit(selected_action):
if selected_action == "wait":
@@ -681,20 +719,32 @@ def create_ui():
return f"Unsupported misc action: {selected_action}"
misc_submit_btn.click(
fn=handle_misc_submit,
inputs=[misc_action_dropdown],
outputs=[status_display]
fn=handle_misc_submit, inputs=[misc_action_dropdown], outputs=[status_display]
).then(
fn=ui_handler.wait_for_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
# Load initial data
demo.load(
fn=ui_handler.refresh_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
return demo
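For quick local testing, the Blocks app can also be launched standalone, outside the FastAPI mount in __main__.py; a sketch using the stock Gradio launcher:

    if __name__ == "__main__":
        create_ui().launch()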

View File

@@ -8,21 +8,22 @@ Exports:
- run_full_dataset(dataset, ...)
- MCPComputerAgent
"""
import time
from typing import Any, Optional
from agent.computers import is_agent_computer
from datasets import load_dataset, Dataset
from hud.datasets import Task, run_dataset
from datasets import Dataset, load_dataset
from hud import trace
from hud.datasets import Task, run_dataset
from .agent import MCPComputerAgent
# ---------------------------------------------------------------------------
# Single-task runner
# ---------------------------------------------------------------------------
async def run_single_task(
dataset: str | Dataset | list[dict[str, Any]],
*,
@@ -47,24 +48,20 @@ async def run_single_task(
# Load dataset and pick a sample
if isinstance(dataset, str):
dataset = load_dataset(dataset, split="train") # type: ignore[arg-type]
dataset = load_dataset(dataset, split="train") # type: ignore[arg-type]
elif isinstance(dataset, list):
dataset = dataset
else:
dataset = dataset["train"]
sample_task = dataset[task_id] # type: ignore[index]
task_prompt = sample_task.get("prompt", f"Task {sample_task.get('id', 0)}") # type: ignore[attr-defined]
# Filter any existing Computer tools
# The eval framework will add its own Computer tool per task
if tools:
tools = [
tool
for tool in tools
if not is_agent_computer(tool)
]
tools = [tool for tool in tools if not is_agent_computer(tool)]
with trace(name=task_prompt):
task = Task(**sample_task) # type: ignore[arg-type]
@@ -87,13 +84,14 @@ async def run_single_task(
)
print(f"Running: {task_prompt}")
result = await agent.run(task, max_steps=10)
print(f"✅ Reward: {getattr(result, 'reward')}")
print(f"✅ Reward: {result.reward}")
# ---------------------------------------------------------------------------
# Full-dataset runner
# ---------------------------------------------------------------------------
async def run_full_dataset(
dataset: str | Dataset | list[dict[str, Any]],
*,
@@ -121,9 +119,9 @@ async def run_full_dataset(
# Run with our MCP-based agent class.
if isinstance(dataset, str):
dataset_name = dataset.split('/')[-1]
dataset_name = dataset.split("/")[-1]
job_name = job_name or f"Evaluation {dataset_name}"
dataset = load_dataset(dataset, split=split) # type: ignore[arg-type]
dataset = load_dataset(dataset, split=split) # type: ignore[arg-type]
else:
dataset_name = "custom"
job_name = job_name or f"Evaluation {time.strftime('%H:%M %Y-%m-%d')}"
@@ -131,12 +129,8 @@ async def run_full_dataset(
# Filter any existing Computer tools
# The eval framework will add its own Computer tool per task
if tools:
tools = [
tool
for tool in tools
if not is_agent_computer(tool)
]
tools = [tool for tool in tools if not is_agent_computer(tool)]
# Execute evaluation
return await run_dataset(
name=job_name,
@@ -170,4 +164,4 @@ __all__ = [
"run_single_task",
"run_full_dataset",
"MCPComputerAgent",
]
]
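A call sketch (the import path and the keyword names elided from this hunk are assumptions, and the dataset id is a placeholder):

    import asyncio
    from agent.integrations.hud import run_single_task  # path assumed

    asyncio.run(run_single_task("your-org/your-dataset", task_id=0))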

View File

@@ -9,26 +9,26 @@ Key differences from the OpenAI OperatorAgent variant:
- Planning is executed via `ComputerAgent.run(messages)`.
- The first yielded result per step is returned as the agent response.
"""
from __future__ import annotations
import base64
import io
import uuid
from pathlib import Path
from typing import Any, ClassVar, Optional
import hud
import mcp.types as types
from agent.agent import ComputerAgent as BaseComputerAgent
from agent.callbacks import PromptInstructionsCallback
from agent.callbacks.trajectory_saver import TrajectorySaverCallback
from agent.computers import is_agent_computer
from agent.responses import make_failed_tool_call_items
from hud.agents import MCPAgent
from hud.tools.computer.settings import computer_settings
from hud.types import AgentResponse, MCPToolCall, MCPToolResult, Trace
from agent.responses import make_failed_tool_call_items
from agent.computers import is_agent_computer
from PIL import Image
import mcp.types as types
import hud
import uuid
import base64
from pathlib import Path
class MCPComputerAgent(MCPAgent):
@@ -114,8 +114,10 @@ class MCPComputerAgent(MCPAgent):
self.last_screenshot_b64 = None
buffer = io.BytesIO()
Image.new('RGB', (self.metadata["display_width"], self.metadata["display_height"])).save(buffer, format='PNG')
self.last_screenshot_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
Image.new("RGB", (self.metadata["display_width"], self.metadata["display_height"])).save(
buffer, format="PNG"
)
self.last_screenshot_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
# Ensure a computer shim is present so width/height/environment are known
computer_shim = {
@@ -128,12 +130,8 @@ class MCPComputerAgent(MCPAgent):
}
agent_tools: list[Any] = [computer_shim]
if tools:
agent_tools.extend([
tool
for tool in tools
if not is_agent_computer(tool)
])
agent_tools.extend([tool for tool in tools if not is_agent_computer(tool)])
agent_kwargs = {
"model": self.model,
"trajectory_dir": trajectory_dir,
@@ -150,9 +148,7 @@ class MCPComputerAgent(MCPAgent):
"telemetry_enabled": telemetry_enabled,
}
self.computer_agent = BaseComputerAgent(
**agent_kwargs
)
self.computer_agent = BaseComputerAgent(**agent_kwargs)
async def get_system_messages(self) -> list[Any]:
"""Create initial messages.
@@ -161,9 +157,7 @@ class MCPComputerAgent(MCPAgent):
"""
return []
async def format_blocks(
self, blocks: list[types.ContentBlock]
) -> list[dict[str, Any]]:
async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[dict[str, Any]]:
"""
Format blocks for OpenAI input format.
@@ -200,42 +194,49 @@ class MCPComputerAgent(MCPAgent):
# Call the ComputerAgent LLM API
async for result in self.computer_agent.run(messages): # type: ignore[arg-type]
items = result['output']
items = result["output"]
if not items or tool_calls:
break
for item in items:
if item['type'] in ['reasoning', 'message', 'computer_call', 'function_call', 'function_call_output']:
if item["type"] in [
"reasoning",
"message",
"computer_call",
"function_call",
"function_call_output",
]:
agent_result.append(item)
# Add messages to output text
if item['type'] == 'reasoning':
if item["type"] == "reasoning":
output_text.extend(
f"Reasoning: {summary['text']}"
for summary in item['summary']
f"Reasoning: {summary['text']}" for summary in item["summary"]
)
elif item['type'] == 'message':
if isinstance(item['content'], list):
elif item["type"] == "message":
if isinstance(item["content"], list):
output_text.extend(
item['text']
for item in item['content']
if item['type'] == 'output_text'
item["text"]
for item in item["content"]
if item["type"] == "output_text"
)
elif isinstance(item['content'], str):
output_text.append(item['content'])
elif isinstance(item["content"], str):
output_text.append(item["content"])
# If we get a tool call, we're not done
if item['type'] == 'computer_call':
if item["type"] == "computer_call":
id = item["call_id"]
tool_calls.append(MCPToolCall(
name="openai_computer",
arguments=item["action"],
id=id,
))
tool_calls.append(
MCPToolCall(
name="openai_computer",
arguments=item["action"],
id=id,
)
)
is_done = False
self.tool_call_inputs[id] = agent_result
break
# if we have tool calls, we should exit the loop
if tool_calls:
break
@@ -247,7 +248,7 @@ class MCPComputerAgent(MCPAgent):
tool_calls=tool_calls,
done=is_done,
)
def _log_image(self, image_b64: str):
callbacks = self.computer_agent.callbacks
for callback in callbacks:
@@ -257,9 +258,7 @@ class MCPComputerAgent(MCPAgent):
callback._save_artifact("screenshot_after", image_bytes)
async def format_tool_results(
self,
tool_calls: list[MCPToolCall],
tool_results: list[MCPToolResult]
self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
) -> list[dict[str, Any]]:
"""Extract latest screenshot from tool results in dict form.
@@ -274,45 +273,60 @@ class MCPComputerAgent(MCPAgent):
previous_output = self.previous_output.copy() or []
# First we need to remove any pending computer_calls from the end of previous_output
while previous_output and previous_output[-1]['type'] == 'computer_call':
while previous_output and previous_output[-1]["type"] == "computer_call":
previous_output.pop()
messages.extend(previous_output)
# If the call is a 'response', don't add the result
if call.name == 'response':
if call.name == "response":
continue
# Otherwise, if we have a result, we should add it to the messages
content = [
{ "type": "input_text", "text": content.text } if isinstance(content, types.TextContent)
else { "type": "input_image", "image_url": f"data:image/png;base64,{content.data}" } if isinstance(content, types.ImageContent)
else { "type": "input_text", "text": "" }
(
{"type": "input_text", "text": content.text}
if isinstance(content, types.TextContent)
else (
{
"type": "input_image",
"image_url": f"data:image/png;base64,{content.data}",
}
if isinstance(content, types.ImageContent)
else {"type": "input_text", "text": ""}
)
)
for content in result.content
]
messages.append({
"role": "user",
"content": content,
})
messages.append(
{
"role": "user",
"content": content,
}
)
continue
# Add the assistant's computer call
messages.extend(self.tool_call_inputs[call.id])
if result.isError:
error_text = "".join([
content.text
for content in result.content
if isinstance(content, types.TextContent)
])
error_text = "".join(
[
content.text
for content in result.content
if isinstance(content, types.TextContent)
]
)
# Replace computer call with failed tool call
messages.pop()
messages.extend(make_failed_tool_call_items(
tool_name=call.name,
tool_kwargs=call.arguments or {},
error_message=error_text,
call_id=call.id,
))
messages.extend(
make_failed_tool_call_items(
tool_name=call.name,
tool_kwargs=call.arguments or {},
error_message=error_text,
call_id=call.id,
)
)
else:
# Get the latest screenshot
screenshots = [
@@ -325,23 +339,27 @@ class MCPComputerAgent(MCPAgent):
if screenshots:
self._log_image(screenshots[0])
self.last_screenshot_b64 = screenshots[0]
messages.append({
"type": "computer_call_output",
"call_id": call.id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshots[0]}"
},
})
messages.append(
{
"type": "computer_call_output",
"call_id": call.id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshots[0]}",
},
}
)
else:
# Otherwise, replace computer call with failed tool call
messages.pop()
messages.extend(make_failed_tool_call_items(
tool_name=call.name,
tool_kwargs=call.arguments or {},
error_message="No screenshots returned.",
call_id=call.id,
))
messages.extend(
make_failed_tool_call_items(
tool_name=call.name,
tool_kwargs=call.arguments or {},
error_message="No screenshots returned.",
call_id=call.id,
)
)
return messages

View File

@@ -7,30 +7,33 @@ OpenAI-like response blocks. We intentionally only support a single-step call
by consuming the first yielded result from `ComputerAgent.run()`.
"""
import traceback
import time
import traceback
import uuid
from typing import Any, Dict, List, Optional
from agent.agent import ComputerAgent as BaseComputerAgent
from agent.callbacks import PromptInstructionsCallback
from hud.tools.computer.settings import computer_settings
from PIL import Image
from hud.agents import OperatorAgent
from hud.tools.computer.settings import computer_settings
# OpenAI Responses typed models (required)
from openai.types.responses import (
Response,
ResponseComputerToolCall,
ResponseInputParam,
ResponseOutputItem,
ResponseComputerToolCall,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
ResponseUsage,
)
from PIL import Image
def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> List[ResponseOutputItem]:
def _map_agent_output_to_openai_blocks(
output_items: List[Dict[str, Any]],
) -> List[ResponseOutputItem]:
"""Map our agent output items to OpenAI ResponseOutputItem typed models.
Only a subset is supported: computer_call, assistant message (text), and reasoning.
@@ -40,14 +43,16 @@ def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> Li
for item in output_items or []:
t = item.get("type")
if t == "computer_call":
comp = ResponseComputerToolCall.model_validate({
"id": item.get("id") or f"cu_{uuid.uuid4().hex}",
"type": "computer_call",
"call_id": item["call_id"],
"action": item["action"],
"pending_safety_checks": item.get("pending_safety_checks", []),
"status": "completed",
})
comp = ResponseComputerToolCall.model_validate(
{
"id": item.get("id") or f"cu_{uuid.uuid4().hex}",
"type": "computer_call",
"call_id": item["call_id"],
"action": item["action"],
"pending_safety_checks": item.get("pending_safety_checks", []),
"status": "completed",
}
)
blocks.append(comp)
# Exit early here, as the Responses API only supports a single step
break
@@ -55,31 +60,38 @@ def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> Li
content_blocks: List[ResponseOutputText] = []
for c in item.get("content", []) or []:
content_blocks.append(
ResponseOutputText.model_validate({
"type": "output_text",
"text": c["text"],
"annotations": [],
})
ResponseOutputText.model_validate(
{
"type": "output_text",
"text": c["text"],
"annotations": [],
}
)
)
if content_blocks:
msg = ResponseOutputMessage.model_validate({
"id": item.get("id") or f"msg_{uuid.uuid4()}",
"type": "message",
"role": "assistant",
"status": "completed",
"content": [ct.model_dump() for ct in content_blocks],
})
msg = ResponseOutputMessage.model_validate(
{
"id": item.get("id") or f"msg_{uuid.uuid4()}",
"type": "message",
"role": "assistant",
"status": "completed",
"content": [ct.model_dump() for ct in content_blocks],
}
)
blocks.append(msg)
elif t == "reasoning":
reasoning = ResponseReasoningItem.model_validate({
"id": item.get("id") or f"rsn_{uuid.uuid4()}",
"type": "reasoning",
"summary": item["summary"],
})
reasoning = ResponseReasoningItem.model_validate(
{
"id": item.get("id") or f"rsn_{uuid.uuid4()}",
"type": "reasoning",
"summary": item["summary"],
}
)
blocks.append(reasoning)
# Unhandled types are ignored
return blocks
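# A hedged illustration of the mapping above; the input item is hypothetical.
# A computer_call dict validates into a typed ResponseComputerToolCall block:
#
#     items = [{
#         "type": "computer_call",
#         "call_id": "call_123",
#         "action": {"type": "click", "button": "left", "x": 100, "y": 200},
#     }]
#     blocks = _map_agent_output_to_openai_blocks(items)
#     # blocks[0].call_id == "call_123"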
def _to_plain_dict_list(items: Any) -> List[Dict[str, Any]]:
out: List[Dict[str, Any]] = []
for it in list(items):
@@ -92,6 +104,7 @@ def _to_plain_dict_list(items: Any) -> List[Dict[str, Any]]:
out.append(dict(it)) # may raise if not mapping
return out
class FakeAsyncOpenAI:
"""Minimal fake OpenAI client with only `responses.create` implemented.
@@ -132,10 +145,12 @@ class FakeAsyncOpenAI:
# Pre-pend instructions message
effective_input = full_input
if instructions:
effective_input = [{
"role": "user",
"content": instructions,
}] + full_input
effective_input = [
{
"role": "user",
"content": instructions,
}
] + full_input
# Run a single iteration of the ComputerAgent
agent_result: Optional[Dict[str, Any]] = None
@@ -152,32 +167,43 @@ class FakeAsyncOpenAI:
blocks_to_cache = full_input + output
for b in blocks_to_cache:
bid = getattr(b, "id", None) or f"tmp-{hash(repr(b))}"
self.blocks_cache[bid] = b # type: ignore[assignment]
self.blocks_cache[bid] = b # type: ignore[assignment]
block_ids.append(bid)
response_id = agent_result.get("id") or f"fake-{int(time.time()*1000)}"
self.context_cache[response_id] = block_ids
try:
return Response.model_validate({
"id": response_id,
"created_at": time.time(),
"object": "response",
"model": model,
"output": output,
"parallel_tool_calls": False,
"tool_choice": "auto",
"tools": [],
"previous_response_id": previous_response_id,
"usage": ResponseUsage.model_validate({
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
"input_tokens_details": usage.get("input_tokens_details", { "cached_tokens": 0 }),
"output_tokens_details": usage.get("output_tokens_details", { "reasoning_tokens": 0 }),
}),
})
return Response.model_validate(
{
"id": response_id,
"created_at": time.time(),
"object": "response",
"model": model,
"output": output,
"parallel_tool_calls": False,
"tool_choice": "auto",
"tools": [],
"previous_response_id": previous_response_id,
"usage": ResponseUsage.model_validate(
{
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
"input_tokens_details": usage.get(
"input_tokens_details", {"cached_tokens": 0}
),
"output_tokens_details": usage.get(
"output_tokens_details", {"reasoning_tokens": 0}
),
}
),
}
)
except Exception as e:
print(f"Error while validating agent response (attempt {attempt + 1}/{max_retries}): ", e)
print(
f"Error while validating agent response (attempt {attempt + 1}/{max_retries}): ",
e,
)
if attempt == max_retries - 1:
print(traceback.format_exc())
raise e
@@ -221,9 +247,15 @@ class ProxyOperatorAgent(OperatorAgent):
allowed_tools = allowed_tools or ["openai_computer"]
computer_shim = {
'screenshot': lambda: Image.new('RGB', (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT)),
'environment': 'linux',
'dimensions': (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT)
"screenshot": lambda: Image.new(
"RGB",
(computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT),
),
"environment": "linux",
"dimensions": (
computer_settings.OPENAI_COMPUTER_WIDTH,
computer_settings.OPENAI_COMPUTER_HEIGHT,
),
}
# Build tools ensuring the computer_shim is included
agent_tools: list[Any] = [computer_shim]
@@ -258,6 +290,7 @@ class ProxyOperatorAgent(OperatorAgent):
**kwargs,
)
__all__ = [
"FakeAsyncOpenAI",
"ProxyOperatorAgent",

View File

@@ -3,26 +3,34 @@ Agent loops for agent
"""
# Import the loops to register them
from . import anthropic
from . import openai
from . import uitars
from . import omniparser
from . import gta1
from . import composed_grounded
from . import glm45v
from . import opencua
from . import internvl
from . import holo
from . import (
anthropic,
composed_grounded,
gemini,
glm45v,
gta1,
holo,
internvl,
moondream3,
omniparser,
openai,
opencua,
qwen,
uitars,
)
__all__ = [
"anthropic",
"openai",
"uitars",
"omniparser",
"gta1",
"composed_grounded",
"glm45v",
"anthropic",
"openai",
"uitars",
"omniparser",
"gta1",
"composed_grounded",
"glm45v",
"opencua",
"internvl",
"holo",
]
"moondream3",
"gemini",
"qwen",
]

File diff suppressed because it is too large

View File

@@ -2,13 +2,15 @@
Base protocol for async agent configurations
"""
from typing import Protocol, List, Dict, Any, Optional, Tuple, Union
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
from ..types import AgentCapability
class AsyncAgentConfig(Protocol):
"""Protocol defining the interface for async agent configurations."""
@abstractmethod
async def predict_step(
self,
@@ -22,11 +24,11 @@ class AsyncAgentConfig(Protocol):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Predict the next step based on input items.
Args:
messages: Input items following Responses format (message, function_call, computer_call)
model: Model name to use
@@ -39,37 +41,34 @@ class AsyncAgentConfig(Protocol):
_on_usage: Callback for usage tracking
_on_screenshot: Callback for screenshot events
**kwargs: Additional arguments
Returns:
Dictionary with "output" (output items) and "usage" array
"""
...
@abstractmethod
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str
self, model: str, image_b64: str, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.
Args:
model: Model name to use
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
None or tuple with (x, y) coordinates
"""
...
@abstractmethod
def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by this agent config.
Returns:
List of capability strings (e.g., ["step", "click"])
"""

View File

@@ -3,122 +3,117 @@ Composed-grounded agent loop implementation that combines grounding and thinking
Uses a two-stage approach: grounding model for element detection, thinking model for reasoning.
"""
import uuid
import asyncio
import json
import base64
from typing import Dict, List, Any, Optional, Tuple
import json
import uuid
from io import BytesIO
from PIL import Image
import litellm
from typing import Any, Dict, List, Optional, Tuple
import litellm
from PIL import Image
from ..agent import find_agent_config
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..responses import (
convert_computer_calls_xy2desc,
convert_responses_items_to_completion_messages,
convert_completion_messages_to_responses_items,
convert_computer_calls_desc2xy,
get_all_element_descriptions
convert_computer_calls_xy2desc,
convert_responses_items_to_completion_messages,
get_all_element_descriptions,
)
from ..agent import find_agent_config
from ..types import AgentCapability, AgentResponse, Messages, Tools
GROUNDED_COMPUTER_TOOL_SCHEMA = {
"type": "function",
"function": {
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment"
],
"description": "The action to perform (required for all actions)"
},
"element_description": {
"type": "string",
"description": "Description of the element to interact with (required for click, double_click, move, scroll actions)"
},
"start_element_description": {
"type": "string",
"description": "Description of the element to start dragging from (required for drag action)"
},
"end_element_description": {
"type": "string",
"description": "Description of the element to drag to (required for drag action)"
},
"text": {
"type": "string",
"description": "The text to type (required for type action)"
},
"keys": {
"type": "array",
"items": {
"type": "string"
"type": "function",
"function": {
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment",
],
"description": "The action to perform (required for all actions)",
},
"element_description": {
"type": "string",
"description": "Description of the element to interact with (required for click, double_click, move, scroll actions)",
},
"start_element_description": {
"type": "string",
"description": "Description of the element to start dragging from (required for drag action)",
},
"end_element_description": {
"type": "string",
"description": "Description of the element to drag to (required for drag action)",
},
"text": {
"type": "string",
"description": "The text to type (required for type action)",
},
"keys": {
"type": "array",
"items": {"type": "string"},
"description": "Key(s) to press (required for keypress action)",
},
"button": {
"type": "string",
"enum": ["left", "right", "wheel", "back", "forward"],
"description": "The mouse button to use for click action (required for click and double_click action)",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (required for scroll action)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (required for scroll action)",
},
},
"description": "Key(s) to press (required for keypress action)"
"required": ["action"],
},
"button": {
"type": "string",
"enum": [
"left",
"right",
"wheel",
"back",
"forward"
],
"description": "The mouse button to use for click action (required for click and double_click action)",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (required for scroll action)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (required for scroll action)",
},
},
"required": [
"action"
]
}
}
},
}
def _prepare_tools_for_grounded(tool_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Prepare tools for grounded API format"""
grounded_tools = []
for schema in tool_schemas:
if schema["type"] == "computer":
grounded_tools.append(GROUNDED_COMPUTER_TOOL_SCHEMA)
else:
grounded_tools.append(schema)
return grounded_tools
def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str]:
"""Get the last computer call output image from messages."""
for message in reversed(messages):
if (isinstance(message, dict) and
message.get("type") == "computer_call_output" and
isinstance(message.get("output"), dict) and
message["output"].get("type") == "input_image"):
if (
isinstance(message, dict)
and message.get("type") == "computer_call_output"
and isinstance(message.get("output"), dict)
and message["output"].get("type") == "input_image"
):
image_url = message["output"].get("image_url", "")
if image_url.startswith("data:image/png;base64,"):
return image_url.split(",", 1)[1]
@@ -129,14 +124,14 @@ def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str
class ComposedGroundedConfig(AsyncAgentConfig):
"""
Composed-grounded agent configuration that uses both grounding and thinking models.
The model parameter should be in the format: "grounding_model+thinking_model"
e.g., "huggingface-local/HelloKKMe/GTA1-7B+gemini/gemini-1.5-pro"
"""
def __init__(self):
self.desc2xy: Dict[str, Tuple[float, float]] = {}
async def predict_step(
self,
messages: List[Dict[str, Any]],
@@ -150,11 +145,11 @@ class ComposedGroundedConfig(AsyncAgentConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Composed-grounded predict step implementation.
Process:
0. Store last computer call image, if none then take a screenshot
1. Convert computer calls from xy to descriptions
@@ -167,18 +162,20 @@ class ComposedGroundedConfig(AsyncAgentConfig):
"""
# Parse the composed model
if "+" not in model:
raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
raise ValueError(
f"Composed model must be in format 'grounding_model+thinking_model', got: {model}"
)
grounding_model, thinking_model = model.split("+", 1)
pre_output_items = []
# Step 0: Store last computer call image, if none then take a screenshot
last_image_b64 = get_last_computer_call_image(messages)
if last_image_b64 is None:
# Take a screenshot
screenshot_b64 = await computer_handler.screenshot() # type: ignore
screenshot_b64 = await computer_handler.screenshot() # type: ignore
if screenshot_b64:
call_id = uuid.uuid4().hex
pre_output_items += [
{
@@ -187,45 +184,42 @@ class ComposedGroundedConfig(AsyncAgentConfig):
"content": [
{
"type": "output_text",
"text": "Taking a screenshot to see the current computer screen."
"text": "Taking a screenshot to see the current computer screen.",
}
]
],
},
{
"action": {
"type": "screenshot"
},
"action": {"type": "screenshot"},
"call_id": call_id,
"status": "completed",
"type": "computer_call"
"type": "computer_call",
},
{
"type": "computer_call_output",
"call_id": call_id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_b64}"
}
"image_url": f"data:image/png;base64,{screenshot_b64}",
},
},
]
last_image_b64 = screenshot_b64
# Call screenshot callback if provided
if _on_screenshot:
await _on_screenshot(screenshot_b64)
tool_schemas = _prepare_tools_for_grounded(tools) # type: ignore
tool_schemas = _prepare_tools_for_grounded(tools) # type: ignore
# Step 1: Convert computer calls from xy to descriptions
input_messages = messages + pre_output_items
messages_with_descriptions = convert_computer_calls_xy2desc(input_messages, self.desc2xy)
# Step 2: Convert responses items to completion messages
completion_messages = convert_responses_items_to_completion_messages(
messages_with_descriptions,
allow_images_in_tool_results=False
messages_with_descriptions, allow_images_in_tool_results=False
)
# Step 3: Call thinking model with litellm.acompletion
api_kwargs = {
"model": thinking_model,
@@ -233,98 +227,90 @@ class ComposedGroundedConfig(AsyncAgentConfig):
"tools": tool_schemas,
"max_retries": max_retries,
"stream": stream,
**kwargs
**kwargs,
}
if use_prompt_caching:
api_kwargs["use_prompt_caching"] = use_prompt_caching
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
# Make the completion call
response = await litellm.acompletion(**api_kwargs)
# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
# Extract usage information
usage = {
**response.usage.model_dump(), # type: ignore
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(usage)
# Step 4: Convert completion messages back to responses items format
response_dict = response.model_dump() # type: ignore
response_dict = response.model_dump() # type: ignore
choice_messages = [choice["message"] for choice in response_dict["choices"]]
thinking_output_items = []
for choice_message in choice_messages:
thinking_output_items.extend(convert_completion_messages_to_responses_items([choice_message]))
thinking_output_items.extend(
convert_completion_messages_to_responses_items([choice_message])
)
# Step 5: Get all element descriptions and populate desc2xy mapping
element_descriptions = get_all_element_descriptions(thinking_output_items)
if element_descriptions and last_image_b64:
# Use grounding model to predict coordinates for each description
grounding_agent_conf = find_agent_config(grounding_model)
if grounding_agent_conf:
grounding_agent = grounding_agent_conf.agent_class()
for desc in element_descriptions:
for _ in range(3): # try 3 times
for _ in range(3): # try 3 times
coords = await grounding_agent.predict_click(
model=grounding_model,
image_b64=last_image_b64,
instruction=desc
model=grounding_model, image_b64=last_image_b64, instruction=desc
)
if coords:
self.desc2xy[desc] = coords
break
# Step 6: Convert computer calls from descriptions back to xy coordinates
final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)
# Step 7: Return output and usage
return {
"output": pre_output_items + final_output_items,
"usage": usage
}
return {"output": pre_output_items + final_output_items, "usage": usage}
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using the grounding model.
For composed models, uses only the grounding model part for click prediction.
"""
# Parse the composed model to get grounding model
if "+" not in model:
raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
raise ValueError(
f"Composed model must be in format 'grounding_model+thinking_model', got: {model}"
)
grounding_model, thinking_model = model.split("+", 1)
# Find and use the grounding agent
grounding_agent_conf = find_agent_config(grounding_model)
if grounding_agent_conf:
grounding_agent = grounding_agent_conf.agent_class()
return await grounding_agent.predict_click(
model=grounding_model,
image_b64=image_b64,
instruction=instruction,
**kwargs
model=grounding_model, image_b64=image_b64, instruction=instruction, **kwargs
)
return None
def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["click", "step"]

View File

@@ -0,0 +1,410 @@
"""
Gemini 2.5 Computer Use agent loop
Maps internal Agent SDK message format to Google's Gemini Computer Use API and back.
Key features:
- Lazy import of google.genai
- Configure Computer Use tool with excluded browser-specific predefined functions
- Optional custom function declarations hook for computer-call specific functions
- Convert Gemini function_call parts into internal computer_call actions
"""
from __future__ import annotations
import base64
import io
import uuid
from typing import Any, Dict, List, Optional, Tuple
from PIL import Image
from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability
def _lazy_import_genai():
"""Import google.genai lazily to avoid hard dependency unless used."""
try:
from google import genai # type: ignore
from google.genai import types # type: ignore
return genai, types
except Exception as e: # pragma: no cover
raise RuntimeError(
"google.genai is required for the Gemini Computer Use loop. Install the Google Gemini SDK."
) from e
def _data_url_to_bytes(data_url: str) -> Tuple[bytes, str]:
"""Convert a data URL to raw bytes and mime type."""
if not data_url.startswith("data:"):
# Assume it's a raw base64 PNG payload
try:
return base64.b64decode(data_url), "image/png"
except Exception:
return b"", "application/octet-stream"
header, b64 = data_url.split(",", 1)
mime = "image/png"
if ";" in header:
mime = header.split(";")[0].split(":", 1)[1] or "image/png"
return base64.b64decode(b64), mime
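# A hedged usage sketch with a stand-in payload (any base64 string works):
#
#     payload = base64.b64encode(b"fake-png-bytes").decode()
#     raw, mime = _data_url_to_bytes(f"data:image/png;base64,{payload}")
#     # raw == b"fake-png-bytes", mime == "image/png"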
def _bytes_image_size(img_bytes: bytes) -> Tuple[int, int]:
try:
img = Image.open(io.BytesIO(img_bytes))
return img.size
except Exception:
return (1024, 768)
def _find_last_user_text(messages: List[Dict[str, Any]]) -> List[str]:
texts: List[str] = []
for msg in reversed(messages):
if msg.get("type") in (None, "message") and msg.get("role") == "user":
content = msg.get("content")
if isinstance(content, str):
return [content]
elif isinstance(content, list):
for c in content:
if c.get("type") in ("input_text", "output_text") and c.get("text"):
texts.append(c["text"]) # newest first
if texts:
return list(reversed(texts))
return []
def _find_last_screenshot(messages: List[Dict[str, Any]]) -> Optional[bytes]:
for msg in reversed(messages):
if msg.get("type") == "computer_call_output":
out = msg.get("output", {})
if isinstance(out, dict) and out.get("type") in ("input_image", "computer_screenshot"):
image_url = out.get("image_url", "")
if image_url:
data, _ = _data_url_to_bytes(image_url)
return data
return None
def _denormalize(v: int, size: int) -> int:
# Gemini returns 0-999 normalized
try:
return max(0, min(size - 1, int(round(v / 1000 * size))))
except Exception:
return 0
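# Worked example: coordinates arrive normalized to 0-999, so on a
# 1920-pixel-wide screen x=500 maps to round(500 / 1000 * 1920) == 960,
# and results are clamped to [0, size - 1]:
#
#     _denormalize(500, 1920)  # -> 960
#     _denormalize(999, 1920)  # -> 1918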
def _map_gemini_fc_to_computer_call(
fc: Dict[str, Any],
screen_w: int,
screen_h: int,
) -> Optional[Dict[str, Any]]:
name = fc.get("name")
args = fc.get("args", {}) or {}
action: Dict[str, Any] = {}
if name == "click_at":
x = _denormalize(int(args.get("x", 0)), screen_w)
y = _denormalize(int(args.get("y", 0)), screen_h)
action = {"type": "click", "x": x, "y": y, "button": "left"}
elif name == "type_text_at":
x = _denormalize(int(args.get("x", 0)), screen_w)
y = _denormalize(int(args.get("y", 0)), screen_h)
text = args.get("text", "")
if args.get("press_enter") == True:
text += "\n"
action = {"type": "type", "x": x, "y": y, "text": text}
elif name == "hover_at":
x = _denormalize(int(args.get("x", 0)), screen_w)
y = _denormalize(int(args.get("y", 0)), screen_h)
action = {"type": "move", "x": x, "y": y}
elif name == "key_combination":
keys = str(args.get("keys", ""))
action = {"type": "keypress", "keys": keys}
elif name == "scroll_document":
direction = args.get("direction", "down")
magnitude = 800
dx, dy = 0, 0
if direction == "down":
dy = magnitude
elif direction == "up":
dy = -magnitude
elif direction == "right":
dx = magnitude
elif direction == "left":
dx = -magnitude
action = {
"type": "scroll",
"scroll_x": dx,
"scroll_y": dy,
"x": int(screen_w / 2),
"y": int(screen_h / 2),
}
elif name == "scroll_at":
x = _denormalize(int(args.get("x", 500)), screen_w)
y = _denormalize(int(args.get("y", 500)), screen_h)
direction = args.get("direction", "down")
magnitude = int(args.get("magnitude", 800))
dx, dy = 0, 0
if direction == "down":
dy = magnitude
elif direction == "up":
dy = -magnitude
elif direction == "right":
dx = magnitude
elif direction == "left":
dx = -magnitude
action = {"type": "scroll", "scroll_x": dx, "scroll_y": dy, "x": x, "y": y}
elif name == "drag_and_drop":
x = _denormalize(int(args.get("x", 0)), screen_w)
y = _denormalize(int(args.get("y", 0)), screen_h)
dx = _denormalize(int(args.get("destination_x", x)), screen_w)
dy = _denormalize(int(args.get("destination_y", y)), screen_h)
action = {
"type": "drag",
"start_x": x,
"start_y": y,
"end_x": dx,
"end_y": dy,
"button": "left",
}
elif name == "wait_5_seconds":
action = {"type": "wait"}
else:
# Unsupported / excluded browser-specific or custom function; ignore
return None
return {
"type": "computer_call",
"call_id": uuid.uuid4().hex,
"status": "completed",
"action": action,
}
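# A hedged illustration with a hypothetical function call: click_at at the
# normalized center of a 1920x1080 screen maps to pixel (960, 540).
#
#     fc = {"name": "click_at", "args": {"x": 500, "y": 500}}
#     item = _map_gemini_fc_to_computer_call(fc, 1920, 1080)
#     # item["action"] == {"type": "click", "x": 960, "y": 540, "button": "left"}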
@register_agent(models=r"^gemini-2\.5-computer-use-preview-10-2025$")
class GeminiComputerUseConfig(AsyncAgentConfig):
async def predict_step(
self,
messages: List[Dict[str, Any]],
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs,
) -> Dict[str, Any]:
genai, types = _lazy_import_genai()
client = genai.Client()
# Build excluded predefined functions for browser-specific behavior
excluded = [
"open_web_browser",
"search",
"navigate",
"go_forward",
"go_back",
"scroll_document",
]
# Optional custom functions: host code may extend these via the `tools` parameter if desired
CUSTOM_FUNCTION_DECLARATIONS: List[Any] = []
# Compose tools config
generate_content_config = types.GenerateContentConfig(
tools=[
types.Tool(
computer_use=types.ComputerUse(
environment=types.Environment.ENVIRONMENT_BROWSER,
excluded_predefined_functions=excluded,
)
),
# types.Tool(function_declarations=CUSTOM_FUNCTION_DECLARATIONS), # enable when custom functions needed
]
)
# Prepare contents: last user text + latest screenshot
user_texts = _find_last_user_text(messages)
screenshot_bytes = _find_last_screenshot(messages)
parts: List[Any] = []
for t in user_texts:
parts.append(types.Part(text=t))
screen_w, screen_h = 1024, 768
if screenshot_bytes:
screen_w, screen_h = _bytes_image_size(screenshot_bytes)
parts.append(types.Part.from_bytes(data=screenshot_bytes, mime_type="image/png"))
# If we don't have any content, pass a minimal user prompt to elicit the next action
if not parts:
parts = [types.Part(text="Proceed to the next action.")]
contents = [types.Content(role="user", parts=parts)]
api_kwargs = {
"model": model,
"contents": contents,
"config": generate_content_config,
}
if _on_api_start:
await _on_api_start(
{
"model": api_kwargs["model"],
# "contents": api_kwargs["contents"], # Disabled for now
"config": api_kwargs["config"],
}
)
response = client.models.generate_content(**api_kwargs)
if _on_api_end:
await _on_api_end(
{
"model": api_kwargs["model"],
# "contents": api_kwargs["contents"], # Disabled for now
"config": api_kwargs["config"],
},
response,
)
# Usage (Gemini SDK may not always provide token usage; populate when available)
usage: Dict[str, Any] = {}
try:
# Some SDKs expose response.usage; if available, copy
if getattr(response, "usage_metadata", None):
md = response.usage_metadata
usage = {
"prompt_tokens": getattr(md, "prompt_token_count", None) or 0,
"completion_tokens": getattr(md, "candidates_token_count", None) or 0,
"total_tokens": getattr(md, "total_token_count", None) or 0,
}
except Exception:
pass
if _on_usage and usage:
await _on_usage(usage)
# Parse output into internal items
output_items: List[Dict[str, Any]] = []
candidate = response.candidates[0]
# Text parts from the model (assistant message)
text_parts: List[str] = []
function_calls: List[Dict[str, Any]] = []
for p in candidate.content.parts:
if getattr(p, "text", None):
text_parts.append(p.text)
if getattr(p, "function_call", None):
# p.function_call has name and args
fc = {
"name": getattr(p.function_call, "name", None),
"args": dict(getattr(p.function_call, "args", {}) or {}),
}
function_calls.append(fc)
if text_parts:
output_items.append(
{
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "\n".join(text_parts)}],
}
)
# Map function calls to internal computer_call actions
for fc in function_calls:
item = _map_gemini_fc_to_computer_call(fc, screen_w, screen_h)
if item is not None:
output_items.append(item)
return {"output": output_items, "usage": usage}
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs,
) -> Optional[Tuple[float, float]]:
"""Ask Gemini CUA to output a single click action for the given instruction.
Excludes all predefined tools except `click_at` and sends the screenshot.
Returns pixel (x, y) if a click is proposed, else None.
"""
genai, types = _lazy_import_genai()
client = genai.Client()
# Exclude all but click_at
exclude_all_but_click = [
"open_web_browser",
"wait_5_seconds",
"go_back",
"go_forward",
"search",
"navigate",
"hover_at",
"type_text_at",
"key_combination",
"scroll_document",
"scroll_at",
"drag_and_drop",
]
config = types.GenerateContentConfig(
tools=[
types.Tool(
computer_use=types.ComputerUse(
environment=types.Environment.ENVIRONMENT_BROWSER,
excluded_predefined_functions=exclude_all_but_click,
)
)
]
)
# Prepare prompt parts
try:
img_bytes = base64.b64decode(image_b64)
except Exception:
img_bytes = b""
w, h = _bytes_image_size(img_bytes) if img_bytes else (1024, 768)
parts: List[Any] = [types.Part(text=f"Click {instruction}.")]
if img_bytes:
parts.append(types.Part.from_bytes(data=img_bytes, mime_type="image/png"))
contents = [types.Content(role="user", parts=parts)]
response = client.models.generate_content(
model=model,
contents=contents,
config=config,
)
# Parse first click_at
try:
candidate = response.candidates[0]
for p in candidate.content.parts:
fc = getattr(p, "function_call", None)
if fc and getattr(fc, "name", None) == "click_at":
args = dict(getattr(fc, "args", {}) or {})
x = _denormalize(int(args.get("x", 0)), w)
y = _denormalize(int(args.get("y", 0)), h)
return float(x), float(y)
except Exception:
return None
return None
def get_capabilities(self) -> List[AgentCapability]:
return ["click", "step"]

View File

@@ -4,33 +4,36 @@ Supports vision-language models for computer control with bounding box parsing.
"""
import asyncio
import json
import base64
import json
import re
from typing import Dict, List, Any, Optional, Tuple
from io import BytesIO
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple
import litellm
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
from litellm.types.utils import ModelResponse
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
from PIL import Image
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..responses import (
convert_responses_items_to_completion_messages,
convert_completion_messages_to_responses_items,
make_reasoning_item,
make_output_text_item,
convert_responses_items_to_completion_messages,
make_click_item,
make_double_click_item,
make_drag_item,
make_input_image_item,
make_keypress_item,
make_output_text_item,
make_reasoning_item,
make_scroll_item,
make_type_item,
make_wait_item,
make_input_image_item
)
from ..types import AgentCapability, AgentResponse, Messages, Tools
# GLM-4.5V specific constants
GLM_ACTION_SPACE = """
@@ -251,16 +254,18 @@ Call rule: `FAIL()`
}
}"""
def encode_image_to_base64(image_path: str) -> str:
"""Encode image file to base64 string with data URI."""
with open(image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
return f"data:image/png;base64,{encoded_string}"
def parse_glm_response(response: str) -> Dict[str, Any]:
"""
Parse GLM-4.5V response to extract action and memory.
The special tokens <|begin_of_box|> and <|end_of_box|> mark bounding boxes.
Coordinates are normalized values between 0 and 1000.
"""
@@ -274,26 +279,23 @@ def parse_glm_response(response: str) -> Dict[str, Any]:
action_pattern = r"[\w_]+\([^)]*\)"
matches = re.findall(action_pattern, response)
action = matches[0] if matches else None
# Extract memory section
memory_pattern = r"Memory:(.*?)$"
memory_match = re.search(memory_pattern, response, re.DOTALL)
memory = memory_match.group(1).strip() if memory_match else "[]"
# Extract action text (everything before Memory:)
action_text_pattern = r'^(.*?)Memory:'
action_text_pattern = r"^(.*?)Memory:"
action_text_match = re.search(action_text_pattern, response, re.DOTALL)
action_text = action_text_match.group(1).strip() if action_text_match else response
# Clean up action text by removing special tokens
if action_text:
action_text = action_text.replace("<|begin_of_box|>", "").replace("<|end_of_box|>", "")
return {
"action": action,
"action_text": action_text,
"memory": memory
}
return {"action": action, "action_text": action_text, "memory": memory}
def get_last_image_from_messages(messages: Messages) -> Optional[str]:
"""Extract the last image from messages for processing."""
@@ -314,23 +316,28 @@ def get_last_image_from_messages(messages: Messages) -> Optional[str]:
image_url_obj = item.get("image_url", {})
if isinstance(image_url_obj, dict):
image_url = image_url_obj.get("url", "")
if isinstance(image_url, str) and image_url.startswith("data:image/"):
if isinstance(image_url, str) and image_url.startswith(
"data:image/"
):
return image_url.split(",", 1)[1]
return None
def convert_responses_items_to_glm45v_pc_prompt(messages: Messages, task: str, memory: str = "") -> List[Dict[str, Any]]:
def convert_responses_items_to_glm45v_pc_prompt(
messages: Messages, task: str, memory: str = ""
) -> List[Dict[str, Any]]:
"""Convert responses items to GLM-4.5V PC prompt format with historical actions.
Args:
messages: List of message items from the conversation
task: The task description
memory: Current memory state
Returns:
List of content items for the prompt (text and image_url items)
"""
action_space = GLM_ACTION_SPACE
# Template head
head_text = f"""You are a GUI Agent, and your primary task is to respond accurately to user requests or questions. In addition to directly answering the user's queries, you can also use tools or perform GUI operations directly until you fulfill the user's request or provide a correct answer. You should carefully read and understand the images and questions provided by the user, and engage in thinking and reflection when appropriate. The coordinates involved are all represented in thousandths (0-999).
@@ -345,7 +352,7 @@ Ubuntu
# Historical Actions and Current Memory
History:"""
# Template tail
tail_text = f"""
Memory:
@@ -363,18 +370,18 @@ Memory:
Current Screenshot:
"""
# Build history from messages
history = []
history_images = []
# Group messages into steps
current_step = []
step_num = 0
for message in messages:
msg_type = message.get("type")
if msg_type == "reasoning":
current_step.append(message)
elif msg_type == "message" and message.get("role") == "assistant":
@@ -386,7 +393,7 @@ Current Screenshot:
# End of step - process it
if current_step:
step_num += 1
# Extract bot thought from message content
bot_thought = ""
for item in current_step:
@@ -397,14 +404,14 @@ Current Screenshot:
bot_thought = content_item.get("text", "")
break
break
# Extract action from computer_call
action_text = ""
for item in current_step:
if item.get("type") == "computer_call":
action = item.get("action", {})
action_type = action.get("type", "")
if action_type == "click":
x, y = action.get("x", 0), action.get("y", 0)
# Convert to 0-999 range (assuming screen dimensions)
@@ -436,7 +443,7 @@ Current Screenshot:
elif action_type == "wait":
action_text = "WAIT()"
break
# Extract screenshot from computer_call_output
screenshot_url = None
for item in current_step:
@@ -445,34 +452,34 @@ Current Screenshot:
if output.get("type") == "input_image":
screenshot_url = output.get("image_url", "")
break
# Store step info
step_info = {
"step_num": step_num,
"bot_thought": bot_thought,
"action_text": action_text,
"screenshot_url": screenshot_url
"screenshot_url": screenshot_url,
}
history.append(step_info)
# Store screenshot for last 4 steps
if screenshot_url:
history_images.append(screenshot_url)
current_step = []
# Build content array with head, history, and tail
content = []
current_text = head_text
total_history_steps = len(history)
history_image_count = min(4, len(history_images)) # Last 4 images
for step_idx, step_info in enumerate(history):
step_num = step_info["step_num"]
bot_thought = step_info["bot_thought"]
action_text = step_info["action_text"]
if step_idx < total_history_steps - history_image_count:
# For steps beyond the last 4, use text placeholder
current_text += f"\nstep {step_num}: Screenshot:(Omitted in context.) Thought: {bot_thought}\nAction: {action_text}"
@@ -480,20 +487,21 @@ Current Screenshot:
# For the last 4 steps, insert images
current_text += f"\nstep {step_num}: Screenshot:"
content.append({"type": "text", "text": current_text})
# Add image
img_idx = step_idx - (total_history_steps - history_image_count)
if img_idx < len(history_images):
content.append({"type": "image_url", "image_url": {"url": history_images[img_idx]}})
current_text = f" Thought: {bot_thought}\nAction: {action_text}"
# Add tail
current_text += tail_text
content.append({"type": "text", "text": current_text})
return content
def model_dump(obj) -> Dict[str, Any]:
if isinstance(obj, dict):
return {k: model_dump(v) for k, v in obj.items()}
@@ -502,58 +510,61 @@ def model_dump(obj) -> Dict[str, Any]:
else:
return obj
def convert_glm_completion_to_responses_items(response: ModelResponse, image_width: int, image_height: int) -> List[Dict[str, Any]]:
def convert_glm_completion_to_responses_items(
response: ModelResponse, image_width: int, image_height: int
) -> List[Dict[str, Any]]:
"""
Convert GLM-4.5V completion response to responses items format.
Args:
response: LiteLLM ModelResponse from GLM-4.5V
image_width: Original image width for coordinate scaling
image_height: Original image height for coordinate scaling
Returns:
List of response items in the proper format
"""
import uuid
response_items = []
if not response.choices or not response.choices[0].message:
return response_items
message = response.choices[0].message
content = message.content or ""
reasoning_content = getattr(message, 'reasoning_content', None)
reasoning_content = getattr(message, "reasoning_content", None)
# Add reasoning item if present
if reasoning_content:
reasoning_item = model_dump(make_reasoning_item(reasoning_content))
response_items.append(reasoning_item)
# Parse the content to extract action and text
parsed_response = parse_glm_response(content)
action = parsed_response.get("action", "")
action_text = parsed_response.get("action_text", "")
# Add message item with text content (excluding action and memory)
if action_text:
# Remove action from action_text if it's there
clean_text = action_text
if action and action in clean_text:
clean_text = clean_text.replace(action, "").strip()
# Remove memory section
memory_pattern = r"Memory:\s*\[.*?\]\s*$"
clean_text = re.sub(memory_pattern, "", clean_text, flags=re.DOTALL).strip()
if clean_text:
message_item = model_dump(make_output_text_item(clean_text))
response_items.append(message_item)
# Convert action to computer call if present
if action:
call_id = f"call_{uuid.uuid4().hex[:8]}"
# Parse different action types and create appropriate computer calls
if action.startswith("left_click"):
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
@@ -566,7 +577,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action.startswith("right_click"):
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
if coord_match:
@@ -577,7 +588,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action.startswith("left_double_click"):
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
if coord_match:
@@ -588,7 +599,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action.startswith("left_drag"):
start_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
end_match = re.search(r"end_box='?\[(\d+),\s*(\d+)\]'?", action)
@@ -605,18 +616,18 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action.startswith("key"):
key_match = re.search(r"keys='([^']+)'", action)
if key_match:
keys = key_match.group(1)
# Split keys by '+' for key combinations, or use as single key
key_list = keys.split('+') if '+' in keys else [keys]
key_list = keys.split("+") if "+" in keys else [keys]
computer_call = model_dump(make_keypress_item(key_list))
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action.startswith("type"):
content_match = re.search(r"content='([^']*)'", action)
if content_match:
@@ -625,7 +636,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action.startswith("scroll"):
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
direction_match = re.search(r"direction='([^']+)'", action)
@@ -648,15 +659,16 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action == "WAIT()":
computer_call = model_dump(make_wait_item())
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
return response_items
@register_agent(models=r"(?i).*GLM-4\.5V.*")
class Glm4vConfig(AsyncAgentConfig):
"""GLM-4.5V agent configuration using liteLLM."""
@@ -674,11 +686,11 @@ class Glm4vConfig(AsyncAgentConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Predict the next step using GLM-4.5V model.
Args:
messages: Input messages following Responses format
model: Model name to use
@@ -691,7 +703,7 @@ class Glm4vConfig(AsyncAgentConfig):
_on_api_end: Callback for API end
_on_usage: Callback for usage tracking
_on_screenshot: Callback for screenshot events
Returns:
Dict with "output" and "usage" keys
"""
@@ -708,7 +720,7 @@ class Glm4vConfig(AsyncAgentConfig):
user_instruction = item.get("text", "")
break
break
# Get the last image for processing
last_image_b64 = get_last_image_from_messages(messages)
if not last_image_b64 and computer_handler:
@@ -718,35 +730,28 @@ class Glm4vConfig(AsyncAgentConfig):
last_image_b64 = screenshot_b64
if _on_screenshot:
await _on_screenshot(screenshot_b64)
if not last_image_b64:
raise ValueError("No image available for GLM-4.5V processing")
# Convert responses items to GLM-4.5V PC prompt format with historical actions
prompt_content = convert_responses_items_to_glm45v_pc_prompt(
messages=messages,
task=user_instruction,
memory="[]" # Initialize with empty memory for now
memory="[]", # Initialize with empty memory for now
)
# Add the current screenshot to the end
prompt_content.append({
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{last_image_b64}"}
})
prompt_content.append(
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{last_image_b64}"}}
)
# Prepare messages for liteLLM
litellm_messages = [
{
"role": "system",
"content": "You are a helpful GUI agent assistant."
},
{
"role": "user",
"content": prompt_content
}
{"role": "system", "content": "You are a helpful GUI agent assistant."},
{"role": "user", "content": prompt_content},
]
# Prepare API call kwargs
api_kwargs = {
"model": model,
@@ -757,20 +762,20 @@ class Glm4vConfig(AsyncAgentConfig):
# "skip_special_tokens": False,
# }
}
# Add API callbacks
if _on_api_start:
await _on_api_start(api_kwargs)
# Call liteLLM
response = await litellm.acompletion(**api_kwargs)
if _on_api_end:
await _on_api_end(api_kwargs, response)
# Get image dimensions for coordinate scaling
image_width, image_height = 1920, 1080 # Default dimensions
# Try to get actual dimensions from the image
try:
image_data = base64.b64decode(last_image_b64)
@@ -778,41 +783,38 @@ class Glm4vConfig(AsyncAgentConfig):
image_width, image_height = image.size
except Exception:
pass # Use default dimensions
# Convert GLM completion response to responses items
response_items = convert_glm_completion_to_responses_items(response, image_width, image_height)
response_items = convert_glm_completion_to_responses_items(
response, image_width, image_height
)
# Extract usage information
response_usage = {
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(
response.usage
).model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(response_usage)
# Create agent response
agent_response = {
"output": response_items,
"usage": response_usage
}
agent_response = {"output": response_items, "usage": response_usage}
return agent_response
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using GLM-4.5V model.
Args:
model: Model name to use
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
Tuple with (x, y) coordinates or None
"""
@@ -824,22 +826,22 @@ Respond with a single click action in this format:
left_click(start_box='[x,y]')
Where x,y are coordinates normalized to 0-999 range."""
# Prepare messages for liteLLM
litellm_messages = [
{
"role": "system",
"content": "You are a helpful GUI agent assistant."
},
{"role": "system", "content": "You are a helpful GUI agent assistant."},
{
"role": "user",
"content": [
{"type": "text", "text": click_prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
]
}
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_b64}"},
},
],
},
]
# Prepare API call kwargs
api_kwargs = {
"model": model,
@@ -848,21 +850,21 @@ Where x,y are coordinates normalized to 0-999 range."""
"temperature": 0.001,
"extra_body": {
"skip_special_tokens": False,
}
},
}
# Call liteLLM
response = await litellm.acompletion(**api_kwargs)
# Extract response content
response_content = response.choices[0].message.content.strip()
print(response)
# Parse response for click coordinates
# Look for coordinates in the response, handling special tokens
coord_pattern = r"<\|begin_of_box\|>.*?left_click\(start_box='?\[(\d+),(\d+)\]'?\).*?<\|end_of_box\|>"
match = re.search(coord_pattern, response_content)
if not match:
# Fallback: look for coordinates without special tokens
coord_pattern = r"left_click\(start_box='?\[(\d+),(\d+)\]'?\)"
@@ -870,7 +872,7 @@ Where x,y are coordinates normalized to 0-999 range."""
if match:
x, y = int(match.group(1)), int(match.group(2))
# Get actual image dimensions for scaling
try:
image_data = base64.b64decode(image_b64)
@@ -879,15 +881,15 @@ Where x,y are coordinates normalized to 0-999 range."""
except Exception:
# Use default dimensions
image_width, image_height = 1920, 1080
# Convert from 0-999 normalized coordinates to actual pixel coordinates
actual_x = int((x / 999.0) * image_width)
actual_y = int((y / 999.0) * image_height)
return (actual_x, actual_y)
return None
except Exception as e:
# Log error and return None
print(f"Error in predict_click: {e}")
@@ -896,7 +898,7 @@ Where x,y are coordinates normalized to 0-999 range."""
def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by this agent config.
Returns:
List of capability strings
"""

View File

@@ -5,75 +5,80 @@ Code: https://github.com/Yan98/GTA1
"""
import asyncio
import json
import re
import base64
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from io import BytesIO
import uuid
from PIL import Image
import litellm
import json
import math
import re
import uuid
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import litellm
from PIL import Image
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability, AgentResponse, Messages, Tools
SYSTEM_PROMPT = '''
SYSTEM_PROMPT = """
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
Output the coordinate pair exactly:
(x,y)
'''.strip()
""".strip()
def extract_coordinates(raw_string: str) -> Tuple[float, float]:
"""Extract coordinates from model output."""
try:
matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
return tuple(map(float, matches[0])) # type: ignore
return tuple(map(float, matches[0])) # type: ignore
except Exception:
return (0.0, 0.0)
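# A minimal sketch of the expected parsing behavior (example strings are
# illustrative, not actual model outputs):
#   extract_coordinates("The element is at (512.5, 384)")  -> (512.5, 384.0)
#   extract_coordinates("no coordinates here")             -> (0.0, 0.0)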
def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
def smart_resize(
height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360
) -> Tuple[int, int]:
"""Smart resize function similar to qwen_vl_utils."""
# Calculate the total pixels
total_pixels = height * width
# If already within bounds, return original dimensions
if min_pixels <= total_pixels <= max_pixels:
# Round to nearest factor
new_height = (height // factor) * factor
new_width = (width // factor) * factor
return new_height, new_width
# Calculate scaling factor
if total_pixels > max_pixels:
scale = (max_pixels / total_pixels) ** 0.5
else:
scale = (min_pixels / total_pixels) ** 0.5
# Apply scaling
new_height = int(height * scale)
new_width = int(width * scale)
# Round to nearest factor
new_height = (new_height // factor) * factor
new_width = (new_width // factor) * factor
# Ensure minimum size
new_height = max(new_height, factor)
new_width = max(new_width, factor)
return new_height, new_width
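# Worked example (dimensions are illustrative): a 1920x1080 screenshot totals
# 2,073,600 pixels, within [3136, 8847360], so each side is only rounded down
# to a multiple of 28:
#   smart_resize(1080, 1920)  -> (1064, 1904)   # (1080//28)*28, (1920//28)*28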
@register_agent(models=r".*GTA1.*")
class GTA1Config(AsyncAgentConfig):
"""GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction."""
def __init__(self):
self.current_model = None
self.last_screenshot_b64 = None
async def predict_step(
self,
@@ -87,25 +92,21 @@ class GTA1Config(AsyncAgentConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
raise NotImplementedError()
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[float, float]]:
"""
Predict click coordinates using GTA1 model via litellm.acompletion.
Args:
model: The GTA1 model name
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -113,66 +114,62 @@ class GTA1Config(AsyncAgentConfig):
image_data = base64.b64decode(image_b64)
image = Image.open(BytesIO(image_data))
width, height = image.width, image.height
# Smart resize the image (similar to qwen_vl_utils)
resized_height, resized_width = smart_resize(
height, width,
height,
width,
factor=28, # Default factor for Qwen models
min_pixels=3136,
max_pixels=4096 * 2160
max_pixels=4096 * 2160,
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height
# Convert resized image back to base64
buffered = BytesIO()
resized_image.save(buffered, format="PNG")
resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()
# Prepare system and user messages
system_message = {
"role": "system",
"content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
"content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width),
}
user_message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{resized_image_b64}"
}
"image_url": {"url": f"data:image/png;base64,{resized_image_b64}"},
},
{
"type": "text",
"text": instruction
}
]
{"type": "text", "text": instruction},
],
}
# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": [system_message, user_message],
"max_tokens": 2056,
"temperature": 0.0,
**kwargs
**kwargs,
}
# Use liteLLM acompletion
response = await litellm.acompletion(**api_kwargs)
# Extract response text
output_text = response.choices[0].message.content # type: ignore
output_text = response.choices[0].message.content # type: ignore
# Extract and rescale coordinates
pred_x, pred_y = extract_coordinates(output_text) # type: ignore
pred_x, pred_y = extract_coordinates(output_text) # type: ignore
pred_x *= scale_x
pred_y *= scale_y
return (math.floor(pred_x), math.floor(pred_y))
def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["click"]

View File

@@ -21,8 +21,8 @@ import litellm
from PIL import Image
from ..decorators import register_agent
from .base import AsyncAgentConfig
from ..types import AgentCapability
from .base import AsyncAgentConfig
def _strip_hf_prefix(model: str) -> str:
@@ -53,7 +53,9 @@ def _maybe_smart_resize(image: Image.Image, model: str) -> Tuple[Image.Image, Tu
if image_processor is None:
return image, (orig_w, orig_h)
factor = getattr(image_processor, "patch_size", 14) * getattr(image_processor, "merge_size", 1)
factor = getattr(image_processor, "patch_size", 14) * getattr(
image_processor, "merge_size", 1
)
min_pixels = getattr(image_processor, "min_pixels", 256 * 256)
max_pixels = getattr(image_processor, "max_pixels", 1536 * 1536)

View File

@@ -18,13 +18,12 @@ import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
from PIL import Image
import litellm
from PIL import Image
from ..decorators import register_agent
from .composed_grounded import ComposedGroundedConfig
from ..types import AgentCapability
from .composed_grounded import ComposedGroundedConfig
# Regex patterns for extracting coordinates
# Accept optional whitespace and optional decimal fractions
@@ -91,7 +90,7 @@ class InternVLConfig(ComposedGroundedConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""Fallback to a self-composed model"""
return await super().predict_step(
@@ -105,15 +104,11 @@ class InternVLConfig(ComposedGroundedConfig):
_on_api_end=_on_api_end,
_on_usage=_on_usage,
_on_screenshot=_on_screenshot,
**kwargs
**kwargs,
)
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using InternVL via litellm.acompletion.

View File

@@ -0,0 +1,493 @@
"""
Moondream3+ composed-grounded agent loop implementation.
Grounding is handled by a local Moondream3 preview model via Transformers.
Thinking is delegated to the trailing LLM in the composed model string: "moondream3+<thinking_model>".
Differences from composed_grounded:
- Provides a singleton Moondream3 client outside the class.
- predict_click uses model.point(image, instruction, settings={"max_objects": 1}) and returns pixel coordinates.
- If the last image was a screenshot (or we take one), run model.detect(image, "all ui elements") to get bboxes, then
run model.query on each cropped bbox to caption it. Overlay labels on the screenshot and emit via _on_screenshot.
- Add a user message listing all detected form UI names so the thinker can reference them.
- If the thinking model doesn't support vision, filter out image content before calling litellm.
"""
from __future__ import annotations
import base64
import io
import uuid
from typing import Any, Dict, List, Optional, Tuple
import litellm
from PIL import Image, ImageDraw, ImageFont
from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..responses import (
convert_completion_messages_to_responses_items,
convert_computer_calls_desc2xy,
convert_computer_calls_xy2desc,
convert_responses_items_to_completion_messages,
get_all_element_descriptions,
)
from ..types import AgentCapability
_MOONDREAM_SINGLETON = None
def get_moondream_model() -> Any:
"""Get a singleton instance of the Moondream3 preview model."""
global _MOONDREAM_SINGLETON
if _MOONDREAM_SINGLETON is None:
try:
import torch
from transformers import AutoModelForCausalLM
_MOONDREAM_SINGLETON = AutoModelForCausalLM.from_pretrained(
"moondream/moondream3-preview",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="cuda",
)
except ImportError as e:
raise RuntimeError(
"moondream3 requires torch and transformers. Install with: pip install cua-agent[moondream3]"
) from e
return _MOONDREAM_SINGLETON
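# Minimal usage sketch (assumes a CUDA host and access to the
# moondream/moondream3-preview checkpoint; the helper name and instruction
# string are illustrative):
def _example_point(img: Image.Image) -> None:
    md = get_moondream_model()
    result = md.point(img, "the Submit button", settings={"max_objects": 1})
    print(result.get("points", []))  # e.g. [{"x": 0.42, "y": 0.77}] (normalized)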
def _decode_image_b64(image_b64: str) -> Image.Image:
data = base64.b64decode(image_b64)
return Image.open(io.BytesIO(data)).convert("RGB")
def _image_to_b64(img: Image.Image) -> str:
buf = io.BytesIO()
img.save(buf, format="PNG")
return base64.b64encode(buf.getvalue()).decode("utf-8")
def _supports_vision(model: str) -> bool:
"""Heuristic vision support detection for thinking model."""
m = model.lower()
vision_markers = [
"gpt-4o",
"gpt-4.1",
"o1",
"o3",
"claude-3",
"claude-3.5",
"sonnet",
"haiku",
"opus",
"gemini-1.5",
"llava",
]
return any(v in m for v in vision_markers)
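# Examples of the heuristic (model names are illustrative):
#   _supports_vision("openai/gpt-4o")               -> True   ("gpt-4o")
#   _supports_vision("anthropic/claude-3-5-sonnet") -> True   ("claude-3", "sonnet")
#   _supports_vision("gpt-3.5-turbo")               -> False  (images get filtered below)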
def _filter_images_from_completion_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
filtered: List[Dict[str, Any]] = []
for msg in messages:
msg_copy = {**msg}
content = msg_copy.get("content")
if isinstance(content, list):
msg_copy["content"] = [c for c in content if c.get("type") != "image_url"]
filtered.append(msg_copy)
return filtered
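# Example (message shape is illustrative): image parts are dropped, text kept.
#   [{"role": "user", "content": [{"type": "text", "text": "Click OK"},
#                                 {"type": "image_url", "image_url": {"url": "data:..."}}]}]
#   -> [{"role": "user", "content": [{"type": "text", "text": "Click OK"}]}]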
def _annotate_detect_and_label_ui(base_img: Image.Image, model_md) -> Tuple[str, List[str]]:
"""Detect UI elements with Moondream, caption each, draw labels with backgrounds.
Args:
base_img: PIL image of the screenshot (RGB or RGBA). Will be copied/converted internally.
model_md: Moondream model instance with .detect() and .query() methods.
Returns:
A tuple of (annotated_image_base64_png, detected_names)
"""
# Ensure RGBA for semi-transparent fills
if base_img.mode != "RGBA":
base_img = base_img.convert("RGBA")
W, H = base_img.width, base_img.height
# Detect objects
try:
detect_result = model_md.detect(base_img, "all ui elements")
objects = detect_result.get("objects", []) if isinstance(detect_result, dict) else []
except Exception:
objects = []
draw = ImageDraw.Draw(base_img)
try:
font = ImageFont.load_default()
except Exception:
font = None
detected_names: List[str] = []
for i, obj in enumerate(objects):
try:
# Clamp normalized coords and crop
x_min = max(0.0, min(1.0, float(obj.get("x_min", 0.0))))
y_min = max(0.0, min(1.0, float(obj.get("y_min", 0.0))))
x_max = max(0.0, min(1.0, float(obj.get("x_max", 0.0))))
y_max = max(0.0, min(1.0, float(obj.get("y_max", 0.0))))
left, top, right, bottom = (
int(x_min * W),
int(y_min * H),
int(x_max * W),
int(y_max * H),
)
left, top = max(0, left), max(0, top)
right, bottom = min(W - 1, right), min(H - 1, bottom)
crop = base_img.crop((left, top, right, bottom))
# Prompted short caption
try:
result = model_md.query(crop, "Caption this UI element in few words.")
caption_text = (result or {}).get("answer", "")
except Exception:
caption_text = ""
name = (caption_text or "").strip() or f"element_{i+1}"
detected_names.append(name)
# Draw bbox
draw.rectangle([left, top, right, bottom], outline=(255, 215, 0, 255), width=2)
# Label background with padding and rounded corners
label = f"{i+1}. {name}"
padding = 3
if font:
text_bbox = draw.textbbox((0, 0), label, font=font)
else:
text_bbox = draw.textbbox((0, 0), label)
text_w = text_bbox[2] - text_bbox[0]
text_h = text_bbox[3] - text_bbox[1]
tx = left + 3
ty = top - (text_h + 2 * padding + 4)
if ty < 0:
ty = top + 3
bg_left = tx - padding
bg_top = ty - padding
bg_right = tx + text_w + padding
bg_bottom = ty + text_h + padding
try:
draw.rounded_rectangle(
[bg_left, bg_top, bg_right, bg_bottom],
radius=4,
fill=(0, 0, 0, 160),
outline=(255, 215, 0, 200),
width=1,
)
except Exception:
draw.rectangle(
[bg_left, bg_top, bg_right, bg_bottom],
fill=(0, 0, 0, 160),
outline=(255, 215, 0, 200),
width=1,
)
text_fill = (255, 255, 255, 255)
if font:
draw.text((tx, ty), label, fill=text_fill, font=font)
else:
draw.text((tx, ty), label, fill=text_fill)
except Exception:
continue
# Encode PNG base64
annotated = base_img
if annotated.mode not in ("RGBA", "RGB"):
annotated = annotated.convert("RGBA")
annotated_b64 = _image_to_b64(annotated)
return annotated_b64, detected_names
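# Usage sketch (wiring is illustrative; requires the Moondream singleton above):
#   img = _decode_image_b64(screenshot_b64)
#   annotated_b64, names = _annotate_detect_and_label_ui(img, get_moondream_model())
#   names -> e.g. ["Search field", "Submit button", "element_3"]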
GROUNDED_COMPUTER_TOOL_SCHEMA = {
"type": "function",
"function": {
"name": "computer",
"description": (
"Control a computer by taking screenshots and interacting with UI elements. "
"The screenshot action will include a list of detected form UI element names when available. "
"Use element descriptions to locate and interact with UI elements on the screen."
),
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment",
],
"description": "The action to perform (required for all actions)",
},
"element_description": {
"type": "string",
"description": "Description of the element to interact with (required for click/double_click/move/scroll)",
},
"start_element_description": {
"type": "string",
"description": "Description of the element to start dragging from (required for drag)",
},
"end_element_description": {
"type": "string",
"description": "Description of the element to drag to (required for drag)",
},
"text": {
"type": "string",
"description": "The text to type (required for type)",
},
"keys": {
"type": "array",
"items": {"type": "string"},
"description": "Key(s) to press (required for keypress)",
},
"button": {
"type": "string",
"enum": ["left", "right", "wheel", "back", "forward"],
"description": "The mouse button to use for click/double_click",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount (required for scroll)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount (required for scroll)",
},
},
"required": ["action"],
},
},
}
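# Example arguments a thinking model might emit against this schema (values are
# illustrative; element descriptions are later grounded to pixels via Moondream):
#   {"action": "click", "element_description": "blue Submit button", "button": "left"}
#   {"action": "drag", "start_element_description": "file icon",
#    "end_element_description": "trash can"}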
@register_agent(r"moondream3\+.*", priority=2)
class Moondream3PlusConfig(AsyncAgentConfig):
def __init__(self):
self.desc2xy: Dict[str, Tuple[float, float]] = {}
async def predict_step(
self,
messages: List[Dict[str, Any]],
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs,
) -> Dict[str, Any]:
# Parse composed model: moondream3+<thinking_model>
if "+" not in model:
raise ValueError(f"Composed model must be 'moondream3+<thinking_model>', got: {model}")
_, thinking_model = model.split("+", 1)
pre_output_items: List[Dict[str, Any]] = []
# Acquire last screenshot; if missing, take one
last_image_b64: Optional[str] = None
for message in reversed(messages):
if (
isinstance(message, dict)
and message.get("type") == "computer_call_output"
and isinstance(message.get("output"), dict)
and message["output"].get("type") == "input_image"
):
image_url = message["output"].get("image_url", "")
if image_url.startswith("data:image/png;base64,"):
last_image_b64 = image_url.split(",", 1)[1]
break
if last_image_b64 is None and computer_handler is not None:
# Take a screenshot
screenshot_b64 = await computer_handler.screenshot() # type: ignore
if screenshot_b64:
call_id = uuid.uuid4().hex
pre_output_items += [
{
"type": "message",
"role": "assistant",
"content": [
{
"type": "output_text",
"text": "Taking a screenshot to analyze the current screen.",
}
],
},
{
"type": "computer_call",
"call_id": call_id,
"status": "completed",
"action": {"type": "screenshot"},
},
{
"type": "computer_call_output",
"call_id": call_id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_b64}",
},
},
]
last_image_b64 = screenshot_b64
if _on_screenshot:
await _on_screenshot(screenshot_b64)
# If we have a last screenshot, run Moondream detection and labeling
detected_names: List[str] = []
if last_image_b64 is not None:
base_img = _decode_image_b64(last_image_b64)
model_md = get_moondream_model()
annotated_b64, detected_names = _annotate_detect_and_label_ui(base_img, model_md)
if _on_screenshot:
await _on_screenshot(annotated_b64, "annotated_form_ui")
# Also push a user message listing all detected names
if detected_names:
names_text = "\n".join(f"- {n}" for n in detected_names)
pre_output_items.append(
{
"type": "message",
"role": "user",
"content": [
{"type": "input_text", "text": "Detected form UI elements on screen:"},
{"type": "input_text", "text": names_text},
{
"type": "input_text",
"text": "Please continue with the next action needed to perform your task.",
},
],
}
)
tool_schemas = []
for schema in tools or []:
if schema.get("type") == "computer":
tool_schemas.append(GROUNDED_COMPUTER_TOOL_SCHEMA)
else:
tool_schemas.append(schema)
# Step 1: Convert computer calls from xy to descriptions
input_messages = messages + pre_output_items
messages_with_descriptions = convert_computer_calls_xy2desc(input_messages, self.desc2xy)
# Step 2: Convert responses items to completion messages
completion_messages = convert_responses_items_to_completion_messages(
messages_with_descriptions,
allow_images_in_tool_results=False,
)
# Optionally filter images if model lacks vision
if not _supports_vision(thinking_model):
completion_messages = _filter_images_from_completion_messages(completion_messages)
# Step 3: Call thinking model with litellm.acompletion
api_kwargs = {
"model": thinking_model,
"messages": completion_messages,
"tools": tool_schemas,
"max_retries": max_retries,
"stream": stream,
**kwargs,
}
if use_prompt_caching:
api_kwargs["use_prompt_caching"] = use_prompt_caching
if _on_api_start:
await _on_api_start(api_kwargs)
response = await litellm.acompletion(**api_kwargs)
if _on_api_end:
await _on_api_end(api_kwargs, response)
usage = {
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(usage)
# Step 4: Convert completion messages back to responses items format
response_dict = response.model_dump() # type: ignore
choice_messages = [choice["message"] for choice in response_dict["choices"]]
thinking_output_items: List[Dict[str, Any]] = []
for choice_message in choice_messages:
thinking_output_items.extend(
convert_completion_messages_to_responses_items([choice_message])
)
# Step 5: Use Moondream to get coordinates for each description
element_descriptions = get_all_element_descriptions(thinking_output_items)
if element_descriptions and last_image_b64:
for desc in element_descriptions:
for _ in range(3): # try 3 times
coords = await self.predict_click(
model=model,
image_b64=last_image_b64,
instruction=desc,
)
if coords:
self.desc2xy[desc] = coords
break
# Step 6: Convert computer calls from descriptions back to xy coordinates
final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)
# Step 7: Return output and usage
return {"output": pre_output_items + final_output_items, "usage": usage}
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs,
) -> Optional[Tuple[float, float]]:
"""Predict click coordinates using Moondream3's point API.
Returns pixel coordinates (x, y) as floats.
"""
img = _decode_image_b64(image_b64)
W, H = img.width, img.height
model_md = get_moondream_model()
try:
result = model_md.point(img, instruction, settings={"max_objects": 1})
except Exception:
return None
try:
pt = (result or {}).get("points", [])[0]
x_norm = float(pt.get("x", 0.0))
y_norm = float(pt.get("y", 0.0))
x_px = max(0.0, min(float(W - 1), x_norm * W))
y_px = max(0.0, min(float(H - 1), y_norm * H))
return (x_px, y_px)
except Exception:
return None
def get_capabilities(self) -> List[AgentCapability]:
return ["click", "step"]

View File

@@ -5,100 +5,102 @@ Code: https://github.com/microsoft/OmniParser
"""
import asyncio
import json
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
import litellm
import inspect
import base64
import inspect
import json
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import litellm
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability, AgentResponse, Messages, Tools
SOM_TOOL_SCHEMA = {
"type": "function",
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment"
],
"description": "The action to perform"
},
"element_id": {
"type": "integer",
"description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
},
"start_element_id": {
"type": "integer",
"description": "The ID of the element to start dragging from (required for drag action)"
},
"end_element_id": {
"type": "integer",
"description": "The ID of the element to drag to (required for drag action)"
},
"text": {
"type": "string",
"description": "The text to type (required for type action)"
},
"keys": {
"type": "string",
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
},
"button": {
"type": "string",
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
},
"type": "function",
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment",
],
"description": "The action to perform",
},
"element_id": {
"type": "integer",
"description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)",
},
"start_element_id": {
"type": "integer",
"description": "The ID of the element to start dragging from (required for drag action)",
},
"end_element_id": {
"type": "integer",
"description": "The ID of the element to drag to (required for drag action)",
},
"text": {
"type": "string",
"description": "The text to type (required for type action)",
},
"keys": {
"type": "string",
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')",
},
"button": {
"type": "string",
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
},
},
"required": ["action"],
},
"required": [
"action"
]
}
}
OMNIPARSER_AVAILABLE = False
try:
from som import OmniParser
OMNIPARSER_AVAILABLE = True
except ImportError:
pass
OMNIPARSER_SINGLETON = None
def get_parser():
global OMNIPARSER_SINGLETON
if OMNIPARSER_SINGLETON is None:
OMNIPARSER_SINGLETON = OmniParser()
return OMNIPARSER_SINGLETON
def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Get the last computer_call_output message from a messages list.
Args:
messages: List of messages to search through
Returns:
The last computer_call_output message dict, or None if not found
"""
@@ -107,11 +109,12 @@ def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Di
return message
return None
def _prepare_tools_for_omniparser(tool_schemas: List[Dict[str, Any]]) -> Tuple[Tools, dict]:
"""Prepare tools for OpenAI API format"""
omniparser_tools = []
id2xy = dict()
for schema in tool_schemas:
if schema["type"] == "computer":
omniparser_tools.append(SOM_TOOL_SCHEMA)
@@ -122,72 +125,80 @@ def _prepare_tools_for_omniparser(tool_schemas: List[Dict[str, Any]]) -> Tuple[T
elif schema["type"] == "function":
# Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
# Schema should be: {type, name, description, parameters}
omniparser_tools.append({ "type": "function", **schema["function"] })
omniparser_tools.append({"type": "function", **schema["function"]})
return omniparser_tools, id2xy
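# Sketch: id2xy starts empty here and is filled in predict_step from OmniParser
# detections, mapping each element ID to its bbox center (values illustrative):
#   id2xy[3] = ((bbox.x1 + bbox.x2) / 2, (bbox.y1 + bbox.y2) / 2)  # e.g. (412.0, 96.5)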
async def replace_function_with_computer_call(item: Dict[str, Any], id2xy: Dict[int, Tuple[float, float]]):
item_type = item.get("type")
def _get_xy(element_id: Optional[int]) -> Union[Tuple[float, float], Tuple[None, None]]:
if element_id is None:
return (None, None)
return id2xy.get(element_id, (None, None))
async def replace_function_with_computer_call(
item: Dict[str, Any], id2xy: Dict[int, Tuple[float, float]]
):
item_type = item.get("type")
if item_type == "function_call":
fn_name = item.get("name")
fn_args = json.loads(item.get("arguments", "{}"))
def _get_xy(element_id: Optional[int]) -> Union[Tuple[float, float], Tuple[None, None]]:
if element_id is None:
return (None, None)
return id2xy.get(element_id, (None, None))
item_id = item.get("id")
call_id = item.get("call_id")
if fn_name == "computer":
action = fn_args.get("action")
element_id = fn_args.get("element_id")
start_element_id = fn_args.get("start_element_id")
end_element_id = fn_args.get("end_element_id")
text = fn_args.get("text")
keys = fn_args.get("keys")
button = fn_args.get("button")
scroll_x = fn_args.get("scroll_x")
scroll_y = fn_args.get("scroll_y")
if item_type == "function_call":
fn_name = item.get("name")
fn_args = json.loads(item.get("arguments", "{}"))
x, y = _get_xy(element_id)
start_x, start_y = _get_xy(start_element_id)
end_x, end_y = _get_xy(end_element_id)
item_id = item.get("id")
call_id = item.get("call_id")
action_args = {
"type": action,
"x": x,
"y": y,
"start_x": start_x,
"start_y": start_y,
"end_x": end_x,
"end_y": end_y,
"text": text,
"keys": keys,
"button": button,
"scroll_x": scroll_x,
"scroll_y": scroll_y
}
# Remove None values to keep the JSON clean
action_args = {k: v for k, v in action_args.items() if v is not None}
if fn_name == "computer":
action = fn_args.get("action")
element_id = fn_args.get("element_id")
start_element_id = fn_args.get("start_element_id")
end_element_id = fn_args.get("end_element_id")
text = fn_args.get("text")
keys = fn_args.get("keys")
button = fn_args.get("button")
scroll_x = fn_args.get("scroll_x")
scroll_y = fn_args.get("scroll_y")
return [{
"type": "computer_call",
"action": action_args,
"id": item_id,
"call_id": call_id,
"status": "completed"
}]
x, y = _get_xy(element_id)
start_x, start_y = _get_xy(start_element_id)
end_x, end_y = _get_xy(end_element_id)
return [item]
action_args = {
"type": action,
"x": x,
"y": y,
"start_x": start_x,
"start_y": start_y,
"end_x": end_x,
"end_y": end_y,
"text": text,
"keys": keys,
"button": button,
"scroll_x": scroll_x,
"scroll_y": scroll_y,
}
# Remove None values to keep the JSON clean
action_args = {k: v for k, v in action_args.items() if v is not None}
async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[Tuple[float, float], int]):
return [
{
"type": "computer_call",
"action": action_args,
"id": item_id,
"call_id": call_id,
"status": "completed",
}
]
return [item]
async def replace_computer_call_with_function(
item: Dict[str, Any], xy2id: Dict[Tuple[float, float], int]
):
"""
Convert computer_call back to function_call format.
Also handles computer_call_output -> function_call_output conversion.
Args:
item: The item to convert
xy2id: Mapping from (x, y) coordinates to element IDs
@@ -202,12 +213,12 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
if item_type == "computer_call":
action_data = item.get("action", {})
# Extract coordinates and convert back to element IDs
element_id = _get_element_id(action_data.get("x"), action_data.get("y"))
start_element_id = _get_element_id(action_data.get("start_x"), action_data.get("start_y"))
end_element_id = _get_element_id(action_data.get("end_x"), action_data.get("end_y"))
# Build function arguments
fn_args = {
"action": action_data.get("type"),
@@ -218,33 +229,36 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
"keys": action_data.get("keys"),
"button": action_data.get("button"),
"scroll_x": action_data.get("scroll_x"),
"scroll_y": action_data.get("scroll_y")
"scroll_y": action_data.get("scroll_y"),
}
# Remove None values to keep the JSON clean
fn_args = {k: v for k, v in fn_args.items() if v is not None}
return [{
"type": "function_call",
"name": "computer",
"arguments": json.dumps(fn_args),
"id": item.get("id"),
"call_id": item.get("call_id"),
"status": "completed",
# Fall back to string representation
"content": f"Used tool: {action_data.get("type")}({json.dumps(fn_args)})"
}]
return [
{
"type": "function_call",
"name": "computer",
"arguments": json.dumps(fn_args),
"id": item.get("id"),
"call_id": item.get("call_id"),
"status": "completed",
# Fall back to string representation
"content": f"Used tool: {action_data.get("type")}({json.dumps(fn_args)})",
}
]
elif item_type == "computer_call_output":
# Simple conversion: computer_call_output -> function_call_output
return [{
"type": "function_call_output",
"call_id": item.get("call_id"),
"content": [item.get("output")],
"id": item.get("id"),
"status": "completed"
}]
return [
{
"type": "function_call_output",
"call_id": item.get("call_id"),
"content": [item.get("output")],
"id": item.get("id"),
"status": "completed",
}
]
return [item]
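# Worked example (coordinates illustrative; assumes _get_element_id resolves the
# pair back through xy2id, e.g. xy2id == {(412.0, 96.5): 3}):
#   computer_call  {"action": {"type": "click", "x": 412.0, "y": 96.5}, ...}
#   -> function_call with arguments '{"action": "click", "element_id": 3}'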
@@ -252,7 +266,7 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
@register_agent(models=r"omniparser\+.*|omni\+.*", priority=2)
class OmniparserConfig(AsyncAgentConfig):
"""Omniparser agent configuration implementing AsyncAgentConfig protocol."""
async def predict_step(
self,
messages: List[Dict[str, Any]],
@@ -266,25 +280,27 @@ class OmniparserConfig(AsyncAgentConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
OmniParser-grounded agent loop using liteLLM responses.
Annotates screenshots with numbered elements and delegates reasoning to the composed LLM.
"""
if not OMNIPARSER_AVAILABLE:
raise ValueError("omniparser loop requires som to be installed. Install it with `pip install cua-som`.")
raise ValueError(
"omniparser loop requires som to be installed. Install it with `pip install cua-som`."
)
tools = tools or []
llm_model = model.split('+')[-1]
llm_model = model.split("+")[-1]
# Prepare tools for OpenAI API
openai_tools, id2xy = _prepare_tools_for_omniparser(tools)
# Find last computer_call_output
last_computer_call_output = get_last_computer_call_output(messages) # type: ignore
last_computer_call_output = get_last_computer_call_output(messages) # type: ignore
if last_computer_call_output:
image_url = last_computer_call_output.get("output", {}).get("image_url", "")
image_data = image_url.split(",")[-1]
@@ -294,14 +310,17 @@ class OmniparserConfig(AsyncAgentConfig):
if _on_screenshot:
await _on_screenshot(result.annotated_image_base64, "annotated_image")
for element in result.elements:
id2xy[element.id] = ((element.bbox.x1 + element.bbox.x2) / 2, (element.bbox.y1 + element.bbox.y2) / 2)
id2xy[element.id] = (
(element.bbox.x1 + element.bbox.x2) / 2,
(element.bbox.y1 + element.bbox.y2) / 2,
)
# handle computer calls -> function calls
new_messages = []
for message in messages:
if not isinstance(message, dict):
message = message.__dict__
new_messages += await replace_computer_call_with_function(message, id2xy) # type: ignore
new_messages += await replace_computer_call_with_function(message, id2xy) # type: ignore
messages = new_messages
# Prepare API call kwargs
@@ -312,13 +331,13 @@ class OmniparserConfig(AsyncAgentConfig):
"stream": stream,
"truncation": "auto",
"num_retries": max_retries,
**kwargs
**kwargs,
}
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
print(str(api_kwargs)[:1000])
# Use liteLLM responses
@@ -330,60 +349,50 @@ class OmniparserConfig(AsyncAgentConfig):
# Extract usage information
usage = {
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0), # type: ignore
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0), # type: ignore
}
if _on_usage:
await _on_usage(usage)
# handle som function calls -> xy computer calls
new_output = []
for i in range(len(response.output)): # type: ignore
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) # type: ignore
return {
"output": new_output,
"usage": usage
}
for i in range(len(response.output)): # type: ignore
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) # type: ignore
return {"output": new_output, "usage": usage}
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[float, float]]:
"""
Predict click coordinates using OmniParser and LLM.
Uses OmniParser to annotate the image with element IDs, then uses LLM
to identify the correct element ID based on the instruction.
"""
if not OMNIPARSER_AVAILABLE:
return None
# Parse the image with OmniParser to get annotated image and elements
parser = get_parser()
result = parser.parse(image_b64)
# Extract the LLM model from composed model string
llm_model = model.split('+')[-1]
llm_model = model.split("+")[-1]
# Create system prompt for element ID prediction
SYSTEM_PROMPT = f'''
SYSTEM_PROMPT = """
You are an expert UI element locator. Given a GUI image annotated with numerical IDs over each interactable element, along with a user's element description, provide the ID of the specified element.
The image shows UI elements with numbered overlays. Each number corresponds to a clickable/interactable element.
Output only the element ID as a single integer.
'''.strip()
""".strip()
# Prepare messages for LLM
messages = [
{
"role": "system",
"content": SYSTEM_PROMPT
},
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": [
@@ -391,31 +400,25 @@ Output only the element ID as a single integer.
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{result.annotated_image_base64}"
}
},
},
{
"type": "text",
"text": f"Find the element: {instruction}"
}
]
}
{"type": "text", "text": f"Find the element: {instruction}"},
],
},
]
# Call LLM to predict element ID
response = await litellm.acompletion(
model=llm_model,
messages=messages,
max_tokens=10,
temperature=0.1
model=llm_model, messages=messages, max_tokens=10, temperature=0.1
)
# Extract element ID from response
response_text = response.choices[0].message.content.strip() # type: ignore
response_text = response.choices[0].message.content.strip() # type: ignore
# Try to parse the element ID
try:
element_id = int(response_text)
# Find the element with this ID and return its center coordinates
for element in result.elements:
if element.id == element_id:
@@ -425,9 +428,9 @@ Output only the element ID as a single integer.
except ValueError:
# If we can't parse the ID, return None
pass
return None
def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["step"]

View File

@@ -6,12 +6,14 @@ import asyncio
import base64
import json
from io import BytesIO
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import litellm
from PIL import Image
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..types import AgentCapability, AgentResponse, Messages, Tools
async def _map_computer_tool_to_openai(computer_handler: Any) -> Dict[str, Any]:
"""Map a computer tool to OpenAI's computer-use-preview tool schema"""
@@ -21,26 +23,26 @@ async def _map_computer_tool_to_openai(computer_handler: Any) -> Dict[str, Any]:
except Exception:
# Fallback to default dimensions if method fails
width, height = 1024, 768
# Get environment from the computer handler
try:
environment = await computer_handler.get_environment()
except Exception:
# Fallback to default environment if method fails
environment = "linux"
return {
"type": "computer_use_preview",
"display_width": width,
"display_height": height,
"environment": environment # mac, windows, linux, browser
"environment": environment, # mac, windows, linux, browser
}
async def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools:
"""Prepare tools for OpenAI API format"""
openai_tools = []
for schema in tool_schemas:
if schema["type"] == "computer":
# Map computer tool to OpenAI format
@@ -49,19 +51,19 @@ async def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools
elif schema["type"] == "function":
# Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
# Schema should be: {type, name, description, parameters}
openai_tools.append({ "type": "function", **schema["function"] })
openai_tools.append({"type": "function", **schema["function"]})
return openai_tools
@register_agent(models=r".*computer-use-preview.*")
@register_agent(models=r".*(^|/)computer-use-preview")
class OpenAIComputerUseConfig:
"""
OpenAI computer-use-preview agent configuration using liteLLM responses.
Supports OpenAI's computer use preview models.
"""
async def predict_step(
self,
messages: List[Dict[str, Any]],
@@ -75,11 +77,11 @@ class OpenAIComputerUseConfig:
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Predict the next step based on input items.
Args:
messages: Input items following Responses format
model: Model name to use
@@ -92,12 +94,12 @@ class OpenAIComputerUseConfig:
_on_usage: Callback for usage tracking
_on_screenshot: Callback for screenshot events
**kwargs: Additional arguments
Returns:
Dictionary with "output" (output items) and "usage" array
"""
tools = tools or []
# Prepare tools for OpenAI API
openai_tools = await _prepare_tools_for_openai(tools)
@@ -110,16 +112,16 @@ class OpenAIComputerUseConfig:
"reasoning": {"summary": "concise"},
"truncation": "auto",
"num_retries": max_retries,
**kwargs
**kwargs,
}
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
# Use liteLLM responses
response = await litellm.aresponses(**api_kwargs)
# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
@@ -136,24 +138,21 @@ class OpenAIComputerUseConfig:
output_dict = response.model_dump()
output_dict["usage"] = usage
return output_dict
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str
self, model: str, image_b64: str, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.
Uses OpenAI computer-use-preview with manually constructed input items
and a prompt that instructs the agent to only output clicks.
Args:
model: Model name to use
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -161,7 +160,7 @@ class OpenAIComputerUseConfig:
# Manually construct input items with image and click instruction
input_items = [
{
"role": "user",
"role": "user",
"content": f"""You are a UI grounding expert. Follow these guidelines:
1. NEVER ask for confirmation. Complete all tasks autonomously.
@@ -173,19 +172,16 @@ class OpenAIComputerUseConfig:
7. Be decisive and action-oriented. Complete the requested task fully.
Remember: You are expected to complete tasks autonomously. The user trusts you to do what they asked.
Task: Click {instruction}. Output ONLY a click action on the target element."""
Task: Click {instruction}. Output ONLY a click action on the target element.""",
},
{
"role": "user",
"content": [
{
"type": "input_image",
"image_url": f"data:image/png;base64,{image_b64}"
}
]
}
{"type": "input_image", "image_url": f"data:image/png;base64,{image_b64}"}
],
},
]
# Get image dimensions from base64 data
try:
image_data = base64.b64decode(image_b64)
@@ -194,15 +190,15 @@ Task: Click {instruction}. Output ONLY a click action on the target element."""
except Exception:
# Fallback to default dimensions if image parsing fails
display_width, display_height = 1024, 768
# Prepare computer tool for click actions
computer_tool = {
"type": "computer_use_preview",
"display_width": display_width,
"display_height": display_height,
"environment": "windows"
"environment": "windows",
}
# Prepare API call kwargs
api_kwargs = {
"model": model,
@@ -211,32 +207,34 @@ Task: Click {instruction}. Output ONLY a click action on the target element."""
"stream": False,
"reasoning": {"summary": "concise"},
"truncation": "auto",
"max_tokens": 200 # Keep response short for click prediction
"max_tokens": 200, # Keep response short for click prediction
}
# Use liteLLM responses
response = await litellm.aresponses(**api_kwargs)
# Extract click coordinates from response output
output_dict = response.model_dump()
output_items = output_dict.get("output", [])
output_items = output_dict.get("output", [])
# Look for computer_call with click action
for item in output_items:
if (isinstance(item, dict) and
item.get("type") == "computer_call" and
isinstance(item.get("action"), dict)):
if (
isinstance(item, dict)
and item.get("type") == "computer_call"
and isinstance(item.get("action"), dict)
):
action = item["action"]
if action.get("x") is not None and action.get("y") is not None:
return (int(action.get("x")), int(action.get("y")))
return None
def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by this agent config.
Returns:
List of capability strings
"""

View File

@@ -4,20 +4,22 @@ Based on OpenCUA model for GUI grounding tasks.
"""
import asyncio
import json
import re
import base64
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from io import BytesIO
import uuid
from PIL import Image
import litellm
import json
import math
import re
import uuid
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import litellm
from PIL import Image
from .composed_grounded import ComposedGroundedConfig
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability, AgentResponse, Messages, Tools
from .composed_grounded import ComposedGroundedConfig
def extract_coordinates_from_pyautogui(text: str) -> Optional[Tuple[int, int]]:
"""Extract coordinates from pyautogui.click(x=..., y=...) format."""
@@ -32,10 +34,11 @@ def extract_coordinates_from_pyautogui(text: str) -> Optional[Tuple[int, int]]:
except Exception:
return None
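# Illustrative parse (input string is an assumed model output):
#   extract_coordinates_from_pyautogui("pyautogui.click(x=960, y=540)") -> (960, 540)
#   Text without a parseable click falls through to None.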
@register_agent(models=r"(?i).*OpenCUA.*")
class OpenCUAConfig(ComposedGroundedConfig):
"""OpenCUA agent configuration implementing AsyncAgentConfig protocol for click prediction."""
def __init__(self):
super().__init__()
self.current_model = None
@@ -53,7 +56,7 @@ class OpenCUAConfig(ComposedGroundedConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""Fallback to a self-composed model"""
return await super().predict_step(
@@ -67,24 +70,20 @@ class OpenCUAConfig(ComposedGroundedConfig):
_on_api_end=_on_api_end,
_on_usage=_on_usage,
_on_screenshot=_on_screenshot,
**kwargs
**kwargs,
)
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using OpenCUA model via litellm.acompletion.
Args:
model: The OpenCUA model name
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -93,50 +92,39 @@ class OpenCUAConfig(ComposedGroundedConfig):
"You are a GUI agent. You are given a task and a screenshot of the screen. "
"You need to perform a series of pyautogui actions to complete the task."
)
system_message = {
"role": "system",
"content": system_prompt
}
system_message = {"role": "system", "content": system_prompt}
# Prepare user message with image and instruction
user_message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_b64}"
}
},
{
"type": "text",
"text": f"Click on {instruction}"
}
]
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
{"type": "text", "text": f"Click on {instruction}"},
],
}
# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": [system_message, user_message],
"max_new_tokens": 2056,
"temperature": 0,
**kwargs
**kwargs,
}
# Use liteLLM acompletion
response = await litellm.acompletion(**api_kwargs)
# Extract response text
output_text = response.choices[0].message.content
# print(output_text)
# Extract coordinates from pyautogui format
coordinates = extract_coordinates_from_pyautogui(output_text)
return coordinates
def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["click"]

View File

@@ -0,0 +1,510 @@
"""
Qwen3-VL agent loop implementation using litellm with function/tool calling.
- Passes a ComputerUse tool schema to acompletion
- Converts between Responses items and completion messages using helpers
"""
from __future__ import annotations
import json
import re
from typing import Any, Dict, List, Optional, Tuple
import litellm
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..responses import (
convert_completion_messages_to_responses_items,
convert_responses_items_to_completion_messages,
)
from ..types import AgentCapability
# ComputerUse tool schema (OpenAI function tool format)
QWEN3_COMPUTER_TOOL: Dict[str, Any] = {
"type": "function",
"function": {
"name": "computer",
"description": (
"Use a mouse and keyboard to interact with a computer, and take screenshots.\n"
"* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n"
"* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n"
"* The screen's resolution is 1000x1000.\n"
"* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n"
"* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n"
"* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges."
),
"parameters": {
"type": "object",
"properties": {
"action": {
"description": "The action to perform.",
"enum": [
"key",
"type",
"mouse_move",
"left_click",
"left_click_drag",
"right_click",
"middle_click",
"double_click",
"triple_click",
"scroll",
"hscroll",
"screenshot",
"wait",
# "terminate",
# "answer",
],
"type": "string",
},
"keys": {
"description": "Required only by action=key.",
"type": "array",
"items": {"type": "string"},
},
"text": {
"description": "Required only by action=type and action=answer.",
"type": "string",
},
"coordinate": {
"description": "(x, y): Pixel coordinates from top-left.",
"type": "array",
"items": {"type": ["number", "integer"]},
"minItems": 2,
"maxItems": 2,
},
"pixels": {
"description": "Scroll amount. Positive=up, negative=down. For scroll/hscroll.",
"type": "number",
},
"time": {
"description": "Seconds to wait (action=wait).",
"type": "number",
},
# "status": {
# "description": "Task status (action=terminate).",
# "type": "string",
# "enum": ["success", "failure"],
# },
},
"required": ["action"],
},
},
}
def _build_nous_system(functions: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Use qwen-agent NousFnCallPrompt to generate a system message embedding tool schema."""
try:
from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
ContentItem as NousContentItem,
)
from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
Message as NousMessage,
)
from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
NousFnCallPrompt,
)
except ImportError:
raise ImportError(
"qwen-agent not installed. Please install it with `pip install cua-agent[qwen]`."
)
msgs = NousFnCallPrompt().preprocess_fncall_messages(
messages=[
NousMessage(
role="system", content=[NousContentItem(text="You are a helpful assistant.")]
)
],
functions=functions,
lang="en",
)
sys = msgs[0].model_dump()
# Convert qwen-agent structured content to OpenAI-style content list
content = [{"type": "text", "text": c["text"]} for c in sys.get("content", [])]
return {"role": "system", "content": content}
def _parse_tool_call_from_text(text: str) -> Optional[Dict[str, Any]]:
"""Extract JSON object within <tool_call>...</tool_call> from model text."""
m = re.search(r"<tool_call>\s*(\{[\s\S]*?\})\s*</tool_call>", text)
if not m:
return None
try:
return json.loads(m.group(1))
except Exception:
return None
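# Example (assumed model output wrapping a tool call):
#   _parse_tool_call_from_text(
#       '<tool_call>{"name": "computer", "arguments": {"action": "left_click", "coordinate": [500, 250]}}</tool_call>'
#   )
#   -> {"name": "computer", "arguments": {"action": "left_click", "coordinate": [500, 250]}}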
async def _unnormalize_coordinate(args: Dict[str, Any], dims: Tuple[int, int]) -> Dict[str, Any]:
"""Coordinates appear in 0..1000 space, scale to actual screen size using dims if provided."""
coord = args.get("coordinate")
if not coord or not isinstance(coord, (list, tuple)) or len(coord) < 2:
return args
x, y = float(coord[0]), float(coord[1])
width, height = float(dims[0]), float(dims[1])
x_abs = max(0.0, min(width, (x / 1000.0) * width))
y_abs = max(0.0, min(height, (y / 1000.0) * height))
args = {**args, "coordinate": [round(x_abs), round(y_abs)]}
return args
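# Worked example (dims come from the resized screenshot; values illustrative):
#   await _unnormalize_coordinate({"action": "left_click", "coordinate": [500, 250]}, (1920, 1080))
#   -> {"action": "left_click", "coordinate": [960, 270]}   # 500/1000*1920, 250/1000*1080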
def convert_qwen_tool_args_to_computer_action(args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Convert Qwen computer tool arguments to the Computer Calls action schema.
Qwen (example):
{"action": "left_click", "coordinate": [114, 68]}
Target (example):
{"action": "left_click", "x": 114, "y": 68}
Other mappings:
- right_click, middle_click, double_click (triple_click -> double_click)
- mouse_move -> { action: "move", x, y }
- key -> { action: "keypress", keys: [...] }
- type -> { action: "type", text }
- scroll/hscroll -> { action: "scroll", scroll_x, scroll_y, x, y }
- wait -> { action: "wait" }
- terminate/answer are not direct UI actions; return None for now
"""
if not isinstance(args, dict):
return None
action = args.get("action")
if not isinstance(action, str):
return None
# Coordinates helper
coord = args.get("coordinate")
x = y = None
if isinstance(coord, (list, tuple)) and len(coord) >= 2:
try:
x = int(round(float(coord[0])))
y = int(round(float(coord[1])))
except Exception:
x = y = None
# Map actions
a = action.lower()
if a in {"left_click", "right_click", "middle_click", "double_click"}:
if x is None or y is None:
return None
return {"action": a, "x": x, "y": y}
if a == "triple_click":
# Approximate as double_click
if x is None or y is None:
return None
return {"action": "double_click", "x": x, "y": y}
if a == "mouse_move":
if x is None or y is None:
return None
return {"action": "move", "x": x, "y": y}
if a == "key":
keys = args.get("keys")
if isinstance(keys, list) and all(isinstance(k, str) for k in keys):
return {"action": "keypress", "keys": keys}
return None
if a == "type":
text = args.get("text")
if isinstance(text, str):
return {"action": "type", "text": text}
return None
if a in {"scroll", "hscroll"}:
pixels = args.get("pixels") or 0
try:
pixels_val = int(round(float(pixels)))
except Exception:
pixels_val = 0
scroll_x = pixels_val if a == "hscroll" else 0
scroll_y = pixels_val if a == "scroll" else 0
# Include cursor position if available (optional)
out: Dict[str, Any] = {"action": "scroll", "scroll_x": scroll_x, "scroll_y": scroll_y}
if x is not None and y is not None:
out.update({"x": x, "y": y})
return out
if a == "wait":
return {"action": "wait"}
# Non-UI or terminal actions: terminate/answer -> not mapped here
return None
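# Worked examples (arguments are illustrative):
#   {"action": "left_click", "coordinate": [114, 68]} -> {"action": "left_click", "x": 114, "y": 68}
#   {"action": "key", "keys": ["ctrl", "c"]}          -> {"action": "keypress", "keys": ["ctrl", "c"]}
#   {"action": "scroll", "pixels": -120}              -> {"action": "scroll", "scroll_x": 0, "scroll_y": -120}
#   {"action": "terminate", "status": "success"}      -> None (not a direct UI action)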
@register_agent(models=r"(?i).*qwen.*", priority=-1)
class Qwen3VlConfig(AsyncAgentConfig):
async def predict_step(
self,
messages: List[Dict[str, Any]],
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs,
) -> Dict[str, Any]:
# Build messages using NousFnCallPrompt system with tool schema in text
# Start with converted conversation (images/text preserved)
converted_msgs = convert_responses_items_to_completion_messages(
messages,
allow_images_in_tool_results=False,
)
# Prepend Nous-generated system if available
nous_system = _build_nous_system([QWEN3_COMPUTER_TOOL["function"]])
completion_messages = ([nous_system] if nous_system else []) + converted_msgs
# If there is no screenshot in the conversation, take one now and inject it.
# Also record a pre_output_items assistant message to reflect action.
def _has_any_image(msgs: List[Dict[str, Any]]) -> bool:
for m in msgs:
content = m.get("content")
if isinstance(content, list):
for p in content:
if isinstance(p, dict) and p.get("type") == "image_url":
return True
return False
pre_output_items: List[Dict[str, Any]] = []
if not _has_any_image(completion_messages):
if computer_handler is None or not hasattr(computer_handler, "screenshot"):
raise RuntimeError(
"No screenshots present and computer_handler.screenshot is not available."
)
screenshot_b64 = await computer_handler.screenshot()
if not screenshot_b64:
raise RuntimeError("Failed to capture screenshot from computer_handler.")
# Inject a user message with the screenshot so the model can see current context
completion_messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{screenshot_b64}"},
},
{"type": "text", "text": "Current screen"},
],
}
)
# Add assistant message to outputs to reflect the action, similar to composed_grounded.py
pre_output_items.append(
{
"type": "message",
"role": "assistant",
"content": [
{
"type": "text",
"text": "Taking a screenshot to see the current computer screen.",
}
],
}
)
# Smart-resize all screenshots and attach min/max pixel hints. Fail fast if deps missing.
# Also record the last resized width/height to unnormalize coordinates later.
last_rw: Optional[int] = None
last_rh: Optional[int] = None
MIN_PIXELS = 3136
MAX_PIXELS = 12845056
try:
import base64
import io
from PIL import Image # type: ignore
from qwen_vl_utils import smart_resize # type: ignore
except Exception:
raise ImportError(
"qwen-vl-utils not installed. Please install it with `pip install cua-agent[qwen]`."
)
for msg in completion_messages:
content = msg.get("content")
if not isinstance(content, list):
continue
for part in content:
if isinstance(part, dict) and part.get("type") == "image_url":
url = ((part.get("image_url") or {}).get("url")) or ""
# Expect data URL like data:image/png;base64,<b64>
if url.startswith("data:") and "," in url:
b64 = url.split(",", 1)[1]
img_bytes = base64.b64decode(b64)
im = Image.open(io.BytesIO(img_bytes))
h, w = im.height, im.width
rh, rw = smart_resize(
h, w, factor=32, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS
)
# Attach hints on this image block
part["min_pixels"] = MIN_PIXELS
part["max_pixels"] = MAX_PIXELS
last_rw, last_rh = rw, rh
api_kwargs: Dict[str, Any] = {
"model": model,
"messages": completion_messages,
"max_retries": max_retries,
"stream": stream,
            **kwargs,
}
if use_prompt_caching:
api_kwargs["use_prompt_caching"] = use_prompt_caching
if _on_api_start:
await _on_api_start(api_kwargs)
response = await litellm.acompletion(**api_kwargs)
if _on_api_end:
await _on_api_end(api_kwargs, response)
usage = {
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage( # type: ignore
response.usage
).model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(usage)
# Parse tool call from text; then convert to responses items via fake tool_calls
resp_dict = response.model_dump() # type: ignore
choice = (resp_dict.get("choices") or [{}])[0]
content_text = ((choice.get("message") or {}).get("content")) or ""
tool_call = _parse_tool_call_from_text(content_text)
output_items: List[Dict[str, Any]] = []
if tool_call and isinstance(tool_call, dict):
fn_name = tool_call.get("name") or "computer"
raw_args = tool_call.get("arguments") or {}
# Unnormalize coordinates to actual screen size using last resized dims
if last_rw is None or last_rh is None:
raise RuntimeError(
"No screenshots found to derive dimensions for coordinate unnormalization."
)
args = await _unnormalize_coordinate(raw_args, (last_rw, last_rh))
# Build an OpenAI-style tool call so we can reuse the converter
fake_cm = {
"role": "assistant",
"tool_calls": [
{
"type": "function",
"id": "call_0",
"function": {
"name": fn_name,
"arguments": json.dumps(args),
},
}
],
}
output_items.extend(convert_completion_messages_to_responses_items([fake_cm]))
else:
# Fallback: just return assistant text
fake_cm = {"role": "assistant", "content": content_text}
output_items.extend(convert_completion_messages_to_responses_items([fake_cm]))
# Prepend any pre_output_items (e.g., simulated screenshot-taking message)
return {"output": (pre_output_items + output_items), "usage": usage}
def get_capabilities(self) -> List[AgentCapability]:
return ["step"]
async def predict_click(
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using Qwen3-VL via litellm.acompletion.
        Exposes only a reduced tool schema containing left_click, to bias the model toward emitting a single click.
        Returns (x, y) in absolute pixels when screen dimensions can be derived; otherwise integers in the 0..1000 normalized space.
"""
# Reduced tool
reduced_tool = {
"type": "function",
"function": {
**QWEN3_COMPUTER_TOOL["function"],
"parameters": {
"type": "object",
"properties": {
"action": {"type": "string", "enum": ["left_click"]},
"coordinate": {
"description": "(x, y) in 0..1000 reference space",
"type": "array",
"items": {"type": ["number", "integer"]},
"minItems": 2,
"maxItems": 2,
},
},
"required": ["action", "coordinate"],
},
},
}
# Build Nous system (lazy import inside helper already raises clear guidance if missing)
nous_system = _build_nous_system([reduced_tool["function"]])
# Pre-process using smart_resize
min_pixels = 3136
max_pixels = 12845056
try:
# Lazy import to avoid hard dependency
import base64
import io
# If PIL is available, estimate size from image to derive smart bounds
from PIL import Image
from qwen_vl_utils import smart_resize # type: ignore
img_bytes = base64.b64decode(image_b64)
im = Image.open(io.BytesIO(img_bytes))
h, w = im.height, im.width
# Qwen notebook suggests factor=32 and a wide min/max range
rh, rw = smart_resize(h, w, factor=32, min_pixels=min_pixels, max_pixels=max_pixels)
except Exception:
raise ImportError(
"qwen-vl-utils not installed. Please install it with `pip install cua-agent[qwen]`."
)
messages = []
if nous_system:
messages.append(nous_system)
image_block: Dict[str, Any] = {
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_b64}"},
"min_pixels": min_pixels,
"max_pixels": max_pixels,
}
# Single user message with image and instruction, matching OpenAI-style content blocks
messages.append(
{
"role": "user",
"content": [
image_block,
{"type": "text", "text": instruction},
],
}
)
api_kwargs: Dict[str, Any] = {
"model": model,
"messages": messages,
            **kwargs,
}
response = await litellm.acompletion(**api_kwargs)
resp = response.model_dump() # type: ignore
choice = (resp.get("choices") or [{}])[0]
content_text = ((choice.get("message") or {}).get("content")) or ""
tool_call = _parse_tool_call_from_text(content_text) or {}
args = tool_call.get("arguments") or {}
args = await _unnormalize_coordinate(args, (rh, rw))
coord = args.get("coordinate")
if isinstance(coord, (list, tuple)) and len(coord) >= 2:
return int(coord[0]), int(coord[1])
return None
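# A minimal usage sketch for the click-grounding path above, assuming any
# litellm-routable Qwen3-VL endpoint (the model id and image path here are
# placeholders, not part of this module):
#
#   import asyncio, base64
#
#   async def _demo():
#       with open("screenshot.png", "rb") as f:
#           image_b64 = base64.b64encode(f.read()).decode()
#       cfg = Qwen3VlConfig()
#       xy = await cfg.predict_click(
#           model="openai/Qwen/Qwen3-VL-8B-Instruct",  # placeholder model id
#           image_b64=image_b64,
#           instruction="Click the Submit button",
#       )
#       print(xy)  # e.g. (412, 563) in pixels, or None if no click was parsed
#
#   asyncio.run(_demo())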

View File

@@ -4,39 +4,50 @@ Paper: https://arxiv.org/abs/2501.12326
Code: https://github.com/bytedance/UI-TARS
"""
import ast
import asyncio
from ctypes import cast
import json
import base64
import json
import math
import re
import ast
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from ctypes import cast
from io import BytesIO
from PIL import Image
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import litellm
from litellm.types.utils import ModelResponse
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
from litellm.responses.utils import Usage
from openai.types.responses.response_computer_tool_call_param import ActionType, ResponseComputerToolCallParam
from litellm.types.utils import ModelResponse
from openai.types.responses.response_computer_tool_call_param import (
ActionType,
ResponseComputerToolCallParam,
)
from openai.types.responses.response_input_param import ComputerCallOutput
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam, Summary
from openai.types.responses.response_output_message_param import (
ResponseOutputMessageParam,
)
from openai.types.responses.response_reasoning_item_param import (
ResponseReasoningItemParam,
Summary,
)
from PIL import Image
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..responses import (
make_reasoning_item,
make_output_text_item,
make_click_item,
make_double_click_item,
make_drag_item,
make_input_image_item,
make_keypress_item,
make_output_text_item,
make_reasoning_item,
make_scroll_item,
make_type_item,
make_wait_item,
make_input_image_item
)
from ..types import AgentCapability, AgentResponse, Messages, Tools
# Constants from reference code
IMAGE_FACTOR = 28
@@ -94,6 +105,7 @@ click(point='<|box_start|>(x1,y1)<|box_end|>')
## User Instruction
{instruction}"""
def round_by_factor(number: float, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
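# e.g. round_by_factor(1000, 28) == 1008 (1000/28 is about 35.7, so the nearest
# multiple is 36 * 28); the ceil/floor variants below are the analogues used by smart_resize.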
@@ -110,7 +122,11 @@ def floor_by_factor(number: float, factor: int) -> int:
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
height: int,
width: int,
factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
@@ -144,14 +160,14 @@ def escape_single_quotes(text):
def parse_action(action_str):
"""Parse action string into structured format."""
try:
node = ast.parse(action_str, mode='eval')
node = ast.parse(action_str, mode="eval")
if not isinstance(node, ast.Expression):
raise ValueError("Not an expression")
call = node.body
if not isinstance(call, ast.Call):
raise ValueError("Not a function call")
# Get function name
if isinstance(call.func, ast.Name):
func_name = call.func.id
@@ -159,7 +175,7 @@ def parse_action(action_str):
func_name = call.func.attr
else:
func_name = None
# Get keyword arguments
kwargs = {}
for kw in call.keywords:
@@ -171,12 +187,9 @@ def parse_action(action_str):
else:
value = None
kwargs[key] = value
return {
'function': func_name,
'args': kwargs
}
return {"function": func_name, "args": kwargs}
except Exception as e:
print(f"Failed to parse action '{action_str}': {e}")
return None
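# For instance, parse_action("click(start_box='(481,542)')") returns
# {"function": "click", "args": {"start_box": "(481,542)"}}; malformed strings
# fall through to the except branch and yield None.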
@@ -185,39 +198,39 @@ def parse_action(action_str):
def parse_uitars_response(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
"""Parse UITARS model response into structured actions."""
text = text.strip()
# Extract thought
thought = None
if text.startswith("Thought:"):
thought_match = re.search(r"Thought: (.+?)(?=\s*Action:|$)", text, re.DOTALL)
if thought_match:
thought = thought_match.group(1).strip()
# Extract action
if "Action:" not in text:
raise ValueError("No Action found in response")
action_str = text.split("Action:")[-1].strip()
# Handle special case for type actions
if "type(content" in action_str:
def escape_quotes(match):
return match.group(1)
pattern = r"type\(content='(.*?)'\)"
content = re.sub(pattern, escape_quotes, action_str)
action_str = escape_single_quotes(content)
action_str = "type(content='" + action_str + "')"
# Parse the action
parsed_action = parse_action(action_str.replace("\n", "\\n").lstrip())
if parsed_action is None:
raise ValueError(f"Action can't parse: {action_str}")
action_type = parsed_action["function"]
params = parsed_action["args"]
# Process parameters
action_inputs = {}
for param_name, param in params.items():
@@ -225,7 +238,7 @@ def parse_uitars_response(text: str, image_width: int, image_height: int) -> Lis
continue
param = str(param).lstrip()
action_inputs[param_name.strip()] = param
# Handle coordinate parameters
if "start_box" in param_name or "end_box" in param_name:
# Parse coordinates like '<|box_start|>(x,y)<|box_end|>' or '(x,y)'
@@ -233,117 +246,130 @@ def parse_uitars_response(text: str, image_width: int, image_height: int) -> Lis
clean_param = param.replace("<|box_start|>", "").replace("<|box_end|>", "")
# Then remove parentheses and split
numbers = clean_param.replace("(", "").replace(")", "").split(",")
try:
float_numbers = [float(num.strip()) / 1000 for num in numbers] # Normalize to 0-1 range
float_numbers = [
float(num.strip()) / 1000 for num in numbers
] # Normalize to 0-1 range
if len(float_numbers) == 2:
# Single point, duplicate for box format
float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
float_numbers = [
float_numbers[0],
float_numbers[1],
float_numbers[0],
float_numbers[1],
]
action_inputs[param_name.strip()] = str(float_numbers)
except ValueError as e:
# If parsing fails, keep the original parameter value
print(f"Warning: Could not parse coordinates '{param}': {e}")
action_inputs[param_name.strip()] = param
return [{
"thought": thought,
"action_type": action_type,
"action_inputs": action_inputs,
"text": text
}]
return [
{
"thought": thought,
"action_type": action_type,
"action_inputs": action_inputs,
"text": text,
}
]
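# Example: the raw model output
#   "Thought: I should open the menu.\nAction: click(start_box='(500,300)')"
# parses to thought="I should open the menu.", action_type="click", and
# action_inputs={"start_box": "[0.5, 0.3, 0.5, 0.3]"} -- the point is
# normalized from the 0-1000 reference space to 0-1 and duplicated into box form.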
def convert_to_computer_actions(parsed_responses: List[Dict[str, Any]], image_width: int, image_height: int) -> List[ResponseComputerToolCallParam | ResponseOutputMessageParam]:
def convert_to_computer_actions(
parsed_responses: List[Dict[str, Any]], image_width: int, image_height: int
) -> List[ResponseComputerToolCallParam | ResponseOutputMessageParam]:
"""Convert parsed UITARS responses to computer actions."""
computer_actions = []
for response in parsed_responses:
action_type = response.get("action_type")
action_inputs = response.get("action_inputs", {})
if action_type == "finished":
finished_text = action_inputs.get("content", "Task completed successfully.")
computer_actions.append(make_output_text_item(finished_text))
break
elif action_type == "wait":
computer_actions.append(make_wait_item())
elif action_type == "call_user":
computer_actions.append(make_output_text_item("I need assistance from the user to proceed with this task."))
computer_actions.append(
make_output_text_item("I need assistance from the user to proceed with this task.")
)
elif action_type in ["click", "left_single"]:
start_box = action_inputs.get("start_box")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
computer_actions.append(make_click_item(x, y, "left"))
elif action_type == "double_click":
start_box = action_inputs.get("start_box")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
computer_actions.append(make_double_click_item(x, y))
elif action_type == "right_click":
start_box = action_inputs.get("start_box")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
computer_actions.append(make_click_item(x, y, "right"))
elif action_type == "type":
content = action_inputs.get("content", "")
computer_actions.append(make_type_item(content))
elif action_type == "hotkey":
key = action_inputs.get("key", "")
keys = key.split()
computer_actions.append(make_keypress_item(keys))
elif action_type == "press":
key = action_inputs.get("key", "")
computer_actions.append(make_keypress_item([key]))
elif action_type == "scroll":
start_box = action_inputs.get("start_box")
direction = action_inputs.get("direction", "down")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
else:
x, y = image_width // 2, image_height // 2
scroll_y = 5 if "up" in direction.lower() else -5
computer_actions.append(make_scroll_item(x, y, 0, scroll_y))
elif action_type == "drag":
start_box = action_inputs.get("start_box")
end_box = action_inputs.get("end_box")
if start_box and end_box:
start_coords = eval(start_box)
end_coords = eval(end_box)
start_x = int((start_coords[0] + start_coords[2]) / 2 * image_width)
start_y = int((start_coords[1] + start_coords[3]) / 2 * image_height)
end_x = int((end_coords[0] + end_coords[2]) / 2 * image_width)
end_y = int((end_coords[1] + end_coords[3]) / 2 * image_height)
path = [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}]
computer_actions.append(make_drag_item(path))
return computer_actions
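# E.g. a parsed click with start_box "[0.5, 0.3, 0.5, 0.3]" on a 1920x1080
# screenshot becomes make_click_item(960, 324, "left").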
@@ -354,33 +380,35 @@ def pil_to_base64(image: Image.Image) -> str:
return base64.b64encode(buffer.getvalue()).decode("utf-8")
def process_image_for_uitars(image_data: str, max_pixels: int = MAX_PIXELS, min_pixels: int = MIN_PIXELS) -> tuple[Image.Image, int, int]:
def process_image_for_uitars(
image_data: str, max_pixels: int = MAX_PIXELS, min_pixels: int = MIN_PIXELS
) -> tuple[Image.Image, int, int]:
"""Process image for UITARS model input."""
# Decode base64 image
if image_data.startswith('data:image'):
image_data = image_data.split(',')[1]
if image_data.startswith("data:image"):
image_data = image_data.split(",")[1]
image_bytes = base64.b64decode(image_data)
image = Image.open(BytesIO(image_bytes))
original_width, original_height = image.size
# Resize image according to UITARS requirements
if image.width * image.height > max_pixels:
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
width = int(image.width * resize_factor)
height = int(image.height * resize_factor)
image = image.resize((width, height))
if image.width * image.height < min_pixels:
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
width = math.ceil(image.width * resize_factor)
height = math.ceil(image.height * resize_factor)
image = image.resize((width, height))
if image.mode != "RGB":
image = image.convert("RGB")
return image, original_width, original_height
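# Sketch of the resize math above: captures above max_pixels shrink both sides
# by sqrt(max_pixels / (w*h)); captures below min_pixels grow by the analogous
# ceiled factor, so the pixel count lands inside [min_pixels, max_pixels] while
# the aspect ratio is preserved.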
@@ -391,7 +419,11 @@ def sanitize_message(msg: Any) -> Any:
for key, value in msg.items():
if key == "content" and isinstance(value, list):
result[key] = [
{k: v for k, v in item.items() if k != "image_url"} if isinstance(item, dict) else item
(
{k: v for k, v in item.items() if k != "image_url"}
if isinstance(item, dict)
else item
)
for item in value
]
else:
@@ -406,38 +438,41 @@ def sanitize_message(msg: Any) -> Any:
def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any]]:
"""
Convert UITARS internal message format back to LiteLLM format.
This function processes reasoning, computer_call, and computer_call_output messages
and converts them to the appropriate LiteLLM assistant message format.
Args:
messages: List of UITARS internal messages
Returns:
List of LiteLLM formatted messages
"""
litellm_messages = []
current_assistant_content = []
for message in messages:
if isinstance(message, dict):
message_type = message.get("type")
if message_type == "reasoning":
# Extract reasoning text from summary
summary = message.get("summary", [])
if summary and isinstance(summary, list):
for summary_item in summary:
if isinstance(summary_item, dict) and summary_item.get("type") == "summary_text":
if (
isinstance(summary_item, dict)
and summary_item.get("type") == "summary_text"
):
reasoning_text = summary_item.get("text", "")
if reasoning_text:
current_assistant_content.append(f"Thought: {reasoning_text}")
elif message_type == "computer_call":
# Convert computer action to UITARS action format
action = message.get("action", {})
action_type = action.get("type")
if action_type == "click":
x, y = action.get("x", 0), action.get("y", 0)
button = action.get("button", "left")
@@ -447,59 +482,65 @@ def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any
action_text = f"Action: right_single(start_box='({x},{y})')"
else:
action_text = f"Action: click(start_box='({x},{y})')"
elif action_type == "double_click":
x, y = action.get("x", 0), action.get("y", 0)
action_text = f"Action: left_double(start_box='({x},{y})')"
elif action_type == "drag":
start_x, start_y = action.get("start_x", 0), action.get("start_y", 0)
end_x, end_y = action.get("end_x", 0), action.get("end_y", 0)
action_text = f"Action: drag(start_box='({start_x},{start_y})', end_box='({end_x},{end_y})')"
elif action_type == "key":
key = action.get("key", "")
action_text = f"Action: hotkey(key='{key}')"
elif action_type == "type":
text = action.get("text", "")
# Escape single quotes in the text
escaped_text = escape_single_quotes(text)
action_text = f"Action: type(content='{escaped_text}')"
elif action_type == "scroll":
x, y = action.get("x", 0), action.get("y", 0)
direction = action.get("direction", "down")
action_text = f"Action: scroll(start_box='({x},{y})', direction='{direction}')"
elif action_type == "wait":
action_text = "Action: wait()"
else:
# Fallback for unknown action types
action_text = f"Action: {action_type}({action})"
current_assistant_content.append(action_text)
# When we hit a computer_call_output, finalize the current assistant message
if current_assistant_content:
litellm_messages.append({
"role": "assistant",
"content": [{"type": "text", "text": "\n".join(current_assistant_content)}]
})
litellm_messages.append(
{
"role": "assistant",
"content": [
{"type": "text", "text": "\n".join(current_assistant_content)}
],
}
)
current_assistant_content = []
elif message_type == "computer_call_output":
# Add screenshot from computer call output
output = message.get("output", {})
if isinstance(output, dict) and output.get("type") == "input_image":
image_url = output.get("image_url", "")
if image_url:
litellm_messages.append({
"role": "user",
"content": [{"type": "image_url", "image_url": {"url": image_url}}]
})
litellm_messages.append(
{
"role": "user",
"content": [{"type": "image_url", "image_url": {"url": image_url}}],
}
)
elif message.get("role") == "user":
# # Handle user messages
# content = message.get("content", "")
@@ -514,24 +555,22 @@ def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any
# "content": content
# })
pass
# Add any remaining assistant content
if current_assistant_content:
litellm_messages.append({
"role": "assistant",
"content": current_assistant_content
})
litellm_messages.append({"role": "assistant", "content": current_assistant_content})
return litellm_messages
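# Illustration: a reasoning item whose summary text is "Open menu", followed by
# a computer_call {"type": "click", "x": 10, "y": 20} and its screenshot output,
# round-trips to:
#   {"role": "assistant", "content": [{"type": "text",
#        "text": "Thought: Open menu\nAction: click(start_box='(10,20)')"}]}
#   {"role": "user", "content": [{"type": "image_url",
#        "image_url": {"url": "<screenshot data URL>"}}]}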
@register_agent(models=r"(?i).*ui-?tars.*")
class UITARSConfig:
"""
    UITARS agent configuration using liteLLM for the ByteDance-Seed/UI-TARS-1.5-7B model.
Supports UITARS vision-language models for computer control.
"""
async def predict_step(
self,
messages: List[Dict[str, Any]],
@@ -545,11 +584,11 @@ class UITARSConfig:
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Predict the next step based on input messages.
Args:
messages: Input messages following Responses format
model: Model name to use
@@ -562,22 +601,22 @@ class UITARSConfig:
_on_usage: Callback for usage tracking
_on_screenshot: Callback for screenshot events
**kwargs: Additional arguments
Returns:
Dictionary with "output" (output items) and "usage" array
"""
tools = tools or []
# Create response items
response_items = []
# Find computer tool for screen dimensions
computer_tool = None
for tool_schema in tools:
if tool_schema["type"] == "computer":
computer_tool = tool_schema["computer"]
break
# Get screen dimensions
screen_width, screen_height = 1024, 768
if computer_tool:
@@ -585,20 +624,20 @@ class UITARSConfig:
screen_width, screen_height = await computer_tool.get_dimensions()
except:
pass
# Process messages to extract instruction and image
instruction = ""
image_data = None
# Convert messages to list if string
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
# Extract instruction and latest screenshot
for message in reversed(messages):
if isinstance(message, dict):
content = message.get("content", "")
# Handle different content formats
if isinstance(content, str):
if not instruction and message.get("role") == "user":
@@ -614,46 +653,41 @@ class UITARSConfig:
image_data = image_url.get("url", "")
else:
image_data = image_url
# Also check for computer_call_output with screenshots
if message.get("type") == "computer_call_output" and not image_data:
output = message.get("output", {})
if isinstance(output, dict) and output.get("type") == "input_image":
image_data = output.get("image_url", "")
if instruction and image_data:
break
if not instruction:
instruction = "Help me complete this task by analyzing the screen and taking appropriate actions."
instruction = (
"Help me complete this task by analyzing the screen and taking appropriate actions."
)
# Create prompt
user_prompt = UITARS_PROMPT_TEMPLATE.format(
instruction=instruction,
action_space=UITARS_ACTION_SPACE,
language="English"
instruction=instruction, action_space=UITARS_ACTION_SPACE, language="English"
)
# Convert conversation history to LiteLLM format
history_messages = convert_uitars_messages_to_litellm(messages)
# Prepare messages for liteLLM
litellm_messages = [
{
"role": "system",
"content": "You are a helpful assistant."
}
]
litellm_messages = [{"role": "system", "content": "You are a helpful assistant."}]
# Add current user instruction with screenshot
current_user_message = {
"role": "user",
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
]
],
}
litellm_messages.append(current_user_message)
# Process image for UITARS
if not image_data:
# Take screenshot if none found in messages
@@ -667,17 +701,22 @@ class UITARSConfig:
raise ValueError("No screenshot found in messages and no computer_handler provided")
processed_image, original_width, original_height = process_image_for_uitars(image_data)
encoded_image = pil_to_base64(processed_image)
# Add conversation history
if history_messages:
litellm_messages.extend(history_messages)
else:
litellm_messages.append({
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
]
})
litellm_messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{encoded_image}"},
}
],
}
)
# Prepare API call kwargs
api_kwargs = {
@@ -687,146 +726,142 @@ class UITARSConfig:
"temperature": kwargs.get("temperature", 0.0),
"do_sample": kwargs.get("temperature", 0.0) > 0.0,
"num_retries": max_retries,
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
}
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
# Call liteLLM with UITARS model
response = await litellm.acompletion(**api_kwargs)
# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
# Extract response content
response_content = response.choices[0].message.content.strip() # type: ignore
response_content = response.choices[0].message.content.strip() # type: ignore
# Parse UITARS response
parsed_responses = parse_uitars_response(response_content, original_width, original_height)
# Convert to computer actions
computer_actions = convert_to_computer_actions(parsed_responses, original_width, original_height)
computer_actions = convert_to_computer_actions(
parsed_responses, original_width, original_height
)
# Add computer actions to response items
thought = parsed_responses[0].get("thought", "")
if thought:
response_items.append(make_reasoning_item(thought))
response_items.extend(computer_actions)
# Extract usage information
response_usage = {
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(
response.usage
).model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(response_usage)
# Create agent response
agent_response = {
"output": response_items,
"usage": response_usage
}
agent_response = {"output": response_items, "usage": response_usage}
return agent_response
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str
self, model: str, image_b64: str, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.
UITARS supports click prediction through its action parsing.
Args:
model: Model name to use
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
Tuple with (x, y) coordinates or None
"""
try:
# Create prompt using grounding template
user_prompt = GROUNDING_UITARS_PROMPT_TEMPLATE.format(
instruction=instruction
)
user_prompt = GROUNDING_UITARS_PROMPT_TEMPLATE.format(instruction=instruction)
# Process image for UITARS
processed_image, original_width, original_height = process_image_for_uitars(image_b64)
encoded_image = pil_to_base64(processed_image)
# Prepare messages for liteLLM
litellm_messages = [
{
"role": "system",
"content": "You are a helpful assistant."
},
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
]
}
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{encoded_image}"},
},
],
},
]
# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": litellm_messages,
"max_tokens": 2056,
"temperature": 0.0,
"do_sample": False
"do_sample": False,
}
# Call liteLLM with UITARS model
response = await litellm.acompletion(**api_kwargs)
# Extract response content
response_content = response.choices[0].message.content.strip() # type: ignore
response_content = response.choices[0].message.content.strip() # type: ignore
print(response_content)
# Parse the response to extract click coordinates
# Look for click action with coordinates (with special tokens)
click_pattern = r"click\(point='<\|box_start\|>\((\d+),(\d+)\)<\|box_end\|>'\)"
match = re.search(click_pattern, response_content)
# Fallback: Look for simpler format without special tokens
if not match:
# Pattern for: click(start_box='(x,y)') or click(point='(x,y)')
fallback_pattern = r"click\((?:start_box|point)='\((\d+),(\d+)\)'\)"
match = re.search(fallback_pattern, response_content)
if match:
x, y = int(match.group(1)), int(match.group(2))
# Scale coordinates back to original image dimensions
scale_x = original_width / processed_image.width
scale_y = original_height / processed_image.height
scaled_x = int(x * scale_x)
scaled_y = int(y * scale_y)
return (scaled_x, scaled_y)
return None
except Exception as e:
# Log error and return None
print(f"Error in predict_click: {e}")
return None
def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by this agent config.
Returns:
List of capability strings
"""
return ["step", "click"]
return ["step", "click"]

View File

@@ -1,19 +1,22 @@
"""
Example usage of the proxy server and client requests.
"""
import dotenv
dotenv.load_dotenv()
import asyncio
import json
import os
from typing import Any, Dict
import aiohttp
from typing import Dict, Any
async def test_http_endpoint():
"""Test the HTTP /responses endpoint."""
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
assert isinstance(anthropic_api_key, str), "ANTHROPIC_API_KEY environment variable must be set"
@@ -21,11 +24,9 @@ async def test_http_endpoint():
simple_request = {
"model": "anthropic/claude-3-5-sonnet-20241022",
"input": "Tell me a three sentence bedtime story about a unicorn.",
"env": {
"ANTHROPIC_API_KEY": anthropic_api_key
}
"env": {"ANTHROPIC_API_KEY": anthropic_api_key},
}
# Example 2: Multi-modal request with image
multimodal_request = {
"model": "anthropic/claude-3-5-sonnet-20241022",
@@ -36,70 +37,72 @@ async def test_http_endpoint():
{"type": "input_text", "text": "what is in this image?"},
{
"type": "input_image",
"image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
]
"image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
},
],
}
],
"env": {
"ANTHROPIC_API_KEY": anthropic_api_key
}
"env": {"ANTHROPIC_API_KEY": anthropic_api_key},
}
# Example 3: Request with custom agent and computer kwargs
custom_request = {
"model": "anthropic/claude-3-5-sonnet-20241022",
"input": "Take a screenshot and tell me what you see",
"env": {
"ANTHROPIC_API_KEY": anthropic_api_key
}
"env": {"ANTHROPIC_API_KEY": anthropic_api_key},
}
# Test requests
base_url = "https://m-linux-96lcxd2c2k.containers.cloud.trycua.com:8443"
# base_url = "http://localhost:8000"
api_key = os.getenv("CUA_API_KEY")
assert isinstance(api_key, str), "CUA_API_KEY environment variable must be set"
async with aiohttp.ClientSession() as session:
for i, request_data in enumerate([
simple_request,
# multimodal_request,
custom_request
], 1):
for i, request_data in enumerate(
[
simple_request,
# multimodal_request,
custom_request,
],
1,
):
print(f"\n--- Test {i} ---")
print(f"Request: {json.dumps(request_data, indent=2)}")
try:
print(f"Sending request to {base_url}/responses")
async with session.post(
f"{base_url}/responses",
json=request_data,
headers={"Content-Type": "application/json", "X-API-Key": api_key}
headers={"Content-Type": "application/json", "X-API-Key": api_key},
) as response:
result = await response.json()
print(f"Status: {response.status}")
print(f"Response: {json.dumps(result, indent=2)}")
except Exception as e:
print(f"Error: {e}")
def curl_examples():
"""Print curl command examples."""
print("=== CURL Examples ===\n")
print("1. Simple text request:")
print("""curl http://localhost:8000/responses \\
print(
"""curl http://localhost:8000/responses \\
-H "Content-Type: application/json" \\
-d '{
"model": "anthropic/claude-3-5-sonnet-20241022",
"input": "Tell me a three sentence bedtime story about a unicorn."
}'""")
}'"""
)
print("\n2. Multi-modal request with image:")
print("""curl http://localhost:8000/responses \\
print(
"""curl http://localhost:8000/responses \\
-H "Content-Type: application/json" \\
-d '{
"model": "anthropic/claude-3-5-sonnet-20241022",
@@ -115,10 +118,12 @@ def curl_examples():
]
}
]
}'""")
}'"""
)
print("\n3. Request with custom configuration:")
print("""curl http://localhost:8000/responses \\
print(
"""curl http://localhost:8000/responses \\
-H "Content-Type: application/json" \\
-d '{
"model": "anthropic/claude-3-5-sonnet-20241022",
@@ -131,50 +136,49 @@ def curl_examples():
"os_type": "linux",
"provider_type": "cloud"
}
}'""")
}'"""
)
async def test_p2p_client():
"""Example P2P client using peerjs-python."""
try:
from peerjs import Peer, PeerOptions, ConnectionEventType
from aiortc import RTCConfiguration, RTCIceServer
from peerjs import ConnectionEventType, Peer, PeerOptions
# Set up client peer
options = PeerOptions(
host="0.peerjs.com",
port=443,
secure=True,
config=RTCConfiguration(
iceServers=[RTCIceServer(urls="stun:stun.l.google.com:19302")]
)
config=RTCConfiguration(iceServers=[RTCIceServer(urls="stun:stun.l.google.com:19302")]),
)
client_peer = Peer(id="test-client", peer_options=options)
await client_peer.start()
# Connect to proxy server
connection = client_peer.connect("computer-agent-proxy")
@connection.on(ConnectionEventType.Open)
async def connection_open():
print("Connected to proxy server")
# Send a test request
request = {
"model": "anthropic/claude-3-5-sonnet-20241022",
"input": "Hello from P2P client!"
"input": "Hello from P2P client!",
}
await connection.send(json.dumps(request))
@connection.on(ConnectionEventType.Data)
async def connection_data(data):
print(f"Received response: {data}")
await client_peer.destroy()
# Wait for connection
await asyncio.sleep(10)
except ImportError:
print("P2P dependencies not available. Install peerjs-python for P2P testing.")
except Exception as e:
@@ -183,7 +187,7 @@ async def test_p2p_client():
if __name__ == "__main__":
import sys
if len(sys.argv) > 1 and sys.argv[1] == "curl":
curl_examples()
elif len(sys.argv) > 1 and sys.argv[1] == "p2p":

View File

@@ -7,24 +7,25 @@ import json
import logging
import os
from contextlib import contextmanager
from typing import Dict, Any, List, Union, Optional
from typing import Any, Dict, List, Optional, Union
from computer import Computer
from ..agent import ComputerAgent
from computer import Computer
logger = logging.getLogger(__name__)
class ResponsesHandler:
"""Handler for /responses endpoint that processes agent requests."""
def __init__(self):
self.computer = None
self.agent = None
# Simple in-memory caches
self._computer_cache: Dict[str, Any] = {}
self._agent_cache: Dict[str, Any] = {}
async def setup_computer_agent(
self,
model: str,
@@ -75,7 +76,9 @@ class ResponsesHandler:
computer = Computer(**default_c_config)
await computer.__aenter__()
self._computer_cache[comp_key] = computer
logger.info(f"Computer created and cached with key={comp_key} config={default_c_config}")
logger.info(
f"Computer created and cached with key={comp_key} config={default_c_config}"
)
else:
logger.info(f"Reusing cached computer for key={comp_key}")
@@ -115,14 +118,14 @@ class ResponsesHandler:
# Bind current agent reference
self.agent = agent
async def process_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process a /responses request and return the result.
Args:
request_data: Dictionary containing model, input, and optional kwargs
Returns:
Dictionary with the agent's response
"""
@@ -133,12 +136,12 @@ class ResponsesHandler:
agent_kwargs = request_data.get("agent_kwargs", {})
computer_kwargs = request_data.get("computer_kwargs", {})
env_overrides = request_data.get("env", {}) or {}
if not model:
raise ValueError("Model is required")
if not input_data:
raise ValueError("Input is required")
# Apply env overrides for the duration of this request
with self._env_overrides(env_overrides):
# Set up (and possibly reuse) computer and agent via caches
@@ -155,28 +158,22 @@ class ResponsesHandler:
# Run agent and get first result
async for result in agent.run(messages):
# Return the first result and break
return {
"success": True,
"result": result,
"model": model
}
return {"success": True, "result": result, "model": model}
# If no results were yielded
return {
"success": False,
"error": "No results from agent",
"model": model
}
return {"success": False, "error": "No results from agent", "model": model}
except Exception as e:
logger.error(f"Error processing request: {e}")
return {
"success": False,
"error": str(e),
"model": request_data.get("model", "unknown")
"model": request_data.get("model", "unknown"),
}
def _convert_input_to_messages(self, input_data: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
def _convert_input_to_messages(
self, input_data: Union[str, List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
"""Convert input data to messages format."""
if isinstance(input_data, str):
# Simple string input
@@ -192,22 +189,18 @@ class ResponsesHandler:
if part.get("type") == "input_text":
content_parts.append({"type": "text", "text": part["text"]})
elif part.get("type") == "input_image":
content_parts.append({
"type": "image_url",
"image_url": {"url": part["image_url"]}
})
content_parts.append(
{"type": "image_url", "image_url": {"url": part["image_url"]}}
)
else:
content_parts.append(part)
messages.append({
"role": msg["role"],
"content": content_parts
})
messages.append({"role": msg["role"], "content": content_parts})
else:
messages.append(msg)
return messages
else:
raise ValueError("Input must be string or list of messages")
async def cleanup(self):
"""Clean up resources."""
if self.computer:

File diff suppressed because it is too large

View File

@@ -2,37 +2,43 @@
Type definitions for agent
"""
from typing import Dict, List, Any, Optional, Callable, Protocol, Literal
from pydantic import BaseModel
import re
from litellm import ResponseInputParam, ResponsesAPIResponse, ToolParam
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Literal, Optional, Protocol
from litellm import ResponseInputParam, ResponsesAPIResponse, ToolParam
from pydantic import BaseModel
# Agent input types
Messages = str | ResponseInputParam | List[Dict[str, Any]]
Tools = Optional[Iterable[ToolParam]]
# Agent output types
AgentResponse = ResponsesAPIResponse
AgentResponse = ResponsesAPIResponse
AgentCapability = Literal["step", "click"]
# Exception types
class ToolError(RuntimeError):
"""Base exception for tool-related errors"""
pass
class IllegalArgumentError(ToolError):
"""Exception raised when function arguments are invalid"""
pass
# Agent config registration
class AgentConfigInfo(BaseModel):
"""Information about a registered agent config"""
agent_class: type
models_regex: str
priority: int = 0
def matches_model(self, model: str) -> bool:
"""Check if this agent config matches the given model"""
return bool(re.match(self.models_regex, model))
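# e.g. AgentConfigInfo(agent_class=UITARSConfig, models_regex=r"(?i).*ui-?tars.*",
# priority=0).matches_model("huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B") is True.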

View File

@@ -2,6 +2,6 @@
UI components for agent
"""
from .gradio import launch_ui, create_gradio_ui
from .gradio import create_gradio_ui, launch_ui
__all__ = ["launch_ui", "create_gradio_ui"]

View File

@@ -1,4 +1,4 @@
from .gradio import launch_ui
if __name__ == "__main__":
launch_ui()
launch_ui()

View File

@@ -18,21 +18,21 @@ Requirements:
- OpenAI or Anthropic API key
"""
import os
import asyncio
import logging
import json
import logging
import os
import platform
from pathlib import Path
from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast
import gradio as gr
from gradio.components.chatbot import MetadataDict
from typing import cast
# Import from agent package
from agent import ComputerAgent
from agent.types import Messages, AgentResponse
from agent.types import AgentResponse, Messages
from computer import Computer
from gradio.components.chatbot import MetadataDict
# Global variables
global_agent = None
@@ -42,11 +42,13 @@ SETTINGS_FILE = Path(".gradio_settings.json")
logging.basicConfig(level=logging.INFO)
import dotenv
if dotenv.load_dotenv():
print(f"DEBUG - Loaded environment variables from {dotenv.find_dotenv()}")
else:
print("DEBUG - No .env file found")
# --- Settings Load/Save Functions ---
def load_settings() -> Dict[str, Any]:
"""Loads settings from the JSON file."""
@@ -84,7 +86,7 @@ def save_settings(settings: Dict[str, Any]):
# async def on_screenshot(self, screenshot_base64: str, action_type: str = "") -> None:
# """Add screenshot to chatbot when a screenshot is taken."""
# image_markdown = f"![Screenshot after {action_type}](data:image/png;base64,{screenshot_base64})"
# if self.chatbot_history is not None:
# self.chatbot_history.append(
# gr.ChatMessage(
@@ -141,7 +143,7 @@ def get_model_string(model_name: str, loop_provider: str) -> str:
ollama_model = model_name.split("OMNI: Ollama ", 1)[1]
return f"omniparser+ollama_chat/{ollama_model}"
return "omniparser+ollama_chat/llama3"
# Map based on loop provider
mapping = MODEL_MAPPINGS.get(loop_provider.lower(), MODEL_MAPPINGS["openai"])
return mapping.get(model_name, mapping["default"])
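# e.g. get_model_string("OMNI: Ollama llama3", "omni") -> "omniparser+ollama_chat/llama3",
# while unrecognized names fall back to the provider mapping's "default" entry.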
@@ -151,6 +153,7 @@ def get_ollama_models() -> List[str]:
"""Get available models from Ollama if installed."""
try:
import subprocess
result = subprocess.run(["ollama", "list"], capture_output=True, text=True)
if result.returncode == 0:
lines = result.stdout.strip().split("\n")
@@ -174,16 +177,14 @@ def create_computer_instance(
os_type: str = "macos",
provider_type: str = "lume",
name: Optional[str] = None,
api_key: Optional[str] = None
api_key: Optional[str] = None,
) -> Computer:
"""Create or get the global Computer instance."""
global global_computer
if global_computer is None:
if provider_type == "localhost":
global_computer = Computer(
verbosity=verbosity,
os_type=os_type,
use_host_computer_server=True
verbosity=verbosity, os_type=os_type, use_host_computer_server=True
)
else:
global_computer = Computer(
@@ -191,7 +192,7 @@ def create_computer_instance(
os_type=os_type,
provider_type=provider_type,
name=name if name else "",
api_key=api_key
api_key=api_key,
)
return global_computer
@@ -217,7 +218,7 @@ def create_agent(
os_type=computer_os,
provider_type=computer_provider,
name=computer_name,
api_key=computer_api_key
api_key=computer_api_key,
)
# Handle custom models
@@ -233,12 +234,15 @@ def create_agent(
"only_n_most_recent_images": only_n_most_recent_images,
"verbosity": verbosity,
}
if save_trajectory:
agent_kwargs["trajectory_dir"] = "trajectories"
if max_trajectory_budget:
agent_kwargs["max_trajectory_budget"] = {"max_budget": max_trajectory_budget, "raise_error": True}
agent_kwargs["max_trajectory_budget"] = {
"max_budget": max_trajectory_budget,
"raise_error": True,
}
global_agent = ComputerAgent(**agent_kwargs)
return global_agent
@@ -247,7 +251,8 @@ def create_agent(
def launch_ui():
"""Standalone function to launch the Gradio app."""
from agent.ui.gradio.ui_components import create_gradio_ui
print(f"Starting Gradio app for CUA Agent...")
print("Starting Gradio app for CUA Agent...")
demo = create_gradio_ui()
demo.launch(share=False, inbrowser=True)

View File

@@ -2,19 +2,25 @@
UI Components for the Gradio interface
"""
import os
import asyncio
import logging
import json
import logging
import os
import platform
from pathlib import Path
from typing import Dict, List, Optional, Any, cast
from typing import Any, Dict, List, Optional, cast
import gradio as gr
from gradio.components.chatbot import MetadataDict
from .app import (
load_settings, save_settings, create_agent, get_model_string,
get_ollama_models, global_agent, global_computer
create_agent,
get_model_string,
get_ollama_models,
global_agent,
global_computer,
load_settings,
save_settings,
)
# Global messages array to maintain conversation history
@@ -23,15 +29,15 @@ global_messages = []
def create_gradio_ui() -> gr.Blocks:
"""Create a Gradio UI for the Computer-Use Agent."""
# Load settings
saved_settings = load_settings()
# Check for API keys
openai_api_key = os.environ.get("OPENAI_API_KEY", "")
anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY", "")
cua_api_key = os.environ.get("CUA_API_KEY", "")
# Model choices
openai_models = ["OpenAI: Computer-Use Preview"]
anthropic_models = [
@@ -43,10 +49,10 @@ def create_gradio_ui() -> gr.Blocks:
omni_models = [
"OMNI: OpenAI GPT-4o",
"OMNI: OpenAI GPT-4o mini",
"OMNI: Claude 3.7 Sonnet (20250219)",
"OMNI: Claude 3.5 Sonnet (20241022)"
"OMNI: Claude 3.7 Sonnet (20250219)",
"OMNI: Claude 3.5 Sonnet (20241022)",
]
# Check if API keys are available
has_openai_key = bool(openai_api_key)
has_anthropic_key = bool(anthropic_api_key)
@@ -59,15 +65,20 @@ def create_gradio_ui() -> gr.Blocks:
# Detect platform
is_mac = platform.system().lower() == "darwin"
# Format model choices
provider_to_models = {
"OPENAI": openai_models,
"ANTHROPIC": anthropic_models,
"OMNI": omni_models + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
"UITARS": ([
"huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
] if is_mac else []) + ["Custom model (OpenAI compatible API)"],
"UITARS": (
[
"huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
]
if is_mac
else []
)
+ ["Custom model (OpenAI compatible API)"],
}
# Apply saved settings
@@ -82,7 +93,9 @@ def create_gradio_ui() -> gr.Blocks:
elif initial_loop == "ANTHROPIC":
initial_model = anthropic_models[0] if anthropic_models else "No models available"
else: # OMNI
initial_model = omni_models[0] if omni_models else "Custom model (OpenAI compatible API)"
initial_model = (
omni_models[0] if omni_models else "Custom model (OpenAI compatible API)"
)
initial_custom_model = saved_settings.get("custom_model", "Qwen2.5-VL-7B-Instruct")
initial_provider_base_url = saved_settings.get("provider_base_url", "http://localhost:1234/v1")
@@ -96,16 +109,27 @@ def create_gradio_ui() -> gr.Blocks:
"Open Safari, search for 'macOS automation tools', and save the first three results as bookmarks",
"Configure SSH keys and set up a connection to a remote server",
]
def generate_python_code(agent_loop_choice, model_name, tasks, recent_images=3, save_trajectory=True, computer_os="linux", computer_provider="cloud", container_name="", cua_cloud_api_key="", max_budget=None):
def generate_python_code(
agent_loop_choice,
model_name,
tasks,
recent_images=3,
save_trajectory=True,
computer_os="linux",
computer_provider="cloud",
container_name="",
cua_cloud_api_key="",
max_budget=None,
):
"""Generate Python code for the current configuration and tasks."""
tasks_str = ""
for task in tasks:
if task and task.strip():
tasks_str += f' "{task}",\n'
model_string = get_model_string(model_name, agent_loop_choice)
computer_args = []
if computer_os != "macos":
computer_args.append(f'os_type="{computer_os}"')
@@ -115,14 +139,14 @@ def create_gradio_ui() -> gr.Blocks:
computer_args.append(f'name="{container_name}"')
if cua_cloud_api_key:
computer_args.append(f'api_key="{cua_cloud_api_key}"')
computer_args_str = ", ".join(computer_args)
if computer_args_str:
computer_args_str = f"({computer_args_str})"
else:
computer_args_str = "()"
code = f'''import asyncio
code = f"""import asyncio
from computer import Computer
from agent import ComputerAgent
@@ -131,22 +155,22 @@ async def main():
agent = ComputerAgent(
model="{model_string}",
tools=[computer],
only_n_most_recent_images={recent_images},'''
only_n_most_recent_images={recent_images},"""
if save_trajectory:
code += '''
trajectory_dir="trajectories",'''
code += """
trajectory_dir="trajectories","""
if max_budget:
code += f'''
max_trajectory_budget={{"max_budget": {max_budget}, "raise_error": True}},'''
code += '''
code += f"""
max_trajectory_budget={{"max_budget": {max_budget}, "raise_error": True}},"""
code += """
)
'''
"""
if tasks_str:
code += f'''
code += f"""
# Prompts for the computer-use agent
tasks = [
{tasks_str.rstrip()}
@@ -158,23 +182,23 @@ async def main():
async for result in agent.run(messages):
for item in result["output"]:
if item["type"] == "message":
print(item["content"][0]["text"])'''
print(item["content"][0]["text"])"""
else:
code += f'''
code += """
# Execute a single task
task = "Search for information about CUA on GitHub"
print(f"Executing task: {{task}}")
messages = [{{"role": "user", "content": task}}]
print(f"Executing task: {task}")
messages = [{"role": "user", "content": task}]
async for result in agent.run(messages):
for item in result["output"]:
if item["type"] == "message":
print(item["content"][0]["text"])'''
print(item["content"][0]["text"])"""
code += '''
code += """
if __name__ == "__main__":
asyncio.run(main())'''
asyncio.run(main())"""
return code
# Create the Gradio interface
@@ -199,11 +223,11 @@ if __name__ == "__main__":
value=generate_python_code(initial_loop, "gpt-4o", []),
interactive=False,
)
with gr.Accordion("Computer Configuration", open=True):
is_windows = platform.system().lower() == "windows"
is_mac = platform.system().lower() == "darwin"
providers = ["cloud", "localhost", "docker"]
if is_mac:
providers += ["lume"]
@@ -227,30 +251,30 @@ if __name__ == "__main__":
value=computer_choices[0],
info="Select the operating system for the computer",
)
computer_provider = gr.Radio(
choices=providers,
label="Provider",
value="lume" if is_mac else "cloud",
info="Select the computer provider",
)
container_name = gr.Textbox(
label="Container Name",
placeholder="Enter container name (optional)",
value=os.environ.get("CUA_CONTAINER_NAME", ""),
info="Optional name for the container",
)
cua_cloud_api_key = gr.Textbox(
label="CUA Cloud API Key",
placeholder="Enter your CUA Cloud API key",
value=os.environ.get("CUA_API_KEY", ""),
type="password",
info="Required for cloud provider",
visible=(not has_cua_key)
visible=(not has_cua_key),
)
with gr.Accordion("Agent Configuration", open=True):
agent_loop = gr.Dropdown(
choices=["OPENAI", "ANTHROPIC", "OMNI", "UITARS"],
@@ -267,90 +291,113 @@ if __name__ == "__main__":
value=openai_models[0] if openai_models else "No models available",
info="Select OpenAI model",
interactive=True,
visible=(initial_loop == "OPENAI")
visible=(initial_loop == "OPENAI"),
)
anthropic_model_choice = gr.Dropdown(
choices=anthropic_models,
label="Anthropic Model",
value=anthropic_models[0] if anthropic_models else "No models available",
value=(
anthropic_models[0] if anthropic_models else "No models available"
),
info="Select Anthropic model",
interactive=True,
visible=(initial_loop == "ANTHROPIC")
visible=(initial_loop == "ANTHROPIC"),
)
omni_model_choice = gr.Dropdown(
choices=omni_models + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
choices=omni_models
+ ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
label="OMNI Model",
value=omni_models[0] if omni_models else "Custom model (OpenAI compatible API)",
value=(
omni_models[0]
if omni_models
else "Custom model (OpenAI compatible API)"
),
info="Select OMNI model or choose a custom model option",
interactive=True,
visible=(initial_loop == "OMNI")
visible=(initial_loop == "OMNI"),
)
uitars_model_choice = gr.Dropdown(
choices=provider_to_models.get("UITARS", ["No models available"]),
label="UITARS Model",
value=provider_to_models.get("UITARS", ["No models available"])[0] if provider_to_models.get("UITARS") else "No models available",
value=(
provider_to_models.get("UITARS", ["No models available"])[0]
if provider_to_models.get("UITARS")
else "No models available"
),
info="Select UITARS model",
interactive=True,
visible=(initial_loop == "UITARS")
visible=(initial_loop == "UITARS"),
)
model_choice = gr.Textbox(visible=False)
# API key inputs
with gr.Group(visible=not has_openai_key and (initial_loop == "OPENAI" or initial_loop == "OMNI")) as openai_key_group:
with gr.Group(
visible=not has_openai_key
and (initial_loop == "OPENAI" or initial_loop == "OMNI")
) as openai_key_group:
openai_api_key_input = gr.Textbox(
label="OpenAI API Key",
placeholder="Enter your OpenAI API key",
value=os.environ.get("OPENAI_API_KEY", ""),
interactive=True,
type="password",
info="Required for OpenAI models"
info="Required for OpenAI models",
)
with gr.Group(visible=not has_anthropic_key and (initial_loop == "ANTHROPIC" or initial_loop == "OMNI")) as anthropic_key_group:
with gr.Group(
visible=not has_anthropic_key
and (initial_loop == "ANTHROPIC" or initial_loop == "OMNI")
) as anthropic_key_group:
anthropic_api_key_input = gr.Textbox(
label="Anthropic API Key",
placeholder="Enter your Anthropic API key",
value=os.environ.get("ANTHROPIC_API_KEY", ""),
interactive=True,
type="password",
info="Required for Anthropic models"
info="Required for Anthropic models",
)
# API key handlers
def set_openai_api_key(key):
if key and key.strip():
os.environ["OPENAI_API_KEY"] = key.strip()
print(f"DEBUG - Set OpenAI API key environment variable")
print("DEBUG - Set OpenAI API key environment variable")
return key
def set_anthropic_api_key(key):
if key and key.strip():
os.environ["ANTHROPIC_API_KEY"] = key.strip()
print(f"DEBUG - Set Anthropic API key environment variable")
print("DEBUG - Set Anthropic API key environment variable")
return key
openai_api_key_input.change(
fn=set_openai_api_key,
inputs=[openai_api_key_input],
outputs=[openai_api_key_input],
queue=False
queue=False,
)
anthropic_api_key_input.change(
fn=set_anthropic_api_key,
inputs=[anthropic_api_key_input],
outputs=[anthropic_api_key_input],
queue=False
queue=False,
)
# UI update function
def update_ui(loop=None, openai_model=None, anthropic_model=None, omni_model=None, uitars_model=None):
def update_ui(
loop=None,
openai_model=None,
anthropic_model=None,
omni_model=None,
uitars_model=None,
):
loop = loop or agent_loop.value
model_value = None
if loop == "OPENAI" and openai_model:
model_value = openai_model
@@ -360,21 +407,37 @@ if __name__ == "__main__":
model_value = omni_model
elif loop == "UITARS" and uitars_model:
model_value = uitars_model
openai_visible = (loop == "OPENAI")
anthropic_visible = (loop == "ANTHROPIC")
omni_visible = (loop == "OMNI")
uitars_visible = (loop == "UITARS")
show_openai_key = not has_openai_key and (loop == "OPENAI" or (loop == "OMNI" and model_value and "OpenAI" in model_value and "Custom" not in model_value))
show_anthropic_key = not has_anthropic_key and (loop == "ANTHROPIC" or (loop == "OMNI" and model_value and "Claude" in model_value and "Custom" not in model_value))
openai_visible = loop == "OPENAI"
anthropic_visible = loop == "ANTHROPIC"
omni_visible = loop == "OMNI"
uitars_visible = loop == "UITARS"
show_openai_key = not has_openai_key and (
loop == "OPENAI"
or (
loop == "OMNI"
and model_value
and "OpenAI" in model_value
and "Custom" not in model_value
)
)
show_anthropic_key = not has_anthropic_key and (
loop == "ANTHROPIC"
or (
loop == "OMNI"
and model_value
and "Claude" in model_value
and "Custom" not in model_value
)
)
is_custom_openai_api = model_value == "Custom model (OpenAI compatible API)"
is_custom_ollama = model_value == "Custom model (ollama)"
is_any_custom = is_custom_openai_api or is_custom_ollama
model_choice_value = model_value if model_value else ""
return [
gr.update(visible=openai_visible),
gr.update(visible=anthropic_visible),
@@ -385,15 +448,18 @@ if __name__ == "__main__":
gr.update(visible=is_any_custom),
gr.update(visible=is_custom_openai_api),
gr.update(visible=is_custom_openai_api),
gr.update(value=model_choice_value)
gr.update(value=model_choice_value),
]
# Custom model inputs
custom_model = gr.Textbox(
label="Custom Model Name",
placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct or llama3)",
value=initial_custom_model,
visible=(initial_model == "Custom model (OpenAI compatible API)" or initial_model == "Custom model (ollama)"),
visible=(
initial_model == "Custom model (OpenAI compatible API)"
or initial_model == "Custom model (ollama)"
),
interactive=True,
)
@@ -413,36 +479,56 @@ if __name__ == "__main__":
interactive=True,
type="password",
)
# Provider visibility update function
def update_provider_visibility(provider):
"""Update visibility of container name and API key based on selected provider."""
is_localhost = provider == "localhost"
return [
gr.update(visible=not is_localhost), # container_name
gr.update(visible=not is_localhost and not has_cua_key) # cua_cloud_api_key
gr.update(
visible=not is_localhost and not has_cua_key
), # cua_cloud_api_key
]
# Connect provider change event
computer_provider.change(
fn=update_provider_visibility,
inputs=[computer_provider],
outputs=[container_name, cua_cloud_api_key],
queue=False
queue=False,
)
# Connect UI update events
for dropdown in [agent_loop, omni_model_choice, uitars_model_choice, openai_model_choice, anthropic_model_choice]:
for dropdown in [
agent_loop,
omni_model_choice,
uitars_model_choice,
openai_model_choice,
anthropic_model_choice,
]:
dropdown.change(
fn=update_ui,
inputs=[agent_loop, openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice],
outputs=[
openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice,
openai_key_group, anthropic_key_group,
custom_model, provider_base_url, provider_api_key,
model_choice
inputs=[
agent_loop,
openai_model_choice,
anthropic_model_choice,
omni_model_choice,
uitars_model_choice,
],
queue=False
outputs=[
openai_model_choice,
anthropic_model_choice,
omni_model_choice,
uitars_model_choice,
openai_key_group,
anthropic_key_group,
custom_model,
provider_base_url,
provider_api_key,
model_choice,
],
queue=False,
)
save_trajectory = gr.Checkbox(
@@ -461,7 +547,7 @@ if __name__ == "__main__":
info="Number of recent images to keep in context",
interactive=True,
)
max_budget = gr.Number(
label="Max Budget ($)",
value=lambda: None,
@@ -479,9 +565,7 @@ if __name__ == "__main__":
)
chatbot_history = gr.Chatbot(type="messages")
msg = gr.Textbox(
placeholder="Ask me to perform tasks in a virtual environment"
)
msg = gr.Textbox(placeholder="Ask me to perform tasks in a virtual environment")
clear = gr.Button("Clear")
cancel_button = gr.Button("Cancel", variant="stop")
@@ -498,11 +582,23 @@ if __name__ == "__main__":
global global_agent
if global_agent:
print("DEBUG - Cancelling agent task")
history.append(gr.ChatMessage(role="assistant", content="Task cancelled by user", metadata={"title": "❌ Cancelled"}))
history.append(
gr.ChatMessage(
role="assistant",
content="Task cancelled by user",
metadata={"title": "❌ Cancelled"},
)
)
else:
history.append(gr.ChatMessage(role="assistant", content="No active agent task to cancel", metadata={"title": " Info"}))
history.append(
gr.ChatMessage(
role="assistant",
content="No active agent task to cancel",
metadata={"title": " Info"},
)
)
return history
# Process response function
async def process_response(
history,
@@ -542,10 +638,13 @@ if __name__ == "__main__":
model_choice_value = uitars_model_value
else:
model_choice_value = "No models available"
# Determine if this is a custom model selection
is_custom_model_selected = model_choice_value in ["Custom model (OpenAI compatible API)", "Custom model (ollama)"]
is_custom_model_selected = model_choice_value in [
"Custom model (OpenAI compatible API)",
"Custom model (ollama)",
]
# Determine the model name string to analyze
if is_custom_model_selected:
model_string_to_analyze = custom_model_value
@@ -583,13 +682,19 @@ if __name__ == "__main__":
model_string=model_string,
save_trajectory=save_traj,
only_n_most_recent_images=recent_imgs,
custom_model_name=custom_model_value if is_custom_model_selected else None,
custom_model_name=(
custom_model_value if is_custom_model_selected else None
),
computer_os=computer_os,
computer_provider=computer_provider,
computer_name=container_name,
computer_api_key=cua_cloud_api_key,
verbosity=logging.DEBUG,
max_trajectory_budget=max_budget_value if max_budget_value and max_budget_value > 0 else None,
max_trajectory_budget=(
max_budget_value
if max_budget_value and max_budget_value > 0
else None
),
)
if global_agent is None:
@@ -605,7 +710,7 @@ if __name__ == "__main__":
# Add user message to global history
global global_messages
global_messages.append({"role": "user", "content": last_user_message})
# Stream responses from the agent
async for result in global_agent.run(global_messages):
global_messages += result.get("output", [])
@@ -613,18 +718,20 @@ if __name__ == "__main__":
# from pprint import pprint
# pprint(result)
# print(f"DEBUG - Agent response ------- END")
# Process the result output
for item in result.get("output", []):
if item.get("type") == "message":
content = item.get("content", [])
for content_part in content:
if content_part.get("text"):
history.append(gr.ChatMessage(
role=item.get("role", "assistant"),
content=content_part.get("text", ""),
metadata=content_part.get("metadata", {})
))
history.append(
gr.ChatMessage(
role=item.get("role", "assistant"),
content=content_part.get("text", ""),
metadata=content_part.get("metadata", {}),
)
)
elif item.get("type") == "computer_call":
action = item.get("action", {})
action_type = action.get("type", "")
@@ -632,43 +739,52 @@ if __name__ == "__main__":
action_title = f"🛠️ Performing {action_type}"
if action.get("x") and action.get("y"):
action_title += f" at ({action['x']}, {action['y']})"
history.append(gr.ChatMessage(
role="assistant",
content=f"```json\n{json.dumps(action)}\n```",
metadata={"title": action_title}
))
history.append(
gr.ChatMessage(
role="assistant",
content=f"```json\n{json.dumps(action)}\n```",
metadata={"title": action_title},
)
)
elif item.get("type") == "function_call":
function_name = item.get("name", "")
arguments = item.get("arguments", "{}")
history.append(gr.ChatMessage(
role="assistant",
content=f"🔧 Calling function: {function_name}\n```json\n{arguments}\n```",
metadata={"title": f"Function Call: {function_name}"}
))
history.append(
gr.ChatMessage(
role="assistant",
content=f"🔧 Calling function: {function_name}\n```json\n{arguments}\n```",
metadata={"title": f"Function Call: {function_name}"},
)
)
elif item.get("type") == "function_call_output":
output = item.get("output", "")
history.append(gr.ChatMessage(
role="assistant",
content=f"📤 Function output:\n```\n{output}\n```",
metadata={"title": "Function Output"}
))
history.append(
gr.ChatMessage(
role="assistant",
content=f"📤 Function output:\n```\n{output}\n```",
metadata={"title": "Function Output"},
)
)
elif item.get("type") == "computer_call_output":
output = item.get("output", {}).get("image_url", "")
image_markdown = f"![Computer output]({output})"
history.append(gr.ChatMessage(
role="assistant",
content=image_markdown,
metadata={"title": "🖥️ Computer Output"}
))
history.append(
gr.ChatMessage(
role="assistant",
content=image_markdown,
metadata={"title": "🖥️ Computer Output"},
)
)
yield history
except Exception as e:
import traceback
traceback.print_exc()
history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
yield history
# Connect the submit button
submit_event = msg.submit(
fn=chat_submit,
@@ -706,44 +822,77 @@ if __name__ == "__main__":
global global_messages
global_messages.clear()
return None
clear.click(clear_chat, None, chatbot_history, queue=False)
# Connect cancel button
cancel_button.click(
cancel_agent_task,
[chatbot_history],
[chatbot_history],
queue=False
cancel_agent_task, [chatbot_history], [chatbot_history], queue=False
)
# Code display update function
def update_code_display(agent_loop, model_choice_val, custom_model_val, chat_history, recent_images_val, save_trajectory_val, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget_val):
def update_code_display(
agent_loop,
model_choice_val,
custom_model_val,
chat_history,
recent_images_val,
save_trajectory_val,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
max_budget_val,
):
messages = []
if chat_history:
for msg in chat_history:
if isinstance(msg, dict) and msg.get("role") == "user":
messages.append(msg.get("content", ""))
return generate_python_code(
agent_loop,
model_choice_val or custom_model_val or "gpt-4o",
messages,
agent_loop,
model_choice_val or custom_model_val or "gpt-4o",
messages,
recent_images_val,
save_trajectory_val,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
max_budget_val
max_budget_val,
)
# Update code display when configuration changes
for component in [agent_loop, model_choice, custom_model, chatbot_history, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget]:
for component in [
agent_loop,
model_choice,
custom_model,
chatbot_history,
recent_images,
save_trajectory,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
max_budget,
]:
component.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget],
outputs=[code_display]
inputs=[
agent_loop,
model_choice,
custom_model,
chatbot_history,
recent_images,
save_trajectory,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
max_budget,
],
outputs=[code_display],
)
return demo

View File

@@ -5,26 +5,30 @@ This directory contains benchmarks designed to test agent providers in the Compu
## Overview
The benchmark system evaluates models on GUI grounding tasks, specifically click prediction accuracy. It supports both:
- **Computer Agent SDK providers** (using model strings like `"huggingface-local/HelloKKMe/GTA1-7B"`)
- **Reference agent implementations** (custom model classes implementing the `ModelProtocol`)
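Both kinds of entries are driven through the same `ModelWrapper` interface from `utils.py`, so benchmark scripts never branch on the model type. A minimal sketch of that flow (the instruction text is illustrative):

```python
import asyncio

from utils import ModelWrapper, get_available_models, take_screenshot

async def demo():
    image = take_screenshot()  # any PIL.Image works here
    for model in get_available_models():  # strings and ModelProtocol instances
        wrapper = ModelWrapper(model)     # unified interface over both kinds
        await wrapper.load_model()
        coords = await wrapper.predict_click(image, "Click the OK button")
        print(wrapper.model_name, coords) # (x, y) tuple, or None on failure
        await wrapper.unload_model()      # free VRAM before the next model

asyncio.run(demo())
```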
## Available Benchmarks
### 1. ScreenSpot-v2 (`ss-v2.py`)
- **Dataset**: ScreenSpot-v2 (click-only GUI grounding)
- **Format**: Standard-resolution screenshots
- **Task**: Predict click coordinates given an instruction and image
- **Metrics**: Accuracy, Error Rate, Timing, VRAM usage
### 2. ScreenSpot-Pro (`ss-pro.py`)
### 2. ScreenSpot-Pro (`ss-pro.py`)
- **Dataset**: ScreenSpot-Pro (high-resolution click-only GUI grounding)
- **Format**: High-resolution screenshots
- **Task**: Predict click coordinates given an instruction and image
- **Metrics**: Accuracy, Error Rate, Timing, VRAM usage
### 3. Interactive Testing (`interactive.py`)
- **Real-time testing**: Take screenshots and visualize model predictions
- **Commands**:
- **Commands**:
- Type instruction → test all models on last screenshot
- `screenshot` → take screenshot
- `models` → list available models
@@ -34,14 +38,16 @@ The benchmark system evaluates models on GUI grounding tasks, specifically click
## Running Benchmarks
### 1. Configure Models
Edit `utils.py` to specify which models you want to test in `get_available_models()`.
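The returned list can mix provider strings with reference implementations; for example (entries below are illustrative):

```python
from models import GTA1Model

def get_available_models():
    return [
        # Computer Agent SDK provider strings
        "huggingface-local/HelloKKMe/GTA1-7B",
        "openai/computer-use-preview+openai/gpt-4o-mini",
        # Reference implementations (ModelProtocol classes)
        GTA1Model("HelloKKMe/GTA1-7B"),
    ]
```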
### 2. Run Benchmark
```bash
# ScreenSpot-v2 benchmark
python ss-v2.py --samples 50
# ScreenSpot-Pro benchmark
# ScreenSpot-Pro benchmark
python ss-pro.py --samples 50
# Interactive testing
@@ -51,6 +57,7 @@ python interactive.py
## Output
### Console Output
```
Model Results:
Accuracy: 85.50% (171/200)
@@ -59,10 +66,11 @@ Model Results:
```
### Generated Files
- **Markdown Report**: `*_results.md` with detailed results tables
- **Visualizations**: `output/` directory with prediction visualizations
- **Interactive Output**: `interactive_output/` for interactive session results
## Contributing
To add a new reference model, follow the instructions in [contrib.md](contrib.md).
To add a new reference model, follow the instructions in [contrib.md](contrib.md).

View File

@@ -17,29 +17,29 @@ class YourModelName(ModelProtocol):
def __init__(self, model_path: str):
self.model_path = model_path
self._model = None
@property
def model_name(self) -> str:
return self.model_path
async def load_model(self) -> None:
"""Load the model into memory."""
# Your model loading logic here
pass
async def unload_model(self) -> None:
"""Unload the model from memory."""
# Your model cleanup logic here
pass
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates for the given image and instruction.
Args:
image: PIL Image to analyze
instruction: Text instruction describing what to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -56,7 +56,7 @@ def get_available_models() -> List[Union[str, ModelProtocol]]:
models = [
# Computer Agent SDK providers
"huggingface-local/HelloKKMe/GTA1-7B",
# Reference implementations
GTA1Model("HelloKKMe/GTA1-7B"),
YourModelName("path/to/your/model"), # Add your model here
@@ -79,6 +79,7 @@ This will help you verify that your model loads correctly and produces reasonabl
Here's a complete example of adding a hypothetical "MyVisionModel":
1. **Create `models/my_vision_model.py`:**
```python
import torch
from transformers import AutoModel, AutoProcessor
@@ -91,11 +92,11 @@ class MyVisionModel(ModelProtocol):
self.model_path = model_path
self.model = None
self.processor = None
@property
def model_name(self) -> str:
return f"MyVisionModel({self.model_path})"
async def load_model(self) -> None:
"""Load the model and processor."""
self.processor = AutoProcessor.from_pretrained(self.model_path)
@@ -104,7 +105,7 @@ class MyVisionModel(ModelProtocol):
torch_dtype=torch.float16,
device_map="auto"
)
async def unload_model(self) -> None:
"""Clean up model resources."""
del self.model
@@ -112,7 +113,7 @@ class MyVisionModel(ModelProtocol):
self.model = None
self.processor = None
torch.cuda.empty_cache()
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
"""Predict click coordinates."""
try:
@@ -122,19 +123,19 @@ class MyVisionModel(ModelProtocol):
images=image,
return_tensors="pt"
)
# Run inference
with torch.no_grad():
outputs = self.model(**inputs)
# Extract coordinates (model-specific logic)
x, y = self._extract_coordinates(outputs)
return (int(x), int(y))
except Exception as e:
print(f"Prediction failed: {e}")
return None
def _extract_coordinates(self, outputs):
"""Extract x, y coordinates from model outputs."""
# Your model-specific coordinate extraction logic
@@ -142,6 +143,7 @@ class MyVisionModel(ModelProtocol):
```
2. **Update `models/__init__.py`:**
```python
from .gta1 import GTA1Model
from .my_vision_model import MyVisionModel
@@ -150,6 +152,7 @@ __all__ = ["GTA1Model", "MyVisionModel"]
```
3. **Update `utils.py`:**
```python
from models import GTA1Model, MyVisionModel

View File

@@ -9,60 +9,56 @@ Models are loaded/unloaded one at a time to avoid memory issues.
import asyncio
import os
from datetime import datetime
from typing import List, Dict, Any
from typing import Any, Dict, List
from utils import (
ModelWrapper,
take_screenshot,
get_available_models,
save_prediction_visualization,
get_available_models
take_screenshot,
)
async def predict_with_all_models(image, instruction: str, models) -> List[Dict[str, Any]]:
"""
Predict click coordinates with all models sequentially.
Args:
image: PIL Image to analyze
instruction: Instruction text
models: List of model instances
Returns:
List of prediction results
"""
predictions = []
for model in models:
model_wrapper = ModelWrapper(model)
print(f"\n🔄 Loading {model_wrapper.model_name}...")
try:
# Load model
await model_wrapper.load_model()
# Predict
coords = await model_wrapper.predict_click(image, instruction)
predictions.append({
'model_name': model_wrapper.model_name,
'coords': coords,
'error': None
})
predictions.append(
{"model_name": model_wrapper.model_name, "coords": coords, "error": None}
)
if coords:
print(f"{model_wrapper.model_name}: ({coords[0]}, {coords[1]})")
else:
print(f"{model_wrapper.model_name}: No prediction")
except Exception as e:
print(f"{model_wrapper.model_name}: ERROR - {str(e)}")
predictions.append({
'model_name': model_wrapper.model_name,
'coords': None,
'error': str(e)
})
predictions.append(
{"model_name": model_wrapper.model_name, "coords": None, "error": str(e)}
)
finally:
# Always unload model to free memory
try:
@@ -70,7 +66,7 @@ async def predict_with_all_models(image, instruction: str, models) -> List[Dict[
print(f"🗑️ Unloaded {model_wrapper.model_name}")
except Exception as e:
print(f"⚠️ Error unloading {model_wrapper.model_name}: {e}")
return predictions
@@ -103,87 +99,91 @@ async def main():
Main interactive loop.
"""
print_header()
# Get available models
models = get_available_models()
print_models(models)
# Create output directory for visualizations
output_dir = "interactive_output"
os.makedirs(output_dir, exist_ok=True)
session_count = 0
last_screenshot = None
screenshot_timestamp = None
while True:
try:
# Get user input
print(f"\n{'='*40}")
user_input = input("🎯 Enter instruction (or command): ").strip()
if not user_input:
continue
# Handle commands
if user_input.lower() in ['quit', 'exit', 'q']:
if user_input.lower() in ["quit", "exit", "q"]:
print("👋 Goodbye!")
break
elif user_input.lower() == 'models':
elif user_input.lower() == "models":
print_models(models)
continue
elif user_input.lower() == 'screenshot':
elif user_input.lower() == "screenshot":
print("📸 Taking screenshot...")
try:
last_screenshot = take_screenshot()
screenshot_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
screenshot_path = os.path.join(output_dir, f"screenshot_{screenshot_timestamp}.png")
screenshot_path = os.path.join(
output_dir, f"screenshot_{screenshot_timestamp}.png"
)
last_screenshot.save(screenshot_path)
print(f"✅ Screenshot captured and saved to: {screenshot_path}")
print(f"📝 Ready for instructions! Screenshot size: {last_screenshot.size}")
except Exception as e:
print(f"❌ Error taking screenshot: {e}")
continue
# Handle instruction input
if last_screenshot is None:
print("⚠️ No screenshot available! Please take a screenshot first using 'screenshot' command.")
print(
"⚠️ No screenshot available! Please take a screenshot first using 'screenshot' command."
)
continue
session_count += 1
print(f"\n🎯 Session {session_count}: '{user_input}'")
print(f"📷 Using screenshot from: {screenshot_timestamp}")
# Predict with all models using last screenshot
print(f"\n🤖 Testing {len(models)} models on screenshot...")
predictions = await predict_with_all_models(last_screenshot, user_input, models)
# Display results summary
print(f"\n📊 Results Summary:")
print("\n📊 Results Summary:")
print("-" * 50)
for pred in predictions:
if pred['coords']:
if pred["coords"]:
print(f"{pred['model_name']}: ({pred['coords'][0]}, {pred['coords'][1]})")
elif pred['error']:
elif pred["error"]:
print(f"{pred['model_name']}: ERROR - {pred['error']}")
else:
print(f"{pred['model_name']}: No prediction")
# Save visualization
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
vis_filename = f"session_{session_count:03d}_{timestamp}.png"
vis_path = os.path.join(output_dir, vis_filename)
try:
save_prediction_visualization(last_screenshot, user_input, predictions, vis_path)
print(f"\n💾 Visualization saved to: {vis_path}")
except Exception as e:
print(f"⚠️ Error saving visualization: {e}")
print(f"\n✨ Session {session_count} completed!")
except KeyboardInterrupt:
print("\n\n👋 Interrupted by user. Goodbye!")
break

View File

@@ -2,34 +2,37 @@
Base protocol for benchmark models.
"""
from typing import Protocol, Optional, Tuple
from typing import Optional, Protocol, Tuple
from PIL import Image
class ModelProtocol(Protocol):
"""Protocol for benchmark models that can predict click coordinates."""
@property
def model_name(self) -> str:
"""Return the name of the model."""
...
async def load_model(self) -> None:
"""Load the model into memory."""
...
async def unload_model(self) -> None:
"""Unload the model from memory."""
...
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
async def predict_click(
self, image: Image.Image, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates for the given image and instruction.
Args:
image: PIL Image to analyze
instruction: Text instruction describing what to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""

View File

@@ -2,54 +2,51 @@
GTA1 model implementation for benchmarking.
"""
from typing import Optional, Tuple
from PIL import Image
import torch
import re
import gc
import re
from typing import Optional, Tuple
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from .base import ModelProtocol
class GTA1Model:
"""Ground truth GTA1 model implementation."""
def __init__(self, model_path: str = "HelloKKMe/GTA1-7B"):
self.model_path = model_path
self.model = None
self.processor = None
self.max_new_tokens = 32
self.system_prompt = '''
self.system_prompt = """
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
Output the coordinate pair exactly:
(x,y)
'''.strip()
""".strip()
@property
def model_name(self) -> str:
"""Return the name of the model."""
return f"GTA1-{self.model_path.split('/')[-1]}"
async def load_model(self) -> None:
"""Load the model into memory."""
if self.model is None:
print(f"Loading GTA1 model: {self.model_path}")
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
self.model_path,
torch_dtype=torch.bfloat16,
device_map="auto"
self.model_path, torch_dtype=torch.bfloat16, device_map="auto"
)
self.processor = AutoProcessor.from_pretrained(
self.model_path,
min_pixels=3136,
max_pixels=4096 * 2160
self.model_path, min_pixels=3136, max_pixels=4096 * 2160
)
print("GTA1 model loaded successfully")
async def unload_model(self) -> None:
"""Unload the model from memory."""
if self.model is not None:
@@ -62,23 +59,25 @@ Output the coordinate pair exactly:
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("GTA1 model unloaded")
def _extract_coordinates(self, raw_string: str) -> Tuple[int, int]:
"""Extract coordinates from model output."""
try:
matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
return tuple(map(int, map(float, matches[0]))) # type: ignore
return tuple(map(int, map(float, matches[0]))) # type: ignore
except:
return (0, 0)
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
async def predict_click(
self, image: Image.Image, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates for the given image and instruction.
Args:
image: PIL Image to analyze
instruction: Text instruction describing what to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -87,76 +86,73 @@ Output the coordinate pair exactly:
assert self.processor is not None
assert self.model is not None
try:
width, height = image.width, image.height
# Resize image according to processor requirements
resized_height, resized_width = smart_resize(
image.height,
image.width,
factor=self.processor.image_processor.patch_size * self.processor.image_processor.merge_size,
factor=self.processor.image_processor.patch_size
* self.processor.image_processor.merge_size,
min_pixels=self.processor.image_processor.min_pixels,
max_pixels=self.processor.image_processor.max_pixels,
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height
# Prepare messages
system_message = {
"role": "system",
"content": self.system_prompt.format(height=resized_height, width=resized_width)
"content": self.system_prompt.format(height=resized_height, width=resized_width),
}
user_message = {
"role": "user",
"content": [
{"type": "image", "image": resized_image},
{"type": "text", "text": instruction}
]
{"type": "text", "text": instruction},
],
}
# Process inputs
image_inputs, video_inputs = process_vision_info([system_message, user_message]) # type: ignore
image_inputs, video_inputs = process_vision_info([system_message, user_message]) # type: ignore
text = self.processor.apply_chat_template(
[system_message, user_message],
tokenize=False,
add_generation_prompt=True
[system_message, user_message], tokenize=False, add_generation_prompt=True
)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt"
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.model.device)
# Generate prediction
output_ids = self.model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=False,
temperature=1.0,
use_cache=True
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=False,
temperature=1.0,
use_cache=True,
)
generated_ids = [
output_ids[len(input_ids):]
output_ids[len(input_ids) :]
for input_ids, output_ids in zip(inputs.input_ids, output_ids)
]
output_text = self.processor.batch_decode(
generated_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=True
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)[0]
# Extract and rescale coordinates
pred_x, pred_y = self._extract_coordinates(output_text)
pred_x = int(pred_x * scale_x)
pred_y = int(pred_y * scale_y)
return (pred_x, pred_y)
except Exception as e:
print(f"Error in GTA1 prediction: {e}")
return None

View File

@@ -15,103 +15,106 @@ from typing import Optional
from datasets import load_dataset
from tqdm import tqdm
from utils import (
ModelWrapper,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
ModelWrapper,
get_available_models,
get_gpu_memory
get_gpu_memory,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
)
async def evaluate_model(model_wrapper: ModelWrapper, dataset, max_samples: Optional[int] = None) -> dict:
async def evaluate_model(
model_wrapper: ModelWrapper, dataset, max_samples: Optional[int] = None
) -> dict:
"""
Evaluate a model on the ScreenSpot-Pro dataset.
Args:
model_wrapper: ModelWrapper instance
dataset: ScreenSpot-Pro dataset (list of samples)
max_samples: Maximum number of samples to evaluate (None for all)
Returns:
Dictionary with evaluation results
"""
print(f"\nEvaluating model: {model_wrapper.model_name}")
# Load model
await model_wrapper.load_model()
total_samples = len(dataset)
if max_samples is not None:
total_samples = min(max_samples, total_samples)
correct_predictions = 0
error_predictions = 0
results = []
for i in tqdm(range(total_samples), desc=f"Evaluating {model_wrapper.model_name}"):
sample = dataset[i]
# Extract sample data
image = sample['image']
instruction = sample['instruction']
bbox = sample['bbox'] # [x1, y1, x2, y2]
sample_id = sample['img_filename']
image = sample["image"]
instruction = sample["instruction"]
bbox = sample["bbox"] # [x1, y1, x2, y2]
sample_id = sample["img_filename"]
# Predict click coordinates with timing
start_time = time.time()
click_coords = await model_wrapper.predict_click(image, instruction)
prediction_time = time.time() - start_time
# Check if prediction is correct
is_correct = is_click_in_bbox(click_coords, bbox)
if is_correct:
correct_predictions += 1
results.append({
'id': sample_id,
'instruction': instruction,
'bbox': bbox,
'predicted_coords': click_coords,
'is_correct': is_correct,
'failed': False,
'prediction_time': prediction_time
})
results.append(
{
"id": sample_id,
"instruction": instruction,
"bbox": bbox,
"predicted_coords": click_coords,
"is_correct": is_correct,
"failed": False,
"prediction_time": prediction_time,
}
)
# Unload model
await model_wrapper.unload_model()
# Calculate metrics
accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
error_rate = error_predictions / total_samples if total_samples > 0 else 0.0
# Calculate timing statistics
successful_times = [r['prediction_time'] for r in results if not r['failed']]
successful_times = [r["prediction_time"] for r in results if not r["failed"]]
avg_prediction_time = sum(successful_times) / len(successful_times) if successful_times else 0.0
median_prediction_time = statistics.median(successful_times) if successful_times else 0.0
min_prediction_time = min(successful_times) if successful_times else 0.0
max_prediction_time = max(successful_times) if successful_times else 0.0
# Get VRAM statistics
vram_stats = model_wrapper.get_vram_stats()
return {
'model_name': model_wrapper.model_name,
'total_samples': total_samples,
'correct_predictions': correct_predictions,
'failed_predictions': error_predictions,
'accuracy': accuracy,
'failure_rate': error_rate,
'avg_prediction_time': avg_prediction_time,
'median_prediction_time': median_prediction_time,
'min_prediction_time': min_prediction_time,
'max_prediction_time': max_prediction_time,
'vram_max_mb': vram_stats['max_mb'],
'vram_avg_mb': vram_stats['avg_mb'],
'results': results
"model_name": model_wrapper.model_name,
"total_samples": total_samples,
"correct_predictions": correct_predictions,
"failed_predictions": error_predictions,
"accuracy": accuracy,
"failure_rate": error_rate,
"avg_prediction_time": avg_prediction_time,
"median_prediction_time": median_prediction_time,
"min_prediction_time": min_prediction_time,
"max_prediction_time": max_prediction_time,
"vram_max_mb": vram_stats["max_mb"],
"vram_avg_mb": vram_stats["avg_mb"],
"results": results,
}
@@ -120,42 +123,44 @@ async def main():
Main function to run the benchmark.
"""
# Parse command line arguments
parser = argparse.ArgumentParser(description='ScreenSpot-Pro Benchmark Script')
parser.add_argument('--samples', type=int, default=300,
help='Number of samples to evaluate (default: 300)')
parser.add_argument('--seed', type=int, default=42,
help='Random seed for shuffling (default: 42)')
parser = argparse.ArgumentParser(description="ScreenSpot-Pro Benchmark Script")
parser.add_argument(
"--samples", type=int, default=300, help="Number of samples to evaluate (default: 300)"
)
parser.add_argument(
"--seed", type=int, default=42, help="Random seed for shuffling (default: 42)"
)
args = parser.parse_args()
# Set random seed
random.seed(args.seed)
# Load dataset
print("Loading ScreenSpot-Pro dataset...")
ds = load_dataset("lmms-lab/ScreenSpot-Pro")
dataset = ds['train'] # type: ignore
dataset = ds["train"] # type: ignore
# Convert to list to support indexing
dataset_list = list(dataset)
print(f"Dataset loaded: {len(dataset_list)} samples")
# Shuffle dataset with seed
random.shuffle(dataset_list)
print(f"Dataset shuffled with seed {args.seed}")
# Get available models
models = get_available_models()
# Evaluation settings
max_samples = args.samples # Use command line argument
# Run evaluations
all_results = []
for model in models:
model_wrapper = ModelWrapper(model)
result = await evaluate_model(model_wrapper, dataset_list, max_samples)
all_results.append(result)
# Print summary
print(f"\n{result['model_name']} Results:")
print(f" Accuracy: {result['accuracy']*100:.2f}%")
@@ -164,15 +169,17 @@ async def main():
print(f" Error Rate: {result['failure_rate']*100:.2f}%")
print(f" Avg Time: {result['avg_prediction_time']:.2f}s")
print(f" Median Time: {result['median_prediction_time']:.2f}s")
print(f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s")
print(
f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s"
)
print(f" VRAM Max: {result['vram_max_mb']:.1f}MB")
print(f" VRAM Avg: {result['vram_avg_mb']:.1f}MB")
# Print GPU memory info
gpu_memory = get_gpu_memory()
if gpu_memory and gpu_memory[0] > 0:
print(f" GPU Free Memory: {gpu_memory[0]:.1f}MB")
# Save results
if all_results:
save_results_to_markdown(all_results)
@@ -183,4 +190,4 @@ async def main():
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

View File

@@ -15,36 +15,37 @@ from typing import Optional
from datasets import load_dataset
from tqdm import tqdm
from utils import (
ModelWrapper,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
ModelWrapper,
get_available_models,
get_gpu_memory
get_gpu_memory,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
)
async def evaluate_model(model_wrapper: ModelWrapper, samples, max_samples: Optional[int] = None) -> dict:
async def evaluate_model(
model_wrapper: ModelWrapper, samples, max_samples: Optional[int] = None
) -> dict:
"""
Evaluate a model on any iterable of samples.
Args:
model_wrapper: ModelWrapper instance
samples: Iterable of dicts with keys: image, bbox, instruction
max_samples: Maximum number of samples to evaluate (None for all)
Returns:
Dictionary with evaluation results
"""
print(f"\nEvaluating model: {model_wrapper.model_name}")
# Load model
await model_wrapper.load_model()
# Convert to list if needed and limit samples
if hasattr(samples, '__len__'):
if hasattr(samples, "__len__"):
total_samples = len(samples)
if max_samples is not None:
total_samples = min(max_samples, total_samples)
@@ -55,69 +56,71 @@ async def evaluate_model(model_wrapper: ModelWrapper, samples, max_samples: Opti
if max_samples is not None:
sample_list = sample_list[:max_samples]
total_samples = len(sample_list)
correct_predictions = 0
error_predictions = 0
results = []
for i, sample in enumerate(tqdm(sample_list, desc=f"Evaluating {model_wrapper.model_name}")):
# Extract required data (only these 3 keys matter)
image = sample['image']
instruction = sample['instruction']
bbox = sample['bbox'] # [x1, y1, x2, y2]
image = sample["image"]
instruction = sample["instruction"]
bbox = sample["bbox"] # [x1, y1, x2, y2]
# Predict click coordinates with timing
start_time = time.time()
click_coords = await model_wrapper.predict_click(image, instruction)
prediction_time = time.time() - start_time
# Check if prediction is correct
is_correct = is_click_in_bbox(click_coords, bbox)
if is_correct:
correct_predictions += 1
results.append({
'sample_idx': i,
'instruction': instruction,
'bbox': bbox,
'predicted_coords': click_coords,
'is_correct': is_correct,
'failed': False,
'prediction_time': prediction_time
})
results.append(
{
"sample_idx": i,
"instruction": instruction,
"bbox": bbox,
"predicted_coords": click_coords,
"is_correct": is_correct,
"failed": False,
"prediction_time": prediction_time,
}
)
# Unload model
await model_wrapper.unload_model()
# Calculate metrics
accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
error_rate = error_predictions / total_samples if total_samples > 0 else 0.0
# Calculate timing statistics
successful_times = [r['prediction_time'] for r in results if not r['failed']]
successful_times = [r["prediction_time"] for r in results if not r["failed"]]
avg_prediction_time = sum(successful_times) / len(successful_times) if successful_times else 0.0
median_prediction_time = statistics.median(successful_times) if successful_times else 0.0
min_prediction_time = min(successful_times) if successful_times else 0.0
max_prediction_time = max(successful_times) if successful_times else 0.0
# Get VRAM statistics
vram_stats = model_wrapper.get_vram_stats()
return {
'model_name': model_wrapper.model_name,
'total_samples': total_samples,
'correct_predictions': correct_predictions,
'failed_predictions': error_predictions,
'accuracy': accuracy,
'failure_rate': error_rate,
'avg_prediction_time': avg_prediction_time,
'median_prediction_time': median_prediction_time,
'min_prediction_time': min_prediction_time,
'max_prediction_time': max_prediction_time,
'vram_max_mb': vram_stats['max_mb'],
'vram_avg_mb': vram_stats['avg_mb'],
'results': results
"model_name": model_wrapper.model_name,
"total_samples": total_samples,
"correct_predictions": correct_predictions,
"failed_predictions": error_predictions,
"accuracy": accuracy,
"failure_rate": error_rate,
"avg_prediction_time": avg_prediction_time,
"median_prediction_time": median_prediction_time,
"min_prediction_time": min_prediction_time,
"max_prediction_time": max_prediction_time,
"vram_max_mb": vram_stats["max_mb"],
"vram_avg_mb": vram_stats["avg_mb"],
"results": results,
}
@@ -126,56 +129,60 @@ async def main():
Main function to run the benchmark.
"""
# Parse command line arguments
parser = argparse.ArgumentParser(description='ScreenSpot-v2 Benchmark Script')
parser.add_argument('--samples', type=int, default=500,
help='Number of samples to evaluate (default: 500)')
parser.add_argument('--seed', type=int, default=42,
help='Random seed for shuffling (default: 42)')
parser = argparse.ArgumentParser(description="ScreenSpot-v2 Benchmark Script")
parser.add_argument(
"--samples", type=int, default=500, help="Number of samples to evaluate (default: 500)"
)
parser.add_argument(
"--seed", type=int, default=42, help="Random seed for shuffling (default: 42)"
)
args = parser.parse_args()
# Set random seed
random.seed(args.seed)
# Load dataset
print("Loading ScreenSpot-v2 dataset...")
ds = load_dataset("lmms-lab/ScreenSpot-v2")
dataset = ds['train'] # type: ignore
dataset = ds["train"] # type: ignore
# Convert to simple list of dicts with only required keys
samples = []
for item in dataset:
# Convert dataset item to dict if needed
item_dict = dict(item) if hasattr(item, 'keys') else item
item_dict = dict(item) if hasattr(item, "keys") else item
# Convert ScreenSpot-v2 bbox format [x, y, w, h] to [x1, y1, x2, y2]
bbox_xywh = item_dict['bbox'] # type: ignore
bbox_xywh = item_dict["bbox"] # type: ignore
x, y, w, h = bbox_xywh
bbox_xyxy = [x, y, x + w, y + h]
samples.append({
'image': item_dict['image'], # type: ignore
'instruction': item_dict['instruction'], # type: ignore
'bbox': bbox_xyxy
})
samples.append(
{
"image": item_dict["image"], # type: ignore
"instruction": item_dict["instruction"], # type: ignore
"bbox": bbox_xyxy,
}
)
print(f"Dataset loaded: {len(samples)} samples")
# Shuffle samples with seed
random.shuffle(samples)
print(f"Samples shuffled with seed {args.seed}")
# Get available models
models = get_available_models()
# Evaluation settings
max_samples = args.samples # Use command line argument
# Run evaluations
all_results = []
for model in models:
model_wrapper = ModelWrapper(model)
result = await evaluate_model(model_wrapper, samples, max_samples)
all_results.append(result)
# Print summary
print(f"\n{result['model_name']} Results:")
print(f" Accuracy: {result['accuracy']*100:.2f}%")
@@ -184,18 +191,22 @@ async def main():
print(f" Error Rate: {result['failure_rate']*100:.2f}%")
print(f" Avg Time: {result['avg_prediction_time']:.2f}s")
print(f" Median Time: {result['median_prediction_time']:.2f}s")
print(f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s")
print(
f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s"
)
print(f" VRAM Max: {result['vram_max_mb']:.1f}MB")
print(f" VRAM Avg: {result['vram_avg_mb']:.1f}MB")
# Print GPU memory info
gpu_memory = get_gpu_memory()
if gpu_memory and gpu_memory[0] > 0:
print(f" GPU Free Memory: {gpu_memory[0]:.1f}MB")
# Save results
if all_results:
save_results_to_markdown(all_results, "screenspot_v2_results.md", title="ScreenSpot-v2 Benchmark Results")
save_results_to_markdown(
all_results, "screenspot_v2_results.md", title="ScreenSpot-v2 Benchmark Results"
)
save_visualizations(all_results, samples)
print("\nBenchmark completed successfully!")
else:
@@ -203,4 +214,4 @@ async def main():
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

View File

@@ -4,38 +4,40 @@ Shared utilities for ScreenSpot-Pro benchmarking and interactive testing.
"""
import dotenv
dotenv.load_dotenv()
import asyncio
import base64
import gc
import os
import sys
import subprocess as sp
import statistics
import subprocess as sp
import sys
from datetime import datetime
from io import BytesIO
from typing import List, Union, Tuple, Optional
from typing import List, Optional, Tuple, Union
import torch
from PIL import Image, ImageDraw
from tqdm import tqdm
import gc
import torch
# Add parent directory to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from agent.agent import ComputerAgent
from models.base import ModelProtocol
def get_gpu_memory() -> List[int]:
"""
Get GPU memory usage using nvidia-smi.
Returns:
List of free memory values in MB for each GPU
"""
try:
command = "nvidia-smi --query-gpu=memory.free --format=csv"
memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
memory_free_info = sp.check_output(command.split()).decode("ascii").split("\n")[:-1][1:]
memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
return memory_free_values
except (sp.CalledProcessError, FileNotFoundError, IndexError):
@@ -51,39 +53,34 @@ def get_gpu_memory() -> List[int]:
def get_vram_usage() -> dict:
"""
Get current VRAM usage statistics.
Returns:
Dictionary with VRAM usage info (in MB)
"""
if torch.cuda.is_available():
device = torch.cuda.current_device()
allocated = torch.cuda.memory_allocated(device) / 1024 / 1024 # Convert to MB
reserved = torch.cuda.memory_reserved(device) / 1024 / 1024 # Convert to MB
reserved = torch.cuda.memory_reserved(device) / 1024 / 1024 # Convert to MB
total = torch.cuda.get_device_properties(device).total_memory / 1024 / 1024
return {
'allocated_mb': allocated,
'reserved_mb': reserved,
'total_mb': total,
'free_mb': total - reserved
"allocated_mb": allocated,
"reserved_mb": reserved,
"total_mb": total,
"free_mb": total - reserved,
}
else:
return {
'allocated_mb': 0.0,
'reserved_mb': 0.0,
'total_mb': 0.0,
'free_mb': 0.0
}
return {"allocated_mb": 0.0, "reserved_mb": 0.0, "total_mb": 0.0, "free_mb": 0.0}
def get_available_models() -> List[Union[str, ModelProtocol]]:
"""
Get list of available models for testing.
Returns:
List of model strings and model classes
"""
local_provider = "huggingface-local/" # Options: huggingface-local/ or mlx/
# from models.gta1 import GTA1Model
models = [
@@ -94,42 +91,41 @@ def get_available_models() -> List[Union[str, ModelProtocol]]:
# f"{local_provider}HelloKKMe/GTA1-32B",
"openai/computer-use-preview+openai/gpt-4o-mini",
"anthropic/claude-opus-4-20250514+openai/gpt-4o-mini",
# === Reference model classes ===
# GTA1Model("HelloKKMe/GTA1-7B"),
# GTA1Model("HelloKKMe/GTA1-32B"),
# GTA1Model("HelloKKMe/GTA1-32B"),
]
return models
def is_click_in_bbox(click_coords: Optional[Tuple[int, int]], bbox: List[int]) -> bool:
"""
Check if click coordinates are within the bounding box.
Args:
click_coords: (x, y) coordinates or None
bbox: [x1, y1, x2, y2] bounding box
Returns:
True if click is within bbox, False otherwise
"""
if click_coords is None:
return False
x, y = click_coords
x1, y1, x2, y2 = bbox
return x1 <= x <= x2 and y1 <= y <= y2
def image_to_base64(image: Image.Image) -> str:
"""
Convert PIL Image to base64 string.
Args:
image: PIL Image
Returns:
Base64 encoded image string
"""
@@ -142,213 +138,252 @@ class ModelWrapper:
"""
Wrapper to provide unified interface for both ComputerAgent and custom models.
"""
def __init__(self, model: Union[str, ModelProtocol]):
self.model = model
self.is_computer_agent = isinstance(model, str)
self.agent: Optional[ComputerAgent] = None
self.vram_usage_history: List[float] = [] # Track VRAM usage over time
if self.is_computer_agent:
self.model_name = str(model)
else:
self.model_name = f"{model.__class__.__name__}('{getattr(model, 'model_name', 'unknown')}')"
self.model_name = (
f"{model.__class__.__name__}('{getattr(model, 'model_name', 'unknown')}')"
)
async def load_model(self) -> None:
"""Load the model."""
if self.is_computer_agent:
self.agent = ComputerAgent(model=str(self.model))
else:
await self.model.load_model() # type: ignore
await self.model.load_model() # type: ignore
# Record initial VRAM usage after loading
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
self.vram_usage_history.append(vram_info["allocated_mb"])
async def unload_model(self) -> None:
"""Unload the model."""
if not self.is_computer_agent:
await self.model.unload_model() # type: ignore
await self.model.unload_model() # type: ignore
else:
del self.agent
self.agent = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Record VRAM usage after unloading
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
self.vram_usage_history.append(vram_info["allocated_mb"])
def get_vram_stats(self) -> dict:
"""Get VRAM usage statistics for this model."""
if not self.vram_usage_history:
return {'max_mb': 0.0, 'avg_mb': 0.0}
return {"max_mb": 0.0, "avg_mb": 0.0}
return {
'max_mb': max(self.vram_usage_history),
'avg_mb': sum(self.vram_usage_history) / len(self.vram_usage_history)
"max_mb": max(self.vram_usage_history),
"avg_mb": sum(self.vram_usage_history) / len(self.vram_usage_history),
}
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
async def predict_click(
self, image: Image.Image, instruction: str
) -> Optional[Tuple[int, int]]:
"""Predict click coordinates."""
# Record VRAM usage before prediction
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
self.vram_usage_history.append(vram_info["allocated_mb"])
if self.is_computer_agent:
if self.agent is None:
await self.load_model()
if self.agent is not None:
image_b64 = image_to_base64(image)
result = await self.agent.predict_click(instruction=instruction, image_b64=image_b64)
result = await self.agent.predict_click(
instruction=instruction, image_b64=image_b64
)
# Record VRAM usage after prediction
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
self.vram_usage_history.append(vram_info["allocated_mb"])
return result
return None
else:
result = await self.model.predict_click(image, instruction) # type: ignore
result = await self.model.predict_click(image, instruction) # type: ignore
# Record VRAM usage after prediction
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
self.vram_usage_history.append(vram_info["allocated_mb"])
return result
def save_results_to_markdown(all_results: List[dict],output_file: str = "screenspot_pro_results.md", title: str = "ScreenSpot-Pro Benchmark Results") -> None:
def save_results_to_markdown(
all_results: List[dict],
output_file: str = "screenspot_pro_results.md",
title: str = "ScreenSpot-Pro Benchmark Results",
) -> None:
"""
Save evaluation results to a markdown table.
Args:
all_results: List of evaluation results for each model
output_file: Output markdown file path
"""
with open(output_file, 'w', encoding='utf-8') as f:
with open(output_file, "w", encoding="utf-8") as f:
f.write(f"# {title}\n\n")
f.write(f"**Evaluation Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
# Summary table
f.write("## Summary\n\n")
f.write("| Model | Total Samples | Correct | Errors | Accuracy | Error Rate | Avg Time (s) | Median Time (s) | Time Range (s) | VRAM Max (GB) | VRAM Avg (GB) |\n")
f.write("|-------|---------------|---------|--------|----------|------------|--------------|-----------------|----------------|---------------|---------------|\n")
f.write(
"| Model | Total Samples | Correct | Errors | Accuracy | Error Rate | Avg Time (s) | Median Time (s) | Time Range (s) | VRAM Max (GB) | VRAM Avg (GB) |\n"
)
f.write(
"|-------|---------------|---------|--------|----------|------------|--------------|-----------------|----------------|---------------|---------------|\n"
)
for result in all_results:
model_name = result['model_name']
total = result['total_samples']
correct = result['correct_predictions']
errors = result['failed_predictions']
accuracy = result['accuracy'] * 100
error_rate = result['failure_rate'] * 100
avg_time = result.get('avg_prediction_time', 0.0)
median_time = result.get('median_prediction_time', 0.0)
min_time = result.get('min_prediction_time', 0.0)
max_time = result.get('max_prediction_time', 0.0)
model_name = result["model_name"]
total = result["total_samples"]
correct = result["correct_predictions"]
errors = result["failed_predictions"]
accuracy = result["accuracy"] * 100
error_rate = result["failure_rate"] * 100
avg_time = result.get("avg_prediction_time", 0.0)
median_time = result.get("median_prediction_time", 0.0)
min_time = result.get("min_prediction_time", 0.0)
max_time = result.get("max_prediction_time", 0.0)
time_range = f"{min_time:.2f} - {max_time:.2f}"
vram_max = result.get('vram_max_mb', 0.0) / 1024
vram_avg = result.get('vram_avg_mb', 0.0) / 1024
f.write(f"| {model_name} | {total} | {correct} | {errors} | {accuracy:.2f}% | {error_rate:.2f}% | {avg_time:.2f} | {median_time:.2f} | {time_range} | {vram_max:.1f} | {vram_avg:.1f} |\n")
vram_max = result.get("vram_max_mb", 0.0) / 1024
vram_avg = result.get("vram_avg_mb", 0.0) / 1024
f.write(
f"| {model_name} | {total} | {correct} | {errors} | {accuracy:.2f}% | {error_rate:.2f}% | {avg_time:.2f} | {median_time:.2f} | {time_range} | {vram_max:.1f} | {vram_avg:.1f} |\n"
)
# Detailed results for each model
for result in all_results:
f.write(f"\n## {result['model_name']} - Detailed Results\n\n")
f.write("| Sample Index | Instruction | BBox | Predicted | Correct | Error | Time (s) |\n")
f.write(
"| Sample Index | Instruction | BBox | Predicted | Correct | Error | Time (s) |\n"
)
f.write("|-----------|-------------|------|-----------|---------|-------|----------|\n")
for sample_result in result['results'][:10]: # Show first 10 samples
sample_idx = sample_result['sample_idx']
instruction = sample_result['instruction'][:50] + "..." if len(sample_result['instruction']) > 50 else sample_result['instruction']
bbox = str(sample_result['bbox'])
predicted = str(sample_result['predicted_coords']) if sample_result['predicted_coords'] else "None"
correct = "PASS" if sample_result['is_correct'] else "FAIL"
error = "YES" if sample_result['failed'] else "NO"
pred_time = sample_result.get('prediction_time', 0.0)
f.write(f"| {sample_idx} | {instruction} | {bbox} | {predicted} | {correct} | {error} | {pred_time:.2f} |\n")
if len(result['results']) > 10:
for sample_result in result["results"][:10]: # Show first 10 samples
sample_idx = sample_result["sample_idx"]
instruction = (
sample_result["instruction"][:50] + "..."
if len(sample_result["instruction"]) > 50
else sample_result["instruction"]
)
bbox = str(sample_result["bbox"])
predicted = (
str(sample_result["predicted_coords"])
if sample_result["predicted_coords"]
else "None"
)
correct = "PASS" if sample_result["is_correct"] else "FAIL"
error = "YES" if sample_result["failed"] else "NO"
pred_time = sample_result.get("prediction_time", 0.0)
f.write(
f"| {sample_idx} | {instruction} | {bbox} | {predicted} | {correct} | {error} | {pred_time:.2f} |\n"
)
if len(result["results"]) > 10:
f.write(f"\n*Showing first 10 of {len(result['results'])} samples*\n")
print(f"\nResults saved to: {output_file}")
def save_visualizations(all_results: List[dict], samples, output_dir: str = "output") -> None:
"""
Save visualizations of predicted coordinates vs bboxes to an output folder.
Args:
all_results: List of evaluation results for each model
samples: List of sample dicts with image, bbox, instruction keys
output_dir: Output directory path
"""
os.makedirs(output_dir, exist_ok=True)
for result in all_results:
model_name = result['model_name'].replace('/', '_').replace('\\', '_')
model_name = result["model_name"].replace("/", "_").replace("\\", "_")
model_dir = os.path.join(output_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
print(f"Saving visualizations for {result['model_name']}...")
# Save first 10 samples for visualization
for i, sample_result in enumerate(tqdm(result['results'][:10], desc=f"Saving {model_name} visualizations")):
for i, sample_result in enumerate(
tqdm(result["results"][:10], desc=f"Saving {model_name} visualizations")
):
# Get sample data using index
sample_idx = sample_result['sample_idx']
sample_idx = sample_result["sample_idx"]
if sample_idx < len(samples):
sample = samples[sample_idx]
image = sample['image'].copy() # Make a copy to avoid modifying original
image = sample["image"].copy() # Make a copy to avoid modifying original
else:
print(f"Warning: Could not find sample at index {sample_idx}")
continue
bbox = sample_result['bbox']
predicted_coords = sample_result['predicted_coords']
is_correct = sample_result['is_correct']
bbox = sample_result["bbox"]
predicted_coords = sample_result["predicted_coords"]
is_correct = sample_result["is_correct"]
# Draw on image
draw = ImageDraw.Draw(image)
# Draw bounding box (ground truth) in green
x1, y1, x2, y2 = bbox
draw.rectangle([x1, y1, x2, y2], outline="green", width=3)
draw.text((x1, y1-20), "Ground Truth", fill="green")
draw.text((x1, y1 - 20), "Ground Truth", fill="green")
# Draw predicted click in red or blue
if predicted_coords is not None:
px, py = predicted_coords
color = "blue" if is_correct else "red"
# Draw crosshair
crosshair_size = 15
draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=3)
draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=3)
draw.text((px+10, py-20), f"Predicted ({px},{py})", fill=color)
draw.line(
[(px - crosshair_size, py), (px + crosshair_size, py)], fill=color, width=3
)
draw.line(
[(px, py - crosshair_size), (px, py + crosshair_size)], fill=color, width=3
)
draw.text((px + 10, py - 20), f"Predicted ({px},{py})", fill=color)
# Add status text
status = "CORRECT" if is_correct else "INCORRECT"
status_color = "blue" if is_correct else "red"
draw.text((10, 10), f"Status: {status}", fill=status_color)
draw.text((10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black")
draw.text(
(10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black"
)
# Save image
filename = f"sample_{i+1:02d}_idx{sample_idx}_{status.lower()}.png"
filepath = os.path.join(model_dir, filename)
image.save(filepath)
print(f"Visualizations saved to: {model_dir}")
def save_prediction_visualization(image: Image.Image, instruction: str, predictions: List[dict],
output_file: str = "interactive_prediction.png") -> None:
def save_prediction_visualization(
image: Image.Image,
instruction: str,
predictions: List[dict],
output_file: str = "interactive_prediction.png",
) -> None:
"""
Save visualization of multiple model predictions on a single image.
Args:
image: PIL Image to visualize
instruction: Instruction text
@@ -358,32 +393,32 @@ def save_prediction_visualization(image: Image.Image, instruction: str, predicti
# Create a copy of the image
vis_image = image.copy()
draw = ImageDraw.Draw(vis_image)
# Colors for different models
colors = ["red", "blue", "orange", "purple", "brown", "pink", "gray", "olive"]
# Draw predictions
for i, pred in enumerate(predictions):
color = colors[i % len(colors)]
model_name = pred['model_name']
coords = pred.get('coords')
error = pred.get('error')
model_name = pred["model_name"]
coords = pred.get("coords")
error = pred.get("error")
if coords is not None:
px, py = coords
# Draw crosshair
crosshair_size = 20
draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=4)
draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=4)
draw.line([(px - crosshair_size, py), (px + crosshair_size, py)], fill=color, width=4)
draw.line([(px, py - crosshair_size), (px, py + crosshair_size)], fill=color, width=4)
# Draw model name
draw.text((px+15, py+15), f"{model_name}: ({px},{py})", fill=color)
draw.text((px + 15, py + 15), f"{model_name}: ({px},{py})", fill=color)
else:
# Draw error text
draw.text((10, 50 + i*20), f"{model_name}: ERROR - {error}", fill=color)
draw.text((10, 50 + i * 20), f"{model_name}: ERROR - {error}", fill=color)
# Add instruction at the top
draw.text((10, 10), f"Instruction: {instruction}", fill="black")
# Save image
vis_image.save(output_file)
print(f"Prediction visualization saved to: {output_file}")
@@ -392,12 +427,13 @@ def save_prediction_visualization(image: Image.Image, instruction: str, predicti
def take_screenshot() -> Image.Image:
"""
Take a screenshot of the current screen.
Returns:
PIL Image of the screenshot
"""
try:
import pyautogui
screenshot = pyautogui.screenshot()
return screenshot
except ImportError:
@@ -406,4 +442,3 @@ def take_screenshot() -> Image.Image:
except Exception as e:
print(f"Error taking screenshot: {e}")
raise

View File

@@ -9,58 +9,61 @@ from agent import ComputerAgent
from computer import Computer
from computer.helpers import sandboxed
@sandboxed()
def read_file(location: str) -> str:
"""Read contents of a file
Parameters
----------
location : str
Path to the file to read
Returns
-------
str
Contents of the file or error message
"""
try:
with open(location, 'r') as f:
with open(location, "r") as f:
return f.read()
except Exception as e:
return f"Error reading file: {str(e)}"
def save_note(content: str, filename: str = "note.txt") -> str:
"""Save content to a note file
Parameters
----------
content : str
Content to save to the file
filename : str, optional
Name of the file to save to (default is "note.txt")
Returns
-------
str
Success or error message
"""
try:
with open(filename, 'w') as f:
with open(filename, "w") as f:
f.write(content)
return f"Saved note to {filename}"
except Exception as e:
return f"Error saving note: {str(e)}"
def calculate(a: int, b: int) -> int:
"""Calculate the sum of two integers
Parameters
----------
a : int
First integer
b : int
Second integer
Returns
-------
int
@@ -68,15 +71,18 @@ def calculate(a: int, b: int) -> int:
"""
return a + b
async def main():
"""Example usage of ComputerAgent with different models"""
# Example 1: Using Claude with computer and custom tools
print("=== Example 1: Claude with Computer ===")
import os
import dotenv
import json
import os
import dotenv
dotenv.load_dotenv()
assert os.getenv("CUA_CONTAINER_NAME") is not None, "CUA_CONTAINER_NAME is not set"
@@ -86,38 +92,37 @@ async def main():
os_type="linux",
provider_type="cloud",
name=os.getenv("CUA_CONTAINER_NAME") or "",
api_key=os.getenv("CUA_API_KEY") or ""
api_key=os.getenv("CUA_API_KEY") or "",
) as computer:
agent = ComputerAgent(
# Supported models:
# == OpenAI CUA (computer-use-preview) ==
model="openai/computer-use-preview",
# == Anthropic CUA (Claude > 3.5) ==
# model="anthropic/claude-opus-4-20250514",
# model="anthropic/claude-opus-4-20250514",
# model="anthropic/claude-sonnet-4-20250514",
# model="anthropic/claude-3-7-sonnet-20250219",
# model="anthropic/claude-3-5-sonnet-20241022",
# == UI-TARS ==
# model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
# TODO: add local mlx provider
# model="mlx-community/UI-TARS-1.5-7B-6bit",
# model="ollama_chat/0000/ui-tars-1.5-7b",
# == Omniparser + Any LLM ==
# model="omniparser+..."
# model="omniparser+anthropic/claude-opus-4-20250514",
tools=[computer],
only_n_most_recent_images=3,
verbosity=logging.INFO,
trajectory_dir="trajectories",
use_prompt_caching=True,
max_trajectory_budget={ "max_budget": 1.0, "raise_error": True, "reset_after_each_run": False },
max_trajectory_budget={
"max_budget": 1.0,
"raise_error": True,
"reset_after_each_run": False,
},
)
history = []
while True:
user_input = input("> ")
@@ -143,5 +148,6 @@ async def main():
# elif item["type"] == "function_call_output":
# print("===>", item["output"])
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())
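(Aside: the hunks above define read_file, save_note, and calculate as custom tools, but the visible diff only passes the computer to the agent. Presumably they are meant to be registered alongside it, roughly as in this sketch — an assumption, not shown in the diff:)

    agent = ComputerAgent(
        model="anthropic/claude-3-5-sonnet-20241022",
        tools=[computer, read_file, save_note, calculate],  # assumption: plain functions as tools
    )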

View File

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
[project]
name = "cua-agent"
version = "0.4.0"
version = "0.4.35"
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
readme = "README.md"
authors = [
@@ -29,6 +29,11 @@ requires-python = ">=3.12"
[project.optional-dependencies]
openai = []
anthropic = []
qwen = [
"qwen-vl-utils",
"qwen-agent",
"Pillow>=10.0.0",
]
omni = [
"cua-som>=0.1.0,<0.2.0",
]
@@ -49,7 +54,7 @@ glm45v-hf = [
opencua-hf = [
"accelerate",
"torch",
"transformers==4.53.0",
"transformers>=4.53.0",
"tiktoken>=0.11.0",
"blobfile>=3.0.0"
]
@@ -60,6 +65,11 @@ internvl-hf = [
"einops",
"timm"
]
moondream3 = [
"accelerate",
"torch",
"transformers>=4.55.0"
]
ui = [
"gradio>=5.23.3",
"python-dotenv>=1.0.1",
@@ -68,7 +78,10 @@ cli = [
"yaspin>=3.1.0",
]
hud = [
"hud-python==0.4.26",
"hud-python==0.4.52",
]
gemini = [
"google-genai>=1.41.0",
]
all = [
# uitars requirements
@@ -88,7 +101,13 @@ all = [
# cli requirements
"yaspin>=3.1.0",
# hud requirements
"hud-python==0.4.26",
"hud-python==0.4.52",
# gemini requirements
"google-genai>=1.41.0",
# qwen requirements
"qwen-vl-utils",
"qwen-agent",
"Pillow>=10.0.0",
]
[tool.uv]
@@ -98,4 +117,4 @@ constraint-dependencies = ["fastrtc>0.43.0", "mlx-audio>0.2.3"]
distribution = true
[tool.pdm.build]
includes = ["agent/"]
includes = ["agent/"]

View File

@@ -0,0 +1,10 @@
[bumpversion]
current_version = 0.1.27
commit = True
tag = True
tag_name = computer-server-v{new_version}
message = Bump cua-computer-server to v{new_version}
[bumpversion:file:pyproject.toml]
search = version = "{current_version}"
replace = version = "{new_version}"

View File

@@ -8,10 +8,11 @@
</picture>
</div>
[![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#)
[![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85)
[![PyPI](https://img.shields.io/pypi/v/cua-computer-server?color=333333)](https://pypi.org/project/cua-computer-server/)
[![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#)
[![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85)
[![PyPI](https://img.shields.io/pypi/v/cua-computer-server?color=333333)](https://pypi.org/project/cua-computer-server/)
</h1>
</div>
@@ -42,4 +43,4 @@ Refer to this notebook for a step-by-step guide on how to use the Computer-Use S
- [Commands](https://trycua.com/docs/libraries/computer-server/Commands)
- [REST-API](https://trycua.com/docs/libraries/computer-server/REST-API)
- [WebSocket-API](https://trycua.com/docs/libraries/computer-server/WebSocket-API)
- [Index](https://trycua.com/docs/libraries/computer-server/index)
- [Index](https://trycua.com/docs/libraries/computer-server/index)

View File

@@ -4,6 +4,7 @@ This allows the server to be started with `python -m computer_server`.
"""
import sys
from .cli import main
if __name__ == "__main__":

View File

@@ -36,7 +36,7 @@ def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
help="Path to SSL private key file (enables HTTPS)",
)
parser.add_argument(
"--ssl-certfile",
"--ssl-certfile",
type=str,
help="Path to SSL certificate file (enables HTTPS)",
)
@@ -72,17 +72,19 @@ def main() -> None:
# Check if watchdog should be enabled
container_name = os.environ.get("CONTAINER_NAME")
enable_watchdog = args.watchdog or bool(container_name)
enable_watchdog = (args.watchdog or bool(container_name)) and not sys.platform.startswith("win")
if container_name:
logger.info(f"Container environment detected (CONTAINER_NAME={container_name}), enabling watchdog")
logger.info(
f"Container environment detected (CONTAINER_NAME={container_name}), enabling watchdog"
)
elif args.watchdog:
logger.info("Watchdog explicitly enabled via --watchdog flag")
# Start watchdog if enabled
if enable_watchdog:
logger.info(f"Starting watchdog monitoring with {args.watchdog_interval}s interval")
def run_watchdog_thread():
"""Run watchdog in a separate thread."""
loop = asyncio.new_event_loop()
@@ -90,38 +92,32 @@ def main() -> None:
try:
# Create CLI args dict for watchdog
cli_args = {
'host': args.host,
'port': args.port,
'log_level': args.log_level,
'ssl_keyfile': args.ssl_keyfile,
'ssl_certfile': args.ssl_certfile
"host": args.host,
"port": args.port,
"log_level": args.log_level,
"ssl_keyfile": args.ssl_keyfile,
"ssl_certfile": args.ssl_certfile,
}
# Create watchdog with restart settings
from .watchdog import Watchdog
watchdog = Watchdog(
cli_args=cli_args,
ping_interval=args.watchdog_interval
)
watchdog = Watchdog(cli_args=cli_args, ping_interval=args.watchdog_interval)
watchdog.restart_enabled = not args.no_restart
loop.run_until_complete(watchdog.start_monitoring())
except Exception as e:
logger.error(f"Watchdog error: {e}")
finally:
loop.close()
# Start watchdog in background thread
watchdog_thread = threading.Thread(
target=run_watchdog_thread,
daemon=True,
name="watchdog"
)
watchdog_thread = threading.Thread(target=run_watchdog_thread, daemon=True, name="watchdog")
watchdog_thread.start()
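The watchdog startup above is an instance of a common pattern: run an asyncio coroutine on its own event loop inside a daemon thread. A generic, self-contained sketch of that pattern (names are illustrative):

    import asyncio
    import threading

    def run_coroutine_in_daemon_thread(coro_factory, name="bg"):
        """Run coro_factory() on a fresh event loop in a daemon thread."""
        def runner():
            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(coro_factory())
            finally:
                loop.close()
        thread = threading.Thread(target=runner, daemon=True, name=name)
        thread.start()
        return thread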
# Create and start the server
logger.info(f"Starting CUA Computer API server on {args.host}:{args.port}...")
# Handle SSL configuration
ssl_args = {}
if args.ssl_keyfile and args.ssl_certfile:
@@ -131,10 +127,12 @@ def main() -> None:
}
logger.info("HTTPS mode enabled with SSL certificates")
elif args.ssl_keyfile or args.ssl_certfile:
logger.warning("Both --ssl-keyfile and --ssl-certfile are required for HTTPS. Running in HTTP mode.")
logger.warning(
"Both --ssl-keyfile and --ssl-certfile are required for HTTPS. Running in HTTP mode."
)
else:
logger.info("HTTP mode (no SSL certificates provided)")
server = Server(host=args.host, port=args.port, log_level=args.log_level, **ssl_args)
try:

View File

@@ -1,4 +1,5 @@
from typing import Optional


class BaseDioramaHandler:
    """Base Diorama handler for unsupported OSes."""

    async def diorama_cmd(self, action: str, arguments: Optional[dict] = None) -> dict:
        return {"success": False, "error": "Diorama is not supported on this OS yet."}

View File

@@ -1,31 +1,38 @@
#!/usr/bin/env python3
"""Diorama: A virtual desktop manager for macOS"""
import os
import asyncio
import logging
import sys
import io
import logging
import os
import sys
from typing import Union
from PIL import Image, ImageDraw
from computer_server.diorama.draw import capture_all_apps, AppActivationContext, get_frontmost_and_active_app, get_all_windows, get_running_apps
from computer_server.diorama.diorama_computer import DioramaComputer
from computer_server.diorama.draw import (
AppActivationContext,
capture_all_apps,
get_all_windows,
get_frontmost_and_active_app,
get_running_apps,
)
from computer_server.handlers.macos import *
from PIL import Image, ImageDraw
# simple, nicely formatted logging
logger = logging.getLogger(__name__)
automation_handler = MacOSAutomationHandler()
class Diorama:
"""Virtual desktop manager that provides automation capabilities for macOS applications.
Manages application windows and provides an interface for taking screenshots,
mouse interactions, keyboard input, and coordinate transformations between
screenshot space and screen space.
"""
_scheduler_queue = None
_scheduler_task = None
_loop = None
@@ -34,10 +41,10 @@ class Diorama:
@classmethod
def create_from_apps(cls, *args) -> DioramaComputer:
"""Create a DioramaComputer instance from a list of application names.
Args:
*args: Variable number of application names to include in the desktop
Returns:
DioramaComputer: A computer interface for the specified applications
"""
@@ -46,10 +53,10 @@ class Diorama:
# Dictionary to store cursor positions for each unique app_list hash
_cursor_positions = {}
def __init__(self, app_list):
"""Initialize a Diorama instance for the specified applications.
Args:
app_list: List of application names to manage
"""
@@ -57,10 +64,10 @@ class Diorama:
self.interface = self.Interface(self)
self.computer = DioramaComputer(self)
self.focus_context = None
# Create a hash for this app_list to use as a key
self.app_list_hash = hash(tuple(sorted(app_list)))
# Initialize cursor position for this app_list if it doesn't exist
if self.app_list_hash not in Diorama._cursor_positions:
Diorama._cursor_positions[self.app_list_hash] = (0, 0)
@@ -68,7 +75,7 @@ class Diorama:
@classmethod
def _ensure_scheduler(cls):
"""Ensure the async scheduler loop is running.
Creates and starts the scheduler task if it hasn't been started yet.
"""
if not cls._scheduler_started:
@@ -81,7 +88,7 @@ class Diorama:
@classmethod
async def _scheduler_loop(cls):
"""Main scheduler loop that processes automation commands.
Continuously processes commands from the scheduler queue, handling
screenshots, mouse actions, keyboard input, and scrolling operations.
"""
@@ -91,31 +98,37 @@ class Diorama:
args = cmd.get("arguments", {})
future = cmd.get("future")
logger.info(f"Processing command: {action} | args={args}")
app_whitelist = args.get("app_list", [])
all_windows = get_all_windows()
running_apps = get_running_apps()
frontmost_app, active_app_to_use, active_app_pid = get_frontmost_and_active_app(all_windows, running_apps, app_whitelist)
frontmost_app, active_app_to_use, active_app_pid = get_frontmost_and_active_app(
all_windows, running_apps, app_whitelist
)
focus_context = AppActivationContext(active_app_pid, active_app_to_use, logger)
with focus_context:
try:
if action == "screenshot":
logger.info(f"Taking screenshot for apps: {app_whitelist}")
result, img = capture_all_apps(
app_whitelist=app_whitelist,
save_to_disk=False,
take_focus=False
app_whitelist=app_whitelist, save_to_disk=False, take_focus=False
)
logger.info("Screenshot complete.")
if future:
future.set_result((result, img))
# Mouse actions
elif action in ["left_click", "right_click", "double_click", "move_cursor", "drag_to"]:
elif action in [
"left_click",
"right_click",
"double_click",
"move_cursor",
"drag_to",
]:
x = args.get("x")
y = args.get("y")
duration = args.get("duration", 0.5)
if action == "left_click":
await automation_handler.left_click(x, y)
@@ -134,7 +147,7 @@ class Diorama:
y = args.get("y")
if x is not None and y is not None:
await automation_handler.move_cursor(x, y)
clicks = args.get("clicks", 1)
if action == "scroll_up":
await automation_handler.scroll_up(clicks)
@@ -171,31 +184,31 @@ class Diorama:
if future:
future.set_exception(e)
class Interface():
class Interface:
"""Interface for interacting with the virtual desktop.
Provides methods for taking screenshots, mouse interactions, keyboard input,
and coordinate transformations between screenshot and screen coordinates.
"""
def __init__(self, diorama):
"""Initialize the interface with a reference to the parent Diorama instance.
Args:
diorama: The parent Diorama instance
"""
self._diorama = diorama
self._scene_hitboxes = []
self._scene_size = None
async def _send_cmd(self, action, arguments=None):
"""Send a command to the scheduler queue.
Args:
action (str): The action to perform
arguments (dict, optional): Arguments for the action
Returns:
The result of the command execution
"""
@@ -203,11 +216,13 @@ class Diorama:
loop = asyncio.get_event_loop()
future = loop.create_future()
logger.info(f"Enqueuing {action} command for apps: {self._diorama.app_list}")
await Diorama._scheduler_queue.put({
"action": action,
"arguments": {"app_list": self._diorama.app_list, **(arguments or {})},
"future": future
})
await Diorama._scheduler_queue.put(
{
"action": action,
"arguments": {"app_list": self._diorama.app_list, **(arguments or {})},
"future": future,
}
)
try:
return await future
except asyncio.CancelledError:
@@ -216,21 +231,23 @@ class Diorama:
async def screenshot(self, as_bytes: bool = True) -> Union[str, Image.Image]:
"""Take a screenshot of the managed applications.
Args:
as_bytes (bool): If True, return base64-encoded bytes; if False, return PIL Image
Returns:
Union[str, Image.Image]: Base64-encoded PNG bytes or PIL Image object
"""
import base64
result, img = await self._send_cmd("screenshot")
self._scene_hitboxes = result.get("hitboxes", [])
self._scene_size = img.size
if as_bytes:
# PIL Image to bytes, then base64 encode for JSON
import io
img_byte_arr = io.BytesIO()
img.save(img_byte_arr, format="PNG")
img_bytes = img_byte_arr.getvalue()
@@ -241,7 +258,7 @@ class Diorama:
async def left_click(self, x, y):
"""Perform a left mouse click at the specified coordinates.
Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -258,7 +275,7 @@ class Diorama:
async def right_click(self, x, y):
"""Perform a right mouse click at the specified coordinates.
Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -269,13 +286,13 @@ class Diorama:
x, y = x or last_pos[0], y or last_pos[1]
# Update cursor position for this app_list hash
Diorama._cursor_positions[app_list_hash] = (x, y)
sx, sy = await self.to_screen_coordinates(x, y)
await self._send_cmd("right_click", {"x": sx, "y": sy})
async def double_click(self, x, y):
"""Perform a double mouse click at the specified coordinates.
Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -286,13 +303,13 @@ class Diorama:
x, y = x or last_pos[0], y or last_pos[1]
# Update cursor position for this app_list hash
Diorama._cursor_positions[app_list_hash] = (x, y)
sx, sy = await self.to_screen_coordinates(x, y)
await self._send_cmd("double_click", {"x": sx, "y": sy})
async def move_cursor(self, x, y):
"""Move the mouse cursor to the specified coordinates.
Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -303,13 +320,13 @@ class Diorama:
x, y = x or last_pos[0], y or last_pos[1]
# Update cursor position for this app_list hash
Diorama._cursor_positions[app_list_hash] = (x, y)
sx, sy = await self.to_screen_coordinates(x, y)
await self._send_cmd("move_cursor", {"x": sx, "y": sy})
async def drag_to(self, x, y, duration=0.5):
"""Drag the mouse from current position to the specified coordinates.
Args:
x (int): X coordinate in screenshot space (or None to use last position)
y (int): Y coordinate in screenshot space (or None to use last position)
@@ -321,13 +338,13 @@ class Diorama:
x, y = x or last_pos[0], y or last_pos[1]
# Update cursor position for this app_list hash
Diorama._cursor_positions[app_list_hash] = (x, y)
sx, sy = await self.to_screen_coordinates(x, y)
await self._send_cmd("drag_to", {"x": sx, "y": sy, "duration": duration})
async def get_cursor_position(self):
"""Get the current cursor position in screen coordinates.
Returns:
tuple: (x, y) coordinates of the cursor in screen space
"""
@@ -335,7 +352,7 @@ class Diorama:
async def type_text(self, text):
"""Type the specified text using the keyboard.
Args:
text (str): The text to type
"""
@@ -343,7 +360,7 @@ class Diorama:
async def press_key(self, key):
"""Press a single key on the keyboard.
Args:
key (str): The key to press
"""
@@ -351,7 +368,7 @@ class Diorama:
async def hotkey(self, keys):
"""Press a combination of keys simultaneously.
Args:
keys (list): List of keys to press together
"""
@@ -359,7 +376,7 @@ class Diorama:
async def scroll_up(self, clicks: int = 1):
"""Scroll up at the current cursor position.
Args:
clicks (int): Number of scroll clicks to perform
"""
@@ -367,12 +384,12 @@ class Diorama:
app_list_hash = hash(tuple(sorted(self._diorama.app_list)))
last_pos = Diorama._cursor_positions.get(app_list_hash, (0, 0))
x, y = last_pos[0], last_pos[1]
await self._send_cmd("scroll_up", {"clicks": clicks, "x": x, "y": y})
async def scroll_down(self, clicks: int = 1):
"""Scroll down at the current cursor position.
Args:
clicks (int): Number of scroll clicks to perform
"""
@@ -380,18 +397,18 @@ class Diorama:
app_list_hash = hash(tuple(sorted(self._diorama.app_list)))
last_pos = Diorama._cursor_positions.get(app_list_hash, (0, 0))
x, y = last_pos[0], last_pos[1]
await self._send_cmd("scroll_down", {"clicks": clicks, "x": x, "y": y})
async def get_screen_size(self) -> dict[str, int]:
"""Get the size of the screenshot area.
Returns:
dict[str, int]: Dictionary with 'width' and 'height' keys
"""
if not self._scene_size:
await self.screenshot()
return { "width": self._scene_size[0], "height": self._scene_size[1] }
return {"width": self._scene_size[0], "height": self._scene_size[1]}
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screenshot coordinates to screen coordinates.
@@ -404,29 +421,29 @@ class Diorama:
tuple[float, float]: (x, y) absolute coordinates in screen space
"""
if not self._scene_hitboxes:
await self.screenshot() # get hitboxes
await self.screenshot() # get hitboxes
# Try all hitboxes
for h in self._scene_hitboxes[::-1]:
rect_from = h.get("hitbox")
rect_to = h.get("target")
if not rect_from or len(rect_from) != 4:
continue
# check if (x, y) is inside rect_from
x0, y0, x1, y1 = rect_from
if x0 <= x <= x1 and y0 <= y <= y1:
logger.info(f"Found hitbox: {h}")
# remap (x, y) to rect_to
tx0, ty0, tx1, ty1 = rect_to
# calculate offset from x0, y0
offset_x = x - x0
offset_y = y - y0
# remap offset to rect_to
tx = tx0 + offset_x
ty = ty0 + offset_y
return tx, ty
return x, y
@@ -441,34 +458,37 @@ class Diorama:
tuple[float, float]: (x, y) absolute coordinates in screenshot space
"""
if not self._scene_hitboxes:
await self.screenshot() # get hitboxes
await self.screenshot() # get hitboxes
# Try all hitboxes
for h in self._scene_hitboxes[::-1]:
rect_from = h.get("target")
rect_to = h.get("hitbox")
if not rect_from or len(rect_from) != 4:
continue
# check if (x, y) is inside rect_from
x0, y0, x1, y1 = rect_from
if x0 <= x <= x1 and y0 <= y <= y1:
# remap (x, y) to rect_to
tx0, ty0, tx1, ty1 = rect_to
# calculate offset from x0, y0
offset_x = x - x0
offset_y = y - y0
# remap offset to rect_to
tx = tx0 + offset_x
ty = ty0 + offset_y
return tx, ty
return x, y
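Both converters above reduce to the same rectangle-to-rectangle translation: preserve the point's offset from the source rect's top-left corner (no scaling). A standalone sketch of that remap:

    def remap_point(x, y, rect_from, rect_to):
        """Translate (x, y) from rect_from into rect_to, preserving the
        offset from the top-left corner, as in the hitbox loops above."""
        x0, y0, x1, y1 = rect_from
        tx0, ty0, _tx1, _ty1 = rect_to
        if not (x0 <= x <= x1 and y0 <= y <= y1):
            return x, y  # outside the hitbox: pass through unchanged
        return tx0 + (x - x0), ty0 + (y - y0)

    # A point 10px inside a hitbox at (0, 0) lands 10px inside its target at (100, 200):
    assert remap_point(10, 10, (0, 0, 50, 50), (100, 200, 150, 250)) == (110, 210)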
import pyautogui
import time
import pyautogui
async def main():
"""Main function demonstrating Diorama usage with multiple desktops and mouse tracking."""
desktop1 = Diorama.create_from_apps(["Discord", "Notes"])
@@ -511,7 +531,7 @@ async def main():
# Draw on a copy of the screenshot
frame = base_img.copy()
frame_draw = ImageDraw.Draw(frame)
frame_draw.ellipse((sx-5, sy-5, sx+5, sy+5), fill="blue", outline="blue")
frame_draw.ellipse((sx - 5, sy - 5, sx + 5, sy + 5), fill="blue", outline="blue")
# Save the frame
frame.save("app_screenshots/desktop3_mouse.png")
print(f"Mouse at screen ({mouse_x}, {mouse_y}) -> screenshot ({sx:.1f}, {sy:.1f})")
@@ -520,15 +540,13 @@ async def main():
print("Stopped tracking.")
draw.text((rect[0], rect[1]), str(idx), fill="red")
canvas.save("app_screenshots/desktop3_hitboxes.png")
# move mouse in a square spiral around the screen
import math
import random
step = 20 # pixels per move
dot_radius = 10
width = screen_size["width"]
@@ -539,11 +557,12 @@ async def main():
await desktop3.interface.move_cursor(x, y)
img = await desktop3.interface.screenshot(as_bytes=False)
draw = ImageDraw.Draw(img)
draw.ellipse((x-dot_radius, y-dot_radius, x+dot_radius, y+dot_radius), fill="red")
draw.ellipse((x - dot_radius, y - dot_radius, x + dot_radius, y + dot_radius), fill="red")
img.save("current.png")
await asyncio.sleep(0.03)
x += step
y = math.sin(x / width * math.pi * 2) * 50 + 25
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,14 +1,16 @@
import asyncio
class DioramaComputer:
"""
A minimal Computer-like interface for Diorama, compatible with ComputerAgent.
Implements _initialized, run(), and __aenter__ for agent compatibility.
"""
def __init__(self, diorama):
"""
Initialize the DioramaComputer with a diorama instance.
Args:
diorama: The diorama instance to wrap with a computer-like interface.
"""
@@ -19,10 +21,10 @@ class DioramaComputer:
async def __aenter__(self):
"""
Async context manager entry method for compatibility with ComputerAgent.
Ensures an event loop is running and marks the instance as initialized.
Creates a new event loop if none is currently running.
Returns:
DioramaComputer: The initialized instance.
"""
@@ -37,10 +39,10 @@ class DioramaComputer:
async def run(self):
"""
Run method stub for compatibility with ComputerAgent interface.
Ensures the instance is initialized before returning. If not already
initialized, calls __aenter__ to perform initialization.
Returns:
DioramaComputer: The initialized instance.
"""

File diff suppressed because it is too large

View File

@@ -1,13 +1,15 @@
import inspect
import platform
import sys
import platform
import inspect
from computer_server.diorama.diorama import Diorama
from computer_server.diorama.base import BaseDioramaHandler
from typing import Optional
from computer_server.diorama.base import BaseDioramaHandler
from computer_server.diorama.diorama import Diorama
class MacOSDioramaHandler(BaseDioramaHandler):
"""Handler for Diorama commands on macOS, using local diorama module."""
async def diorama_cmd(self, action: str, arguments: Optional[dict] = None) -> dict:
if platform.system().lower() != "darwin":
return {"success": False, "error": "Diorama is only supported on macOS."}
@@ -30,4 +32,5 @@ class MacOSDioramaHandler(BaseDioramaHandler):
return {"success": True, "result": result}
except Exception as e:
import traceback
return {"success": False, "error": str(e), "trace": traceback.format_exc()}

View File

@@ -8,31 +8,31 @@ like the menubar and dock, which are needed for proper screenshot composition.
import sys
import time
from typing import Dict, Any, Optional, Tuple
from typing import Any, Dict, Optional, Tuple
# Import Objective-C bridge libraries
try:
import AppKit
import Foundation
from AppKit import NSRunningApplication, NSWorkspace
from ApplicationServices import (
AXUIElementCreateSystemWide,
AXUIElementCreateApplication,
AXUIElementCopyAttributeValue,
AXUIElementCopyAttributeValues,
kAXChildrenAttribute,
kAXRoleAttribute,
kAXTitleAttribute,
kAXPositionAttribute,
kAXSizeAttribute,
kAXErrorSuccess,
AXValueGetType,
kAXValueCGSizeType,
kAXValueCGPointType,
AXUIElementCreateApplication,
AXUIElementCreateSystemWide,
AXUIElementGetTypeID,
AXValueGetType,
AXValueGetValue,
kAXChildrenAttribute,
kAXErrorSuccess,
kAXMenuBarAttribute,
kAXPositionAttribute,
kAXRoleAttribute,
kAXSizeAttribute,
kAXTitleAttribute,
kAXValueCGPointType,
kAXValueCGSizeType,
)
from AppKit import NSWorkspace, NSRunningApplication
import Foundation
except ImportError:
print("Error: This script requires PyObjC to be installed.")
print("Please install it with: pip install pyobjc")
@@ -74,13 +74,8 @@ def element_value(element, type):
def get_element_bounds(element):
"""Get the bounds of an accessibility element"""
bounds = {
"x": 0,
"y": 0,
"width": 0,
"height": 0
}
bounds = {"x": 0, "y": 0, "width": 0, "height": 0}
# Get position
position_value = element_attribute(element, kAXPositionAttribute)
if position_value:
@@ -88,7 +83,7 @@ def get_element_bounds(element):
if position_value:
bounds["x"] = position_value.x
bounds["y"] = position_value.y
# Get size
size_value = element_attribute(element, kAXSizeAttribute)
if size_value:
@@ -96,7 +91,7 @@ def get_element_bounds(element):
if size_value:
bounds["width"] = size_value.width
bounds["height"] = size_value.height
return bounds
@@ -111,13 +106,13 @@ def find_dock_process():
def get_menubar_bounds():
"""Get the bounds of the macOS menubar
Returns:
Dictionary with x, y, width, height of the menubar
"""
# Get the system-wide accessibility element
system_element = AXUIElementCreateSystemWide()
# Try to find the menubar
menubar = element_attribute(system_element, kAXMenuBarAttribute)
if menubar is None:
@@ -127,19 +122,19 @@ def get_menubar_bounds():
app_pid = frontmost_app.processIdentifier()
app_element = AXUIElementCreateApplication(app_pid)
menubar = element_attribute(app_element, kAXMenuBarAttribute)
if menubar is None:
print("Error: Could not get menubar")
# Return default menubar bounds as fallback
return {"x": 0, "y": 0, "width": 1800, "height": 24}
# Get menubar bounds
return get_element_bounds(menubar)
def get_dock_bounds():
"""Get the bounds of the macOS Dock
Returns:
Dictionary with x, y, width, height of the Dock
"""
@@ -148,19 +143,19 @@ def get_dock_bounds():
print("Error: Could not find Dock process")
# Return empty bounds as fallback
return {"x": 0, "y": 0, "width": 0, "height": 0}
# Create an accessibility element for the Dock
dock_element = AXUIElementCreateApplication(dock_pid)
if dock_element is None:
print(f"Error: Could not create accessibility element for Dock (PID {dock_pid})")
return {"x": 0, "y": 0, "width": 0, "height": 0}
# Get the Dock's children
children = element_attribute(dock_element, kAXChildrenAttribute)
if not children or len(children) == 0:
print("Error: Could not get Dock children")
return {"x": 0, "y": 0, "width": 0, "height": 0}
# Find the Dock's list (first child is usually the main dock list)
dock_list = None
for child in children:
@@ -168,28 +163,25 @@ def get_dock_bounds():
if role == "AXList":
dock_list = child
break
if dock_list is None:
print("Error: Could not find Dock list")
return {"x": 0, "y": 0, "width": 0, "height": 0}
# Get the bounds of the dock list
return get_element_bounds(dock_list)
def get_ui_element_bounds():
"""Get the bounds of important UI elements like menubar and dock
Returns:
Dictionary with menubar and dock bounds
"""
menubar_bounds = get_menubar_bounds()
dock_bounds = get_dock_bounds()
return {
"menubar": menubar_bounds,
"dock": dock_bounds
}
return {"menubar": menubar_bounds, "dock": dock_bounds}
if __name__ == "__main__":
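A short sketch of consuming the bounds returned above (keys follow get_element_bounds):

    bounds = get_ui_element_bounds()
    print("menubar height:", bounds["menubar"]["height"])
    print("dock width:", bounds["dock"]["width"])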

View File

@@ -1,24 +1,26 @@
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List, Tuple
from typing import Any, Dict, List, Optional, Tuple
class BaseAccessibilityHandler(ABC):
"""Abstract base class for OS-specific accessibility handlers."""
@abstractmethod
async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the accessibility tree of the current window."""
pass
@abstractmethod
async def find_element(self, role: Optional[str] = None,
title: Optional[str] = None,
value: Optional[str] = None) -> Dict[str, Any]:
async def find_element(
self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None
) -> Dict[str, Any]:
"""Find an element in the accessibility tree by criteria."""
pass
class BaseFileHandler(ABC):
"""Abstract base class for OS-specific file handlers."""
@abstractmethod
async def file_exists(self, path: str) -> Dict[str, Any]:
"""Check if a file exists at the specified path."""
@@ -43,7 +45,7 @@ class BaseFileHandler(ABC):
async def write_text(self, path: str, content: str) -> Dict[str, Any]:
"""Write text content to a file."""
pass
@abstractmethod
async def write_bytes(self, path: str, content_b64: str) -> Dict[str, Any]:
"""Write binary content to a file. Sent over the websocket as a base64 string."""
@@ -65,9 +67,11 @@ class BaseFileHandler(ABC):
pass
@abstractmethod
async def read_bytes(self, path: str, offset: int = 0, length: Optional[int] = None) -> Dict[str, Any]:
async def read_bytes(
self, path: str, offset: int = 0, length: Optional[int] = None
) -> Dict[str, Any]:
"""Read the binary contents of a file. Sent over the websocket as a base64 string.
Args:
path: Path to the file
offset: Byte offset to start reading from (default: 0)
@@ -80,9 +84,10 @@ class BaseFileHandler(ABC):
"""Get the size of a file in bytes."""
pass
class BaseAutomationHandler(ABC):
"""Abstract base class for OS-specific automation handlers.
Categories:
- Mouse Actions: Methods for mouse control
- Keyboard Actions: Methods for keyboard input
@@ -90,18 +95,22 @@ class BaseAutomationHandler(ABC):
- Screen Actions: Methods for screen interaction
- Clipboard Actions: Methods for clipboard operations
"""
# Mouse Actions
@abstractmethod
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_down(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Perform a mouse down at the current or specified position."""
pass
@abstractmethod
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_up(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Perform a mouse up at the current or specified position."""
pass
@abstractmethod
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a left click at the current or specified position."""
@@ -113,7 +122,9 @@ class BaseAutomationHandler(ABC):
pass
@abstractmethod
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
async def double_click(
self, x: Optional[int] = None, y: Optional[int] = None
) -> Dict[str, Any]:
"""Perform a double click at the current or specified position."""
pass
@@ -123,9 +134,11 @@ class BaseAutomationHandler(ABC):
pass
@abstractmethod
async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag_to(
self, x: int, y: int, button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag the cursor from current position to specified coordinates.
Args:
x: The x coordinate to drag to
y: The y coordinate to drag to
@@ -133,11 +146,13 @@ class BaseAutomationHandler(ABC):
duration: How long the drag should take in seconds
"""
pass
@abstractmethod
async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag(
self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag the cursor from current position to specified coordinates.
Args:
path: A list of tuples of x and y coordinates to drag to
button: The mouse button to use ('left', 'middle', 'right')
@@ -150,12 +165,12 @@ class BaseAutomationHandler(ABC):
async def key_down(self, key: str) -> Dict[str, Any]:
"""Press and hold the specified key."""
pass
@abstractmethod
async def key_up(self, key: str) -> Dict[str, Any]:
"""Release the specified key."""
pass
@abstractmethod
async def type_text(self, text: str) -> Dict[str, Any]:
"""Type the specified text."""
@@ -176,7 +191,7 @@ class BaseAutomationHandler(ABC):
async def scroll(self, x: int, y: int) -> Dict[str, Any]:
"""Scroll the specified amount."""
pass
@abstractmethod
async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll down by the specified number of clicks."""
@@ -212,9 +227,9 @@ class BaseAutomationHandler(ABC):
@abstractmethod
async def set_clipboard(self, text: str) -> Dict[str, Any]:
"""Set the clipboard content."""
pass
pass
@abstractmethod
async def run_command(self, command: str) -> Dict[str, Any]:
"""Run a command and return the output."""
pass
pass

View File

@@ -1,68 +1,89 @@
import platform
import subprocess
from typing import Tuple, Type
from .base import BaseAccessibilityHandler, BaseAutomationHandler, BaseFileHandler
from computer_server.diorama.base import BaseDioramaHandler
from .base import BaseAccessibilityHandler, BaseAutomationHandler, BaseFileHandler
# Conditionally import platform-specific handlers
system = platform.system().lower()
if system == 'darwin':
from .macos import MacOSAccessibilityHandler, MacOSAutomationHandler
if system == "darwin":
from computer_server.diorama.macos import MacOSDioramaHandler
elif system == 'linux':
from .macos import MacOSAccessibilityHandler, MacOSAutomationHandler
elif system == "linux":
from .linux import LinuxAccessibilityHandler, LinuxAutomationHandler
elif system == 'windows':
elif system == "windows":
from .windows import WindowsAccessibilityHandler, WindowsAutomationHandler
from .generic import GenericFileHandler
class HandlerFactory:
"""Factory for creating OS-specific handlers."""
@staticmethod
def _get_current_os() -> str:
"""Determine the current OS.
Returns:
str: The OS type ('darwin' for macOS, 'linux' for Linux, or 'windows' for Windows)
Raises:
RuntimeError: If unable to determine the current OS
"""
try:
# Use platform.system() as primary method
system = platform.system().lower()
if system in ['darwin', 'linux', 'windows']:
if system in ["darwin", "linux", "windows"]:
return system
# Fallback to uname if platform.system() doesn't return expected values (Unix-like systems only)
result = subprocess.run(['uname', '-s'], capture_output=True, text=True)
result = subprocess.run(["uname", "-s"], capture_output=True, text=True)
if result.returncode == 0:
return result.stdout.strip().lower()
raise RuntimeError(f"Unsupported OS: {system}")
except Exception as e:
raise RuntimeError(f"Failed to determine current OS: {str(e)}")
@staticmethod
def create_handlers() -> Tuple[BaseAccessibilityHandler, BaseAutomationHandler, BaseDioramaHandler, BaseFileHandler]:
def create_handlers() -> (
Tuple[BaseAccessibilityHandler, BaseAutomationHandler, BaseDioramaHandler, BaseFileHandler]
):
"""Create and return appropriate handlers for the current OS.
Returns:
Tuple[BaseAccessibilityHandler, BaseAutomationHandler, BaseDioramaHandler, BaseFileHandler]: A tuple containing
the appropriate accessibility, automation, diorama, and file handlers for the current OS.
Raises:
NotImplementedError: If the current OS is not supported
RuntimeError: If unable to determine the current OS
"""
os_type = HandlerFactory._get_current_os()
if os_type == 'darwin':
return MacOSAccessibilityHandler(), MacOSAutomationHandler(), MacOSDioramaHandler(), GenericFileHandler()
elif os_type == 'linux':
return LinuxAccessibilityHandler(), LinuxAutomationHandler(), BaseDioramaHandler(), GenericFileHandler()
elif os_type == 'windows':
return WindowsAccessibilityHandler(), WindowsAutomationHandler(), BaseDioramaHandler(), GenericFileHandler()
if os_type == "darwin":
return (
MacOSAccessibilityHandler(),
MacOSAutomationHandler(),
MacOSDioramaHandler(),
GenericFileHandler(),
)
elif os_type == "linux":
return (
LinuxAccessibilityHandler(),
LinuxAutomationHandler(),
BaseDioramaHandler(),
GenericFileHandler(),
)
elif os_type == "windows":
return (
WindowsAccessibilityHandler(),
WindowsAutomationHandler(),
BaseDioramaHandler(),
GenericFileHandler(),
)
else:
raise NotImplementedError(f"OS '{os_type}' is not supported")
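Usage sketch (the four-handler tuple order is taken from the signature above):

    accessibility, automation, diorama, files = HandlerFactory.create_handlers()
    # e.g., inside an async context:
    # result = await automation.screenshot()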

View File

@@ -6,38 +6,41 @@ Includes:
"""
from pathlib import Path
from typing import Dict, Any, Optional
from .base import BaseFileHandler
import base64
from pathlib import Path
from typing import Any, Dict, Optional
from .base import BaseFileHandler
def resolve_path(path: str) -> Path:
"""Resolve a path to its absolute path. Expand ~ to the user's home directory.
Args:
path: The file or directory path to resolve
Returns:
Path: The resolved absolute path
"""
return Path(path).expanduser().resolve()
class GenericFileHandler(BaseFileHandler):
"""
Generic file handler that provides file system operations for all operating systems.
This class implements the BaseFileHandler interface and provides methods for
file and directory operations including reading, writing, creating, and deleting
files and directories.
"""
async def file_exists(self, path: str) -> Dict[str, Any]:
"""
Check if a file exists at the specified path.
Args:
path: The file path to check
Returns:
Dict containing 'success' boolean and either 'exists' boolean or 'error' string
"""
@@ -49,10 +52,10 @@ class GenericFileHandler(BaseFileHandler):
async def directory_exists(self, path: str) -> Dict[str, Any]:
"""
Check if a directory exists at the specified path.
Args:
path: The directory path to check
Returns:
Dict containing 'success' boolean and either 'exists' boolean or 'error' string
"""
@@ -64,25 +67,30 @@ class GenericFileHandler(BaseFileHandler):
async def list_dir(self, path: str) -> Dict[str, Any]:
"""
List all files and directories in the specified directory.
Args:
path: The directory path to list
Returns:
Dict containing 'success' boolean and either 'files' list of names or 'error' string
"""
try:
return {"success": True, "files": [p.name for p in resolve_path(path).iterdir() if p.is_file() or p.is_dir()]}
return {
"success": True,
"files": [
p.name for p in resolve_path(path).iterdir() if p.is_file() or p.is_dir()
],
}
except Exception as e:
return {"success": False, "error": str(e)}
async def read_text(self, path: str) -> Dict[str, Any]:
"""
Read the contents of a text file.
Args:
path: The file path to read from
Returns:
Dict containing 'success' boolean and either 'content' string or 'error' string
"""
@@ -94,11 +102,11 @@ class GenericFileHandler(BaseFileHandler):
async def write_text(self, path: str, content: str) -> Dict[str, Any]:
"""
Write text content to a file.
Args:
path: The file path to write to
content: The text content to write
Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
@@ -108,60 +116,64 @@ class GenericFileHandler(BaseFileHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def write_bytes(self, path: str, content_b64: str, append: bool = False) -> Dict[str, Any]:
async def write_bytes(
self, path: str, content_b64: str, append: bool = False
) -> Dict[str, Any]:
"""
Write binary content to a file from base64 encoded string.
Args:
path: The file path to write to
content_b64: Base64 encoded binary content
append: If True, append to existing file; if False, overwrite
Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
try:
mode = 'ab' if append else 'wb'
mode = "ab" if append else "wb"
with open(resolve_path(path), mode) as f:
f.write(base64.b64decode(content_b64))
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def read_bytes(self, path: str, offset: int = 0, length: Optional[int] = None) -> Dict[str, Any]:
async def read_bytes(
self, path: str, offset: int = 0, length: Optional[int] = None
) -> Dict[str, Any]:
"""
Read binary content from a file and return as base64 encoded string.
Args:
path: The file path to read from
offset: Byte offset to start reading from
length: Number of bytes to read; if None, read entire file from offset
Returns:
Dict containing 'success' boolean and either 'content_b64' string or 'error' string
"""
try:
file_path = resolve_path(path)
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
if offset > 0:
f.seek(offset)
if length is not None:
content = f.read(length)
else:
content = f.read()
return {"success": True, "content_b64": base64.b64encode(content).decode('utf-8')}
return {"success": True, "content_b64": base64.b64encode(content).decode("utf-8")}
except Exception as e:
return {"success": False, "error": str(e)}
async def get_file_size(self, path: str) -> Dict[str, Any]:
"""
Get the size of a file in bytes.
Args:
path: The file path to get size for
Returns:
Dict containing 'success' boolean and either 'size' integer or 'error' string
"""
@@ -175,10 +187,10 @@ class GenericFileHandler(BaseFileHandler):
async def delete_file(self, path: str) -> Dict[str, Any]:
"""
Delete a file at the specified path.
Args:
path: The file path to delete
Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
@@ -191,13 +203,13 @@ class GenericFileHandler(BaseFileHandler):
async def create_dir(self, path: str) -> Dict[str, Any]:
"""
Create a directory at the specified path.
Creates parent directories if they don't exist and doesn't raise an error
if the directory already exists.
Args:
path: The directory path to create
Returns:
Dict containing 'success' boolean and optionally 'error' string
"""
@@ -210,10 +222,10 @@ class GenericFileHandler(BaseFileHandler):
async def delete_dir(self, path: str) -> Dict[str, Any]:
"""
Delete an empty directory at the specified path.
Args:
path: The directory path to delete
Returns:
Dict containing 'success' boolean and optionally 'error' string
"""

View File

@@ -7,14 +7,15 @@ To use GUI automation in a headless environment:
1. Install Xvfb: sudo apt-get install xvfb
2. Run with virtual display: xvfb-run python -m computer_server
"""
from typing import Dict, Any, List, Tuple, Optional
import logging
import subprocess
import asyncio
import base64
import os
import json
import logging
import os
import subprocess
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
# Configure logger
logger = logging.getLogger(__name__)
@@ -23,30 +24,36 @@ logger = logging.getLogger(__name__)
# This allows the server to run in headless environments
try:
import pyautogui
pyautogui.FAILSAFE = False
logger.info("pyautogui successfully imported, GUI automation available")
except Exception as e:
logger.warning(f"pyautogui import failed: {str(e)}. GUI operations will be simulated.")
from pynput.mouse import Button, Controller as MouseController
from pynput.keyboard import Key, Controller as KeyboardController
from pynput.keyboard import Controller as KeyboardController
from pynput.keyboard import Key
from pynput.mouse import Button
from pynput.mouse import Controller as MouseController
from .base import BaseAccessibilityHandler, BaseAutomationHandler
class LinuxAccessibilityHandler(BaseAccessibilityHandler):
"""Linux implementation of accessibility handler."""
async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the accessibility tree of the current window.
Returns:
Dict[str, Any]: A dictionary containing success status and a simulated tree structure
since Linux doesn't have an accessibility API equivalent to macOS's.
"""
# Linux has no accessibility API equivalent to macOS's
# Return a minimal dummy tree
logger.info("Getting accessibility tree (simulated, no accessibility API available on Linux)")
logger.info(
"Getting accessibility tree (simulated, no accessibility API available on Linux)"
)
return {
"success": True,
"tree": {
@@ -54,32 +61,31 @@ class LinuxAccessibilityHandler(BaseAccessibilityHandler):
"title": "Linux Window",
"position": {"x": 0, "y": 0},
"size": {"width": 1920, "height": 1080},
"children": []
}
"children": [],
},
}
async def find_element(self, role: Optional[str] = None,
title: Optional[str] = None,
value: Optional[str] = None) -> Dict[str, Any]:
async def find_element(
self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None
) -> Dict[str, Any]:
"""Find an element in the accessibility tree by criteria.
Args:
role: The role of the element to find.
title: The title of the element to find.
value: The value of the element to find.
Returns:
Dict[str, Any]: A dictionary indicating that element search is not supported on Linux.
"""
logger.info(f"Finding element with role={role}, title={title}, value={value} (not supported on Linux)")
return {
"success": False,
"message": "Element search not supported on Linux"
}
logger.info(
f"Finding element with role={role}, title={title}, value={value} (not supported on Linux)"
)
return {"success": False, "message": "Element search not supported on Linux"}
def get_cursor_position(self) -> Tuple[int, int]:
"""Get the current cursor position.
Returns:
Tuple[int, int]: The x and y coordinates of the cursor position.
Returns (0, 0) if pyautogui is not available.
@@ -89,13 +95,13 @@ class LinuxAccessibilityHandler(BaseAccessibilityHandler):
return pos.x, pos.y
except Exception as e:
logger.warning(f"Failed to get cursor position with pyautogui: {e}")
logger.info("Getting cursor position (simulated)")
return 0, 0
def get_screen_size(self) -> Tuple[int, int]:
"""Get the screen size.
Returns:
Tuple[int, int]: The width and height of the screen in pixels.
Returns (1920, 1080) if pyautogui is not available.
@@ -105,24 +111,28 @@ class LinuxAccessibilityHandler(BaseAccessibilityHandler):
return size.width, size.height
except Exception as e:
logger.warning(f"Failed to get screen size with pyautogui: {e}")
logger.info("Getting screen size (simulated)")
return 1920, 1080
class LinuxAutomationHandler(BaseAutomationHandler):
"""Linux implementation of automation handler using pyautogui."""
keyboard = KeyboardController()
mouse = MouseController()
# Mouse Actions
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_down(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Press and hold a mouse button at the specified coordinates.
Args:
x: The x coordinate to move to before pressing. If None, uses current position.
y: The y coordinate to move to before pressing. If None, uses current position.
button: The mouse button to press ("left", "right", or "middle").
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -133,15 +143,17 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def mouse_up(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Release a mouse button at the specified coordinates.
Args:
x: The x coordinate to move to before releasing. If None, uses current position.
y: The y coordinate to move to before releasing. If None, uses current position.
button: The mouse button to release ("left", "right", or "middle").
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -152,14 +164,14 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def move_cursor(self, x: int, y: int) -> Dict[str, Any]:
"""Move the cursor to the specified coordinates.
Args:
x: The x coordinate to move to.
y: The y coordinate to move to.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -171,11 +183,11 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a left mouse click at the specified coordinates.
Args:
x: The x coordinate to click at. If None, clicks at current position.
y: The y coordinate to click at. If None, clicks at current position.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -189,11 +201,11 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a right mouse click at the specified coordinates.
Args:
x: The x coordinate to click at. If None, clicks at current position.
y: The y coordinate to click at. If None, clicks at current position.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -205,13 +217,15 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
async def double_click(
self, x: Optional[int] = None, y: Optional[int] = None
) -> Dict[str, Any]:
"""Perform a double click at the specified coordinates.
Args:
x: The x coordinate to double click at. If None, clicks at current position.
y: The y coordinate to double click at. If None, clicks at current position.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -223,14 +237,16 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def click(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left") -> Dict[str, Any]:
async def click(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Perform a mouse click with the specified button at the given coordinates.
Args:
x: The x coordinate to click at. If None, clicks at current position.
y: The y coordinate to click at. If None, clicks at current position.
button: The mouse button to click ("left", "right", or "middle").
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -242,15 +258,17 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag_to(
self, x: int, y: int, button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag from the current position to the specified coordinates.
Args:
x: The x coordinate to drag to.
y: The y coordinate to drag to.
button: The mouse button to use for dragging.
duration: The time in seconds to take for the drag operation.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -260,16 +278,18 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def drag(self, start_x: int, start_y: int, end_x: int, end_y: int, button: str = "left") -> Dict[str, Any]:
async def drag(
self, start_x: int, start_y: int, end_x: int, end_y: int, button: str = "left"
) -> Dict[str, Any]:
"""Drag from start coordinates to end coordinates.
Args:
start_x: The starting x coordinate.
start_y: The starting y coordinate.
end_x: The ending x coordinate.
end_y: The ending y coordinate.
button: The mouse button to use for dragging.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -280,14 +300,16 @@ class LinuxAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def drag_path(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5) -> Dict[str, Any]:
async def drag_path(
self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag along a path defined by a list of coordinates.
Args:
path: A list of (x, y) coordinate tuples defining the drag path.
button: The mouse button to use for dragging.
duration: The time in seconds to take for each segment of the drag.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -304,10 +326,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Keyboard Actions
async def key_down(self, key: str) -> Dict[str, Any]:
"""Press and hold a key.
Args:
key: The key to press down.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -316,13 +338,13 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def key_up(self, key: str) -> Dict[str, Any]:
"""Release a key.
Args:
key: The key to release.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -331,13 +353,13 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def type_text(self, text: str) -> Dict[str, Any]:
"""Type the specified text using the keyboard.
Args:
text: The text to type.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -350,10 +372,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def press_key(self, key: str) -> Dict[str, Any]:
"""Press and release a key.
Args:
key: The key to press.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -365,10 +387,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def hotkey(self, keys: List[str]) -> Dict[str, Any]:
"""Press a combination of keys simultaneously.
Args:
keys: A list of keys to press together as a hotkey combination.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -381,11 +403,11 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Scrolling Actions
async def scroll(self, x: int, y: int) -> Dict[str, Any]:
"""Scroll the mouse wheel.
Args:
x: The horizontal scroll amount.
y: The vertical scroll amount.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -394,13 +416,13 @@ class LinuxAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll down by the specified number of clicks.
Args:
clicks: The number of scroll clicks to perform downward.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -412,10 +434,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def scroll_up(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll up by the specified number of clicks.
Args:
clicks: The number of scroll clicks to perform upward.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
@@ -428,13 +450,14 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Screen Actions
async def screenshot(self) -> Dict[str, Any]:
"""Take a screenshot of the current screen.
Returns:
Dict[str, Any]: A dictionary containing success status and base64-encoded image data,
or error message if failed.
"""
try:
from PIL import Image
screenshot = pyautogui.screenshot()
if not isinstance(screenshot, Image.Image):
return {"success": False, "error": "Failed to capture screenshot"}
@@ -448,7 +471,7 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def get_screen_size(self) -> Dict[str, Any]:
"""Get the size of the screen.
Returns:
Dict[str, Any]: A dictionary containing success status and screen dimensions,
or error message if failed.
@@ -461,7 +484,7 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def get_cursor_position(self) -> Dict[str, Any]:
"""Get the current position of the cursor.
Returns:
Dict[str, Any]: A dictionary containing success status and cursor coordinates,
or error message if failed.
@@ -475,13 +498,14 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Clipboard Actions
async def copy_to_clipboard(self) -> Dict[str, Any]:
"""Get the current content of the clipboard.
Returns:
Dict[str, Any]: A dictionary containing success status and clipboard content,
or error message if failed.
"""
try:
import pyperclip
content = pyperclip.paste()
return {"success": True, "content": content}
except Exception as e:
@@ -489,15 +513,16 @@ class LinuxAutomationHandler(BaseAutomationHandler):
async def set_clipboard(self, text: str) -> Dict[str, Any]:
"""Set the clipboard content to the specified text.
Args:
text: The text to copy to the clipboard.
Returns:
Dict[str, Any]: A dictionary with success status and error message if failed.
"""
try:
import pyperclip
pyperclip.copy(text)
return {"success": True}
except Exception as e:
@@ -506,10 +531,10 @@ class LinuxAutomationHandler(BaseAutomationHandler):
# Command Execution
async def run_command(self, command: str) -> Dict[str, Any]:
"""Execute a shell command asynchronously.
Args:
command: The shell command to execute.
Returns:
Dict[str, Any]: A dictionary containing success status, stdout, stderr,
and return code, or error message if failed.
@@ -517,18 +542,16 @@ class LinuxAutomationHandler(BaseAutomationHandler):
try:
# Create subprocess
process = await asyncio.create_subprocess_shell(
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
# Wait for the subprocess to finish
stdout, stderr = await process.communicate()
# Return decoded output
return {
"success": True,
"stdout": stdout.decode() if stdout else "",
"stderr": stderr.decode() if stderr else "",
"return_code": process.returncode
"success": True,
"stdout": stdout.decode() if stdout else "",
"stderr": stderr.decode() if stderr else "",
"return_code": process.returncode,
}
except Exception as e:
return {"success": False, "error": str(e)}

View File

@@ -1,54 +1,57 @@
import pyautogui

pyautogui.FAILSAFE = False

import asyncio
import base64
import copy
import json
import logging
import re
import time
from ctypes import POINTER, byref, c_void_p
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

import AppKit
import Foundation
import objc
from AppKit import NSWorkspace  # type: ignore
from ApplicationServices import AXUIElementCopyAttributeValue  # type: ignore
from ApplicationServices import AXUIElementCopyAttributeValues  # type: ignore
from ApplicationServices import AXUIElementCreateApplication  # type: ignore
from ApplicationServices import AXUIElementCreateSystemWide  # type: ignore
from ApplicationServices import AXUIElementGetTypeID  # type: ignore
from ApplicationServices import AXValueGetType  # type: ignore
from ApplicationServices import AXValueGetValue  # type: ignore
from ApplicationServices import kAXChildrenAttribute  # type: ignore
from ApplicationServices import kAXDescriptionAttribute  # type: ignore
from ApplicationServices import kAXEnabledAttribute  # type: ignore
from ApplicationServices import kAXErrorSuccess  # type: ignore
from ApplicationServices import kAXFocusedApplicationAttribute  # type: ignore
from ApplicationServices import kAXFocusedUIElementAttribute  # type: ignore
from ApplicationServices import kAXFocusedWindowAttribute  # type: ignore
from ApplicationServices import kAXMainWindowAttribute  # type: ignore
from ApplicationServices import kAXPositionAttribute  # type: ignore
from ApplicationServices import kAXRoleAttribute  # type: ignore
from ApplicationServices import kAXRoleDescriptionAttribute  # type: ignore
from ApplicationServices import kAXSelectedTextAttribute  # type: ignore
from ApplicationServices import kAXSelectedTextRangeAttribute  # type: ignore
from ApplicationServices import kAXSizeAttribute  # type: ignore
from ApplicationServices import kAXTitleAttribute  # type: ignore
from ApplicationServices import kAXValueAttribute  # type: ignore
from ApplicationServices import kAXValueCFRangeType  # type: ignore
from ApplicationServices import kAXValueCGPointType  # type: ignore
from ApplicationServices import kAXValueCGSizeType  # type: ignore
from ApplicationServices import kAXVisibleChildrenAttribute  # type: ignore
from ApplicationServices import kAXWindowsAttribute  # type: ignore
from pynput.keyboard import Controller as KeyboardController
from pynput.keyboard import Key
from pynput.mouse import Button
from pynput.mouse import Controller as MouseController
from Quartz.CoreGraphics import *  # type: ignore
from Quartz.CoreGraphics import CGPoint, CGSize  # type: ignore

from .base import BaseAccessibilityHandler, BaseAutomationHandler

logger = logging.getLogger(__name__)
@@ -73,24 +76,26 @@ kCGWindowAlpha = "kCGWindowAlpha" # Window opacity
NSApplicationActivationOptions = {
"regular": 0, # Default activation
"bringing_all_windows_forward": 1 << 0, # NSApplicationActivateAllWindows
"ignoring_other_apps": 1 << 1 # NSApplicationActivateIgnoringOtherApps
"ignoring_other_apps": 1 << 1, # NSApplicationActivateIgnoringOtherApps
}
def CFAttributeToPyObject(attrValue):
"""Convert Core Foundation attribute values to Python objects.
Args:
attrValue: Core Foundation attribute value to convert
Returns:
Converted Python object or None if conversion fails
"""
def list_helper(list_value):
"""Helper function to convert CF arrays to Python lists.
Args:
list_value: Core Foundation array to convert
Returns:
Python list containing converted items
"""
@@ -101,10 +106,10 @@ def CFAttributeToPyObject(attrValue):
def number_helper(number_value):
"""Helper function to convert CF numbers to Python numbers.
Args:
number_value: Core Foundation number to convert
Returns:
Python int or float, or None if conversion fails
"""
@@ -123,10 +128,10 @@ def CFAttributeToPyObject(attrValue):
def axuielement_helper(element_value):
"""Helper function to handle AX UI elements.
Args:
element_value: Accessibility UI element to process
Returns:
The element value unchanged
"""
@@ -164,11 +169,11 @@ def CFAttributeToPyObject(attrValue):
def element_attribute(element, attribute):
"""Get an attribute value from an accessibility element.
Args:
element: The accessibility element
attribute: The attribute name to retrieve
Returns:
The attribute value or None if not found
"""
@@ -190,11 +195,11 @@ def element_attribute(element, attribute):
def element_value(element, type):
"""Extract a typed value from an accessibility element.
Args:
element: The accessibility element containing the value
type: The expected value type
Returns:
The extracted value or None if extraction fails
"""
@@ -206,10 +211,10 @@ def element_value(element, type):
class UIElement:
"""Represents a UI element in the accessibility tree with position, size, and hierarchy information."""
def __init__(self, element, offset_x=0, offset_y=0, max_depth=None, parents_visible_bbox=None):
"""Initialize a UIElement from an accessibility element.
Args:
element: The accessibility element to wrap
offset_x: X offset for position calculations
@@ -297,7 +302,7 @@ class UIElement:
def _set_bboxes(self, parents_visible_bbox):
"""Set bounding box and visible bounding box for the element.
Args:
parents_visible_bbox: Parent's visible bounding box for intersection calculation
"""
@@ -332,13 +337,13 @@ class UIElement:
def _get_children(self, element, start_position, offset_x, offset_y):
"""Get child elements from the accessibility element.
Args:
element: The parent accessibility element
start_position: Starting position for offset calculations
offset_x: X offset for child positioning
offset_y: Y offset for child positioning
Returns:
List of UIElement children
"""
@@ -371,7 +376,7 @@ class UIElement:
def component_hash(self):
"""Generate a hash identifier for this component based on its properties.
Returns:
MD5 hash string of component properties
"""
@@ -388,10 +393,10 @@ class UIElement:
def hash_from_string(self, string):
"""Generate MD5 hash from a string.
Args:
string: Input string to hash
Returns:
MD5 hash hexdigest or empty string if input is None/empty
"""
@@ -403,10 +408,10 @@ class UIElement:
def children_content_hash(self, children):
"""Generate a hash representing the content and structure of child elements.
Args:
children: List of child UIElement objects
Returns:
Combined hash of children content and structure
"""
@@ -426,16 +431,17 @@ class UIElement:
def to_dict(self):
"""Convert the UIElement to a dictionary representation.
Returns:
Dictionary containing all element properties and children
"""
def children_to_dict(children):
"""Convert list of children to dictionary format.
Args:
children: List of UIElement children to convert
Returns:
List of dictionaries representing the children
"""
@@ -464,7 +470,7 @@ class UIElement:
size = f"{self.size.width:.0f};{self.size.height:.0f}"
else:
size = ""
return {
"id": self.identifier,
"name": self.name,
@@ -482,36 +488,38 @@ class UIElement:
}
import Quartz
from AppKit import NSWorkspace, NSRunningApplication
from pathlib import Path
import Quartz
from AppKit import NSRunningApplication, NSWorkspace
def get_all_windows_zorder():
"""Get all windows in the system with their z-order information.
Returns:
List of window dictionaries sorted by z-index, containing window properties
like id, name, pid, owner, bounds, layer, and opacity
"""
window_list = Quartz.CGWindowListCopyWindowInfo(
Quartz.kCGWindowListOptionOnScreenOnly, Quartz.kCGNullWindowID
)
z_order = {
window["kCGWindowNumber"]: z_index for z_index, window in enumerate(window_list[::-1])
}
window_list_all = Quartz.CGWindowListCopyWindowInfo(
Quartz.kCGWindowListOptionAll, Quartz.kCGNullWindowID
)
windows = []
for window in window_list_all:
window_id = window.get("kCGWindowNumber", 0)
window_name = window.get("kCGWindowName", "")
window_pid = window.get("kCGWindowOwnerPID", 0)
window_bounds = window.get("kCGWindowBounds", {})
window_owner = window.get("kCGWindowOwnerName", "")
window_is_on_screen = window.get("kCGWindowIsOnscreen", False)
layer = window.get("kCGWindowLayer", 0)
opacity = window.get("kCGWindowAlpha", 1.0)
z_index = z_order.get(window_id, -1)
if window_name == "Dock" and window_owner == "Dock":
role = "dock"
@@ -522,32 +530,35 @@ def get_all_windows_zorder():
else:
role = "app"
if window_bounds:
windows.append(
{
"id": window_id,
"name": window_name or "Unnamed Window",
"pid": window_pid,
"owner": window_owner,
"role": role,
"is_on_screen": window_is_on_screen,
"bounds": {
"x": window_bounds.get("X", 0),
"y": window_bounds.get("Y", 0),
"width": window_bounds.get("Width", 0),
"height": window_bounds.get("Height", 0),
},
"layer": layer,
"z_index": z_index,
"opacity": opacity,
}
)
windows = sorted(windows, key=lambda x: x["z_index"])
return windows
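A short consumption sketch: since the returned list is sorted by ascending `z_index`, the frontmost window sits at the end (field names taken from the dictionaries built above).

```python
# Find the frontmost on-screen application window.
windows = get_all_windows_zorder()
top = next(
    (w for w in reversed(windows) if w["role"] == "app" and w["is_on_screen"]),
    None,
)
if top is not None:
    b = top["bounds"]
    print(f"{top['owner']} / {top['name']}: z={top['z_index']}, {b['width']}x{b['height']}")
```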
def get_app_info(app):
"""Extract information from an NSRunningApplication object.
Args:
app: NSRunningApplication instance
Returns:
Dictionary containing app name, bundle ID, PID, and status flags
"""
@@ -560,12 +571,13 @@ def get_app_info(app):
"terminated": app.isTerminated(),
}
def get_menubar_items(active_app_pid=None):
"""Get menubar items for the active application.
Args:
active_app_pid: Process ID of the active application, or None to use frontmost app
Returns:
List of menubar item dictionaries with title, bounds, index, and app_pid
"""
@@ -591,26 +603,24 @@ def get_menubar_items(active_app_pid=None):
position_value = element_attribute(item, kAXPositionAttribute)
if position_value:
position_value = element_value(position_value, kAXValueCGPointType)
bounds["x"] = getattr(position_value, 'x', 0)
bounds["y"] = getattr(position_value, 'y', 0)
bounds["x"] = getattr(position_value, "x", 0)
bounds["y"] = getattr(position_value, "y", 0)
size_value = element_attribute(item, kAXSizeAttribute)
if size_value:
size_value = element_value(size_value, kAXValueCGSizeType)
bounds["width"] = getattr(size_value, 'width', 0)
bounds["height"] = getattr(size_value, 'height', 0)
menubar_items.append({
"title": title,
"bounds": bounds,
"index": i,
"app_pid": active_app_pid
})
bounds["width"] = getattr(size_value, "width", 0)
bounds["height"] = getattr(size_value, "height", 0)
menubar_items.append(
{"title": title, "bounds": bounds, "index": i, "app_pid": active_app_pid}
)
return menubar_items
def get_dock_items():
"""Get all items in the macOS Dock.
Returns:
List of dock item dictionaries with title, description, bounds, index,
type, role, and subrole information
"""
dock_items = []
@@ -648,13 +658,13 @@ def get_dock_items():
position_value = element_attribute(item, kAXPositionAttribute)
if position_value:
position_value = element_value(position_value, kAXValueCGPointType)
bounds["x"] = getattr(position_value, 'x', 0)
bounds["y"] = getattr(position_value, 'y', 0)
bounds["x"] = getattr(position_value, "x", 0)
bounds["y"] = getattr(position_value, "y", 0)
size_value = element_attribute(item, kAXSizeAttribute)
if size_value:
size_value = element_value(size_value, kAXValueCGSizeType)
bounds["width"] = getattr(size_value, 'width', 0)
bounds["height"] = getattr(size_value, 'height', 0)
bounds["width"] = getattr(size_value, "width", 0)
bounds["height"] = getattr(size_value, "height", 0)
item_type = "unknown"
if subrole == "AXApplicationDockItem":
item_type = "application"
@@ -666,23 +676,26 @@ def get_dock_items():
item_type = "separator"
elif "trash" in title.lower():
item_type = "trash"
dock_items.append(
{
"title": title,
"description": description,
"bounds": bounds,
"index": i,
"type": item_type,
"role": role,
"subrole": subrole,
}
)
return dock_items
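For example, filtering the result down to application icons is a one-liner over the `type` field assigned above:

```python
# Keep only application icons; separators, files/folders, and the trash are skipped.
app_items = [item for item in get_dock_items() if item["type"] == "application"]
for item in app_items:
    print(item["title"], item["bounds"])
```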
class MacOSAccessibilityHandler(BaseAccessibilityHandler):
"""Handler for macOS accessibility features and UI element inspection."""
def get_desktop_state(self):
"""Get the current state of the desktop including windows, apps, menubar, and dock.
Returns:
Dictionary containing applications, windows, menubar_items, and dock_items
"""
@@ -696,7 +709,9 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
pid = app.processIdentifier()
try:
app_elem = AXUIElementCreateApplication(pid)
err, app_windows = AXUIElementCopyAttributeValue(
app_elem, kAXWindowsAttribute, None
)
trees = []
if err == kAXErrorSuccess and app_windows:
for ax_win in app_windows:
@@ -713,31 +728,32 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
pid = win["pid"]
idx = pid_to_idx.get(pid, 0)
ax_trees = pid_to_ax_trees.get(pid, [])
win["children"] = ax_trees[idx]["children"] if idx < len(ax_trees) and "children" in ax_trees[idx] else []
win["children"] = (
ax_trees[idx]["children"]
if idx < len(ax_trees) and "children" in ax_trees[idx]
else []
)
pid_to_idx[pid] = idx + 1
pid_to_window_ids.setdefault(pid, []).append(win["id"])
for app in running_apps:
info = get_app_info(app)
app_pid = info["pid"]
applications.append({"info": info, "windows": pid_to_window_ids.get(app_pid, [])})
menubar_items = get_menubar_items()
dock_items = get_dock_items()
return {
"applications": applications,
"windows": windows,
"menubar_items": menubar_items,
"dock_items": dock_items
"dock_items": dock_items,
}
def get_application_windows(self, pid: int):
"""Get all windows for a specific application.
Args:
pid: Process ID of the application
Returns:
List of accessibility window elements or empty list if none found
"""
@@ -753,7 +769,7 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
def get_all_windows(self):
"""Get all visible windows in the system.
Returns:
List of window dictionaries with app information and window details
"""
@@ -791,7 +807,7 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
def get_running_apps(self):
"""Get all currently running applications.
Returns:
List of NSRunningApplication objects
"""
@@ -803,11 +819,11 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
def get_ax_attribute(self, element, attribute):
"""Get an accessibility attribute from an element.
Args:
element: The accessibility element
attribute: The attribute name to retrieve
Returns:
The attribute value or None if not found
"""
@@ -815,10 +831,10 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
def serialize_node(self, element):
"""Create a serializable dictionary representation of an accessibility element.
Args:
element: The accessibility element to serialize
Returns:
Dictionary containing element properties like role, title, value, position, and size
"""
@@ -851,16 +867,13 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the complete accessibility tree for the current desktop state.
Returns:
Dictionary containing success status and desktop state information
"""
"""
try:
desktop_state = self.get_desktop_state()
return {"success": True, **desktop_state}
except Exception as e:
return {"success": False, "error": str(e)}
@@ -869,12 +882,12 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None
) -> Dict[str, Any]:
"""Find an accessibility element matching the specified criteria.
Args:
role: The accessibility role to match (optional)
title: The title to match (optional)
value: The value to match (optional)
Returns:
Dictionary containing success status and the found element or error message
"""
@@ -883,10 +896,10 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
def match_element(element):
"""Check if an element matches the search criteria.
Args:
element: The accessibility element to check
Returns:
True if element matches all specified criteria, False otherwise
"""
@@ -900,10 +913,10 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
def search_tree(element):
"""Recursively search the accessibility tree for matching elements.
Args:
element: The accessibility element to search from
Returns:
Serialized element dictionary if match found, None otherwise
"""
@@ -924,58 +937,71 @@ class MacOSAccessibilityHandler(BaseAccessibilityHandler):
except Exception as e:
return {"success": False, "error": str(e)}
class MacOSAutomationHandler(BaseAutomationHandler):
"""Handler for macOS automation including mouse, keyboard, and screen operations."""
# Mouse Actions
mouse = MouseController()
keyboard = KeyboardController()
async def mouse_down(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Press and hold a mouse button at the specified coordinates.
Args:
x: X coordinate (optional, uses current position if None)
y: Y coordinate (optional, uses current position if None)
button: Mouse button to press ("left", "right", or "middle")
Returns:
Dictionary containing success status and error message if failed
"""
try:
if x is not None and y is not None:
self.mouse.position = (x, y)
self.mouse.press(
Button.left
if button == "left"
else Button.right if button == "right" else Button.middle
)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def mouse_up(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Release a mouse button at the specified coordinates.
Args:
x: X coordinate (optional, uses current position if None)
y: Y coordinate (optional, uses current position if None)
button: Mouse button to release ("left", "right", or "middle")
Returns:
Dictionary containing success status and error message if failed
"""
try:
if x is not None and y is not None:
self.mouse.position = (x, y)
self.mouse.release(
Button.left
if button == "left"
else Button.right if button == "right" else Button.middle
)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a left mouse click at the specified coordinates.
Args:
x: X coordinate (optional, uses current position if None)
y: Y coordinate (optional, uses current position if None)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -989,11 +1015,11 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a right mouse click at the specified coordinates.
Args:
x: X coordinate (optional, uses current position if None)
y: Y coordinate (optional, uses current position if None)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1009,11 +1035,11 @@ class MacOSAutomationHandler(BaseAutomationHandler):
self, x: Optional[int] = None, y: Optional[int] = None
) -> Dict[str, Any]:
"""Perform a double left mouse click at the specified coordinates.
Args:
x: X coordinate (optional, uses current position if None)
y: Y coordinate (optional, uses current position if None)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1027,11 +1053,11 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def move_cursor(self, x: int, y: int) -> Dict[str, Any]:
"""Move the mouse cursor to the specified coordinates.
Args:
x: Target X coordinate
y: Target Y coordinate
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1045,18 +1071,22 @@ class MacOSAutomationHandler(BaseAutomationHandler):
self, x: int, y: int, button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag from current position to target coordinates.
Args:
x: Target X coordinate
y: Target Y coordinate
button: Mouse button to use for dragging ("left", "right", or "middle")
duration: Duration of the drag operation in seconds
Returns:
Dictionary containing success status and error message if failed
"""
try:
btn = Button.left if button == "left" else Button.right if button == "right" else Button.middle
btn = (
Button.left
if button == "left"
else Button.right if button == "right" else Button.middle
)
# Press
self.mouse.press(btn)
# Move with sleep to simulate drag duration
@@ -1082,19 +1112,23 @@ class MacOSAutomationHandler(BaseAutomationHandler):
self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag the mouse along a specified path of coordinates.
Args:
path: List of (x, y) coordinate tuples defining the drag path
button: Mouse button to use for dragging ("left", "right", or "middle")
duration: Total duration of the drag operation in seconds
Returns:
Dictionary containing success status and error message if failed
"""
try:
if not path or len(path) < 2:
return {"success": False, "error": "Path must contain at least 2 points"}
btn = Button.left if button == "left" else Button.right if button == "right" else Button.middle
btn = (
Button.left
if button == "left"
else Button.right if button == "right" else Button.middle
)
# Move to the first point
self.mouse.position = path[0]
self.mouse.press(btn)
@@ -1114,10 +1148,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
# Keyboard Actions
async def key_down(self, key: str) -> Dict[str, Any]:
"""Press and hold a keyboard key.
Args:
key: Key name to press (using pyautogui key names)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1127,13 +1161,13 @@ class MacOSAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def key_up(self, key: str) -> Dict[str, Any]:
"""Release a keyboard key.
Args:
key: Key name to release (using pyautogui key names)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1143,13 +1177,13 @@ class MacOSAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def type_text(self, text: str) -> Dict[str, Any]:
"""Type text using the keyboard with Unicode support.
Args:
text: Text string to type
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1162,10 +1196,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def press_key(self, key: str) -> Dict[str, Any]:
"""Press and release a keyboard key.
Args:
key: Key name to press (using pyautogui key names)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1178,10 +1212,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def hotkey(self, keys: List[str]) -> Dict[str, Any]:
"""Press a combination of keys simultaneously.
Args:
keys: List of key names to press together (using pyautogui key names)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1195,11 +1229,11 @@ class MacOSAutomationHandler(BaseAutomationHandler):
# Scrolling Actions
async def scroll(self, x: int, y: int) -> Dict[str, Any]:
"""Scroll the mouse wheel in the specified direction.
Args:
x: Horizontal scroll amount
y: Vertical scroll amount (positive for up, negative for down)
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1208,13 +1242,13 @@ class MacOSAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll down by the specified number of clicks.
Args:
clicks: Number of scroll clicks to perform
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1226,10 +1260,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def scroll_up(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll up by the specified number of clicks.
Args:
clicks: Number of scroll clicks to perform
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1242,7 +1276,7 @@ class MacOSAutomationHandler(BaseAutomationHandler):
# Screen Actions
async def screenshot(self) -> Dict[str, Any]:
"""Capture a screenshot of the current screen.
Returns:
Dictionary containing success status and base64-encoded image data or error message
"""
@@ -1263,7 +1297,7 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def get_screen_size(self) -> Dict[str, Any]:
"""Get the dimensions of the current screen.
Returns:
Dictionary containing success status and screen size or error message
"""
@@ -1275,7 +1309,7 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def get_cursor_position(self) -> Dict[str, Any]:
"""Get the current position of the mouse cursor.
Returns:
Dictionary containing success status and cursor position or error message
"""
@@ -1288,7 +1322,7 @@ class MacOSAutomationHandler(BaseAutomationHandler):
# Clipboard Actions
async def copy_to_clipboard(self) -> Dict[str, Any]:
"""Get the current content of the system clipboard.
Returns:
Dictionary containing success status and clipboard content or error message
"""
@@ -1302,10 +1336,10 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def set_clipboard(self, text: str) -> Dict[str, Any]:
"""Set the content of the system clipboard.
Args:
text: Text to copy to the clipboard
Returns:
Dictionary containing success status and error message if failed
"""
@@ -1319,28 +1353,26 @@ class MacOSAutomationHandler(BaseAutomationHandler):
async def run_command(self, command: str) -> Dict[str, Any]:
"""Run a shell command and return its output.
Args:
command: Shell command to execute
Returns:
Dictionary containing success status, stdout, stderr, and return code
"""
try:
# Create subprocess
process = await asyncio.create_subprocess_shell(
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
# Wait for the subprocess to finish
stdout, stderr = await process.communicate()
# Return decoded output
return {
"success": True,
"stdout": stdout.decode() if stdout else "",
"success": True,
"stdout": stdout.decode() if stdout else "",
"stderr": stderr.decode() if stderr else "",
"return_code": process.returncode
"return_code": process.returncode,
}
except Exception as e:
return {"success": False, "error": str(e)}

View File

@@ -4,15 +4,17 @@ Windows implementation of automation and accessibility handlers.
This implementation uses pyautogui for GUI automation and Windows-specific APIs
for accessibility and system operations.
"""
import asyncio
import base64
import logging
import os
import subprocess
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple

from pynput.keyboard import Controller as KeyboardController
from pynput.mouse import Controller as MouseController
# Configure logger
logger = logging.getLogger(__name__)
@@ -20,6 +22,7 @@ logger = logging.getLogger(__name__)
# Try to import pyautogui
try:
import pyautogui
pyautogui.FAILSAFE = False
logger.info("pyautogui successfully imported, GUI automation available")
except Exception as e:
@@ -28,58 +31,62 @@ except Exception as e:
# Try to import Windows-specific modules
try:
import win32api
import win32con
import win32gui
logger.info("Windows API modules successfully imported")
WINDOWS_API_AVAILABLE = True
except Exception as e:
logger.error(f"Windows API modules import failed: {str(e)}. Some Windows-specific features will be unavailable.")
logger.error(
f"Windows API modules import failed: {str(e)}. Some Windows-specific features will be unavailable."
)
WINDOWS_API_AVAILABLE = False
from .base import BaseAccessibilityHandler, BaseAutomationHandler
class WindowsAccessibilityHandler(BaseAccessibilityHandler):
"""Windows implementation of accessibility handler."""
async def get_accessibility_tree(self) -> Dict[str, Any]:
"""Get the accessibility tree of the current window.
Returns:
Dict[str, Any]: A dictionary containing the success status and either
the accessibility tree or an error message.
Structure: {"success": bool, "tree": dict} or
Structure: {"success": bool, "tree": dict} or
{"success": bool, "error": str}
"""
if not WINDOWS_API_AVAILABLE:
return {"success": False, "error": "Windows API not available"}
try:
# Get the foreground window
hwnd = win32gui.GetForegroundWindow()
if not hwnd:
return {"success": False, "error": "No foreground window found"}
# Get window information
window_text = win32gui.GetWindowText(hwnd)
rect = win32gui.GetWindowRect(hwnd)
tree = {
"role": "Window",
"title": window_text,
"position": {"x": rect[0], "y": rect[1]},
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]},
"children": []
"children": [],
}
# Enumerate child windows
def enum_child_proc(hwnd_child, children_list):
"""Callback function to enumerate child windows and collect their information.
Args:
hwnd_child: Handle to the child window being enumerated.
children_list: List to append child window information to.
Returns:
bool: True to continue enumeration, False to stop.
"""
@@ -87,46 +94,49 @@ class WindowsAccessibilityHandler(BaseAccessibilityHandler):
child_text = win32gui.GetWindowText(hwnd_child)
child_rect = win32gui.GetWindowRect(hwnd_child)
child_class = win32gui.GetClassName(hwnd_child)
child_info = {
"role": child_class,
"title": child_text,
"position": {"x": child_rect[0], "y": child_rect[1]},
"size": {"width": child_rect[2] - child_rect[0], "height": child_rect[3] - child_rect[1]},
"children": []
"size": {
"width": child_rect[2] - child_rect[0],
"height": child_rect[3] - child_rect[1],
},
"children": [],
}
children_list.append(child_info)
except Exception as e:
logger.debug(f"Error getting child window info: {e}")
return True
win32gui.EnumChildWindows(hwnd, enum_child_proc, tree["children"])
return {"success": True, "tree": tree}
except Exception as e:
logger.error(f"Error getting accessibility tree: {e}")
return {"success": False, "error": str(e)}
async def find_element(
self, role: Optional[str] = None, title: Optional[str] = None, value: Optional[str] = None
) -> Dict[str, Any]:
"""Find an element in the accessibility tree by criteria.
Args:
role (Optional[str]): The role or class name of the element to find.
title (Optional[str]): The title or text of the element to find.
value (Optional[str]): The value of the element (not used in Windows implementation).
Returns:
Dict[str, Any]: A dictionary containing the success status and either
the found element or an error message.
Structure: {"success": bool, "element": dict} or
Structure: {"success": bool, "element": dict} or
{"success": bool, "error": str}
"""
if not WINDOWS_API_AVAILABLE:
return {"success": False, "error": "Windows API not available"}
try:
# Find window by title if specified
if title:
@@ -139,10 +149,10 @@ class WindowsAccessibilityHandler(BaseAccessibilityHandler):
"role": "Window",
"title": title,
"position": {"x": rect[0], "y": rect[1]},
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]}
}
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]},
},
}
# Find window by class name if role is specified
if role:
hwnd = win32gui.FindWindow(role, None)
@@ -155,36 +165,40 @@ class WindowsAccessibilityHandler(BaseAccessibilityHandler):
"role": role,
"title": window_text,
"position": {"x": rect[0], "y": rect[1]},
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]}
}
"size": {"width": rect[2] - rect[0], "height": rect[3] - rect[1]},
},
}
return {"success": False, "error": "Element not found"}
except Exception as e:
logger.error(f"Error finding element: {e}")
return {"success": False, "error": str(e)}
class WindowsAutomationHandler(BaseAutomationHandler):
"""Windows implementation of automation handler using pyautogui and Windows APIs."""
mouse = MouseController()
keyboard = KeyboardController()
# Mouse Actions
async def mouse_down(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Press and hold a mouse button at the specified coordinates.
Args:
x (Optional[int]): The x-coordinate to move to before pressing. If None, uses current position.
y (Optional[int]): The y-coordinate to move to before pressing. If None, uses current position.
button (str): The mouse button to press ("left", "right", or "middle").
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
if x is not None and y is not None:
pyautogui.moveTo(x, y)
@@ -192,21 +206,23 @@ class WindowsAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def mouse_up(
self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left"
) -> Dict[str, Any]:
"""Release a mouse button at the specified coordinates.
Args:
x (Optional[int]): The x-coordinate to move to before releasing. If None, uses current position.
y (Optional[int]): The y-coordinate to move to before releasing. If None, uses current position.
button (str): The mouse button to release ("left", "right", or "middle").
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
if x is not None and y is not None:
pyautogui.moveTo(x, y)
@@ -214,20 +230,20 @@ class WindowsAutomationHandler(BaseAutomationHandler):
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def move_cursor(self, x: int, y: int) -> Dict[str, Any]:
"""Move the mouse cursor to the specified coordinates.
Args:
x (int): The x-coordinate to move to.
y (int): The y-coordinate to move to.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.moveTo(x, y)
return {"success": True}
@@ -236,17 +252,17 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a left mouse click at the specified coordinates.
Args:
x (Optional[int]): The x-coordinate to click at. If None, clicks at current position.
y (Optional[int]): The y-coordinate to click at. If None, clicks at current position.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
if x is not None and y is not None:
pyautogui.moveTo(x, y)
@@ -257,17 +273,17 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None) -> Dict[str, Any]:
"""Perform a right mouse click at the specified coordinates.
Args:
x (Optional[int]): The x-coordinate to click at. If None, clicks at current position.
y (Optional[int]): The y-coordinate to click at. If None, clicks at current position.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
if x is not None and y is not None:
pyautogui.moveTo(x, y)
@@ -276,19 +292,21 @@ class WindowsAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def double_click(
self, x: Optional[int] = None, y: Optional[int] = None
) -> Dict[str, Any]:
"""Perform a double left mouse click at the specified coordinates.
Args:
x (Optional[int]): The x-coordinate to double-click at. If None, clicks at current position.
y (Optional[int]): The y-coordinate to double-click at. If None, clicks at current position.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
if x is not None and y is not None:
pyautogui.moveTo(x, y)
@@ -297,52 +315,56 @@ class WindowsAutomationHandler(BaseAutomationHandler):
except Exception as e:
return {"success": False, "error": str(e)}
async def drag_to(
self, x: int, y: int, button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag from the current position to the specified coordinates.
Args:
x (int): The x-coordinate to drag to.
y (int): The y-coordinate to drag to.
button (str): The mouse button to use for dragging ("left", "right", or "middle").
duration (float): The time in seconds to take for the drag operation.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.dragTo(x, y, duration=duration, button=button)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def drag(
self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5
) -> Dict[str, Any]:
"""Drag the mouse through a series of coordinates.
Args:
path (List[Tuple[int, int]]): A list of (x, y) coordinate tuples to drag through.
button (str): The mouse button to use for dragging ("left", "right", or "middle").
duration (float): The total time in seconds for the entire drag operation.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
if not path:
return {"success": False, "error": "Path is empty"}
# Move to first position
pyautogui.moveTo(*path[0])
# Drag through all positions
for x, y in path[1:]:
pyautogui.dragTo(x, y, duration=duration / len(path), button=button)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
@@ -350,70 +372,68 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Keyboard Actions
async def key_down(self, key: str) -> Dict[str, Any]:
"""Press and hold a keyboard key.
Args:
key (str): The key to press down (e.g., 'ctrl', 'shift', 'a').
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.keyDown(key)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def key_up(self, key: str) -> Dict[str, Any]:
"""Release a keyboard key.
Args:
key (str): The key to release (e.g., 'ctrl', 'shift', 'a').
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.keyUp(key)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def type_text(self, text: str) -> Dict[str, Any]:
"""Type the specified text.
Args:
text (str): The text to type.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
# use pynput for Unicode support
self.keyboard.type(text)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def press_key(self, key: str) -> Dict[str, Any]:
"""Press and release a keyboard key.
Args:
key (str): The key to press (e.g., 'enter', 'space', 'tab').
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.press(key)
return {"success": True}
@@ -422,16 +442,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def hotkey(self, keys: List[str]) -> Dict[str, Any]:
"""Press a combination of keys simultaneously.
Args:
keys (List[str]): The keys to press together (e.g., ['ctrl', 'c'], ['alt', 'tab']).
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.hotkey(*keys)
return {"success": True}
@@ -441,35 +461,35 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Scrolling Actions
async def scroll(self, x: int, y: int) -> Dict[str, Any]:
"""Scroll vertically at the current cursor position.
Args:
x (int): Horizontal scroll amount, passed to pynput's mouse controller.
y (int): Vertical scroll amount. Positive values scroll up, negative values scroll down.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
self.mouse.scroll(x, y)
return {"success": True}
except Exception as e:
return {"success": False, "error": str(e)}
async def scroll_down(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll down by the specified number of clicks.
Args:
clicks (int): The number of scroll clicks to perform downward.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.scroll(-clicks)
return {"success": True}
@@ -478,16 +498,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def scroll_up(self, clicks: int = 1) -> Dict[str, Any]:
"""Scroll up by the specified number of clicks.
Args:
clicks (int): The number of scroll clicks to perform upward.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
pyautogui.scroll(clicks)
return {"success": True}
@@ -497,22 +517,23 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Screen Actions
async def screenshot(self) -> Dict[str, Any]:
"""Capture a screenshot of the entire screen.
Returns:
Dict[str, Any]: A dictionary containing the success status and either
base64-encoded image data or an error message.
Structure: {"success": bool, "image_data": str} or
Structure: {"success": bool, "image_data": str} or
{"success": bool, "error": str}
"""
if not pyautogui:
return {"success": False, "error": "pyautogui not available"}
try:
from PIL import Image
screenshot = pyautogui.screenshot()
if not isinstance(screenshot, Image.Image):
return {"success": False, "error": "Failed to capture screenshot"}
buffered = BytesIO()
screenshot.save(buffered, format="PNG", optimize=True)
buffered.seek(0)
@@ -523,11 +544,11 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def get_screen_size(self) -> Dict[str, Any]:
"""Get the size of the screen in pixels.
Returns:
Dict[str, Any]: A dictionary containing the success status and either
screen size information or an error message.
Structure: {"success": bool, "size": {"width": int, "height": int}} or
Structure: {"success": bool, "size": {"width": int, "height": int}} or
{"success": bool, "error": str}
"""
try:
@@ -546,11 +567,11 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def get_cursor_position(self) -> Dict[str, Any]:
"""Get the current position of the mouse cursor.
Returns:
Dict[str, Any]: A dictionary containing the success status and either
cursor position or an error message.
Structure: {"success": bool, "position": {"x": int, "y": int}} or
Structure: {"success": bool, "position": {"x": int, "y": int}} or
{"success": bool, "error": str}
"""
try:
@@ -569,15 +590,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Clipboard Actions
async def copy_to_clipboard(self) -> Dict[str, Any]:
"""Get the current content of the clipboard.
Returns:
Dict[str, Any]: A dictionary containing the success status and either
clipboard content or an error message.
Structure: {"success": bool, "content": str} or
Structure: {"success": bool, "content": str} or
{"success": bool, "error": str}
"""
try:
import pyperclip
content = pyperclip.paste()
return {"success": True, "content": content}
except Exception as e:
@@ -585,15 +607,16 @@ class WindowsAutomationHandler(BaseAutomationHandler):
async def set_clipboard(self, text: str) -> Dict[str, Any]:
"""Set the clipboard content to the specified text.
Args:
text (str): The text to copy to the clipboard.
Returns:
Dict[str, Any]: A dictionary with success status and optional error message.
"""
try:
import pyperclip
pyperclip.copy(text)
return {"success": True}
except Exception as e:
@@ -602,31 +625,29 @@ class WindowsAutomationHandler(BaseAutomationHandler):
# Command Execution
async def run_command(self, command: str) -> Dict[str, Any]:
"""Execute a shell command asynchronously.
Args:
command (str): The shell command to execute.
Returns:
Dict[str, Any]: A dictionary containing the success status and either
command output or an error message.
Structure: {"success": bool, "stdout": str, "stderr": str, "return_code": int} or
Structure: {"success": bool, "stdout": str, "stderr": str, "return_code": int} or
{"success": bool, "error": str}
"""
try:
# Create subprocess
process = await asyncio.create_subprocess_shell(
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
# Wait for the subprocess to finish
stdout, stderr = await process.communicate()
# Return decoded output
return {
"success": True,
"stdout": stdout.decode() if stdout else "",
"stderr": stderr.decode() if stderr else "",
"return_code": process.returncode
"success": True,
"stdout": stdout.decode() if stdout else "",
"stderr": stderr.decode() if stderr else "",
"return_code": process.returncode,
}
except Exception as e:
return {"success": False, "error": str(e)}

View File

@@ -1,27 +1,37 @@
import asyncio
import hashlib
import inspect
import json
import logging
import os
import platform
import time
import traceback
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
from typing import Any, Dict, List, Literal, Optional, Union, cast

import aiohttp
import uvicorn
from fastapi import (
FastAPI,
Header,
HTTPException,
Request,
WebSocket,
WebSocketDisconnect,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse

from .handlers.factory import HandlerFactory
# Authentication session TTL (in seconds). Override via env var CUA_AUTH_TTL_SECONDS. Default: 60s
AUTH_SESSION_TTL_SECONDS: int = int(os.environ.get("CUA_AUTH_TTL_SECONDS", "60"))
try:
from agent import ComputerAgent
HAS_AGENT = True
except ImportError:
HAS_AGENT = False
@@ -54,16 +64,20 @@ app.add_middleware(
protocol_version = 1
try:
from importlib.metadata import version
package_version = version("cua-computer-server")
except Exception:
# Fallback for cases where package is not installed or importlib.metadata is not available
try:
import pkg_resources
package_version = pkg_resources.get_distribution("cua-computer-server").version
except Exception:
package_version = "unknown"
accessibility_handler, automation_handler, diorama_handler, file_handler = (
HandlerFactory.create_handlers()
)
handlers = {
"version": lambda: {"protocol": protocol_version, "package": package_version},
# App-Use commands
@@ -118,87 +132,91 @@ class AuthenticationManager:
def __init__(self):
self.sessions: Dict[str, Dict[str, Any]] = {}
self.container_name = os.environ.get("CONTAINER_NAME")
def _hash_credentials(self, container_name: str, api_key: str) -> str:
"""Create a hash of container name and API key for session identification"""
combined = f"{container_name}:{api_key}"
return hashlib.sha256(combined.encode()).hexdigest()
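The hash is deterministic, so repeated attempts with the same credentials land on one cache entry. A standalone sketch with placeholder values:

```python
import hashlib

# Placeholder credentials, for illustration only.
combined = "my-container:sk-example-key"
session_hash = hashlib.sha256(combined.encode()).hexdigest()
print(session_hash)  # same inputs always yield the same session key
```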
def _is_session_valid(self, session_data: Dict[str, Any]) -> bool:
"""Check if a session is still valid based on expiration time"""
if not session_data.get("valid", False):
return False
expires_at = session_data.get("expires_at", 0)
return time.time() < expires_at
async def auth(self, container_name: str, api_key: str) -> bool:
"""Authenticate container name and API key, using cached sessions when possible"""
# If no CONTAINER_NAME is set, always allow access (local development)
if not self.container_name:
logger.info("No CONTAINER_NAME set in environment. Allowing access (local development mode)")
logger.info(
"No CONTAINER_NAME set in environment. Allowing access (local development mode)"
)
return True
# Layer 1: VM Identity Verification
if container_name != self.container_name:
logger.warning(f"VM name mismatch. Expected: {self.container_name}, Got: {container_name}")
logger.warning(
f"VM name mismatch. Expected: {self.container_name}, Got: {container_name}"
)
return False
# Create hash for session lookup
session_hash = self._hash_credentials(container_name, api_key)
# Check if we have a valid cached session
if session_hash in self.sessions:
session_data = self.sessions[session_hash]
if self._is_session_valid(session_data):
logger.info(f"Using cached authentication for container: {container_name}")
return session_data["valid"]
else:
# Remove expired session
del self.sessions[session_hash]
# No valid cached session, authenticate with API
logger.info(f"Authenticating with TryCUA API for container: {container_name}")
try:
async with aiohttp.ClientSession() as session:
headers = {"Authorization": f"Bearer {api_key}"}
async with session.get(
f"https://www.trycua.com/api/vm/auth?container_name={container_name}",
f"https://www.cua.ai/api/vm/auth?container_name={container_name}",
headers=headers,
) as resp:
is_valid = resp.status == 200 and bool((await resp.text()).strip())
# Cache the result with configurable expiration
self.sessions[session_hash] = {
"valid": is_valid,
"expires_at": time.time() + AUTH_SESSION_TTL_SECONDS,
}
if is_valid:
logger.info(f"Authentication successful for container: {container_name}")
else:
logger.warning(f"Authentication failed for container: {container_name}. Status: {resp.status}")
logger.warning(
f"Authentication failed for container: {container_name}. Status: {resp.status}"
)
return is_valid
except aiohttp.ClientError as e:
logger.error(f"Failed to validate API key with TryCUA API: {str(e)}")
# Cache failed result to avoid repeated requests
self.sessions[session_hash] = {
"valid": False,
"expires_at": time.time() + AUTH_SESSION_TTL_SECONDS,
}
return False
except Exception as e:
logger.error(f"Unexpected error during authentication: {str(e)}")
# Cache failed result to avoid repeated requests
self.sessions[session_hash] = {
"valid": False,
"expires_at": time.time() + AUTH_SESSION_TTL_SECONDS,
}
return False
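Taken together, two back-to-back calls with the same placeholder credentials exercise the cache: the first goes to the API, the second is answered locally until the TTL lapses. A sketch, assuming CONTAINER_NAME is set in the environment so the cloud path runs (otherwise `auth` short-circuits to True):

```python
import asyncio

async def main():
    mgr = AuthenticationManager()
    first = await mgr.auth("my-container", "sk-example-key")   # hits the API
    second = await mgr.auth("my-container", "sk-example-key")  # served from cache
    print(first, second)

asyncio.run(main())
```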
@@ -218,6 +236,7 @@ class ConnectionManager:
manager = ConnectionManager()
auth_manager = AuthenticationManager()
@app.get("/status")
async def status():
sys = platform.system().lower()
@@ -234,80 +253,67 @@ async def status():
features.append("agent")
return {"status": "ok", "os_type": os_type, "features": features}
@app.websocket("/ws", name="websocket_endpoint")
async def websocket_endpoint(websocket: WebSocket):
global handlers
# WebSocket message size is configured at the app or endpoint level, not on the instance
await manager.connect(websocket)
# Check if CONTAINER_NAME is set (indicating cloud provider)
server_container_name = os.environ.get("CONTAINER_NAME")
# If cloud provider, perform authentication handshake
if server_container_name:
try:
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Waiting for authentication...")
logger.info(
f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Waiting for authentication..."
)
# Wait for authentication message
auth_data = await websocket.receive_json()
# Validate auth message format
if auth_data.get("command") != "authenticate":
await websocket.send_json({
"success": False,
"error": "First message must be authentication"
})
await websocket.send_json(
{"success": False, "error": "First message must be authentication"}
)
await websocket.close()
manager.disconnect(websocket)
return
# Extract credentials
client_api_key = auth_data.get("params", {}).get("api_key")
client_container_name = auth_data.get("params", {}).get("container_name")
# Validate credentials using AuthenticationManager
if not client_api_key:
await websocket.send_json({
"success": False,
"error": "API key required"
})
await websocket.send_json({"success": False, "error": "API key required"})
await websocket.close()
manager.disconnect(websocket)
return
if not client_container_name:
await websocket.send_json({
"success": False,
"error": "Container name required"
})
await websocket.send_json({"success": False, "error": "Container name required"})
await websocket.close()
manager.disconnect(websocket)
return
# Use AuthenticationManager for validation
is_authenticated = await auth_manager.auth(client_container_name, client_api_key)
if not is_authenticated:
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.send_json({"success": False, "error": "Authentication failed"})
await websocket.close()
manager.disconnect(websocket)
return
logger.info(f"Authentication successful for VM: {client_container_name}")
await websocket.send_json({
"success": True,
"message": "Authentication successful"
})
await websocket.send_json({"success": True, "message": "Authentication successful"})
except Exception as e:
logger.error(f"Error during authentication handshake: {str(e)}")
await websocket.send_json({
"success": False,
"error": "Authentication failed"
})
await websocket.send_json({"success": False, "error": "Authentication failed"})
await websocket.close()
manager.disconnect(websocket)
return
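Seen from the client, the handshake means the first frame on /ws must be an authenticate command. A rough sketch with the websockets package (container name and key are placeholders; URI shape taken from the watchdog below):
import asyncio
import json
import websockets
async def demo(container: str, api_key: str):
    uri = f"wss://{container}.containers.cloud.trycua.com:8443/ws"
    async with websockets.connect(uri, max_size=1024 * 1024 * 10) as ws:
        # First frame: authenticate, mirroring the server-side handshake above
        await ws.send(json.dumps({
            "command": "authenticate",
            "params": {"api_key": api_key, "container_name": container},
        }))
        reply = json.loads(await ws.recv())
        if not reply.get("success"):
            raise RuntimeError(reply.get("error", "authentication failed"))
        # Subsequent frames use the same {"command", "params"} envelope
        await ws.send(json.dumps({"command": "get_screen_size", "params": {}}))
        print(await ws.recv())
asyncio.run(demo("my-container", "sk-example-key"))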
@@ -330,7 +336,7 @@ async def websocket_endpoint(websocket: WebSocket):
handler_func = handlers[command]
sig = inspect.signature(handler_func)
filtered_params = {k: v for k, v in params.items() if k in sig.parameters}
# Handle both sync and async functions
if asyncio.iscoroutinefunction(handler_func):
result = await handler_func(**filtered_params)
@@ -367,20 +373,21 @@ async def websocket_endpoint(websocket: WebSocket):
pass
manager.disconnect(websocket)
@app.post("/cmd")
async def cmd_endpoint(
request: Request,
container_name: Optional[str] = Header(None, alias="X-Container-Name"),
api_key: Optional[str] = Header(None, alias="X-API-Key")
api_key: Optional[str] = Header(None, alias="X-API-Key"),
):
"""
Backup endpoint for when WebSocket connections fail.
Accepts commands via HTTP POST with streaming response.
Headers:
- X-Container-Name: Container name for cloud authentication
- X-API-Key: API key for cloud authentication
Body:
{
"command": "command_name",
@@ -388,7 +395,7 @@ async def cmd_endpoint(
}
"""
global handlers
# Parse request body
try:
body = await request.json()
@@ -396,32 +403,34 @@ async def cmd_endpoint(
params = body.get("params", {})
except Exception as e:
raise HTTPException(status_code=400, detail=f"Invalid JSON body: {str(e)}")
if not command:
raise HTTPException(status_code=400, detail="Command is required")
# Check if CONTAINER_NAME is set (indicating cloud provider)
server_container_name = os.environ.get("CONTAINER_NAME")
# If cloud provider, perform authentication
if server_container_name:
logger.info(f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication...")
logger.info(
f"Cloud provider detected. CONTAINER_NAME: {server_container_name}. Performing authentication..."
)
# Validate required headers
if not container_name:
raise HTTPException(status_code=401, detail="Container name required")
if not api_key:
raise HTTPException(status_code=401, detail="API key required")
# Validate with AuthenticationManager
is_authenticated = await auth_manager.auth(container_name, api_key)
if not is_authenticated:
raise HTTPException(status_code=401, detail="Authentication failed")
if command not in handlers:
raise HTTPException(status_code=400, detail=f"Unknown command: {command}")
async def generate_response():
"""Generate streaming response for the command execution"""
try:
@@ -429,35 +438,36 @@ async def cmd_endpoint(
handler_func = handlers[command]
sig = inspect.signature(handler_func)
filtered_params = {k: v for k, v in params.items() if k in sig.parameters}
# Handle both sync and async functions
if asyncio.iscoroutinefunction(handler_func):
result = await handler_func(**filtered_params)
else:
# Run sync functions in thread pool to avoid blocking event loop
result = await asyncio.to_thread(handler_func, **filtered_params)
# Stream the successful result
response_data = {"success": True, **result}
yield f"data: {json.dumps(response_data)}\n\n"
except Exception as cmd_error:
logger.error(f"Error executing command {command}: {str(cmd_error)}")
logger.error(traceback.format_exc())
# Stream the error result
error_data = {"success": False, "error": str(cmd_error)}
yield f"data: {json.dumps(error_data)}\n\n"
return StreamingResponse(
generate_response(),
media_type="text/plain",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
}
},
)
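The matching client call is a plain POST carrying the same {"command", "params"} envelope; each result arrives as a "data:" line. A sketch with aiohttp (auth headers only matter in cloud mode):
import asyncio
import json
import aiohttp
async def run_cmd(base_url, command, params, container_name=None, api_key=None):
    headers = {}
    if container_name and api_key:
        headers["X-Container-Name"] = container_name
        headers["X-API-Key"] = api_key
    async with aiohttp.ClientSession() as session:
        async with session.post(
            f"{base_url}/cmd",
            json={"command": command, "params": params},
            headers=headers,
        ) as resp:
            # The server streams "data: {...}\n\n" lines; decode the first one
            async for raw in resp.content:
                line = raw.decode().strip()
                if line.startswith("data: "):
                    return json.loads(line[len("data: "):])
print(asyncio.run(run_cmd("http://localhost:8000", "get_screen_size", {})))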
@app.post("/responses")
async def agent_response_endpoint(
request: Request,
@@ -480,11 +490,17 @@ async def agent_response_endpoint(
"""
if not HAS_AGENT:
raise HTTPException(status_code=501, detail="ComputerAgent not available")
# Authenticate via AuthenticationManager if running in cloud (CONTAINER_NAME set)
container_name = os.environ.get("CONTAINER_NAME")
if container_name:
is_public = os.environ.get("CUA_ENABLE_PUBLIC_PROXY", "").lower().strip() in ["1", "true", "yes", "y", "on"]
is_public = os.environ.get("CUA_ENABLE_PUBLIC_PROXY", "").lower().strip() in [
"1",
"true",
"yes",
"y",
"on",
]
if not is_public:
if not api_key:
raise HTTPException(status_code=401, detail="Missing AGENT PROXY auth headers")
@@ -511,10 +527,12 @@ async def agent_response_endpoint(
def __init__(self, overrides: Dict[str, str]):
self.overrides = overrides
self._original: Dict[str, Optional[str]] = {}
def __enter__(self):
for k, v in (self.overrides or {}).items():
self._original[k] = os.environ.get(k)
os.environ[k] = str(v)
def __exit__(self, exc_type, exc, tb):
for k, old in self._original.items():
if old is None:
@@ -598,9 +616,9 @@ async def agent_response_endpoint(
start = path[0]
await self._auto.mouse_down(start["x"], start["y"])
for pt in path[1:]:
await self._auto.move_cursor(pt["x"], pt["y"])
end = path[-1]
await self._auto.mouse_up(end["x"], end["y"])
async def get_current_url(self) -> str:
# Not available in this server context
@@ -667,7 +685,11 @@ async def agent_response_endpoint(
async for result in agent.run(messages):
total_output += result["output"]
# Try to collect usage if present
if isinstance(result, dict) and "usage" in result and isinstance(result["usage"], dict):
if (
isinstance(result, dict)
and "usage" in result
and isinstance(result["usage"], dict)
):
# Merge usage counters
for k, v in result["usage"].items():
if isinstance(v, (int, float)):
@@ -686,14 +708,14 @@ async def agent_response_endpoint(
logger.error(f"Error running agent: {str(e)}")
logger.error(traceback.format_exc())
error = str(e)
# Build response payload
payload = {
"model": model,
"error": error,
"output": total_output,
"usage": total_usage,
"status": "completed" if not error else "failed"
"status": "completed" if not error else "failed",
}
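For a successful run the payload might therefore serialize to something like (all values illustrative):
{
    "model": "anthropic/claude-3-5-sonnet-20241022",
    "error": None,
    "output": [{"type": "message", "content": [{"text": "Task complete."}]}],
    "usage": {"prompt_tokens": 1204, "completion_tokens": 87},
    "status": "completed",
}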
# CORS: allow any origin

View File

@@ -5,8 +5,9 @@ Provides a clean API for starting and stopping the server.
import asyncio
import logging
import uvicorn
from typing import Optional
import uvicorn
from fastapi import FastAPI
from .main import app as fastapi_app
@@ -32,8 +33,14 @@ class Server:
await server.stop() # Stop the server
"""
def __init__(self, host: str = "0.0.0.0", port: int = 8000, log_level: str = "info",
ssl_keyfile: Optional[str] = None, ssl_certfile: Optional[str] = None):
def __init__(
self,
host: str = "0.0.0.0",
port: int = 8000,
log_level: str = "info",
ssl_keyfile: Optional[str] = None,
ssl_certfile: Optional[str] = None,
):
"""
Initialize the server.
@@ -58,12 +65,12 @@ class Server:
Start the server synchronously. This will block until the server is stopped.
"""
uvicorn.run(
self.app,
host=self.host,
port=self.port,
log_level=self.log_level,
ssl_keyfile=self.ssl_keyfile,
ssl_certfile=self.ssl_certfile
ssl_certfile=self.ssl_certfile,
)
async def start_async(self) -> None:
@@ -72,12 +79,12 @@ class Server:
will run in the background.
"""
server_config = uvicorn.Config(
self.app,
host=self.host,
port=self.port,
log_level=self.log_level,
ssl_keyfile=self.ssl_keyfile,
ssl_certfile=self.ssl_certfile
ssl_certfile=self.ssl_certfile,
)
self._should_exit.clear()
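A usage sketch for the async path, following the start/stop contract shown in the class docstring (import path assumed):
import asyncio
from computer_server.server import Server  # assumed module path
async def main():
    server = Server(host="127.0.0.1", port=8000, log_level="info")
    await server.start_async()  # per the docstring, serves in the background
    await asyncio.sleep(60)     # ... do other work while the API is up ...
    await server.stop()         # counterpart shown in the class docstring
asyncio.run(main())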

View File

@@ -12,9 +12,10 @@ import platform
import subprocess
import sys
import time
import websockets
from typing import Optional
import websockets
logger = logging.getLogger(__name__)
@@ -45,62 +46,62 @@ class Watchdog:
"""Watchdog class to monitor server health via WebSocket connection.
Unix/Linux only - provides restart capabilities.
"""
def __init__(self, cli_args: Optional[dict] = None, ping_interval: int = 30):
"""
Initialize the watchdog.
Args:
cli_args: Dictionary of CLI arguments to replicate when restarting
ping_interval: Interval between ping checks in seconds
"""
# Check if running on Unix/Linux
if platform.system() not in ['Linux', 'Darwin']:
if platform.system() not in ["Linux", "Darwin"]:
raise RuntimeError("Watchdog is only supported on Unix/Linux systems")
# Store CLI arguments for restart
self.cli_args = cli_args or {}
self.host = self.cli_args.get('host', 'localhost')
self.port = self.cli_args.get('port', 8000)
self.host = self.cli_args.get("host", "localhost")
self.port = self.cli_args.get("port", 8000)
self.ping_interval = ping_interval
self.container_name = os.environ.get("CONTAINER_NAME")
self.running = False
self.restart_enabled = True
@property
def ws_uri(self) -> str:
"""Get the WebSocket URI using the current IP address.
Returns:
WebSocket URI for the Computer API Server
"""
ip_address = "localhost" if not self.container_name else f"{self.container_name}.containers.cloud.trycua.com"
ip_address = (
"localhost"
if not self.container_name
else f"{self.container_name}.containers.cloud.trycua.com"
)
protocol = "wss" if self.container_name else "ws"
port = "8443" if self.container_name else "8000"
return f"{protocol}://{ip_address}:{port}/ws"
async def ping(self) -> bool:
"""
Test connection to the WebSocket endpoint.
Returns:
True if connection successful, False otherwise
"""
try:
# Create a simple ping message
ping_message = {
"command": "get_screen_size",
"params": {}
}
ping_message = {"command": "get_screen_size", "params": {}}
# Try to connect to the WebSocket
async with websockets.connect(
self.ws_uri,
max_size=1024 * 1024 * 10 # 10MB limit to match server
self.ws_uri, max_size=1024 * 1024 * 10 # 10MB limit to match server
) as websocket:
# Send ping message
await websocket.send(json.dumps(ping_message))
# Wait for any response or just close
try:
response = await asyncio.wait_for(websocket.recv(), timeout=5)
@@ -111,30 +112,27 @@ class Watchdog:
except Exception as e:
logger.warning(f"Ping failed: {e}")
return False
def kill_processes_on_port(self, port: int) -> bool:
"""
Kill any processes using the specified port.
Args:
port: Port number to check and kill processes on
Returns:
True if processes were killed or none found, False on error
"""
try:
# Find processes using the port
result = subprocess.run(
["lsof", "-ti", f":{port}"],
capture_output=True,
text=True,
timeout=10
["lsof", "-ti", f":{port}"], capture_output=True, text=True, timeout=10
)
if result.returncode == 0 and result.stdout.strip():
pids = result.stdout.strip().split('\n')
pids = result.stdout.strip().split("\n")
logger.info(f"Found {len(pids)} processes using port {port}: {pids}")
# Kill each process
for pid in pids:
if pid.strip():
@@ -145,42 +143,42 @@ class Watchdog:
logger.warning(f"Timeout killing process {pid}")
except Exception as e:
logger.warning(f"Error killing process {pid}: {e}")
return True
else:
logger.debug(f"No processes found using port {port}")
return True
except subprocess.TimeoutExpired:
logger.error(f"Timeout finding processes on port {port}")
return False
except Exception as e:
logger.error(f"Error finding processes on port {port}: {e}")
return False
def restart_server(self) -> bool:
"""
Attempt to restart the server by killing existing processes and starting new one.
Returns:
True if restart was attempted, False on error
"""
if not self.restart_enabled:
logger.info("Server restart is disabled")
return False
try:
logger.info("Attempting to restart server...")
# Kill processes on the port
port_to_kill = 8443 if self.container_name else self.port
if not self.kill_processes_on_port(port_to_kill):
logger.error("Failed to kill processes on port, restart aborted")
return False
# Wait a moment for processes to die
time.sleep(2)
# Try to restart the server
# In container mode, we can't easily restart, so just log
if self.container_name:
@@ -190,50 +188,50 @@ class Watchdog:
else:
# For local mode, try to restart the CLI
logger.info("Attempting to restart local server...")
# Get the current Python executable and script
python_exe = sys.executable
# Try to find the CLI module
try:
# Build command with all original CLI arguments
cmd = [python_exe, "-m", "computer_server.cli"]
# Add all CLI arguments except watchdog-related ones
for key, value in self.cli_args.items():
if key in ['watchdog', 'watchdog_interval', 'no_restart']:
if key in ["watchdog", "watchdog_interval", "no_restart"]:
continue # Skip watchdog args to avoid recursive watchdog
# Convert underscores to hyphens for CLI args
arg_name = f"--{key.replace('_', '-')}"
if isinstance(value, bool):
if value: # Only add flag if True
cmd.append(arg_name)
else:
cmd.extend([arg_name, str(value)])
logger.info(f"Starting server with command: {' '.join(cmd)}")
# Start process in background
subprocess.Popen(
cmd,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
start_new_session=True
start_new_session=True,
)
logger.info("Server restart initiated")
return True
except Exception as e:
logger.error(f"Failed to restart server: {e}")
return False
except Exception as e:
logger.error(f"Error during server restart: {e}")
return False
async def start_monitoring(self) -> None:
"""Start the watchdog monitoring loop."""
self.running = True
@@ -241,14 +239,14 @@ class Watchdog:
logger.info(f"Ping interval: {self.ping_interval} seconds")
if self.container_name:
logger.info(f"Container mode detected: {self.container_name}")
consecutive_failures = 0
max_failures = 3
while self.running:
try:
success = await self.ping()
if success:
if consecutive_failures > 0:
logger.info("Server connection restored")
@@ -257,15 +255,17 @@ class Watchdog:
else:
consecutive_failures += 1
logger.warning(f"Ping failed ({consecutive_failures}/{max_failures})")
if consecutive_failures >= max_failures:
logger.error(f"Server appears to be down after {max_failures} consecutive failures")
logger.error(
f"Server appears to be down after {max_failures} consecutive failures"
)
# Attempt to restart the server
if self.restart_enabled:
logger.info("Attempting automatic server restart...")
restart_success = self.restart_server()
if restart_success:
logger.info("Server restart initiated, waiting before next ping...")
# Wait longer after restart attempt
@@ -275,17 +275,17 @@ class Watchdog:
logger.error("Server restart failed")
else:
logger.warning("Automatic restart is disabled")
# Wait for next ping interval
await asyncio.sleep(self.ping_interval)
except asyncio.CancelledError:
logger.info("Watchdog monitoring cancelled")
break
except Exception as e:
logger.error(f"Unexpected error in watchdog loop: {e}")
await asyncio.sleep(self.ping_interval)
def stop_monitoring(self) -> None:
"""Stop the watchdog monitoring."""
self.running = False
@@ -295,13 +295,13 @@ class Watchdog:
async def run_watchdog(cli_args: Optional[dict] = None, ping_interval: int = 30) -> None:
"""
Run the watchdog monitoring.
Args:
cli_args: Dictionary of CLI arguments to replicate when restarting
ping_interval: Interval between ping checks in seconds
"""
watchdog = Watchdog(cli_args=cli_args, ping_interval=ping_interval)
try:
await watchdog.start_monitoring()
except KeyboardInterrupt:
@@ -313,21 +313,18 @@ async def run_watchdog(cli_args: Optional[dict] = None, ping_interval: int = 30)
if __name__ == "__main__":
# For testing the watchdog standalone
import argparse
parser = argparse.ArgumentParser(description="Run Computer API server watchdog")
parser.add_argument("--host", default="localhost", help="Server host to monitor")
parser.add_argument("--port", type=int, default=8000, help="Server port to monitor")
parser.add_argument("--ping-interval", type=int, default=30, help="Ping interval in seconds")
args = parser.parse_args()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
cli_args = {
'host': args.host,
'port': args.port
}
cli_args = {"host": args.host, "port": args.port}
asyncio.run(run_watchdog(cli_args, args.ping_interval))

View File

@@ -4,14 +4,15 @@ build-backend = "pdm.backend"
[project]
name = "cua-computer-server"
version = "0.1.0"
version = "0.1.27"
description = "Server component for the Computer-Use Interface (CUI) framework powering Cua"
authors = [
{ name = "TryCua", email = "gh@trycua.com" }
]
readme = "README.md"
license = { text = "MIT" }
requires-python = ">=3.9"
requires-python = ">=3.12"
dependencies = [
"fastapi>=0.111.0",
"uvicorn[standard]>=0.27.0",
@@ -21,7 +22,14 @@ dependencies = [
"pillow>=10.2.0",
"aiohttp>=3.9.1",
"pyperclip>=1.9.0",
"websockets>=12.0"
"websockets>=12.0",
# OS-specific runtime deps
"pyobjc-framework-Cocoa>=10.1; sys_platform == 'darwin'",
"pyobjc-framework-Quartz>=10.1; sys_platform == 'darwin'",
"pyobjc-framework-ApplicationServices>=10.1; sys_platform == 'darwin'",
"python-xlib>=0.33; sys_platform == 'linux'",
"pywin32>=310; sys_platform == 'win32'",
"pip-system-certs; sys_platform == 'win32'",
]
[project.optional-dependencies]
@@ -66,23 +74,4 @@ dev = [
]
[tool.pdm.scripts]
api = "python -m computer_server"
[tool.ruff]
line-length = 100
target-version = "py310"
select = ["E", "F", "B", "I"]
fix = true
[tool.ruff.format]
docstring-code-format = true
[tool.mypy]
strict = true
python_version = "3.10"
ignore_missing_imports = true
disallow_untyped_defs = true
check_untyped_defs = true
warn_return_any = true
show_error_codes = true
warn_unused_ignores = false
api = "python -m computer_server"

View File

@@ -10,6 +10,7 @@ Usage:
"""
import sys
from computer_server.cli import main
if __name__ == "__main__":

View File

@@ -6,18 +6,22 @@ This script tests both WebSocket (/ws) and REST (/cmd) connections to the Comput
and keeps it alive, allowing you to verify the server is running correctly.
"""
import argparse
import asyncio
import json
import websockets
import argparse
import sys
import aiohttp
import os
import sys
import aiohttp
import dotenv
import websockets
dotenv.load_dotenv()
async def test_websocket_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
async def test_websocket_connection(
host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None
):
"""Test WebSocket connection to the Computer Server."""
if container_name:
# Container mode: use WSS with container domain and port 8443
@@ -37,19 +41,16 @@ async def test_websocket_connection(host="localhost", port=8000, keep_alive=Fals
if not api_key:
print("Error: API key required for container connections")
return False
print("Sending authentication...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": api_key,
"container_name": container_name
}
"params": {"api_key": api_key, "container_name": container_name},
}
await websocket.send(json.dumps(auth_message))
auth_response = await websocket.recv()
print(f"Authentication response: {auth_response}")
# Check if authentication was successful
auth_data = json.loads(auth_response)
if not auth_data.get("success", False):
@@ -90,7 +91,9 @@ async def test_websocket_connection(host="localhost", port=8000, keep_alive=Fals
return True
async def test_rest_connection(host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None):
async def test_rest_connection(
host="localhost", port=8000, keep_alive=False, container_name=None, api_key=None
):
"""Test REST connection to the Computer Server."""
if container_name:
# Container mode: use HTTPS with container domain and port 8443
@@ -113,13 +116,11 @@ async def test_rest_connection(host="localhost", port=8000, keep_alive=False, co
return False
headers["X-Container-Name"] = container_name
headers["X-API-Key"] = api_key
print(f"Using container authentication headers")
print("Using container authentication headers")
# Test screenshot endpoint
async with session.post(
f"{base_url}/cmd",
json={"command": "screenshot", "params": {}},
headers=headers
f"{base_url}/cmd", json={"command": "screenshot", "params": {}}, headers=headers
) as response:
if response.status == 200:
text = await response.text()
@@ -133,7 +134,7 @@ async def test_rest_connection(host="localhost", port=8000, keep_alive=False, co
async with session.post(
f"{base_url}/cmd",
json={"command": "get_screen_size", "params": {}},
headers=headers
headers=headers,
) as response:
if response.status == 200:
text = await response.text()
@@ -151,7 +152,7 @@ async def test_rest_connection(host="localhost", port=8000, keep_alive=False, co
async with session.post(
f"{base_url}/cmd",
json={"command": "get_cursor_position", "params": {}},
headers=headers
headers=headers,
) as response:
if response.status == 200:
text = await response.text()
@@ -171,7 +172,9 @@ async def test_rest_connection(host="localhost", port=8000, keep_alive=False, co
return True
async def test_connection(host="localhost", port=8000, keep_alive=False, container_name=None, use_rest=False, api_key=None):
async def test_connection(
host="localhost", port=8000, keep_alive=False, container_name=None, use_rest=False, api_key=None
):
"""Test connection to the Computer Server using WebSocket or REST."""
if use_rest:
return await test_rest_connection(host, port, keep_alive, container_name, api_key)
@@ -183,40 +186,50 @@ def parse_args():
parser = argparse.ArgumentParser(description="Test connection to Computer Server")
parser.add_argument("--host", default="localhost", help="Host address (default: localhost)")
parser.add_argument("-p", "--port", type=int, default=8000, help="Port number (default: 8000)")
parser.add_argument("-c", "--container-name", help="Container name for cloud connection (uses WSS/HTTPS and port 8443)")
parser.add_argument("--api-key", help="API key for container authentication (can also use CUA_API_KEY env var)")
parser.add_argument(
"-c",
"--container-name",
help="Container name for cloud connection (uses WSS/HTTPS and port 8443)",
)
parser.add_argument(
"--api-key", help="API key for container authentication (can also use CUA_API_KEY env var)"
)
parser.add_argument("--keep-alive", action="store_true", help="Keep connection alive")
parser.add_argument("--rest", action="store_true", help="Use REST endpoint (/cmd) instead of WebSocket (/ws)")
parser.add_argument(
"--rest", action="store_true", help="Use REST endpoint (/cmd) instead of WebSocket (/ws)"
)
return parser.parse_args()
async def main():
args = parse_args()
# Convert hyphenated argument to underscore for function parameter
container_name = getattr(args, 'container_name', None)
container_name = getattr(args, "container_name", None)
# Get API key from argument or environment variable
api_key = getattr(args, 'api_key', None) or os.environ.get('CUA_API_KEY')
api_key = getattr(args, "api_key", None) or os.environ.get("CUA_API_KEY")
# Check if container name is provided but API key is missing
if container_name and not api_key:
print("Warning: Container name provided but no API key found.")
print("Please provide --api-key argument or set CUA_API_KEY environment variable.")
return 1
print(f"Testing {'REST' if args.rest else 'WebSocket'} connection...")
if container_name:
print(f"Container: {container_name}")
print(f"API Key: {'***' + api_key[-4:] if api_key and len(api_key) > 4 else 'Not provided'}")
print(
f"API Key: {'***' + api_key[-4:] if api_key and len(api_key) > 4 else 'Not provided'}"
)
success = await test_connection(
host=args.host,
port=args.port,
keep_alive=args.keep_alive,
container_name=container_name,
use_rest=args.rest,
api_key=api_key
api_key=api_key,
)
return 0 if success else 1

View File

@@ -0,0 +1,10 @@
[bumpversion]
current_version = 0.4.7
commit = True
tag = True
tag_name = computer-v{new_version}
message = Bump cua-computer to v{new_version}
[bumpversion:file:pyproject.toml]
search = version = "{current_version}"
replace = version = "{new_version}"

View File

@@ -8,10 +8,11 @@
</picture>
</div>
[![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#)
[![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85)
[![PyPI](https://img.shields.io/pypi/v/cua-computer?color=333333)](https://pypi.org/project/cua-computer/)
</h1>
</div>
@@ -29,11 +30,11 @@ from computer import Computer
computer = Computer(os_type="macos", display="1024x768", memory="8GB", cpu="4")
try:
await computer.run()
screenshot = await computer.interface.screenshot()
with open("screenshot.png", "wb") as f:
f.write(screenshot)
await computer.interface.move_cursor(100, 100)
await computer.interface.left_click()
await computer.interface.right_click(300, 300)

View File

@@ -1,19 +1,22 @@
from typing import Optional, List, Literal, Dict, Any, Union, TYPE_CHECKING, cast
import asyncio
from .models import Computer as ComputerConfig, Display
from .interface.factory import InterfaceFactory
import time
from PIL import Image
import io
import re
from .logger import Logger, LogLevel
import json
import logging
from core.telemetry import is_telemetry_enabled, record_event
import os
from . import helpers
import platform
import re
import time
import traceback
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union, cast
from core.telemetry import is_telemetry_enabled, record_event
from PIL import Image
from . import helpers
from .interface.factory import InterfaceFactory
from .logger import Logger, LogLevel
from .models import Computer as ComputerConfig
from .models import Display
SYSTEM_INFO = {
"os": platform.system().lower(),
@@ -27,6 +30,7 @@ from .providers.factory import VMProviderFactory
OSType = Literal["macos", "linux", "windows"]
class Computer:
"""Computer is the main class for interacting with the computer."""
@@ -40,8 +44,11 @@ class Computer:
Returns:
DioramaComputer: A proxy object with the Diorama interface, but using diorama_cmds.
"""
assert "app-use" in self.experiments, "App Usage is an experimental feature. Enable it by passing experiments=['app-use'] to Computer()"
assert (
"app-use" in self.experiments
), "App Usage is an experimental feature. Enable it by passing experiments=['app-use'] to Computer()"
from .diorama_computer import DioramaComputer
return DioramaComputer(self, apps)
def __init__(
@@ -63,7 +70,7 @@ class Computer:
storage: Optional[str] = None,
ephemeral: bool = False,
api_key: Optional[str] = None,
experiments: Optional[List[str]] = None
experiments: Optional[List[str]] = None,
):
"""Initialize a new Computer instance.
@@ -111,32 +118,36 @@ class Computer:
self.os_type = os_type
self.provider_type = provider_type
self.ephemeral = ephemeral
self.api_key = api_key
self.experiments = experiments or []
if "app-use" in self.experiments:
assert self.os_type == "macos", "App use experiment is only supported on macOS"
# The default is currently to use non-ephemeral storage
if storage and ephemeral and storage != "ephemeral":
raise ValueError("Storage path and ephemeral flag cannot be used together")
# Windows Sandbox always uses ephemeral storage
if self.provider_type == VMProviderType.WINSANDBOX:
if not ephemeral and storage != None and storage != "ephemeral":
self.logger.warning("Windows Sandbox storage is always ephemeral. Setting ephemeral=True.")
self.logger.warning(
"Windows Sandbox storage is always ephemeral. Setting ephemeral=True."
)
self.ephemeral = True
self.storage = "ephemeral"
else:
self.storage = "ephemeral" if ephemeral else storage
# For Lumier provider, store the first shared directory path to use
# for VM file sharing
self.shared_path = None
if shared_directories and len(shared_directories) > 0:
self.shared_path = shared_directories[0]
self.logger.info(f"Using first shared directory for VM file sharing: {self.shared_path}")
self.logger.info(
f"Using first shared directory for VM file sharing: {self.shared_path}"
)
# Store telemetry preference
self._telemetry_enabled = telemetry_enabled
@@ -154,8 +165,8 @@ class Computer:
self.interface_logger = Logger("computer.interface", verbosity)
if not use_host_computer_server:
if ":" not in image or len(image.split(":")) != 2:
raise ValueError("Image must be in the format <image_name>:<tag>")
if ":" not in image:
image = f"{image}:latest"
if not name:
# Normalize the name to be used for the VM
@@ -263,8 +274,14 @@ class Computer:
self.logger.info(f"Starting VM: {self.image}")
if not self._provider_context:
try:
provider_type_name = self.provider_type.name if isinstance(self.provider_type, VMProviderType) else self.provider_type
self.logger.verbose(f"Initializing {provider_type_name} provider context...")
provider_type_name = (
self.provider_type.name
if isinstance(self.provider_type, VMProviderType)
else self.provider_type
)
self.logger.verbose(
f"Initializing {provider_type_name} provider context..."
)
# Explicitly set provider parameters
storage = "ephemeral" if self.ephemeral else self.storage
@@ -281,9 +298,13 @@ class Computer:
if self.provider_type == VMProviderType.LUMIER:
self.logger.info(f"Using VM image for Lumier provider: {image}")
if shared_path:
self.logger.info(f"Using shared path for Lumier provider: {shared_path}")
self.logger.info(
f"Using shared path for Lumier provider: {shared_path}"
)
if noVNC_port:
self.logger.info(f"Using noVNC port for Lumier provider: {noVNC_port}")
self.logger.info(
f"Using noVNC port for Lumier provider: {noVNC_port}"
)
self.config.vm_provider = VMProviderFactory.create_provider(
self.provider_type,
port=port,
@@ -339,11 +360,17 @@ class Computer:
except ImportError as ie:
self.logger.error(f"Failed to import provider dependencies: {ie}")
if str(ie).find("lume") >= 0 and str(ie).find("lumier") < 0:
self.logger.error("Please install with: pip install cua-computer[lume]")
self.logger.error(
"Please install with: pip install cua-computer[lume]"
)
elif str(ie).find("lumier") >= 0 or str(ie).find("docker") >= 0:
self.logger.error("Please install with: pip install cua-computer[lumier] and make sure Docker is installed")
self.logger.error(
"Please install with: pip install cua-computer[lumier] and make sure Docker is installed"
)
elif str(ie).find("cloud") >= 0:
self.logger.error("Please install with: pip install cua-computer[cloud]")
self.logger.error(
"Please install with: pip install cua-computer[cloud]"
)
raise
except Exception as e:
self.logger.error(f"Failed to initialize provider context: {e}")
@@ -354,16 +381,14 @@ class Computer:
try:
if self.config.vm_provider is None:
raise RuntimeError(f"VM provider not initialized for {self.config.name}")
vm = await self.config.vm_provider.get_vm(self.config.name)
self.logger.verbose(f"Found existing VM: {self.config.name}")
is_running = vm.get("status") == "running"
except Exception as e:
self.logger.error(f"VM not found: {self.config.name}")
self.logger.error(f"Error: {e}")
raise RuntimeError(
f"VM {self.config.name} could not be found or created."
)
raise RuntimeError(f"VM {self.config.name} could not be found or created.")
# Start the VM if it's not running
if not is_running:
@@ -376,13 +401,10 @@ class Computer:
path = os.path.abspath(os.path.expanduser(path))
if os.path.exists(path):
# Add path in format expected by Lume API
shared_dirs.append({
"hostPath": path,
"readOnly": False
})
shared_dirs.append({"hostPath": path, "readOnly": False})
else:
self.logger.warning(f"Shared directory does not exist: {path}")
# Prepare run options to pass to the provider
run_opts = {}
@@ -392,11 +414,11 @@ class Computer:
"width": self.config.display.width,
"height": self.config.display.height,
}
# Check if scale_factor exists before adding it
if hasattr(self.config.display, "scale_factor"):
display_info["scale_factor"] = self.config.display.scale_factor
run_opts["display"] = display_info
# Add shared directories if available
@@ -406,21 +428,23 @@ class Computer:
# Run the VM with the provider
try:
if self.config.vm_provider is None:
raise RuntimeError(f"VM provider not initialized for {self.config.name}")
raise RuntimeError(
f"VM provider not initialized for {self.config.name}"
)
# Use the complete run_opts we prepared earlier
# Handle ephemeral storage for run_vm method too
storage_param = "ephemeral" if self.ephemeral else self.storage
# Log the image being used
self.logger.info(f"Running VM using image: {self.image}")
# Call provider.run_vm with explicit image parameter
response = await self.config.vm_provider.run_vm(
image=self.image,
name=self.config.name,
run_opts=run_opts,
storage=storage_param
storage=storage_param,
)
self.logger.info(f"VM run response: {response if response else 'None'}")
except Exception as run_error:
@@ -432,14 +456,16 @@ class Computer:
try:
if self.provider_type == VMProviderType.LUMIER:
max_retries = 60 # Increased for Lumier VM startup which takes longer
retry_delay = 3 # 3 seconds between retries for Lumier
retry_delay = 3 # 3 seconds between retries for Lumier
else:
max_retries = 30 # Default for other providers
retry_delay = 2 # 2 seconds between retries
self.logger.info(f"Waiting up to {max_retries * retry_delay} seconds for VM to be ready...")
retry_delay = 2 # 2 seconds between retries
self.logger.info(
f"Waiting up to {max_retries * retry_delay} seconds for VM to be ready..."
)
ip = await self.get_ip(max_retries=max_retries, retry_delay=retry_delay)
# If we get here, we have a valid IP
self.logger.info(f"VM is ready with IP: {ip}")
ip_address = ip
@@ -451,13 +477,16 @@ class Computer:
raise RuntimeError(f"VM failed to become ready: {wait_error}")
except Exception as e:
self.logger.error(f"Failed to initialize computer: {e}")
self.logger.error(traceback.format_exc())
raise RuntimeError(f"Failed to initialize computer: {e}")
try:
# Verify we have a valid IP before initializing the interface
if not ip_address or ip_address == "unknown" or ip_address == "0.0.0.0":
raise RuntimeError(f"Cannot initialize interface - invalid IP address: {ip_address}")
raise RuntimeError(
f"Cannot initialize interface - invalid IP address: {ip_address}"
)
# Initialize the interface using the factory with the specified OS
self.logger.info(f"Initializing interface for {self.os_type} at {ip_address}")
from .interface.base import BaseComputerInterface
@@ -467,18 +496,17 @@ class Computer:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address,
api_key=self.api_key,
vm_name=self.config.name
vm_name=self.config.name,
),
)
else:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address
os=self.os_type, ip_address=ip_address
),
)
@@ -508,10 +536,10 @@ class Computer:
# Set the initialization flag and clear the initializing flag
self._initialized = True
# Set this instance as the default computer for remote decorators
helpers.set_default_computer(self)
self.logger.info("Computer successfully initialized")
except Exception as e:
raise
@@ -520,7 +548,7 @@ class Computer:
duration_ms = (time.time() - start_time) * 1000
self.logger.debug(f"Computer initialization took {duration_ms:.2f}ms")
return
async def disconnect(self) -> None:
"""Disconnect from the computer's WebSocket interface."""
if self._interface:
@@ -534,13 +562,17 @@ class Computer:
self.logger.info("Stopping Computer...")
# In VM mode, first explicitly stop the VM, then exit the provider context
if not self.use_host_computer_server and self._provider_context and self.config.vm_provider is not None:
if (
not self.use_host_computer_server
and self._provider_context
and self.config.vm_provider is not None
):
try:
self.logger.info(f"Stopping VM {self.config.name}...")
await self.config.vm_provider.stop_vm(
name=self.config.name,
storage=self.storage # Pass storage explicitly for clarity
)
name=self.config.name,
storage=self.storage, # Pass storage explicitly for clarity
)
except Exception as e:
self.logger.error(f"Error stopping VM: {e}")
@@ -551,55 +583,156 @@ class Computer:
await self.disconnect()
self.logger.info("Computer stopped")
except Exception as e:
self.logger.debug(f"Error during cleanup: {e}") # Log as debug since this might be expected
self.logger.debug(
f"Error during cleanup: {e}"
) # Log as debug since this might be expected
finally:
# Log stop time for performance monitoring
duration_ms = (time.time() - start_time) * 1000
self.logger.debug(f"Computer stop process took {duration_ms:.2f}ms")
return
async def start(self) -> None:
"""Start the computer."""
await self.run()
async def restart(self) -> None:
"""Restart the computer.
If using a VM provider that supports restart, this will issue a restart
without tearing down the provider context, then reconnect the interface.
Falls back to stop()+run() when a provider restart is not available.
"""
# Host computer server: just disconnect and run again
if self.use_host_computer_server:
try:
await self.disconnect()
finally:
await self.run()
return
# If no VM provider context yet, fall back to full run
if not getattr(self, "_provider_context", None) or self.config.vm_provider is None:
self.logger.info("No provider context active; performing full restart via run()")
await self.run()
return
# Gracefully close current interface connection if present
if self._interface:
try:
self._interface.close()
except Exception as e:
self.logger.debug(f"Error closing interface prior to restart: {e}")
# Attempt provider-level restart if implemented
try:
storage_param = "ephemeral" if self.ephemeral else self.storage
if hasattr(self.config.vm_provider, "restart_vm"):
self.logger.info(f"Restarting VM {self.config.name} via provider...")
await self.config.vm_provider.restart_vm(
name=self.config.name, storage=storage_param
)
else:
# Fallback: stop then start without leaving provider context
self.logger.info(
f"Provider has no restart_vm; performing stop+start for {self.config.name}..."
)
await self.config.vm_provider.stop_vm(name=self.config.name, storage=storage_param)
await self.config.vm_provider.run_vm(
image=self.image, name=self.config.name, run_opts={}, storage=storage_param
)
except Exception as e:
self.logger.error(f"Failed to restart VM via provider: {e}")
# As a last resort, do a full stop (with provider context exit) and run
try:
await self.stop()
finally:
await self.run()
return
# Wait for VM to be ready and reconnect interface
try:
self.logger.info("Waiting for VM to be ready after restart...")
if self.provider_type == VMProviderType.LUMIER:
max_retries = 60
retry_delay = 3
else:
max_retries = 30
retry_delay = 2
ip_address = await self.get_ip(max_retries=max_retries, retry_delay=retry_delay)
self.logger.info(f"Re-initializing interface for {self.os_type} at {ip_address}")
from .interface.base import BaseComputerInterface
if self.provider_type == VMProviderType.CLOUD and self.api_key and self.config.name:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address,
api_key=self.api_key,
vm_name=self.config.name,
),
)
else:
self._interface = cast(
BaseComputerInterface,
InterfaceFactory.create_interface_for_os(
os=self.os_type,
ip_address=ip_address,
),
)
self.logger.info("Connecting to WebSocket interface after restart...")
await self._interface.wait_for_ready(timeout=30)
self.logger.info("Computer reconnected and ready after restart")
except Exception as e:
self.logger.error(f"Failed to reconnect after restart: {e}")
# Try a full reset if reconnection failed
try:
await self.stop()
finally:
await self.run()
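A short usage sketch of the new method (image name illustrative):
from computer import Computer
async def demo():
    computer = Computer(os_type="macos", image="macos-sequoia-cua:latest")
    await computer.run()
    # ... later, if the VM wedges: provider-level restart plus reconnect ...
    await computer.restart()
    screenshot = await computer.interface.screenshot()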
# @property
async def get_ip(self, max_retries: int = 15, retry_delay: int = 3) -> str:
"""Get the IP address of the VM or localhost if using host computer server.
This method delegates to the provider's get_ip method, which waits indefinitely
until the VM has a valid IP address.
Args:
max_retries: Unused parameter, kept for backward compatibility
retry_delay: Delay between retries in seconds (default: 2)
Returns:
IP address of the VM or localhost if using host computer server
"""
# For host computer server, always return localhost immediately
if self.use_host_computer_server:
return "127.0.0.1"
# Get IP from the provider - each provider implements its own waiting logic
if self.config.vm_provider is None:
raise RuntimeError("VM provider is not initialized")
# Log that we're waiting for the IP
self.logger.info(f"Waiting for VM {self.config.name} to get an IP address...")
# Call the provider's get_ip method which will wait indefinitely
storage_param = "ephemeral" if self.ephemeral else self.storage
# Log the image being used
self.logger.info(f"Running VM using image: {self.image}")
# Call provider.get_ip with explicit image parameter
ip = await self.config.vm_provider.get_ip(
name=self.config.name,
storage=storage_param,
retry_delay=retry_delay
name=self.config.name, storage=storage_param, retry_delay=retry_delay
)
# Log success
self.logger.info(f"VM {self.config.name} has IP address: {ip}")
return ip
async def wait_vm_ready(self) -> Optional[Dict[str, Any]]:
"""Wait for VM to be ready with an IP address.
@@ -687,8 +820,8 @@ class Computer:
if self.config.vm_provider is not None:
vm = await self.config.vm_provider.get_vm(self.config.name)
# VM data is returned as a dictionary from the Lumier provider
status = vm.get('status', 'unknown') if vm else "unknown"
ip = vm.get('ip_address') if vm else None
status = vm.get("status", "unknown") if vm else "unknown"
ip = vm.get("ip_address") if vm else None
else:
status = "unknown"
ip = None
@@ -705,16 +838,13 @@ class Computer:
self.logger.info(
f"Updating VM settings: CPU={cpu or self.config.cpu}, Memory={memory or self.config.memory}"
)
update_opts = {
"cpu": cpu or int(self.config.cpu),
"memory": memory or self.config.memory
}
update_opts = {"cpu": cpu or int(self.config.cpu), "memory": memory or self.config.memory}
if self.config.vm_provider is not None:
await self.config.vm_provider.update_vm(
name=self.config.name,
update_opts=update_opts,
storage=self.storage # Pass storage explicitly for clarity
)
await self.config.vm_provider.update_vm(
name=self.config.name,
update_opts=update_opts,
storage=self.storage, # Pass storage explicitly for clarity
)
else:
raise RuntimeError("VM provider not initialized")
@@ -781,65 +911,94 @@ class Computer:
"""
return await self.interface.to_screenshot_coordinates(x, y)
# Add virtual environment management functions to computer interface
async def venv_install(self, venv_name: str, requirements: list[str]):
"""Install packages in a virtual environment.
Args:
venv_name: Name of the virtual environment
requirements: List of package requirements to install
Returns:
Tuple of (stdout, stderr) from the installation command
"""
requirements = requirements or []
# Windows vs POSIX handling
if self.os_type == "windows":
# Use %USERPROFILE% for home directory and cmd.exe semantics
venv_path = f"%USERPROFILE%\\.venvs\\{venv_name}"
ensure_dir_cmd = 'if not exist "%USERPROFILE%\\.venvs" mkdir "%USERPROFILE%\\.venvs"'
create_cmd = f'if not exist "{venv_path}" python -m venv "{venv_path}"'
requirements_str = " ".join(requirements)
# Activate via activate.bat and install
install_cmd = (
f'call "{venv_path}\\Scripts\\activate.bat" && pip install {requirements_str}'
if requirements_str
else "echo No requirements to install"
)
await self.interface.run_command(ensure_dir_cmd)
await self.interface.run_command(create_cmd)
return await self.interface.run_command(install_cmd)
else:
# POSIX (macOS/Linux)
venv_path = f"$HOME/.venvs/{venv_name}"
create_cmd = f'mkdir -p "$HOME/.venvs" && python3 -m venv "{venv_path}"'
# Check if venv exists, if not create it
check_cmd = f'test -d "{venv_path}" || ({create_cmd})'
_ = await self.interface.run_command(check_cmd)
# Install packages
requirements_str = " ".join(requirements)
install_cmd = (
f'. "{venv_path}/bin/activate" && pip install {requirements_str}'
if requirements_str
else "echo No requirements to install"
)
return await self.interface.run_command(install_cmd)
# Create virtual environment if it doesn't exist
venv_path = f"~/.venvs/{venv_name}"
create_cmd = f"mkdir -p ~/.venvs && python3 -m venv {venv_path}"
# Check if venv exists, if not create it
check_cmd = f"test -d {venv_path} || ({create_cmd})"
_ = await self.interface.run_command(check_cmd)
# Install packages
requirements_str = " ".join(requirements)
install_cmd = f". {venv_path}/bin/activate && pip install {requirements_str}"
return await self.interface.run_command(install_cmd)
async def venv_cmd(self, venv_name: str, command: str):
"""Execute a shell command in a virtual environment.
Args:
venv_name: Name of the virtual environment
command: Shell command to execute in the virtual environment
Returns:
Tuple of (stdout, stderr) from the command execution
"""
venv_path = f"~/.venvs/{venv_name}"
# Check if virtual environment exists
check_cmd = f"test -d {venv_path}"
result = await self.interface.run_command(check_cmd)
if result.stderr or "test:" in result.stdout: # venv doesn't exist
return "", f"Virtual environment '{venv_name}' does not exist. Create it first using venv_install."
# Activate virtual environment and run command
full_command = f". {venv_path}/bin/activate && {command}"
return await self.interface.run_command(full_command)
if self.os_type == "windows":
# Windows (cmd.exe)
venv_path = f"%USERPROFILE%\\.venvs\\{venv_name}"
# Check existence and signal if missing
check_cmd = f'if not exist "{venv_path}" (echo VENV_NOT_FOUND) else (echo VENV_FOUND)'
result = await self.interface.run_command(check_cmd)
if "VENV_NOT_FOUND" in getattr(result, "stdout", ""):
# Auto-create the venv with no requirements
await self.venv_install(venv_name, [])
# Activate and run the command
full_command = f'call "{venv_path}\\Scripts\\activate.bat" && {command}'
return await self.interface.run_command(full_command)
else:
# POSIX (macOS/Linux)
venv_path = f"$HOME/.venvs/{venv_name}"
# Check if virtual environment exists
check_cmd = f'test -d "{venv_path}"'
result = await self.interface.run_command(check_cmd)
if result.stderr or "test:" in result.stdout: # venv doesn't exist
# Auto-create the venv with no requirements
await self.venv_install(venv_name, [])
# Activate virtual environment and run command
full_command = f'. "{venv_path}/bin/activate" && {command}'
return await self.interface.run_command(full_command)
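Putting the two together (inside an async context with `computer` already running; package and command illustrative — on Windows the same calls route through the cmd.exe branches above):
# Creates ~/.venvs/demo (or %USERPROFILE%\.venvs\demo) if needed, then installs
await computer.venv_install("demo", ["requests"])
# Runs inside the activated venv; a missing venv is auto-created first
result = await computer.venv_cmd(
    "demo", "python -c 'import requests; print(requests.__version__)'"
)
print(result.stdout)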
async def venv_exec(self, venv_name: str, python_func, *args, **kwargs):
"""Execute Python function in a virtual environment using source code extraction.
Args:
venv_name: Name of the virtual environment
python_func: A callable function to execute
*args: Positional arguments to pass to the function
**kwargs: Keyword arguments to pass to the function
Returns:
The result of the function execution, or raises any exception that occurred
"""
@@ -847,29 +1006,29 @@ class Computer:
import inspect
import json
import textwrap
try:
# Get function source code using inspect.getsource
source = inspect.getsource(python_func)
# Remove common leading whitespace (dedent)
func_source = textwrap.dedent(source).strip()
# Remove decorators
while func_source.lstrip().startswith("@"):
func_source = func_source.split("\n", 1)[1].strip()
# Get function name for execution
func_name = python_func.__name__
# Serialize args and kwargs as JSON (safer than dill for cross-version compatibility)
args_json = json.dumps(args, default=str)
kwargs_json = json.dumps(kwargs, default=str)
except OSError as e:
raise Exception(f"Cannot retrieve source code for function {python_func.__name__}: {e}")
except Exception as e:
raise Exception(f"Failed to reconstruct function source: {e}")
# Create Python code that will define and execute the function
python_code = f'''
import json
@@ -914,25 +1073,27 @@ output_json = json.dumps(output_payload, default=str)
# Print the JSON output with markers
print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
'''
# Encode the Python code in base64 to avoid shell escaping issues
encoded_code = base64.b64encode(python_code.encode('utf-8')).decode('ascii')
encoded_code = base64.b64encode(python_code.encode("utf-8")).decode("ascii")
# Execute the Python code in the virtual environment
python_command = f"python -c \"import base64; exec(base64.b64decode('{encoded_code}').decode('utf-8'))\""
python_command = (
f"python -c \"import base64; exec(base64.b64decode('{encoded_code}').decode('utf-8'))\""
)
result = await self.venv_cmd(venv_name, python_command)
# Parse the output to extract the payload
start_marker = "<<<VENV_EXEC_START>>>"
end_marker = "<<<VENV_EXEC_END>>>"
# Print original stdout
print(result.stdout[:result.stdout.find(start_marker)])
print(result.stdout[: result.stdout.find(start_marker)])
if start_marker in result.stdout and end_marker in result.stdout:
start_idx = result.stdout.find(start_marker) + len(start_marker)
end_idx = result.stdout.find(end_marker)
if start_idx < end_idx:
output_json = result.stdout[start_idx:end_idx]
@@ -941,7 +1102,7 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
output_payload = json.loads(output_json)
except Exception as e:
raise Exception(f"Failed to decode output payload: {e}")
if output_payload["success"]:
return output_payload["result"]
else:
@@ -953,4 +1114,6 @@ print(f"<<<VENV_EXEC_START>>>{{output_json}}<<<VENV_EXEC_END>>>")
raise Exception("Invalid output format: markers found but no content between them")
else:
# Fallback: return stdout/stderr if no payload markers found
raise Exception(f"No output payload found. stdout: {result.stdout}, stderr: {result.stderr}")
raise Exception(
f"No output payload found. stdout: {result.stdout}, stderr: {result.stderr}"
)
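In practice: the function's source is shipped across, executed under the venv's interpreter, and the JSON payload between the markers becomes the return value. A sketch (function body illustrative):
def add(a, b):
    # Runs remotely; only the source travels, so keep the body self-contained
    # and the arguments/result JSON-serializable.
    return a + b
result = await computer.venv_exec("demo", add, 2, 3)  # -> 5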

View File

@@ -1,14 +1,17 @@
import asyncio
from .interface.models import KeyType, Key
from .interface.models import Key, KeyType
class DioramaComputer:
"""
A Computer-compatible proxy for Diorama that sends commands over the ComputerInterface.
"""
def __init__(self, computer, apps):
"""
Initialize the DioramaComputer with a computer instance and list of apps.
Args:
computer: The computer instance to proxy commands through
apps: List of applications available in the diorama environment
@@ -21,7 +24,7 @@ class DioramaComputer:
async def __aenter__(self):
"""
Async context manager entry point.
Returns:
self: The DioramaComputer instance
"""
@@ -31,7 +34,7 @@ class DioramaComputer:
async def run(self):
"""
Initialize and run the DioramaComputer if not already initialized.
Returns:
self: The DioramaComputer instance
"""
@@ -39,14 +42,16 @@ class DioramaComputer:
await self.__aenter__()
return self
class DioramaComputerInterface:
"""
Diorama Interface proxy that sends diorama_cmds via the Computer's interface.
"""
def __init__(self, computer, apps):
"""
Initialize the DioramaComputerInterface.
Args:
computer: The computer instance to send commands through
apps: List of applications available in the diorama environment
@@ -58,14 +63,14 @@ class DioramaComputerInterface:
async def _send_cmd(self, action, arguments=None):
"""
Send a command to the diorama interface through the computer.
Args:
action (str): The action/command to execute
arguments (dict, optional): Additional arguments for the command
Returns:
The result from the diorama command execution
Raises:
RuntimeError: If the computer interface is not initialized or command fails
"""
@@ -77,25 +82,30 @@ class DioramaComputerInterface:
raise RuntimeError("Computer interface not initialized. Call run() first.")
result = await iface.diorama_cmd(action, arguments)
if not result.get("success"):
raise RuntimeError(f"Diorama command failed: {result.get('error')}\n{result.get('trace')}")
raise RuntimeError(
f"Diorama command failed: {result.get('error')}\n{result.get('trace')}"
)
return result.get("result")
async def screenshot(self, as_bytes=True):
"""
Take a screenshot of the diorama scene.
Args:
as_bytes (bool): If True, return image as bytes; if False, return PIL Image object
Returns:
bytes or PIL.Image: Screenshot data in the requested format
"""
from PIL import Image
import base64
from PIL import Image
result = await self._send_cmd("screenshot")
# assume result is a b64 string of an image
img_bytes = base64.b64decode(result)
import io
img = Image.open(io.BytesIO(img_bytes))
self._scene_size = img.size
return img_bytes if as_bytes else img
@@ -103,7 +113,7 @@ class DioramaComputerInterface:
async def get_screen_size(self):
"""
Get the dimensions of the diorama scene.
Returns:
dict: Dictionary containing 'width' and 'height' keys with pixel dimensions
"""
@@ -114,7 +124,7 @@ class DioramaComputerInterface:
async def move_cursor(self, x, y):
"""
Move the cursor to the specified coordinates.
Args:
x (int): X coordinate to move cursor to
y (int): Y coordinate to move cursor to
@@ -124,7 +134,7 @@ class DioramaComputerInterface:
async def left_click(self, x=None, y=None):
"""
Perform a left mouse click at the specified coordinates or current cursor position.
Args:
x (int, optional): X coordinate to click at. If None, clicks at current cursor position
y (int, optional): Y coordinate to click at. If None, clicks at current cursor position
@@ -134,7 +144,7 @@ class DioramaComputerInterface:
async def right_click(self, x=None, y=None):
"""
Perform a right mouse click at the specified coordinates or current cursor position.
Args:
x (int, optional): X coordinate to click at. If None, clicks at current cursor position
y (int, optional): Y coordinate to click at. If None, clicks at current cursor position
@@ -144,7 +154,7 @@ class DioramaComputerInterface:
async def double_click(self, x=None, y=None):
"""
Perform a double mouse click at the specified coordinates or current cursor position.
Args:
x (int, optional): X coordinate to double-click at. If None, clicks at current cursor position
y (int, optional): Y coordinate to double-click at. If None, clicks at current cursor position
@@ -154,7 +164,7 @@ class DioramaComputerInterface:
async def scroll_up(self, clicks=1):
"""
Scroll up by the specified number of clicks.
Args:
clicks (int): Number of scroll clicks to perform upward. Defaults to 1
"""
@@ -163,7 +173,7 @@ class DioramaComputerInterface:
async def scroll_down(self, clicks=1):
"""
Scroll down by the specified number of clicks.
Args:
clicks (int): Number of scroll clicks to perform downward. Defaults to 1
"""
@@ -172,7 +182,7 @@ class DioramaComputerInterface:
async def drag_to(self, x, y, duration=0.5):
"""
Drag from the current cursor position to the specified coordinates.
Args:
x (int): X coordinate to drag to
y (int): Y coordinate to drag to
@@ -183,7 +193,7 @@ class DioramaComputerInterface:
async def get_cursor_position(self):
"""
Get the current cursor position.
Returns:
dict: Dictionary containing the current cursor coordinates
"""
@@ -192,7 +202,7 @@ class DioramaComputerInterface:
async def type_text(self, text):
"""
Type the specified text at the current cursor position.
Args:
text (str): The text to type
"""
@@ -201,7 +211,7 @@ class DioramaComputerInterface:
async def press_key(self, key):
"""
Press a single key.
Args:
key: The key to press
"""
@@ -210,10 +220,10 @@ class DioramaComputerInterface:
async def hotkey(self, *keys):
"""
Press multiple keys simultaneously as a hotkey combination.
Args:
*keys: Variable number of keys to press together. Can be Key enum instances or strings
Raises:
ValueError: If any key is not a Key enum or string type
"""
@@ -224,7 +234,9 @@ class DioramaComputerInterface:
elif isinstance(key, str):
# Try to convert to enum if it matches a known key
key_or_enum = Key.from_string(key)
actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
actual_keys.append(
key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum
)
else:
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
await self._send_cmd("hotkey", {"keys": actual_keys})
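# Usage sketch (illustrative, not part of this diff): Key enums and plain
# strings may be mixed, since strings are normalized via Key.from_string.
# Assumes `Key` is imported from the interface models and `iface` is an
# initialized interface instance.
async def copy_selection(iface):
    await iface.hotkey(Key.COMMAND, "c")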
@@ -232,11 +244,11 @@ class DioramaComputerInterface:
async def to_screen_coordinates(self, x, y):
"""
Convert screenshot coordinates to screen coordinates.
Args:
x (int): X coordinate to convert
y (int): Y coordinate to convert
Returns:
dict: Dictionary containing the converted screen coordinates
"""

View File

@@ -1,8 +1,9 @@
"""
Helper functions and decorators for the Computer module.
"""
import logging
import asyncio
import logging
from functools import wraps
from typing import Any, Callable, Optional, TypeVar, cast
@@ -11,10 +12,11 @@ _default_computer = None
logger = logging.getLogger(__name__)
def set_default_computer(computer):
"""
Set the default computer instance to be used by the remote decorator.
Args:
computer: The computer instance to use as default
"""
@@ -25,21 +27,24 @@ def set_default_computer(computer):
def sandboxed(venv_name: str = "default", computer: str = "default", max_retries: int = 3):
"""
Decorator that wraps a function to be executed remotely via computer.venv_exec.
Args:
venv_name: Name of the virtual environment to execute in
computer: The computer instance to use, or "default" to use the globally set default
max_retries: Maximum number of retries for the remote execution
"""
def decorator(func):
@wraps(func)
async def wrapper(*args, **kwargs):
# Determine which computer instance to use
comp = computer if computer != "default" else _default_computer
if comp is None:
raise RuntimeError("No computer instance available. Either specify a computer instance or call set_default_computer() first.")
raise RuntimeError(
"No computer instance available. Either specify a computer instance or call set_default_computer() first."
)
for i in range(max_retries):
try:
return await comp.venv_exec(venv_name, func, *args, **kwargs)
@@ -48,5 +53,7 @@ def sandboxed(venv_name: str = "default", computer: str = "default", max_retries
await asyncio.sleep(1)
if i == max_retries - 1:
raise e
return wrapper
return decorator
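# Usage sketch (illustrative, not part of this diff): the decorated function
# body runs inside the sandbox virtualenv, not in the local process. The names
# `where_am_i` and `my_computer` are placeholders.
@sandboxed(venv_name="demo")
def where_am_i():
    import platform  # resolved inside the sandbox environment
    return platform.platform()

async def example(my_computer):
    set_default_computer(my_computer)  # needed because computer="default"
    return await where_am_i()          # proxied through computer.venv_exec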

View File

@@ -2,12 +2,12 @@
Interface package for Computer SDK.
"""
from .factory import InterfaceFactory
from .base import BaseComputerInterface
from .factory import InterfaceFactory
from .macos import MacOSComputerInterface
__all__ = [
"InterfaceFactory",
"BaseComputerInterface",
"MacOSComputerInterface",
]

View File

@@ -1,14 +1,23 @@
"""Base interface for computer control."""
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, Tuple, List
from typing import Any, Dict, List, Optional, Tuple
from ..logger import Logger, LogLevel
from .models import MouseButton, CommandResult
from .models import CommandResult, MouseButton
class BaseComputerInterface(ABC):
"""Base class for computer control interfaces."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
def __init__(
self,
ip_address: str,
username: str = "lume",
password: str = "lume",
api_key: Optional[str] = None,
vm_name: Optional[str] = None,
):
"""Initialize interface.
Args:
@@ -24,7 +33,7 @@ class BaseComputerInterface(ABC):
self.api_key = api_key
self.vm_name = vm_name
self.logger = Logger("cua.interface", LogLevel.NORMAL)
# Optional default delay time between commands (in seconds)
self.delay: float = 0.0
@@ -55,9 +64,15 @@ class BaseComputerInterface(ABC):
# Mouse Actions
@abstractmethod
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: "MouseButton" = "left", delay: Optional[float] = None) -> None:
async def mouse_down(
self,
x: Optional[int] = None,
y: Optional[int] = None,
button: "MouseButton" = "left",
delay: Optional[float] = None,
) -> None:
"""Press and hold a mouse button.
Args:
x: X coordinate to press at. If None, uses current cursor position.
y: Y coordinate to press at. If None, uses current cursor position.
@@ -65,11 +80,17 @@ class BaseComputerInterface(ABC):
delay: Optional delay in seconds after the action
"""
pass
@abstractmethod
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: "MouseButton" = "left", delay: Optional[float] = None) -> None:
async def mouse_up(
self,
x: Optional[int] = None,
y: Optional[int] = None,
button: "MouseButton" = "left",
delay: Optional[float] = None,
) -> None:
"""Release a mouse button.
Args:
x: X coordinate to release at. If None, uses current cursor position.
y: Y coordinate to release at. If None, uses current cursor position.
@@ -77,11 +98,13 @@ class BaseComputerInterface(ABC):
delay: Optional delay in seconds after the action
"""
pass
@abstractmethod
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None) -> None:
async def left_click(
self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None
) -> None:
"""Perform a left mouse button click.
Args:
x: X coordinate to click at. If None, uses current cursor position.
y: Y coordinate to click at. If None, uses current cursor position.
@@ -90,9 +113,11 @@ class BaseComputerInterface(ABC):
pass
@abstractmethod
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None) -> None:
async def right_click(
self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None
) -> None:
"""Perform a right mouse button click.
Args:
x: X coordinate to click at. If None, uses current cursor position.
y: Y coordinate to click at. If None, uses current cursor position.
@@ -101,9 +126,11 @@ class BaseComputerInterface(ABC):
pass
@abstractmethod
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None) -> None:
async def double_click(
self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None
) -> None:
"""Perform a double left mouse button click.
Args:
x: X coordinate to double-click at. If None, uses current cursor position.
y: Y coordinate to double-click at. If None, uses current cursor position.
@@ -114,7 +141,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def move_cursor(self, x: int, y: int, delay: Optional[float] = None) -> None:
"""Move the cursor to the specified screen coordinates.
Args:
x: X coordinate to move cursor to.
y: Y coordinate to move cursor to.
@@ -123,7 +150,14 @@ class BaseComputerInterface(ABC):
pass
@abstractmethod
async def drag_to(self, x: int, y: int, button: str = "left", duration: float = 0.5, delay: Optional[float] = None) -> None:
async def drag_to(
self,
x: int,
y: int,
button: str = "left",
duration: float = 0.5,
delay: Optional[float] = None,
) -> None:
"""Drag from current position to specified coordinates.
Args:
@@ -136,7 +170,13 @@ class BaseComputerInterface(ABC):
pass
@abstractmethod
async def drag(self, path: List[Tuple[int, int]], button: str = "left", duration: float = 0.5, delay: Optional[float] = None) -> None:
async def drag(
self,
path: List[Tuple[int, int]],
button: str = "left",
duration: float = 0.5,
delay: Optional[float] = None,
) -> None:
"""Drag the cursor along a path of coordinates.
Args:
@@ -151,27 +191,27 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def key_down(self, key: str, delay: Optional[float] = None) -> None:
"""Press and hold a key.
Args:
key: The key to press and hold (e.g., 'a', 'shift', 'ctrl').
delay: Optional delay in seconds after the action.
"""
pass
@abstractmethod
async def key_up(self, key: str, delay: Optional[float] = None) -> None:
"""Release a previously pressed key.
Args:
key: The key to release (e.g., 'a', 'shift', 'ctrl').
delay: Optional delay in seconds after the action.
"""
pass
@abstractmethod
async def type_text(self, text: str, delay: Optional[float] = None) -> None:
"""Type the specified text string.
Args:
text: The text string to type.
delay: Optional delay in seconds after the action.
@@ -181,7 +221,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def press_key(self, key: str, delay: Optional[float] = None) -> None:
"""Press and release a single key.
Args:
key: The key to press (e.g., 'a', 'enter', 'escape').
delay: Optional delay in seconds after the action.
@@ -191,7 +231,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def hotkey(self, *keys: str, delay: Optional[float] = None) -> None:
"""Press multiple keys simultaneously (keyboard shortcut).
Args:
*keys: Variable number of keys to press together (e.g., 'ctrl', 'c').
delay: Optional delay in seconds after the action.
@@ -202,18 +242,18 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def scroll(self, x: int, y: int, delay: Optional[float] = None) -> None:
"""Scroll the mouse wheel by specified amounts.
Args:
x: Horizontal scroll amount (positive = right, negative = left).
y: Vertical scroll amount (positive = up, negative = down).
delay: Optional delay in seconds after the action.
"""
pass
@abstractmethod
async def scroll_down(self, clicks: int = 1, delay: Optional[float] = None) -> None:
"""Scroll down by the specified number of clicks.
Args:
clicks: Number of scroll clicks to perform downward.
delay: Optional delay in seconds after the action.
@@ -223,7 +263,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def scroll_up(self, clicks: int = 1, delay: Optional[float] = None) -> None:
"""Scroll up by the specified number of clicks.
Args:
clicks: Number of scroll clicks to perform upward.
delay: Optional delay in seconds after the action.
@@ -252,7 +292,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def get_cursor_position(self) -> Dict[str, int]:
"""Get the current cursor position on screen.
Returns:
Dict with 'x' and 'y' keys containing cursor coordinates.
"""
@@ -262,7 +302,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def copy_to_clipboard(self) -> str:
"""Get the current clipboard content.
Returns:
The text content currently stored in the clipboard.
"""
@@ -271,7 +311,7 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def set_clipboard(self, text: str) -> None:
"""Set the clipboard content to the specified text.
Args:
text: The text to store in the clipboard.
"""
@@ -281,10 +321,10 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def file_exists(self, path: str) -> bool:
"""Check if a file exists at the specified path.
Args:
path: The file path to check.
Returns:
True if the file exists, False otherwise.
"""
@@ -293,128 +333,128 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def directory_exists(self, path: str) -> bool:
"""Check if a directory exists at the specified path.
Args:
path: The directory path to check.
Returns:
True if the directory exists, False otherwise.
"""
pass
@abstractmethod
async def list_dir(self, path: str) -> List[str]:
"""List the contents of a directory.
Args:
path: The directory path to list.
Returns:
List of file and directory names in the specified directory.
"""
pass
@abstractmethod
async def read_text(self, path: str) -> str:
"""Read the text contents of a file.
Args:
path: The file path to read from.
Returns:
The text content of the file.
"""
pass
@abstractmethod
async def write_text(self, path: str, content: str) -> None:
"""Write text content to a file.
Args:
path: The file path to write to.
content: The text content to write.
"""
pass
@abstractmethod
async def read_bytes(self, path: str, offset: int = 0, length: Optional[int] = None) -> bytes:
"""Read file binary contents with optional seeking support.
Args:
path: Path to the file
offset: Byte offset to start reading from (default: 0)
length: Number of bytes to read (default: None for entire file)
"""
pass
@abstractmethod
async def write_bytes(self, path: str, content: bytes) -> None:
"""Write binary content to a file.
Args:
path: The file path to write to.
content: The binary content to write.
"""
pass
@abstractmethod
async def delete_file(self, path: str) -> None:
"""Delete a file at the specified path.
Args:
path: The file path to delete.
"""
pass
@abstractmethod
async def create_dir(self, path: str) -> None:
"""Create a directory at the specified path.
Args:
path: The directory path to create.
"""
pass
@abstractmethod
async def delete_dir(self, path: str) -> None:
"""Delete a directory at the specified path.
Args:
path: The directory path to delete.
"""
pass
@abstractmethod
async def get_file_size(self, path: str) -> int:
"""Get the size of a file in bytes.
Args:
path: The file path to get the size of.
Returns:
The size of the file in bytes.
"""
pass
@abstractmethod
async def run_command(self, command: str) -> CommandResult:
"""Run shell command and return structured result.
Executes a shell command using subprocess.run with shell=True and check=False.
The command is run in the target environment and captures both stdout and stderr.
Args:
command (str): The shell command to execute
Returns:
CommandResult: A structured result containing:
- stdout (str): Standard output from the command
- stderr (str): Standard error from the command
- returncode (int): Exit code from the command (0 indicates success)
Raises:
RuntimeError: If the command execution fails at the system level
Example:
result = await interface.run_command("ls -la")
if result.returncode == 0:
@@ -428,12 +468,12 @@ class BaseComputerInterface(ABC):
@abstractmethod
async def get_accessibility_tree(self) -> Dict:
"""Get the accessibility tree of the current screen.
Returns:
Dict containing the hierarchical accessibility information of screen elements.
"""
pass
@abstractmethod
async def to_screen_coordinates(self, x: float, y: float) -> tuple[float, float]:
"""Convert screenshot coordinates to screen coordinates.

View File

@@ -1,42 +1,44 @@
"""Factory for creating computer interfaces."""
from typing import Literal, Optional
from .base import BaseComputerInterface
class InterfaceFactory:
"""Factory for creating OS-specific computer interfaces."""
@staticmethod
def create_interface_for_os(
os: Literal['macos', 'linux', 'windows'],
os: Literal["macos", "linux", "windows"],
ip_address: str,
api_key: Optional[str] = None,
vm_name: Optional[str] = None
vm_name: Optional[str] = None,
) -> BaseComputerInterface:
"""Create an interface for the specified OS.
Args:
os: Operating system type ('macos', 'linux', or 'windows')
ip_address: IP address of the computer to control
api_key: Optional API key for cloud authentication
vm_name: Optional VM name for cloud authentication
Returns:
BaseComputerInterface: The appropriate interface for the OS
Raises:
ValueError: If the OS type is not supported
"""
# Import implementations here to avoid circular imports
from .macos import MacOSComputerInterface
from .linux import LinuxComputerInterface
from .macos import MacOSComputerInterface
from .windows import WindowsComputerInterface
if os == 'macos':
if os == "macos":
return MacOSComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
elif os == 'linux':
elif os == "linux":
return LinuxComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
elif os == 'windows':
elif os == "windows":
return WindowsComputerInterface(ip_address, api_key=api_key, vm_name=vm_name)
else:
raise ValueError(f"Unsupported OS type: {os}")
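# Usage sketch (illustrative, not part of this diff); the address and
# credentials are placeholders:
interface = InterfaceFactory.create_interface_for_os(
    os="linux",
    ip_address="192.168.64.2",
    api_key=None,   # set for cloud-authenticated connections
    vm_name=None,   # paired with api_key when targeting a named cloud VM
)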

View File

@@ -2,21 +2,35 @@ import asyncio
import json
import time
from typing import Any, Dict, List, Optional, Tuple
import aiohttp
import websockets
from PIL import Image
import websockets
import aiohttp
from ..logger import Logger, LogLevel
from ..utils import (
bytes_to_image,
decode_base64_image,
draw_box,
encode_base64_image,
resize_image,
)
from .base import BaseComputerInterface
from ..utils import decode_base64_image, encode_base64_image, bytes_to_image, draw_box, resize_image
from .models import Key, KeyType, MouseButton, CommandResult
from .models import CommandResult, Key, KeyType, MouseButton
class GenericComputerInterface(BaseComputerInterface):
"""Generic interface with common functionality for all supported platforms (Windows, Linux, macOS)."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None, logger_name: str = "computer.interface.generic"):
def __init__(
self,
ip_address: str,
username: str = "lume",
password: str = "lume",
api_key: Optional[str] = None,
vm_name: Optional[str] = None,
logger_name: str = "computer.interface.generic",
):
super().__init__(ip_address, username, password, api_key, vm_name)
self._ws = None
self._reconnect_task = None
@@ -38,7 +52,7 @@ class GenericComputerInterface(BaseComputerInterface):
async def _handle_delay(self, delay: Optional[float] = None):
"""Handle delay between commands using async sleep.
Args:
delay: Optional delay in seconds. If None, uses self.delay.
"""
@@ -51,18 +65,18 @@ class GenericComputerInterface(BaseComputerInterface):
@property
def ws_uri(self) -> str:
"""Get the WebSocket URI using the current IP address.
Returns:
WebSocket URI for the Computer API Server
"""
protocol = "wss" if self.api_key else "ws"
port = "8443" if self.api_key else "8000"
return f"{protocol}://{self.ip_address}:{port}/ws"
@property
def rest_uri(self) -> str:
"""Get the REST URI using the current IP address.
Returns:
REST URI for the Computer API Server
"""
@@ -71,23 +85,41 @@ class GenericComputerInterface(BaseComputerInterface):
return f"{protocol}://{self.ip_address}:{port}/cmd"
# Mouse actions
async def mouse_down(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left", delay: Optional[float] = None) -> None:
async def mouse_down(
self,
x: Optional[int] = None,
y: Optional[int] = None,
button: str = "left",
delay: Optional[float] = None,
) -> None:
await self._send_command("mouse_down", {"x": x, "y": y, "button": button})
await self._handle_delay(delay)
async def mouse_up(self, x: Optional[int] = None, y: Optional[int] = None, button: str = "left", delay: Optional[float] = None) -> None:
async def mouse_up(
self,
x: Optional[int] = None,
y: Optional[int] = None,
button: str = "left",
delay: Optional[float] = None,
) -> None:
await self._send_command("mouse_up", {"x": x, "y": y, "button": button})
await self._handle_delay(delay)
async def left_click(self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None) -> None:
async def left_click(
self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None
) -> None:
await self._send_command("left_click", {"x": x, "y": y})
await self._handle_delay(delay)
async def right_click(self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None) -> None:
async def right_click(
self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None
) -> None:
await self._send_command("right_click", {"x": x, "y": y})
await self._handle_delay(delay)
async def double_click(self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None) -> None:
async def double_click(
self, x: Optional[int] = None, y: Optional[int] = None, delay: Optional[float] = None
) -> None:
await self._send_command("double_click", {"x": x, "y": y})
await self._handle_delay(delay)
@@ -95,37 +127,40 @@ class GenericComputerInterface(BaseComputerInterface):
await self._send_command("move_cursor", {"x": x, "y": y})
await self._handle_delay(delay)
async def drag_to(self, x: int, y: int, button: "MouseButton" = "left", duration: float = 0.5, delay: Optional[float] = None) -> None:
async def drag_to(
self,
x: int,
y: int,
button: "MouseButton" = "left",
duration: float = 0.5,
delay: Optional[float] = None,
) -> None:
await self._send_command(
"drag_to", {"x": x, "y": y, "button": button, "duration": duration}
)
await self._handle_delay(delay)
async def drag(self, path: List[Tuple[int, int]], button: "MouseButton" = "left", duration: float = 0.5, delay: Optional[float] = None) -> None:
await self._send_command(
"drag", {"path": path, "button": button, "duration": duration}
)
async def drag(
self,
path: List[Tuple[int, int]],
button: "MouseButton" = "left",
duration: float = 0.5,
delay: Optional[float] = None,
) -> None:
await self._send_command("drag", {"path": path, "button": button, "duration": duration})
await self._handle_delay(delay)
# Keyboard Actions
async def key_down(self, key: "KeyType", delay: Optional[float] = None) -> None:
await self._send_command("key_down", {"key": key})
await self._handle_delay(delay)
async def key_up(self, key: "KeyType", delay: Optional[float] = None) -> None:
await self._send_command("key_up", {"key": key})
await self._handle_delay(delay)
async def type_text(self, text: str, delay: Optional[float] = None) -> None:
# Temporary fix for https://github.com/trycua/cua/issues/165
# Check if text contains Unicode characters
if any(ord(char) > 127 for char in text):
# For Unicode text, use clipboard and paste
await self.set_clipboard(text)
await self.hotkey(Key.COMMAND, 'v')
else:
# For ASCII text, use the regular typing method
await self._send_command("type_text", {"text": text})
await self._handle_delay(delay)
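# Usage sketch (illustrative, not part of this diff): callers use one entry
# point; the Unicode detour is transparent.
async def fill_form(iface):
    await iface.type_text("hello")  # pure ASCII: sent as a type_text command
    await iface.type_text("héllo")  # ord(c) > 127: clipboard + paste path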
async def press(self, key: "KeyType", delay: Optional[float] = None) -> None:
@@ -203,10 +238,12 @@ class GenericComputerInterface(BaseComputerInterface):
elif isinstance(key, str):
# Try to convert to enum if it matches a known key
key_or_enum = Key.from_string(key)
actual_keys.append(key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum)
actual_keys.append(
key_or_enum.value if isinstance(key_or_enum, Key) else key_or_enum
)
else:
raise ValueError(f"Invalid key type: {type(key)}. Must be Key enum or string.")
await self._send_command("hotkey", {"keys": actual_keys})
await self._handle_delay(delay)
@@ -214,11 +251,11 @@ class GenericComputerInterface(BaseComputerInterface):
async def scroll(self, x: int, y: int, delay: Optional[float] = None) -> None:
await self._send_command("scroll", {"x": x, "y": y})
await self._handle_delay(delay)
async def scroll_down(self, clicks: int = 1, delay: Optional[float] = None) -> None:
await self._send_command("scroll_down", {"clicks": clicks})
await self._handle_delay(delay)
async def scroll_up(self, clicks: int = 1, delay: Optional[float] = None) -> None:
await self._send_command("scroll_up", {"clicks": clicks})
await self._handle_delay(delay)
@@ -302,27 +339,32 @@ class GenericComputerInterface(BaseComputerInterface):
await self._send_command("set_clipboard", {"text": text})
# File Operations
async def _write_bytes_chunked(self, path: str, content: bytes, append: bool = False, chunk_size: int = 1024 * 1024) -> None:
async def _write_bytes_chunked(
self, path: str, content: bytes, append: bool = False, chunk_size: int = 1024 * 1024
) -> None:
"""Write large files in chunks to avoid memory issues."""
total_size = len(content)
current_offset = 0
while current_offset < total_size:
chunk_end = min(current_offset + chunk_size, total_size)
chunk_data = content[current_offset:chunk_end]
# First chunk uses the original append flag, subsequent chunks always append
chunk_append = append if current_offset == 0 else True
result = await self._send_command("write_bytes", {
"path": path,
"content_b64": encode_base64_image(chunk_data),
"append": chunk_append
})
result = await self._send_command(
"write_bytes",
{
"path": path,
"content_b64": encode_base64_image(chunk_data),
"append": chunk_append,
},
)
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to write file chunk"))
current_offset = chunk_end
async def write_bytes(self, path: str, content: bytes, append: bool = False) -> None:
@@ -330,36 +372,39 @@ class GenericComputerInterface(BaseComputerInterface):
if len(content) > 5 * 1024 * 1024: # 5MB threshold
await self._write_bytes_chunked(path, content, append)
return
result = await self._send_command("write_bytes", {"path": path, "content_b64": encode_base64_image(content), "append": append})
result = await self._send_command(
"write_bytes",
{"path": path, "content_b64": encode_base64_image(content), "append": append},
)
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to write file"))
async def _read_bytes_chunked(self, path: str, offset: int, total_length: int, chunk_size: int = 1024 * 1024) -> bytes:
async def _read_bytes_chunked(
self, path: str, offset: int, total_length: int, chunk_size: int = 1024 * 1024
) -> bytes:
"""Read large files in chunks to avoid memory issues."""
chunks = []
current_offset = offset
remaining = total_length
while remaining > 0:
read_size = min(chunk_size, remaining)
result = await self._send_command("read_bytes", {
"path": path,
"offset": current_offset,
"length": read_size
})
result = await self._send_command(
"read_bytes", {"path": path, "offset": current_offset, "length": read_size}
)
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to read file chunk"))
content_b64 = result.get("content_b64", "")
chunk_data = decode_base64_image(content_b64)
chunks.append(chunk_data)
current_offset += read_size
remaining -= read_size
return b''.join(chunks)
return b"".join(chunks)
async def read_bytes(self, path: str, offset: int = 0, length: Optional[int] = None) -> bytes:
# For large files, use chunked reading
@@ -368,34 +413,36 @@ class GenericComputerInterface(BaseComputerInterface):
file_size = await self.get_file_size(path)
# If file is larger than 5MB, read in chunks
if file_size > 5 * 1024 * 1024: # 5MB threshold
return await self._read_bytes_chunked(path, offset, file_size - offset if offset > 0 else file_size)
result = await self._send_command("read_bytes", {
"path": path,
"offset": offset,
"length": length
})
return await self._read_bytes_chunked(
path, offset, file_size - offset if offset > 0 else file_size
)
result = await self._send_command(
"read_bytes", {"path": path, "offset": offset, "length": length}
)
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to read file"))
content_b64 = result.get("content_b64", "")
return decode_base64_image(content_b64)
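# Usage sketch (illustrative, not part of this diff): the _chunked helpers are
# internal; both public entry points switch to 1 MiB chunks automatically once
# a payload crosses the 5 MB threshold.
async def roundtrip(iface, payload: bytes):
    await iface.write_bytes("/tmp/blob.bin", payload)  # chunked if > 5 MB
    data = await iface.read_bytes("/tmp/blob.bin")     # chunked if > 5 MB
    assert data == payload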
async def read_text(self, path: str, encoding: str = 'utf-8') -> str:
async def read_text(self, path: str, encoding: str = "utf-8") -> str:
"""Read text from a file with specified encoding.
Args:
path: Path to the file to read
encoding: Text encoding to use (default: 'utf-8')
Returns:
str: The decoded text content of the file
"""
content_bytes = await self.read_bytes(path)
return content_bytes.decode(encoding)
async def write_text(self, path: str, content: str, encoding: str = 'utf-8', append: bool = False) -> None:
async def write_text(
self, path: str, content: str, encoding: str = "utf-8", append: bool = False
) -> None:
"""Write text to a file with specified encoding.
Args:
path: Path to the file to write
content: Text content to write
@@ -448,7 +495,7 @@ class GenericComputerInterface(BaseComputerInterface):
return CommandResult(
stdout=result.get("stdout", ""),
stderr=result.get("stderr", ""),
returncode=result.get("return_code", 0)
returncode=result.get("return_code", 0),
)
# Accessibility Actions
@@ -458,7 +505,7 @@ class GenericComputerInterface(BaseComputerInterface):
if not result.get("success", False):
raise RuntimeError(result.get("error", "Failed to get accessibility tree"))
return result
async def get_active_window_bounds(self) -> Dict[str, int]:
"""Get the bounds of the currently active window."""
result = await self._send_command("get_active_window_bounds")
@@ -564,33 +611,30 @@ class GenericComputerInterface(BaseComputerInterface):
timeout=120,
)
self.logger.info("WebSocket connection established")
# If api_key and vm_name are provided, perform authentication handshake
if self.api_key and self.vm_name:
self.logger.info("Performing authentication handshake...")
auth_message = {
"command": "authenticate",
"params": {
"api_key": self.api_key,
"container_name": self.vm_name
}
"params": {"api_key": self.api_key, "container_name": self.vm_name},
}
await self._ws.send(json.dumps(auth_message))
# Wait for authentication response
async with self._recv_lock:
auth_response = await asyncio.wait_for(self._ws.recv(), timeout=10)
auth_result = json.loads(auth_response)
if not auth_result.get("success"):
error_msg = auth_result.get("error", "Authentication failed")
self.logger.error(f"Authentication failed: {error_msg}")
await self._ws.close()
self._ws = None
raise ConnectionError(f"Authentication failed: {error_msg}")
self.logger.info("Authentication successful")
self._reconnect_delay = 1 # Reset reconnect delay on successful connection
self._last_ping = time.time()
retry_count = 0 # Reset retry count on successful connection
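# Wire-format sketch (illustrative, not part of this diff) of the handshake
# above; the response shape is inferred from the success check:
#   client -> {"command": "authenticate",
#              "params": {"api_key": "...", "container_name": "..."}}
#   server -> {"success": true}                    # connection proceeds
#   server -> {"success": false, "error": "..."}   # connection is closed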
@@ -600,7 +644,7 @@ class GenericComputerInterface(BaseComputerInterface):
# Only log the first error at WARNING level, then every Nth attempt
if retry_count == 1:
self.logger.warning(
f"Computer API Server not ready yet. Will retry automatically."
"Computer API Server not ready yet. Will retry automatically."
)
elif retry_count % log_interval == 0:
self.logger.warning(
@@ -648,7 +692,7 @@ class GenericComputerInterface(BaseComputerInterface):
# Only log connection lost warnings at most once every min_warning_interval seconds
if current_time - last_warning_time >= min_warning_interval:
self.logger.warning(
f"Computer API Server connection lost. Will retry automatically."
"Computer API Server connection lost. Will retry automatically."
)
last_warning_time = current_time
else:
@@ -661,7 +705,7 @@ class GenericComputerInterface(BaseComputerInterface):
except:
pass
self._ws = None
async def _ensure_connection(self):
"""Ensure WebSocket connection is established."""
if self._reconnect_task is None or self._reconnect_task.done():
@@ -730,32 +774,30 @@ class GenericComputerInterface(BaseComputerInterface):
raise last_error if last_error else RuntimeError("Failed to send command")
async def _send_command_rest(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
async def _send_command_rest(
self, command: str, params: Optional[Dict] = None
) -> Dict[str, Any]:
"""Send command through REST API without retries or connection management."""
try:
# Prepare the request payload
payload = {"command": command, "params": params or {}}
# Prepare headers
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["X-API-Key"] = self.api_key
if self.vm_name:
headers["X-Container-Name"] = self.vm_name
# Send the request
async with aiohttp.ClientSession() as session:
async with session.post(
self.rest_uri,
json=payload,
headers=headers
) as response:
async with session.post(self.rest_uri, json=payload, headers=headers) as response:
# Get the response text
response_text = await response.text()
# Trim whitespace
response_text = response_text.strip()
# Check if it starts with "data: "
if response_text.startswith("data: "):
# Extract everything after "data: "
@@ -766,38 +808,39 @@ class GenericComputerInterface(BaseComputerInterface):
return {
"success": False,
"error": "Server returned malformed response",
"message": response_text
"message": response_text,
}
else:
# Return error response
return {
"success": False,
"error": "Server returned malformed response",
"message": response_text
"message": response_text,
}
except Exception as e:
return {
"success": False,
"error": "Request failed",
"message": str(e)
}
return {"success": False, "error": "Request failed", "message": str(e)}
async def _send_command(self, command: str, params: Optional[Dict] = None) -> Dict[str, Any]:
"""Send command using REST API with WebSocket fallback."""
# Try REST API first
result = await self._send_command_rest(command, params)
# If REST failed with "Request failed", try WebSocket as fallback
if not result.get("success", True) and (result.get("error") == "Request failed" or result.get("error") == "Server returned malformed response"):
self.logger.warning(f"REST API failed for command '{command}', trying WebSocket fallback")
if not result.get("success", True) and (
result.get("error") == "Request failed"
or result.get("error") == "Server returned malformed response"
):
self.logger.warning(
f"REST API failed for command '{command}', trying WebSocket fallback"
)
try:
return await self._send_command_ws(command, params)
except Exception as e:
self.logger.error(f"WebSocket fallback also failed: {e}")
# Return the original REST error
return result
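# Behavior note (illustrative, not part of this diff): only transport-level
# failures ("Request failed", "Server returned malformed response") trigger
# the WebSocket fallback; ordinary command errors are returned to the caller
# unchanged.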
async def wait_for_ready(self, timeout: int = 60, interval: float = 1.0):
@@ -808,7 +851,9 @@ class GenericComputerInterface(BaseComputerInterface):
result = await self._send_command_rest("version", {})
assert result.get("success", True)
except Exception as e:
self.logger.debug(f"REST API failed for command 'version', trying WebSocket fallback: {e}")
self.logger.debug(
f"REST API failed for command 'version', trying WebSocket fallback: {e}"
)
try:
await self._wait_for_ready_ws(timeout, interval)
return
@@ -957,7 +1002,7 @@ class GenericComputerInterface(BaseComputerInterface):
# if self._ws:
# asyncio.create_task(self._ws.close())
# self._ws = None
def force_close(self):
"""Force close the WebSocket connection.
@@ -970,4 +1015,3 @@ class GenericComputerInterface(BaseComputerInterface):
if self._ws:
asyncio.create_task(self._ws.close())
self._ws = None

View File

@@ -1,8 +1,19 @@
from typing import Optional
from .generic import GenericComputerInterface
class LinuxComputerInterface(GenericComputerInterface):
"""Interface for Linux."""
def __init__(self, ip_address: str, username: str = "lume", password: str = "lume", api_key: Optional[str] = None, vm_name: Optional[str] = None):
super().__init__(ip_address, username, password, api_key, vm_name, "computer.interface.linux")
def __init__(
self,
ip_address: str,
username: str = "lume",
password: str = "lume",
api_key: Optional[str] = None,
vm_name: Optional[str] = None,
):
super().__init__(
ip_address, username, password, api_key, vm_name, "computer.interface.linux"
)
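# Usage sketch (illustrative, not part of this diff): the subclass only pins
# the logger name, so construction mirrors the generic interface.
iface = LinuxComputerInterface("192.168.64.2", api_key=None, vm_name=None)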

Some files were not shown because too many files have changed in this diff