Merge upstream/main to resolve conflicts with trycua/cua

Sarina Li
2025-10-27 16:06:06 -07:00
344 changed files with 30505 additions and 16871 deletions

View File

@@ -0,0 +1,10 @@
[bumpversion]
current_version = 0.4.35
commit = True
tag = True
tag_name = agent-v{new_version}
message = Bump cua-agent to v{new_version}
[bumpversion:file:pyproject.toml]
search = version = "{current_version}"
replace = version = "{new_version}"
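
Note: for reference, a minimal Python sketch of the substitution this bumpversion config drives. The `0.4.36` target version and the direct file rewrite are illustrative only; the real tool also creates the commit and the `agent-v{new_version}` tag declared above.

```python
# Sketch of the search/replace this config performs on pyproject.toml.
from pathlib import Path

current_version = "0.4.35"   # from [bumpversion] current_version
new_version = "0.4.36"       # hypothetical next version, for illustration

pyproject = Path("pyproject.toml")
text = pyproject.read_text()
# search / replace patterns from the [bumpversion:file:pyproject.toml] section
text = text.replace(f'version = "{current_version}"', f'version = "{new_version}"')
pyproject.write_text(text)

tag_name = f"agent-v{new_version}"                     # tag_name
commit_message = f"Bump cua-agent to v{new_version}"   # message
```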

View File

@@ -8,10 +8,11 @@
</picture>
</div>
[![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#)
[![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85)
[![PyPI](https://img.shields.io/pypi/v/cua-computer?color=333333)](https://pypi.org/project/cua-computer/)
[![Python](https://img.shields.io/badge/Python-333333?logo=python&logoColor=white&labelColor=333333)](#)
[![macOS](https://img.shields.io/badge/macOS-000000?logo=apple&logoColor=F0F0F0)](#)
[![Discord](https://img.shields.io/badge/Discord-%235865F2.svg?&logo=discord&logoColor=white)](https://discord.com/invite/mVnXXpdE85)
[![PyPI](https://img.shields.io/pypi/v/cua-computer?color=333333)](https://pypi.org/project/cua-computer/)
</h1>
</div>
@@ -47,7 +48,7 @@ async def main():
name=os.getenv("CUA_CONTAINER_NAME"),
api_key=os.getenv("CUA_API_KEY")
) as computer:
# Create agent
agent = ComputerAgent(
model="anthropic/claude-3-5-sonnet-20241022",
@@ -56,10 +57,10 @@ async def main():
trajectory_dir="trajectories",
max_trajectory_budget=5.0 # $5 budget limit
)
# Run agent
messages = [{"role": "user", "content": "Take a screenshot and tell me what you see"}]
async for result in agent.run(messages):
for item in result["output"]:
if item["type"] == "message":
@@ -84,4 +85,4 @@ if __name__ == "__main__":
## License
MIT License - see LICENSE file for details.
MIT License - see LICENSE file for details.
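
Note: the quickstart above is split across diff hunks; assembled, it looks roughly like the sketch below. The `from computer import Computer` import and the final `print` line are assumptions not shown in these hunks; `ComputerAgent` is exported by `agent/__init__.py` later in this commit, and `tools=[computer]` follows its docstring ("computer objects, decorated functions, etc.").

```python
# Assembled sketch of the README quickstart, under the assumptions noted above.
import asyncio
import os

from agent import ComputerAgent       # exported by agent/__init__.py below
from computer import Computer         # assumption: cua-computer's container class


async def main():
    async with Computer(
        name=os.getenv("CUA_CONTAINER_NAME"),
        api_key=os.getenv("CUA_API_KEY"),
    ) as computer:
        # Create agent
        agent = ComputerAgent(
            model="anthropic/claude-3-5-sonnet-20241022",
            tools=[computer],
            trajectory_dir="trajectories",
            max_trajectory_budget=5.0,  # $5 budget limit
        )
        # Run agent
        messages = [{"role": "user", "content": "Take a screenshot and tell me what you see"}]
        async for result in agent.run(messages):
            for item in result["output"]:
                if item["type"] == "message":
                    # Assumed message shape; adjust to the actual output schema.
                    print(item["content"][0]["text"])


if __name__ == "__main__":
    asyncio.run(main())
```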

View File

@@ -5,19 +5,13 @@ agent - Decorator-based Computer Use Agent with liteLLM integration
import logging
import sys
from .decorators import register_agent
from .agent import ComputerAgent
from .types import Messages, AgentResponse
# Import loops to register them
from . import loops
from .agent import ComputerAgent
from .decorators import register_agent
from .types import AgentResponse, Messages
__all__ = [
"register_agent",
"ComputerAgent",
"Messages",
"AgentResponse"
]
__all__ = ["register_agent", "ComputerAgent", "Messages", "AgentResponse"]
__version__ = "0.4.0"

View File

@@ -5,8 +5,9 @@ Usage:
python -m agent.cli <model_string>
"""
import sys
import asyncio
import sys
from .cli import main
if __name__ == "__main__":

View File

@@ -2,27 +2,30 @@ import asyncio
import functools
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator, AsyncIterator, Dict, List, Any, Optional
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional
from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion
from litellm.types.utils import GenericStreamingChunk, ModelResponse
# Try to import HuggingFace dependencies
try:
import torch
from transformers import AutoModelForImageTextToText, AutoProcessor
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
from .models import load_model as load_model_handler
class HuggingFaceLocalAdapter(CustomLLM):
"""HuggingFace Local Adapter for running vision-language models locally."""
def __init__(self, device: str = "auto", trust_remote_code: bool = False, **kwargs):
"""Initialize the adapter.
Args:
device: Device to load model on ("auto", "cuda", "cpu", etc.)
trust_remote_code: Whether to trust remote code
@@ -34,129 +37,120 @@ class HuggingFaceLocalAdapter(CustomLLM):
# Cache for model handlers keyed by model_name
self._handlers: Dict[str, Any] = {}
self._executor = ThreadPoolExecutor(max_workers=1) # Single thread pool
def _get_handler(self, model_name: str):
"""Get or create a model handler for the given model name."""
if model_name not in self._handlers:
self._handlers[model_name] = load_model_handler(model_name=model_name, device=self.device, trust_remote_code=self.trust_remote_code)
self._handlers[model_name] = load_model_handler(
model_name=model_name, device=self.device, trust_remote_code=self.trust_remote_code
)
return self._handlers[model_name]
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Convert OpenAI format messages to HuggingFace format.
Args:
messages: Messages in OpenAI format
Returns:
Messages in HuggingFace format
"""
converted_messages = []
for message in messages:
converted_message = {
"role": message["role"],
"content": []
}
converted_message = {"role": message["role"], "content": []}
content = message.get("content", [])
if isinstance(content, str):
# Simple text content
converted_message["content"].append({
"type": "text",
"text": content
})
converted_message["content"].append({"type": "text", "text": content})
elif isinstance(content, list):
# Multi-modal content
for item in content:
if item.get("type") == "text":
converted_message["content"].append({
"type": "text",
"text": item.get("text", "")
})
converted_message["content"].append(
{"type": "text", "text": item.get("text", "")}
)
elif item.get("type") == "image_url":
# Convert image_url format to image format
image_url = item.get("image_url", {}).get("url", "")
converted_message["content"].append({
"type": "image",
"image": image_url
})
converted_message["content"].append({"type": "image", "image": image_url})
converted_messages.append(converted_message)
return converted_messages
def _generate(self, **kwargs) -> str:
"""Generate response using the local HuggingFace model.
Args:
**kwargs: Keyword arguments containing messages and model info
Returns:
Generated text response
"""
if not HF_AVAILABLE:
raise ImportError(
"HuggingFace transformers dependencies not found. "
"Please install with: pip install \"cua-agent[uitars-hf]\""
'Please install with: pip install "cua-agent[uitars-hf]"'
)
# Extract messages and model from kwargs
messages = kwargs.get('messages', [])
model_name = kwargs.get('model', 'ByteDance-Seed/UI-TARS-1.5-7B')
max_new_tokens = kwargs.get('max_tokens', 128)
messages = kwargs.get("messages", [])
model_name = kwargs.get("model", "ByteDance-Seed/UI-TARS-1.5-7B")
max_new_tokens = kwargs.get("max_tokens", 128)
# Warn about ignored kwargs
ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
ignored_kwargs = set(kwargs.keys()) - {"messages", "model", "max_tokens"}
if ignored_kwargs:
warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")
# Convert messages to HuggingFace format
hf_messages = self._convert_messages(messages)
# Delegate to model handler
handler = self._get_handler(model_name)
generated_text = handler.generate(hf_messages, max_new_tokens=max_new_tokens)
return generated_text
def completion(self, *args, **kwargs) -> ModelResponse:
"""Synchronous completion method.
Returns:
ModelResponse with generated text
"""
generated_text = self._generate(**kwargs)
return completion(
model=f"huggingface-local/{kwargs['model']}",
mock_response=generated_text,
)
async def acompletion(self, *args, **kwargs) -> ModelResponse:
"""Asynchronous completion method.
Returns:
ModelResponse with generated text
"""
# Run _generate in thread pool to avoid blocking
loop = asyncio.get_event_loop()
generated_text = await loop.run_in_executor(
self._executor,
functools.partial(self._generate, **kwargs)
self._executor, functools.partial(self._generate, **kwargs)
)
return await acompletion(
model=f"huggingface-local/{kwargs['model']}",
mock_response=generated_text,
)
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
"""Synchronous streaming method.
Returns:
Iterator of GenericStreamingChunk
"""
generated_text = self._generate(**kwargs)
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
@@ -165,22 +159,21 @@ class HuggingFaceLocalAdapter(CustomLLM):
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
yield generic_streaming_chunk
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
"""Asynchronous streaming method.
Returns:
AsyncIterator of GenericStreamingChunk
"""
# Run _generate in thread pool to avoid blocking
loop = asyncio.get_event_loop()
generated_text = await loop.run_in_executor(
self._executor,
functools.partial(self._generate, **kwargs)
self._executor, functools.partial(self._generate, **kwargs)
)
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
@@ -189,5 +182,5 @@ class HuggingFaceLocalAdapter(CustomLLM):
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
yield generic_streaming_chunk
yield generic_streaming_chunk
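
Note: the adapter is consumed through LiteLLM's custom provider mechanism, which is exactly how `ComputerAgent` registers it later in this commit. A minimal direct-use sketch, assuming the `agent.adapters` import path used by `agent.py`; the model id is the adapter's default and anything loadable by the `load_model` factory should work.

```python
# Minimal sketch: register the adapter as a LiteLLM custom provider and call it.
import litellm

from agent.adapters import HuggingFaceLocalAdapter  # path as imported in agent.py below

adapter = HuggingFaceLocalAdapter(device="auto", trust_remote_code=False)
litellm.custom_provider_map = [
    {"provider": "huggingface-local", "custom_handler": adapter},
]

response = litellm.completion(
    model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
    messages=[{"role": "user", "content": "Describe this screenshot."}],
    max_tokens=128,
)
print(response.choices[0].message.content)
```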

View File

@@ -1,22 +1,23 @@
import os
import asyncio
import os
from typing import Any, AsyncIterator, Dict, Iterator, List
import requests
from typing import List, Dict, Any, Iterator, AsyncIterator
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion
from litellm.types.utils import GenericStreamingChunk, ModelResponse
class HumanAdapter(CustomLLM):
"""Human Adapter for human-in-the-loop completions.
This adapter sends completion requests to a human completion server
where humans can review and respond to AI requests.
"""
def __init__(self, base_url: str | None = None, timeout: float = 300.0, **kwargs):
"""Initialize the human adapter.
Args:
base_url: Base URL for the human completion server.
Defaults to HUMAN_BASE_URL environment variable or http://localhost:8002
@@ -24,60 +25,58 @@ class HumanAdapter(CustomLLM):
**kwargs: Additional arguments
"""
super().__init__()
self.base_url = base_url or os.getenv('HUMAN_BASE_URL', 'http://localhost:8002')
self.base_url = base_url or os.getenv("HUMAN_BASE_URL", "http://localhost:8002")
self.timeout = timeout
# Ensure base_url doesn't end with slash
self.base_url = self.base_url.rstrip('/')
self.base_url = self.base_url.rstrip("/")
def _queue_completion(self, messages: List[Dict[str, Any]], model: str) -> str:
"""Queue a completion request and return the call ID.
Args:
messages: Messages in OpenAI format
model: Model name
Returns:
Call ID for tracking the request
Raises:
Exception: If queueing fails
"""
try:
response = requests.post(
f"{self.base_url}/queue",
json={"messages": messages, "model": model},
timeout=10
f"{self.base_url}/queue", json={"messages": messages, "model": model}, timeout=10
)
response.raise_for_status()
return response.json()["id"]
except requests.RequestException as e:
raise Exception(f"Failed to queue completion request: {e}")
def _wait_for_completion(self, call_id: str) -> Dict[str, Any]:
"""Wait for human to complete the call.
Args:
call_id: ID of the queued completion call
Returns:
Dict containing response and/or tool_calls
Raises:
TimeoutError: If timeout is exceeded
Exception: If completion fails
"""
import time
start_time = time.time()
while True:
try:
# Check status
status_response = requests.get(f"{self.base_url}/status/{call_id}")
status_response.raise_for_status()
status_data = status_response.json()
if status_data["status"] == "completed":
result = {}
if "response" in status_data and status_data["response"]:
@@ -88,38 +87,41 @@ class HumanAdapter(CustomLLM):
elif status_data["status"] == "failed":
error_msg = status_data.get("error", "Unknown error")
raise Exception(f"Completion failed: {error_msg}")
# Check timeout
if time.time() - start_time > self.timeout:
raise TimeoutError(f"Timeout waiting for human response after {self.timeout} seconds")
raise TimeoutError(
f"Timeout waiting for human response after {self.timeout} seconds"
)
# Wait before checking again
time.sleep(1.0)
except requests.RequestException as e:
if time.time() - start_time > self.timeout:
raise TimeoutError(f"Timeout waiting for human response: {e}")
# Continue trying if we haven't timed out
time.sleep(1.0)
async def _async_wait_for_completion(self, call_id: str) -> Dict[str, Any]:
"""Async version of wait_for_completion.
Args:
call_id: ID of the queued completion call
Returns:
Dict containing response and/or tool_calls
Raises:
TimeoutError: If timeout is exceeded
Exception: If completion fails
"""
import aiohttp
import time
import aiohttp
start_time = time.time()
async with aiohttp.ClientSession() as session:
while True:
try:
@@ -127,7 +129,7 @@ class HumanAdapter(CustomLLM):
async with session.get(f"{self.base_url}/status/{call_id}") as response:
response.raise_for_status()
status_data = await response.json()
if status_data["status"] == "completed":
result = {}
if "response" in status_data and status_data["response"]:
@@ -138,166 +140,158 @@ class HumanAdapter(CustomLLM):
elif status_data["status"] == "failed":
error_msg = status_data.get("error", "Unknown error")
raise Exception(f"Completion failed: {error_msg}")
# Check timeout
if time.time() - start_time > self.timeout:
raise TimeoutError(f"Timeout waiting for human response after {self.timeout} seconds")
raise TimeoutError(
f"Timeout waiting for human response after {self.timeout} seconds"
)
# Wait before checking again
await asyncio.sleep(1.0)
except Exception as e:
if time.time() - start_time > self.timeout:
raise TimeoutError(f"Timeout waiting for human response: {e}")
# Continue trying if we haven't timed out
await asyncio.sleep(1.0)
def _generate_response(self, messages: List[Dict[str, Any]], model: str) -> Dict[str, Any]:
"""Generate a human response for the given messages.
Args:
messages: Messages in OpenAI format
model: Model name
Returns:
Dict containing response and/or tool_calls
"""
# Queue the completion request
call_id = self._queue_completion(messages, model)
# Wait for human response
response = self._wait_for_completion(call_id)
return response
async def _async_generate_response(self, messages: List[Dict[str, Any]], model: str) -> Dict[str, Any]:
async def _async_generate_response(
self, messages: List[Dict[str, Any]], model: str
) -> Dict[str, Any]:
"""Async version of _generate_response.
Args:
messages: Messages in OpenAI format
model: Model name
Returns:
Dict containing response and/or tool_calls
"""
# Queue the completion request (sync operation)
call_id = self._queue_completion(messages, model)
# Wait for human response (async)
response = await self._async_wait_for_completion(call_id)
return response
def completion(self, *args, **kwargs) -> ModelResponse:
"""Synchronous completion method.
Returns:
ModelResponse with human-generated text or tool calls
"""
messages = kwargs.get('messages', [])
model = kwargs.get('model', 'human')
messages = kwargs.get("messages", [])
model = kwargs.get("model", "human")
# Generate human response
human_response_data = self._generate_response(messages, model)
# Create ModelResponse with proper structure
from litellm.types.utils import ModelResponse, Choices, Message
import uuid
import time
import uuid
from litellm.types.utils import Choices, Message, ModelResponse
# Create message content based on response type
if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
# Tool calls response
message = Message(
role="assistant",
content=human_response_data.get("response", ""),
tool_calls=human_response_data["tool_calls"]
tool_calls=human_response_data["tool_calls"],
)
else:
# Text response
message = Message(
role="assistant",
content=human_response_data.get("response", "")
)
choice = Choices(
finish_reason="stop",
index=0,
message=message
)
message = Message(role="assistant", content=human_response_data.get("response", ""))
choice = Choices(finish_reason="stop", index=0, message=message)
result = ModelResponse(
id=f"human-{uuid.uuid4()}",
choices=[choice],
created=int(time.time()),
model=f"human/{model}",
object="chat.completion"
object="chat.completion",
)
return result
async def acompletion(self, *args, **kwargs) -> ModelResponse:
"""Asynchronous completion method.
Returns:
ModelResponse with human-generated text or tool calls
"""
messages = kwargs.get('messages', [])
model = kwargs.get('model', 'human')
messages = kwargs.get("messages", [])
model = kwargs.get("model", "human")
# Generate human response
human_response_data = await self._async_generate_response(messages, model)
# Create ModelResponse with proper structure
from litellm.types.utils import ModelResponse, Choices, Message
import uuid
import time
import uuid
from litellm.types.utils import Choices, Message, ModelResponse
# Create message content based on response type
if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
# Tool calls response
message = Message(
role="assistant",
content=human_response_data.get("response", ""),
tool_calls=human_response_data["tool_calls"]
tool_calls=human_response_data["tool_calls"],
)
else:
# Text response
message = Message(
role="assistant",
content=human_response_data.get("response", "")
)
choice = Choices(
finish_reason="stop",
index=0,
message=message
)
message = Message(role="assistant", content=human_response_data.get("response", ""))
choice = Choices(finish_reason="stop", index=0, message=message)
result = ModelResponse(
id=f"human-{uuid.uuid4()}",
choices=[choice],
created=int(time.time()),
model=f"human/{model}",
object="chat.completion"
object="chat.completion",
)
return result
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
"""Synchronous streaming method.
Yields:
Streaming chunks with human-generated text or tool calls
"""
messages = kwargs.get('messages', [])
model = kwargs.get('model', 'human')
messages = kwargs.get("messages", [])
model = kwargs.get("model", "human")
# Generate human response
human_response_data = self._generate_response(messages, model)
import time
# Handle tool calls vs text response
if "tool_calls" in human_response_data and human_response_data["tool_calls"]:
# Stream tool calls as a single chunk
@@ -319,22 +313,26 @@ class HumanAdapter(CustomLLM):
"is_finished": True,
"text": response_text,
"tool_use": None,
"usage": {"completion_tokens": len(response_text.split()), "prompt_tokens": 0, "total_tokens": len(response_text.split())},
"usage": {
"completion_tokens": len(response_text.split()),
"prompt_tokens": 0,
"total_tokens": len(response_text.split()),
},
}
yield generic_chunk
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
"""Asynchronous streaming method.
Yields:
Streaming chunks with human-generated text or tool calls
"""
messages = kwargs.get('messages', [])
model = kwargs.get('model', 'human')
messages = kwargs.get("messages", [])
model = kwargs.get("model", "human")
# Generate human response
human_response = await self._async_generate_response(messages, model)
# Return as single streaming chunk
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
@@ -342,7 +340,11 @@ class HumanAdapter(CustomLLM):
"is_finished": True,
"text": human_response,
"tool_use": None,
"usage": {"completion_tokens": len(human_response.split()), "prompt_tokens": 0, "total_tokens": len(human_response.split())},
"usage": {
"completion_tokens": len(human_response.split()),
"prompt_tokens": 0,
"total_tokens": len(human_response.split()),
},
}
yield generic_streaming_chunk
yield generic_streaming_chunk
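
Note: the adapter queues each request with `POST /queue` and polls `GET /status/{call_id}` once per second until a human answers or `timeout` elapses. A minimal sketch of driving it through LiteLLM, assuming a human completion server is running at the default `http://localhost:8002` and the `agent.adapters` import path used by `agent.py`.

```python
# Minimal sketch: route a completion to a human via the queue/poll server.
import litellm

from agent.adapters import HumanAdapter  # path as imported in agent.py below

adapter = HumanAdapter()  # uses HUMAN_BASE_URL or http://localhost:8002
litellm.custom_provider_map = [
    {"provider": "human", "custom_handler": adapter},
]

response = litellm.completion(
    model="human/human",
    messages=[{"role": "user", "content": "Click the Submit button"}],
)
# Blocks, polling once per second, until a human responds or timeout is hit.
print(response.choices[0].message.content)
```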

View File

@@ -1,24 +1,26 @@
import asyncio
import functools
import warnings
import io
import base64
import functools
import io
import math
import re
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator, AsyncIterator, Dict, List, Any, Optional, Tuple, cast
from PIL import Image
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple, cast
from litellm import acompletion, completion
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from PIL import Image
# Try to import MLX dependencies
try:
import mlx.core as mx
from mlx_vlm import load, generate
from mlx_vlm import generate, load
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config
from transformers.tokenization_utils import PreTrainedTokenizer
MLX_AVAILABLE = True
except ImportError:
MLX_AVAILABLE = False
@@ -29,20 +31,28 @@ MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
def round_by_factor(number: float, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: float, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: float, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
height: int,
width: int,
factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
@@ -70,61 +80,62 @@ def smart_resize(
class MLXVLMAdapter(CustomLLM):
"""MLX VLM Adapter for running vision-language models locally using MLX."""
def __init__(self, **kwargs):
"""Initialize the adapter.
Args:
**kwargs: Additional arguments
"""
super().__init__()
self.models = {} # Cache for loaded models
self.processors = {} # Cache for loaded processors
self.configs = {} # Cache for loaded configs
self._executor = ThreadPoolExecutor(max_workers=1) # Single thread pool
def _load_model_and_processor(self, model_name: str):
"""Load model and processor if not already cached.
Args:
model_name: Name of the model to load
Returns:
Tuple of (model, processor, config)
"""
if not MLX_AVAILABLE:
raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")
if model_name not in self.models:
# Load model and processor
model_obj, processor = load(
model_name,
processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
model_name, processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
)
config = load_config(model_name)
# Cache them
self.models[model_name] = model_obj
self.processors[model_name] = processor
self.configs[model_name] = config
return self.models[model_name], self.processors[model_name], self.configs[model_name]
def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
def _process_coordinates(
self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]
) -> str:
"""Process coordinates in box tokens based on image resizing using smart_resize approach.
Args:
text: Text containing box tokens
original_size: Original image size (width, height)
model_size: Model processed image size (width, height)
Returns:
Text with processed coordinates
"""
# Find all box tokens
box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
def process_coords(match):
model_x, model_y = int(match.group(1)), int(match.group(2))
# Scale coordinates from model space to original image space
@@ -132,15 +143,20 @@ class MLXVLMAdapter(CustomLLM):
new_x = int(model_x * original_size[0] / model_size[0]) # Width
new_y = int(model_y * original_size[1] / model_size[1]) # Height
return f"<|box_start|>({new_x},{new_y})<|box_end|>"
return re.sub(box_pattern, process_coords, text)
def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Image.Image], Dict[int, Tuple[int, int]], Dict[int, Tuple[int, int]]]:
def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[
List[Dict[str, Any]],
List[Image.Image],
Dict[int, Tuple[int, int]],
Dict[int, Tuple[int, int]],
]:
"""Convert OpenAI format messages to MLX VLM format and extract images.
Args:
messages: Messages in OpenAI format
Returns:
Tuple of (processed_messages, images, original_sizes, model_sizes)
"""
@@ -149,13 +165,10 @@ class MLXVLMAdapter(CustomLLM):
original_sizes = {} # Track original sizes of images for coordinate mapping
model_sizes = {} # Track model processed sizes
image_index = 0
for message in messages:
processed_message = {
"role": message["role"],
"content": []
}
processed_message = {"role": message["role"], "content": []}
content = message.get("content", [])
if isinstance(content, str):
# Simple text content
@@ -165,164 +178,163 @@ class MLXVLMAdapter(CustomLLM):
processed_content = []
for item in content:
if item.get("type") == "text":
processed_content.append({
"type": "text",
"text": item.get("text", "")
})
processed_content.append({"type": "text", "text": item.get("text", "")})
elif item.get("type") == "image_url":
image_url = item.get("image_url", {}).get("url", "")
pil_image = None
if image_url.startswith("data:image/"):
# Extract base64 data
base64_data = image_url.split(',')[1]
base64_data = image_url.split(",")[1]
# Convert base64 to PIL Image
image_data = base64.b64decode(base64_data)
pil_image = Image.open(io.BytesIO(image_data))
else:
# Handle file path or URL
pil_image = Image.open(image_url)
# Store original image size for coordinate mapping
original_size = pil_image.size
original_sizes[image_index] = original_size
# Use smart_resize to determine model size
# Note: smart_resize expects (height, width) but PIL gives (width, height)
height, width = original_size[1], original_size[0]
new_height, new_width = smart_resize(height, width)
# Store model size in (width, height) format for consistent coordinate processing
model_sizes[image_index] = (new_width, new_height)
# Resize the image using the calculated dimensions from smart_resize
resized_image = pil_image.resize((new_width, new_height))
images.append(resized_image)
# Add image placeholder to content
processed_content.append({
"type": "image"
})
processed_content.append({"type": "image"})
image_index += 1
processed_message["content"] = processed_content
processed_messages.append(processed_message)
return processed_messages, images, original_sizes, model_sizes
def _generate(self, **kwargs) -> str:
"""Generate response using the local MLX VLM model.
Args:
**kwargs: Keyword arguments containing messages and model info
Returns:
Generated text response
"""
messages = kwargs.get('messages', [])
model_name = kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')
max_tokens = kwargs.get('max_tokens', 128)
messages = kwargs.get("messages", [])
model_name = kwargs.get("model", "mlx-community/UI-TARS-1.5-7B-4bit")
max_tokens = kwargs.get("max_tokens", 128)
# Warn about ignored kwargs
ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
ignored_kwargs = set(kwargs.keys()) - {"messages", "model", "max_tokens"}
if ignored_kwargs:
warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")
# Load model and processor
model, processor, config = self._load_model_and_processor(model_name)
# Convert messages and extract images
processed_messages, images, original_sizes, model_sizes = self._convert_messages(messages)
# Process user text input with box coordinates after image processing
# Swap original_size and model_size arguments for inverse transformation
for msg_idx, msg in enumerate(processed_messages):
if msg.get("role") == "user" and isinstance(msg.get("content"), str):
content = msg.get("content", "")
if "<|box_start|>" in content and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
if (
"<|box_start|>" in content
and original_sizes
and model_sizes
and 0 in original_sizes
and 0 in model_sizes
):
orig_size = original_sizes[0]
model_size = model_sizes[0]
# Swap arguments to perform inverse transformation for user input
processed_messages[msg_idx]["content"] = self._process_coordinates(content, model_size, orig_size)
processed_messages[msg_idx]["content"] = self._process_coordinates(
content, model_size, orig_size
)
try:
# Format prompt according to model requirements using the processor directly
prompt = processor.apply_chat_template(
processed_messages,
tokenize=False,
add_generation_prompt=True,
return_tensors='pt'
processed_messages, tokenize=False, add_generation_prompt=True, return_tensors="pt"
)
tokenizer = cast(PreTrainedTokenizer, processor)
# Generate response
text_content, usage = generate(
model,
tokenizer,
str(prompt),
images, # type: ignore
model,
tokenizer,
str(prompt),
images, # type: ignore
verbose=False,
max_tokens=max_tokens
max_tokens=max_tokens,
)
except Exception as e:
raise RuntimeError(f"Error generating response: {str(e)}") from e
# Process coordinates in the response back to original image space
if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
# Get original image size and model size (using the first image)
orig_size = original_sizes[0]
model_size = model_sizes[0]
# Check if output contains box tokens that need processing
if "<|box_start|>" in text_content:
# Process coordinates from model space back to original image space
text_content = self._process_coordinates(text_content, orig_size, model_size)
return text_content
def completion(self, *args, **kwargs) -> ModelResponse:
"""Synchronous completion method.
Returns:
ModelResponse with generated text
"""
generated_text = self._generate(**kwargs)
result = completion(
model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
mock_response=generated_text,
)
return cast(ModelResponse, result)
async def acompletion(self, *args, **kwargs) -> ModelResponse:
"""Asynchronous completion method.
Returns:
ModelResponse with generated text
"""
# Run _generate in thread pool to avoid blocking
loop = asyncio.get_event_loop()
generated_text = await loop.run_in_executor(
self._executor,
functools.partial(self._generate, **kwargs)
self._executor, functools.partial(self._generate, **kwargs)
)
result = await acompletion(
model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
mock_response=generated_text,
)
return cast(ModelResponse, result)
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
"""Synchronous streaming method.
Returns:
Iterator of GenericStreamingChunk
"""
generated_text = self._generate(**kwargs)
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
@@ -331,22 +343,21 @@ class MLXVLMAdapter(CustomLLM):
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
yield generic_streaming_chunk
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
"""Asynchronous streaming method.
Returns:
AsyncIterator of GenericStreamingChunk
"""
# Run _generate in thread pool to avoid blocking
loop = asyncio.get_event_loop()
generated_text = await loop.run_in_executor(
self._executor,
functools.partial(self._generate, **kwargs)
self._executor, functools.partial(self._generate, **kwargs)
)
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
@@ -355,5 +366,5 @@ class MLXVLMAdapter(CustomLLM):
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
yield generic_streaming_chunk
yield generic_streaming_chunk
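
Note: the adapter resizes each screenshot with `smart_resize` and rescales `<|box_start|>(x,y)<|box_end|>` coordinates between the resized model space and the original image. A worked sketch of that mapping with illustrative numbers; the rounding below approximates `smart_resize` for the common case where neither pixel bound is hit.

```python
# Worked sketch of the coordinate mapping used by _process_coordinates above.
IMAGE_FACTOR = 28
original_w, original_h = 1920, 1080          # PIL size of the screenshot (illustrative)

# Dimensions the model actually sees after resizing (snapped to multiples of 28)
model_h = round(original_h / IMAGE_FACTOR) * IMAGE_FACTOR   # 1092
model_w = round(original_w / IMAGE_FACTOR) * IMAGE_FACTOR   # 1932

# The model emits a click in resized-image space ...
model_x, model_y = 700, 400
# ... which is mapped back to the original screenshot space:
orig_x = int(model_x * original_w / model_w)
orig_y = int(model_y * original_h / model_h)
print(f"<|box_start|>({orig_x},{orig_y})<|box_end|>")
```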

View File

@@ -2,32 +2,40 @@ from typing import Optional
try:
from transformers import AutoConfig
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
from .generic import GenericHFModel
from .internvl import InternVLModel
from .opencua import OpenCUAModel
from .qwen2_5_vl import Qwen2_5_VLModel
from .internvl import InternVLModel
def load_model(model_name: str, device: str = "auto", trust_remote_code: bool = False):
"""Factory function to load and return the right model handler instance.
- If the underlying transformers config class matches OpenCUA, return OpenCUAModel
- Otherwise, return GenericHFModel
"""
if not HF_AVAILABLE:
raise ImportError(
"HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
)
cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
cls = cfg.__class__.__name__
print(f"cls: {cls}")
if "OpenCUA" in cls:
return OpenCUAModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
return OpenCUAModel(
model_name=model_name, device=device, trust_remote_code=trust_remote_code
)
elif "Qwen2_5_VL" in cls:
return Qwen2_5_VLModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
return Qwen2_5_VLModel(
model_name=model_name, device=device, trust_remote_code=trust_remote_code
)
elif "InternVL" in cls:
return InternVLModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
return InternVLModel(
model_name=model_name, device=device, trust_remote_code=trust_remote_code
)
return GenericHFModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
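
Note: the factory picks a handler from the `AutoConfig` class name (OpenCUA, Qwen2_5_VL, InternVL, otherwise the generic handler), and every handler exposes the same `generate(messages, max_new_tokens)` interface used by `HuggingFaceLocalAdapter`. A usage sketch; the import path is an assumption (the adapter imports this factory as `load_model_handler` via a relative import) and the model id is the adapter's default.

```python
# Sketch of the factory in use; which handler class comes back depends on the
# model's AutoConfig class name.
from agent.adapters.models import load_model  # assumed path

handler = load_model(
    model_name="ByteDance-Seed/UI-TARS-1.5-7B",
    device="auto",
    trust_remote_code=False,
)

# Messages in the HuggingFace format produced by the adapter's _convert_messages
hf_messages = [
    {"role": "user", "content": [{"type": "text", "text": "Hello"}]},
]
text = handler.generate(hf_messages, max_new_tokens=128)
print(text)
```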

View File

@@ -1,9 +1,10 @@
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional
# Hugging Face imports are local to avoid hard dependency at module import
try:
import torch # type: ignore
from transformers import AutoModel, AutoProcessor # type: ignore
HF_AVAILABLE = True
except Exception:
HF_AVAILABLE = False
@@ -14,10 +15,12 @@ class GenericHFModel:
Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
"""
def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
def __init__(
self, model_name: str, device: str = "auto", trust_remote_code: bool = False
) -> None:
if not HF_AVAILABLE:
raise ImportError(
"HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
)
self.model_name = model_name
self.device = device
@@ -64,7 +67,7 @@ class GenericHFModel:
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
# Trim prompt tokens from output
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
# Decode
output_text = self.processor.batch_decode(

View File

@@ -1,19 +1,22 @@
from __future__ import annotations
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional
# Hugging Face imports are local to avoid hard dependency at module import
try:
import torch # type: ignore
from transformers import AutoModel, AutoTokenizer # type: ignore
# Attempt to import InternVL's model dependencies
import einops as _ # type: ignore
import timm as _ # type: ignore
from PIL import Image # type: ignore
import torchvision.transforms as T # type: ignore
from torchvision.transforms.functional import InterpolationMode # type: ignore
import base64 # type: ignore
from io import BytesIO # type: ignore
# Attempt to import InternVL's model dependencies
import einops as _ # type: ignore
import requests # type: ignore
import timm as _ # type: ignore
import torch # type: ignore
import torchvision.transforms as T # type: ignore
from PIL import Image # type: ignore
from torchvision.transforms.functional import InterpolationMode # type: ignore
from transformers import AutoModel, AutoTokenizer # type: ignore
HF_AVAILABLE = True
except Exception:
HF_AVAILABLE = False
@@ -25,10 +28,12 @@ class InternVLModel:
Provides preprocessing to support multi-turn conversations with multiple images.
"""
def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
def __init__(
self, model_name: str, device: str = "auto", trust_remote_code: bool = False
) -> None:
if not HF_AVAILABLE:
raise ImportError(
"InternVL dependencies not found. Install with: pip install \"cua-agent[internvl-hf]\""
'InternVL dependencies not found. Install with: pip install "cua-agent[internvl-hf]"'
)
self.model_name = model_name
self.device = device
@@ -60,16 +65,25 @@ class InternVLModel:
def _build_transform(self, input_size: int) -> T.Compose:
MEAN, STD = self.IMAGENET_MEAN, self.IMAGENET_STD
transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD)
])
transform = T.Compose(
[
T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img),
T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
T.ToTensor(),
T.Normalize(mean=MEAN, std=STD),
]
)
return transform
def _find_closest_aspect_ratio(self, aspect_ratio: float, target_ratios: List[tuple], width: int, height: int, image_size: int):
best_ratio_diff = float('inf')
def _find_closest_aspect_ratio(
self,
aspect_ratio: float,
target_ratios: List[tuple],
width: int,
height: int,
image_size: int,
):
best_ratio_diff = float("inf")
best_ratio = (1, 1)
area = width * height
for ratio in target_ratios:
@@ -83,17 +97,29 @@ class InternVLModel:
best_ratio = ratio
return best_ratio
def _dynamic_preprocess(self, image: Image.Image, min_num: int = 1, max_num: int = 12, image_size: int = 448, use_thumbnail: bool = True) -> List[Image.Image]:
def _dynamic_preprocess(
self,
image: Image.Image,
min_num: int = 1,
max_num: int = 12,
image_size: int = 448,
use_thumbnail: bool = True,
) -> List[Image.Image]:
orig_width, orig_height = image.size
aspect_ratio = orig_width / orig_height
target_ratios = set(
(i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
i * j <= max_num and i * j >= min_num)
(i, j)
for n in range(min_num, max_num + 1)
for i in range(1, n + 1)
for j in range(1, n + 1)
if i * j <= max_num and i * j >= min_num
)
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
target_aspect_ratio = self._find_closest_aspect_ratio(
aspect_ratio, target_ratios, orig_width, orig_height, image_size)
aspect_ratio, target_ratios, orig_width, orig_height, image_size
)
target_width = image_size * target_aspect_ratio[0]
target_height = image_size * target_aspect_ratio[1]
@@ -106,7 +132,7 @@ class InternVLModel:
(i % (target_width // image_size)) * image_size,
(i // (target_width // image_size)) * image_size,
((i % (target_width // image_size)) + 1) * image_size,
((i // (target_width // image_size)) + 1) * image_size
((i // (target_width // image_size)) + 1) * image_size,
)
split_img = resized_img.crop(box)
processed_images.append(split_img)
@@ -122,20 +148,24 @@ class InternVLModel:
# data URL base64
header, b64data = src.split(",", 1)
img_bytes = base64.b64decode(b64data)
return Image.open(BytesIO(img_bytes)).convert('RGB')
return Image.open(BytesIO(img_bytes)).convert("RGB")
if src.startswith("http://") or src.startswith("https://"):
resp = requests.get(src, timeout=10)
resp.raise_for_status()
return Image.open(BytesIO(resp.content)).convert('RGB')
return Image.open(BytesIO(resp.content)).convert("RGB")
# Assume local file path
return Image.open(src).convert('RGB')
return Image.open(src).convert("RGB")
def _images_to_pixel_values(self, images: List[Image.Image], input_size: int = 448, max_num: int = 12):
def _images_to_pixel_values(
self, images: List[Image.Image], input_size: int = 448, max_num: int = 12
):
transform = self._build_transform(input_size=input_size)
pixel_values_list = []
num_patches_list: List[int] = []
for img in images:
tiles = self._dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
tiles = self._dynamic_preprocess(
img, image_size=input_size, use_thumbnail=True, max_num=max_num
)
pv = [transform(tile) for tile in tiles]
pv = torch.stack(pv)
num_patches_list.append(pv.shape[0])
@@ -191,7 +221,9 @@ class InternVLModel:
last_user_text_parts = parts_text or last_user_text_parts
elif role == "assistant":
# Only keep text content for history
parts_text = [item.get("text", "") for item in content_items if item.get("type") == "text"]
parts_text = [
item.get("text", "") for item in content_items if item.get("type") == "text"
]
text = "\n".join(parts_text).strip()
if text:
context_lines.append(f"Assistant: {text}")
@@ -200,7 +232,9 @@ class InternVLModel:
pixel_values = None
num_patches_list: List[int] = []
if all_images:
pixel_values, num_patches_list = self._images_to_pixel_values(all_images, input_size=448, max_num=12)
pixel_values, num_patches_list = self._images_to_pixel_values(
all_images, input_size=448, max_num=12
)
if pixel_values is not None:
# Convert dtype/device as in docs
pixel_values = pixel_values.to(torch.bfloat16)
@@ -246,7 +280,9 @@ class InternVLModel:
num_patches_list=num_patches_list,
)
else:
response = self.model.chat(self.tokenizer, pixel_values, question, generation_config)
response = self.model.chat(
self.tokenizer, pixel_values, question, generation_config
)
except Exception as e:
# Fallback: return empty string to avoid crashing the adapter
return ""

View File

@@ -1,13 +1,18 @@
from typing import List, Dict, Any
import re
import base64
import re
from io import BytesIO
from typing import Any, Dict, List
try:
import blobfile as _ # assert blobfile is installed
import torch # type: ignore
from transformers import AutoTokenizer, AutoModel, AutoImageProcessor # type: ignore
from PIL import Image # type: ignore
import blobfile as _ # assert blobfile is installed
from transformers import ( # type: ignore
AutoImageProcessor,
AutoModel,
AutoTokenizer,
)
OPENCUA_AVAILABLE = True
except Exception:
OPENCUA_AVAILABLE = False
@@ -16,10 +21,12 @@ except Exception:
class OpenCUAModel:
"""OpenCUA model handler using AutoTokenizer, AutoModel and AutoImageProcessor."""
def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
def __init__(
self, model_name: str, device: str = "auto", trust_remote_code: bool = False
) -> None:
if not OPENCUA_AVAILABLE:
raise ImportError(
"OpenCUA requirements not found. Install with: pip install \"cua-agent[opencua-hf]\""
'OpenCUA requirements not found. Install with: pip install "cua-agent[opencua-hf]"'
)
self.model_name = model_name
self.device = device
@@ -56,7 +63,11 @@ class OpenCUAModel:
return ""
def generate(self, messages: List[Dict[str, Any]], max_new_tokens: int = 512) -> str:
assert self.model is not None and self.tokenizer is not None and self.image_processor is not None
assert (
self.model is not None
and self.tokenizer is not None
and self.image_processor is not None
)
# Tokenize text side using chat template
input_ids = self.tokenizer.apply_chat_template(
@@ -74,7 +85,11 @@ class OpenCUAModel:
pixel_values = torch.tensor(image_info["pixel_values"]).to(
dtype=torch.bfloat16, device=self.model.device
)
grid_thws = torch.tensor(image_info["image_grid_thw"]) if "image_grid_thw" in image_info else None
grid_thws = (
torch.tensor(image_info["image_grid_thw"])
if "image_grid_thw" in image_info
else None
)
gen_kwargs: Dict[str, Any] = {
"max_new_tokens": max_new_tokens,

View File

@@ -1,9 +1,10 @@
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional
# Hugging Face imports are local to avoid hard dependency at module import
try:
import torch # type: ignore
from transformers import AutoModelForImageTextToText, AutoProcessor # type: ignore
HF_AVAILABLE = True
except Exception:
HF_AVAILABLE = False
@@ -14,10 +15,12 @@ class Qwen2_5_VLModel:
Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
"""
def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
def __init__(
self, model_name: str, device: str = "auto", trust_remote_code: bool = False
) -> None:
if not HF_AVAILABLE:
raise ImportError(
"HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
'HuggingFace transformers dependencies not found. Install with: pip install "cua-agent[uitars-hf]"'
)
self.model_name = model_name
self.device = device
@@ -64,7 +67,7 @@ class Qwen2_5_VLModel:
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
# Trim prompt tokens from output
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
# Decode
output_text = self.processor.batch_decode(

View File

@@ -3,76 +3,83 @@ ComputerAgent - Main agent class that selects and runs agent loops
"""
import asyncio
from pathlib import Path
from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set, Tuple
from litellm.responses.utils import Usage
from .types import (
Messages,
AgentCapability,
ToolError,
IllegalArgumentError
)
from .responses import make_tool_error_item, replace_failed_computer_calls_with_function_calls
from .decorators import find_agent_config
import inspect
import json
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
Callable,
Dict,
List,
Optional,
Set,
Tuple,
Union,
cast,
)
import litellm
import litellm.utils
import inspect
from litellm.responses.utils import Usage
from .adapters import (
HuggingFaceLocalAdapter,
HumanAdapter,
MLXVLMAdapter,
)
from .callbacks import (
ImageRetentionCallback,
LoggingCallback,
TrajectorySaverCallback,
BudgetManagerCallback,
TelemetryCallback,
ImageRetentionCallback,
LoggingCallback,
OperatorNormalizerCallback,
PromptInstructionsCallback,
TelemetryCallback,
TrajectorySaverCallback,
)
from .computers import (
AsyncComputerHandler,
is_agent_computer,
make_computer_handler
from .computers import AsyncComputerHandler, is_agent_computer, make_computer_handler
from .decorators import find_agent_config
from .responses import (
make_tool_error_item,
replace_failed_computer_calls_with_function_calls,
)
from .types import AgentCapability, IllegalArgumentError, Messages, ToolError
def assert_callable_with(f, *args, **kwargs):
"""Check if function can be called with given arguments."""
try:
inspect.signature(f).bind(*args, **kwargs)
return True
except TypeError as e:
sig = inspect.signature(f)
raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
"""Check if function can be called with given arguments."""
try:
inspect.signature(f).bind(*args, **kwargs)
return True
except TypeError as e:
sig = inspect.signature(f)
raise IllegalArgumentError(f"Expected {sig}, got args={args} kwargs={kwargs}") from e
def get_json(obj: Any, max_depth: int = 10) -> Any:
def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
if seen is None:
seen = set()
# Use model_dump() if available
if hasattr(o, 'model_dump'):
if hasattr(o, "model_dump"):
return o.model_dump()
# Check depth limit
if depth > max_depth:
return f"<max_depth_exceeded:{max_depth}>"
# Check for circular references using object id
obj_id = id(o)
if obj_id in seen:
return f"<circular_reference:{type(o).__name__}>"
# Handle Computer objects
if hasattr(o, '__class__') and 'computer' in getattr(o, '__class__').__name__.lower():
if hasattr(o, "__class__") and "computer" in o.__class__.__name__.lower():
return f"<computer:{o.__class__.__name__}>"
# Handle objects with __dict__
if hasattr(o, '__dict__'):
if hasattr(o, "__dict__"):
seen.add(obj_id)
try:
result = {}
@@ -84,7 +91,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
return result
finally:
seen.discard(obj_id)
# Handle common types that might contain nested objects
elif isinstance(o, dict):
seen.add(obj_id)
@@ -96,7 +103,7 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
}
finally:
seen.discard(obj_id)
elif isinstance(o, (list, tuple, set)):
seen.add(obj_id)
try:
@@ -107,32 +114,33 @@ def get_json(obj: Any, max_depth: int = 10) -> Any:
]
finally:
seen.discard(obj_id)
# For basic types that json.dumps can handle
elif isinstance(o, (str, int, float, bool)) or o is None:
return o
# Fallback to string representation
else:
return str(o)
def remove_nones(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: remove_nones(v) for k, v in obj.items() if v is not None}
elif isinstance(obj, list):
return [remove_nones(item) for item in obj if item is not None]
return obj
# Serialize with circular reference and depth protection
serialized = custom_serializer(obj)
# Convert to JSON string and back to ensure JSON compatibility
json_str = json.dumps(serialized)
parsed = json.loads(json_str)
# Final cleanup of any remaining None values
return remove_nones(parsed)
def sanitize_message(msg: Any) -> Any:
"""Return a copy of the message with image_url omitted for computer_call_output messages."""
if msg.get("type") == "computer_call_output":
@@ -143,19 +151,24 @@ def sanitize_message(msg: Any) -> Any:
return sanitized
return msg
def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
call_ids = []
for message in messages:
if message.get("type") == "computer_call_output" or message.get("type") == "function_call_output":
if (
message.get("type") == "computer_call_output"
or message.get("type") == "function_call_output"
):
call_ids.append(message.get("call_id"))
return call_ids
class ComputerAgent:
"""
Main agent class that automatically selects the appropriate agent loop
based on the model and executes tool calls.
"""
def __init__(
self,
model: str,
@@ -172,11 +185,11 @@ class ComputerAgent:
max_trajectory_budget: Optional[float | dict] = None,
telemetry_enabled: Optional[bool] = True,
trust_remote_code: Optional[bool] = False,
**kwargs
**kwargs,
):
"""
Initialize ComputerAgent.
Args:
model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
tools: List of tools (computer objects, decorated functions, etc.)
@@ -193,11 +206,11 @@ class ComputerAgent:
telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
**kwargs: Additional arguments passed to the agent loop
"""
"""
# If the loop is "human/human", we need to prefix a grounding model fallback
if model in ["human/human", "human"]:
model = "openai/computer-use-preview+human/human"
self.model = model
self.tools = tools or []
self.custom_loop = custom_loop
@@ -236,34 +249,33 @@ class ComputerAgent:
# Add image retention callback if only_n_most_recent_images is set
if self.only_n_most_recent_images:
self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
# Add trajectory saver callback if trajectory_dir is set
if self.trajectory_dir:
if isinstance(self.trajectory_dir, dict):
self.callbacks.append(TrajectorySaverCallback(**self.trajectory_dir))
elif isinstance(self.trajectory_dir, (str, Path)):
self.callbacks.append(TrajectorySaverCallback(str(self.trajectory_dir)))
# Add budget manager if max_trajectory_budget is set
if max_trajectory_budget:
if isinstance(max_trajectory_budget, dict):
self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
else:
self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
# == Enable local model providers w/ LiteLLM ==
# Register local model providers
hf_adapter = HuggingFaceLocalAdapter(
device="auto",
trust_remote_code=self.trust_remote_code or False
device="auto", trust_remote_code=self.trust_remote_code or False
)
human_adapter = HumanAdapter()
mlx_adapter = MLXVLMAdapter()
litellm.custom_provider_map = [
{"provider": "huggingface-local", "custom_handler": hf_adapter},
{"provider": "human", "custom_handler": human_adapter},
{"provider": "mlx", "custom_handler": mlx_adapter}
{"provider": "mlx", "custom_handler": mlx_adapter},
]
litellm.suppress_debug_info = True
@@ -280,16 +292,16 @@ class ComputerAgent:
# Instantiate the agent config class
self.agent_loop = config_info.agent_class()
self.agent_config_info = config_info
self.tool_schemas = []
self.computer_handler = None
async def _initialize_computers(self):
"""Initialize computer objects"""
if not self.tool_schemas:
# Process tools and create tool schemas
self.tool_schemas = self._process_tools()
# Find computer tool and create interface adapter
computer_handler = None
for schema in self.tool_schemas:
@@ -297,7 +309,7 @@ class ComputerAgent:
computer_handler = await make_computer_handler(schema["computer"])
break
self.computer_handler = computer_handler
def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
"""Process input messages and create schemas for the agent loop"""
if isinstance(input, str):
@@ -307,69 +319,73 @@ class ComputerAgent:
def _process_tools(self) -> List[Dict[str, Any]]:
"""Process tools and create schemas for the agent loop"""
schemas = []
for tool in self.tools:
# Check if it's a computer object (has interface attribute)
if is_agent_computer(tool):
# This is a computer tool - will be handled by agent loop
schemas.append({
"type": "computer",
"computer": tool
})
schemas.append({"type": "computer", "computer": tool})
elif callable(tool):
# Use litellm.utils.function_to_dict to extract schema from docstring
try:
function_schema = litellm.utils.function_to_dict(tool)
schemas.append({
"type": "function",
"function": function_schema
})
schemas.append({"type": "function", "function": function_schema})
except Exception as e:
print(f"Warning: Could not process tool {tool}: {e}")
else:
print(f"Warning: Unknown tool type: {tool}")
return schemas
def _get_tool(self, name: str) -> Optional[Callable]:
"""Get a tool by name"""
for tool in self.tools:
if hasattr(tool, '__name__') and tool.__name__ == name:
if hasattr(tool, "__name__") and tool.__name__ == name:
return tool
elif hasattr(tool, 'func') and tool.func.__name__ == name:
elif hasattr(tool, "func") and tool.func.__name__ == name:
return tool
return None
# ============================================================================
# AGENT RUN LOOP LIFECYCLE HOOKS
# ============================================================================
async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Initialize run tracking by calling callbacks."""
for callback in self.callbacks:
if hasattr(callback, 'on_run_start'):
if hasattr(callback, "on_run_start"):
await callback.on_run_start(kwargs, old_items)
async def _on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
async def _on_run_end(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> None:
"""Finalize run tracking by calling callbacks."""
for callback in self.callbacks:
if hasattr(callback, 'on_run_end'):
if hasattr(callback, "on_run_end"):
await callback.on_run_end(kwargs, old_items, new_items)
async def _on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
async def _on_run_continue(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> bool:
"""Check if run should continue by calling callbacks."""
for callback in self.callbacks:
if hasattr(callback, 'on_run_continue'):
if hasattr(callback, "on_run_continue"):
should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
if not should_continue:
return False
return True
async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Prepare messages for the LLM call by applying callbacks."""
result = messages
for callback in self.callbacks:
if hasattr(callback, 'on_llm_start'):
if hasattr(callback, "on_llm_start"):
result = await callback.on_llm_start(result)
return result
@@ -377,82 +393,91 @@ class ComputerAgent:
"""Postprocess messages after the LLM call by applying callbacks."""
result = messages
for callback in self.callbacks:
if hasattr(callback, 'on_llm_end'):
if hasattr(callback, "on_llm_end"):
result = await callback.on_llm_end(result)
return result
async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
"""Called when responses are received."""
for callback in self.callbacks:
if hasattr(callback, 'on_responses'):
if hasattr(callback, "on_responses"):
await callback.on_responses(get_json(kwargs), get_json(responses))
async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a computer call is about to start."""
for callback in self.callbacks:
if hasattr(callback, 'on_computer_call_start'):
if hasattr(callback, "on_computer_call_start"):
await callback.on_computer_call_start(get_json(item))
async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
async def _on_computer_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""Called when a computer call has completed."""
for callback in self.callbacks:
if hasattr(callback, 'on_computer_call_end'):
if hasattr(callback, "on_computer_call_end"):
await callback.on_computer_call_end(get_json(item), get_json(result))
async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a function call is about to start."""
for callback in self.callbacks:
if hasattr(callback, 'on_function_call_start'):
if hasattr(callback, "on_function_call_start"):
await callback.on_function_call_start(get_json(item))
async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
async def _on_function_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""Called when a function call has completed."""
for callback in self.callbacks:
if hasattr(callback, 'on_function_call_end'):
if hasattr(callback, "on_function_call_end"):
await callback.on_function_call_end(get_json(item), get_json(result))
async def _on_text(self, item: Dict[str, Any]) -> None:
"""Called when a text message is encountered."""
for callback in self.callbacks:
if hasattr(callback, 'on_text'):
if hasattr(callback, "on_text"):
await callback.on_text(get_json(item))
async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
"""Called when an LLM API call is about to start."""
for callback in self.callbacks:
if hasattr(callback, 'on_api_start'):
if hasattr(callback, "on_api_start"):
await callback.on_api_start(get_json(kwargs))
async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""Called when an LLM API call has completed."""
for callback in self.callbacks:
if hasattr(callback, 'on_api_end'):
if hasattr(callback, "on_api_end"):
await callback.on_api_end(get_json(kwargs), get_json(result))
async def _on_usage(self, usage: Dict[str, Any]) -> None:
"""Called when usage information is received."""
for callback in self.callbacks:
if hasattr(callback, 'on_usage'):
if hasattr(callback, "on_usage"):
await callback.on_usage(get_json(usage))
async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
"""Called when a screenshot is taken."""
for callback in self.callbacks:
if hasattr(callback, 'on_screenshot'):
if hasattr(callback, "on_screenshot"):
await callback.on_screenshot(screenshot, name)
# ============================================================================
# AGENT OUTPUT PROCESSING
# ============================================================================
async def _handle_item(self, item: Any, computer: Optional[AsyncComputerHandler] = None, ignore_call_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
async def _handle_item(
self,
item: Any,
computer: Optional[AsyncComputerHandler] = None,
ignore_call_ids: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""Handle each item; may cause a computer action + screenshot."""
call_id = item.get("call_id")
if ignore_call_ids and call_id and call_id in ignore_call_ids:
return []
item_type = item.get("type", None)
if item_type == "message":
await self._on_text(item)
# # Print messages
@@ -461,7 +486,7 @@ class ComputerAgent:
# if content_item.get("text"):
# print(content_item.get("text"))
return []
try:
if item_type == "computer_call":
await self._on_computer_call_start(item)
@@ -472,14 +497,16 @@ class ComputerAgent:
action = item.get("action")
action_type = action.get("type")
if action_type is None:
print(f"Action type cannot be `None`: action={action}, action_type={action_type}")
print(
f"Action type cannot be `None`: action={action}, action_type={action_type}"
)
return []
# Extract action arguments (all fields except 'type')
action_args = {k: v for k, v in action.items() if k != "type"}
# print(f"{action_type}({action_args})")
# Execute the computer action
computer_method = getattr(computer, action_type, None)
if computer_method:
@@ -487,13 +514,13 @@ class ComputerAgent:
await computer_method(**action_args)
else:
raise ToolError(f"Unknown computer action: {action_type}")
# Take screenshot after action
if self.screenshot_delay and self.screenshot_delay > 0:
await asyncio.sleep(self.screenshot_delay)
screenshot_base64 = await computer.screenshot()
await self._on_screenshot(screenshot_base64, "screenshot_after")
# Handle safety checks
pending_checks = item.get("pending_safety_checks", [])
acknowledged_checks = []
@@ -505,7 +532,7 @@ class ComputerAgent:
# acknowledged_checks.append(check)
# else:
# raise ValueError(f"Safety check failed: {check_message}")
# Create call output
call_output = {
"type": "computer_call_output",
@@ -516,25 +543,25 @@ class ComputerAgent:
"image_url": f"data:image/png;base64,{screenshot_base64}",
},
}
# # Additional URL safety checks for browser environments
# if await computer.get_environment() == "browser":
# current_url = await computer.get_current_url()
# call_output["output"]["current_url"] = current_url
# # TODO: implement a callback for URL safety checks
# # check_blocklisted_url(current_url)
result = [call_output]
await self._on_computer_call_end(item, result)
return result
if item_type == "function_call":
await self._on_function_call_start(item)
# Perform function call
function = self._get_tool(item.get("name"))
if not function:
raise ToolError(f"Function {item.get("name")} not found")
raise ToolError(f"Function {item.get('name')} not found")
args = json.loads(item.get("arguments"))
# Validate arguments before execution
@@ -545,14 +572,14 @@ class ComputerAgent:
result = await function(**args)
else:
result = await asyncio.to_thread(function, **args)
# Create function call output
call_output = {
"type": "function_call_output",
"call_id": item.get("call_id"),
"output": str(result),
}
result = [call_output]
await self._on_function_call_end(item, result)
return result
@@ -564,36 +591,35 @@ class ComputerAgent:
# ============================================================================
# MAIN AGENT LOOP
# ============================================================================
async def run(
self,
messages: Messages,
stream: bool = False,
**kwargs
self, messages: Messages, stream: bool = False, **kwargs
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Run the agent with the given messages using the Computer protocol handler pattern.
Args:
messages: List of message dictionaries
stream: Whether to stream the response
**kwargs: Additional arguments
Returns:
AsyncGenerator that yields response chunks
"""
if not self.agent_config_info:
raise ValueError("Agent configuration not found")
capabilities = self.get_capabilities()
if "step" not in capabilities:
raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions")
raise ValueError(
f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions"
)
await self._initialize_computers()
# Merge kwargs
merged_kwargs = {**self.kwargs, **kwargs}
old_items = self._process_input(messages)
new_items = []
@@ -603,7 +629,7 @@ class ComputerAgent:
"stream": stream,
"model": self.model,
"agent_loop": self.agent_config_info.agent_class.__name__,
**merged_kwargs
**merged_kwargs,
}
await self._on_run_start(run_kwargs, old_items)
@@ -620,7 +646,7 @@ class ComputerAgent:
combined_messages = old_items + new_items
combined_messages = replace_failed_computer_calls_with_function_calls(combined_messages)
preprocessed_messages = await self._on_llm_start(combined_messages)
loop_kwargs = {
"messages": preprocessed_messages,
"model": self.model,
@@ -629,7 +655,7 @@ class ComputerAgent:
"computer_handler": self.computer_handler,
"max_retries": self.max_retries,
"use_prompt_caching": self.use_prompt_caching,
**merged_kwargs
**merged_kwargs,
}
# Run agent loop iteration
@@ -641,13 +667,13 @@ class ComputerAgent:
_on_screenshot=self._on_screenshot,
)
result = get_json(result)
# Lifecycle hook: Postprocess messages after the LLM call
# Use cases:
# - PII deanonymization (if you want tool calls to see PII)
result["output"] = await self._on_llm_end(result.get("output", []))
await self._on_responses(loop_kwargs, result)
# Yield agent response
yield result
@@ -659,7 +685,9 @@ class ComputerAgent:
# Handle computer actions
for item in result.get("output"):
partial_items = await self._handle_item(item, self.computer_handler, ignore_call_ids=output_call_ids)
partial_items = await self._handle_item(
item, self.computer_handler, ignore_call_ids=output_call_ids
)
new_items += partial_items
# Yield partial response
@@ -669,54 +697,52 @@ class ComputerAgent:
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
),
}
await self._on_run_end(loop_kwargs, old_items, new_items)
async def predict_click(
self,
instruction: str,
image_b64: Optional[str] = None
self, instruction: str, image_b64: Optional[str] = None
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.
Args:
instruction: Instruction for where to click
image_b64: Base64 encoded image (optional, will take screenshot if not provided)
Returns:
None or tuple with (x, y) coordinates
"""
if not self.agent_config_info:
raise ValueError("Agent configuration not found")
capabilities = self.get_capabilities()
if "click" not in capabilities:
raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions")
if hasattr(self.agent_loop, 'predict_click'):
raise ValueError(
f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions"
)
if hasattr(self.agent_loop, "predict_click"):
if not image_b64:
if not self.computer_handler:
raise ValueError("Computer tool or image_b64 is required for predict_click")
image_b64 = await self.computer_handler.screenshot()
return await self.agent_loop.predict_click(
model=self.model,
image_b64=image_b64,
instruction=instruction
model=self.model, image_b64=image_b64, instruction=instruction
)
return None
def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by the current agent config.
Returns:
List of capability strings (e.g., ["step", "click"])
"""
if not self.agent_config_info:
raise ValueError("Agent configuration not found")
if hasattr(self.agent_loop, 'get_capabilities'):
if hasattr(self.agent_loop, "get_capabilities"):
return self.agent_loop.get_capabilities()
return ["step"] # Default capability
return ["step"] # Default capability

View File

@@ -3,17 +3,17 @@ Callback system for ComputerAgent preprocessing and postprocessing hooks.
"""
from .base import AsyncCallbackHandler
from .budget_manager import BudgetManagerCallback
from .image_retention import ImageRetentionCallback
from .logging import LoggingCallback
from .trajectory_saver import TrajectorySaverCallback
from .budget_manager import BudgetManagerCallback
from .telemetry import TelemetryCallback
from .operator_validator import OperatorNormalizerCallback
from .prompt_instructions import PromptInstructionsCallback
from .telemetry import TelemetryCallback
from .trajectory_saver import TrajectorySaverCallback
__all__ = [
"AsyncCallbackHandler",
"ImageRetentionCallback",
"ImageRetentionCallback",
"LoggingCallback",
"TrajectorySaverCallback",
"BudgetManagerCallback",

View File

@@ -3,7 +3,7 @@ Base callback handler interface for ComputerAgent preprocessing and postprocessing hooks.
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union
from typing import Any, Dict, List, Optional, Union
class AsyncCallbackHandler(ABC):
@@ -16,42 +16,52 @@ class AsyncCallbackHandler(ABC):
"""Called at the start of an agent run loop."""
pass
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
async def on_run_end(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> None:
"""Called at the end of an agent run loop."""
pass
async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
async def on_run_continue(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> bool:
"""Called during agent run loop to determine if execution should continue.
Args:
kwargs: Run arguments
old_items: Original messages
new_items: New messages generated during run
Returns:
True to continue execution, False to stop
"""
return True
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Called before messages are sent to the agent loop.
Args:
messages: List of message dictionaries to preprocess
Returns:
List of preprocessed message dictionaries
"""
return messages
async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Called after the agent loop returns output.
Args:
output: List of output message dictionaries to postprocess
Returns:
List of postprocessed output dictionaries
"""
@@ -60,63 +70,67 @@ class AsyncCallbackHandler(ABC):
async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
"""
Called when a computer call is about to start.
Args:
item: The computer call item dictionary
"""
pass
async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
async def on_computer_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""
Called when a computer call has completed.
Args:
item: The computer call item dictionary
result: The result of the computer call
"""
pass
async def on_function_call_start(self, item: Dict[str, Any]) -> None:
"""
Called when a function call is about to start.
Args:
item: The function call item dictionary
"""
pass
async def on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
async def on_function_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""
Called when a function call has completed.
Args:
item: The function call item dictionary
result: The result of the function call
"""
pass
async def on_text(self, item: Dict[str, Any]) -> None:
"""
Called when a text message is encountered.
Args:
item: The message item dictionary
"""
pass
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
"""
Called when an API call is about to start.
Args:
kwargs: The kwargs being passed to the API call
"""
pass
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""
Called when an API call has completed.
Args:
kwargs: The kwargs that were passed to the API call
result: The result of the API call
@@ -126,7 +140,7 @@ class AsyncCallbackHandler(ABC):
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""
Called when usage information is received.
Args:
usage: The usage information
"""
@@ -135,7 +149,7 @@ class AsyncCallbackHandler(ABC):
async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
"""
Called when a screenshot is taken.
Args:
screenshot: The screenshot image
name: The name of the screenshot
@@ -145,9 +159,9 @@ class AsyncCallbackHandler(ABC):
async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
"""
Called when responses are received.
Args:
kwargs: The kwargs being passed to the agent loop
responses: The responses received
"""
pass
pass
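A minimal sketch of a custom handler built on the hooks above, overriding only what it needs (as the concrete callbacks that follow do); the agent.callbacks import path is an assumption about the package layout:

from typing import Any, Dict, List

from agent.callbacks import AsyncCallbackHandler  # import path assumed

class TurnCounterCallback(AsyncCallbackHandler):
    """Illustrative handler: counts computer calls per run and prints a summary."""

    def __init__(self) -> None:
        self.computer_calls = 0

    async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        self.computer_calls = 0

    async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
        self.computer_calls += 1

    async def on_run_end(
        self,
        kwargs: Dict[str, Any],
        old_items: List[Dict[str, Any]],
        new_items: List[Dict[str, Any]],
    ) -> None:
        print(f"run finished after {self.computer_calls} computer call(s)")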

View File

@@ -1,17 +1,23 @@
from typing import Dict, List, Any
from typing import Any, Dict, List
from .base import AsyncCallbackHandler
class BudgetExceededError(Exception):
"""Exception raised when budget is exceeded."""
pass
class BudgetManagerCallback(AsyncCallbackHandler):
"""Budget manager callback that tracks usage costs and can stop execution when budget is exceeded."""
def __init__(self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False):
def __init__(
self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False
):
"""
Initialize BudgetManagerCallback.
Args:
max_budget: Maximum budget allowed
reset_after_each_run: Whether to reset budget after each run
@@ -21,24 +27,30 @@ class BudgetManagerCallback(AsyncCallbackHandler):
self.reset_after_each_run = reset_after_each_run
self.raise_error = raise_error
self.total_cost = 0.0
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Reset budget if configured to do so."""
if self.reset_after_each_run:
self.total_cost = 0.0
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""Track usage costs."""
if "response_cost" in usage:
self.total_cost += usage["response_cost"]
async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
async def on_run_continue(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> bool:
"""Check if budget allows continuation."""
if self.total_cost >= self.max_budget:
if self.raise_error:
raise BudgetExceededError(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
raise BudgetExceededError(
f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}"
)
else:
print(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
return False
return True
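A small sketch of the budget mechanics above, driving the hooks directly; the agent.callbacks import path is assumed and the costs are illustrative:

import asyncio

from agent.callbacks import BudgetManagerCallback  # import path assumed

async def demo() -> None:
    budget = BudgetManagerCallback(max_budget=0.05, raise_error=False)
    await budget.on_run_start({}, [])               # resets total_cost (reset_after_each_run=True)
    await budget.on_usage({"response_cost": 0.03})
    await budget.on_usage({"response_cost": 0.04})
    # total_cost is now 0.07 >= max_budget, so the run loop would stop at the next continue check.
    assert await budget.on_run_continue({}, [], []) is False

asyncio.run(demo())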

View File

@@ -2,7 +2,8 @@
Image retention callback handler that limits the number of recent images in message history.
"""
from typing import List, Dict, Any, Optional
from typing import Any, Dict, List, Optional
from .base import AsyncCallbackHandler
@@ -11,40 +12,40 @@ class ImageRetentionCallback(AsyncCallbackHandler):
Callback handler that applies image retention policy to limit the number
of recent images in message history to prevent context window overflow.
"""
def __init__(self, only_n_most_recent_images: Optional[int] = None):
"""
Initialize the image retention callback.
Args:
only_n_most_recent_images: If set, only keep the N most recent images in message history
"""
self.only_n_most_recent_images = only_n_most_recent_images
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Apply image retention policy to messages before sending to agent loop.
Args:
messages: List of message dictionaries
Returns:
List of messages with image retention policy applied
"""
if self.only_n_most_recent_images is None:
return messages
return self._apply_image_retention(messages)
def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Apply image retention policy to keep only the N most recent images.
Removes computer_call_output items with image_url and their corresponding computer_call items,
keeping only the most recent N image pairs based on only_n_most_recent_images setting.
Args:
messages: List of message dictionaries
Returns:
Filtered list of messages with image retention applied
"""
@@ -78,7 +79,11 @@ class ImageRetentionCallback(AsyncCallbackHandler):
# Remove the immediately preceding computer_call with matching call_id (if present)
call_id = messages[idx].get("call_id")
prev_idx = idx - 1
if prev_idx >= 0 and messages[prev_idx].get("type") == "computer_call" and messages[prev_idx].get("call_id") == call_id:
if (
prev_idx >= 0
and messages[prev_idx].get("type") == "computer_call"
and messages[prev_idx].get("call_id") == call_id
):
to_remove.add(prev_idx)
# Check a single reasoning immediately before that computer_call
r_idx = prev_idx - 1
@@ -87,4 +92,4 @@ class ImageRetentionCallback(AsyncCallbackHandler):
# Construct filtered list
filtered = [m for i, m in enumerate(messages) if i not in to_remove]
return filtered
return filtered
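A sketch of the retention policy on a hand-built history; the import path is assumed, and the expected outcome follows the docstring above rather than a verified run:

import asyncio

from agent.callbacks import ImageRetentionCallback  # import path assumed

history = [
    {"type": "computer_call", "call_id": "a", "action": {"type": "screenshot"}},
    {"type": "computer_call_output", "call_id": "a",
     "output": {"type": "input_image", "image_url": "data:image/png;base64,AAA"}},
    {"type": "computer_call", "call_id": "b", "action": {"type": "screenshot"}},
    {"type": "computer_call_output", "call_id": "b",
     "output": {"type": "input_image", "image_url": "data:image/png;base64,BBB"}},
]

trimmed = asyncio.run(ImageRetentionCallback(only_n_most_recent_images=1).on_llm_start(history))
# Expectation: only the most recent screenshot pair (call_id "b") remains in trimmed.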

View File

@@ -4,17 +4,18 @@ Logging callback for ComputerAgent that provides configurable logging of agent l
import json
import logging
from typing import Dict, List, Any, Optional, Union
from typing import Any, Dict, List, Optional, Union
from .base import AsyncCallbackHandler
def sanitize_image_urls(data: Any) -> Any:
"""
Recursively search for 'image_url' keys and set their values to '[omitted]'.
Args:
data: Any data structure (dict, list, or primitive type)
Returns:
A deep copy of the data with all 'image_url' values replaced with '[omitted]'
"""
@@ -28,11 +29,11 @@ def sanitize_image_urls(data: Any) -> Any:
# Recursively sanitize the value
sanitized[key] = sanitize_image_urls(value)
return sanitized
elif isinstance(data, list):
# Recursively sanitize each item in the list
return [sanitize_image_urls(item) for item in data]
else:
# For primitive types (str, int, bool, None, etc.), return as-is
return data
@@ -41,37 +42,36 @@ def sanitize_image_urls(data: Any) -> Any:
class LoggingCallback(AsyncCallbackHandler):
"""
Callback handler that logs agent lifecycle events with configurable verbosity.
Logging levels:
- DEBUG: All events including API calls, message preprocessing, and detailed outputs
- INFO: Major lifecycle events (start/end, messages, outputs)
- INFO: Major lifecycle events (start/end, messages, outputs)
- WARNING: Only warnings and errors
- ERROR: Only errors
"""
def __init__(self, logger: Optional[logging.Logger] = None, level: int = logging.INFO):
"""
Initialize the logging callback.
Args:
logger: Logger instance to use. If None, creates a logger named 'agent.ComputerAgent'
level: Logging level (logging.DEBUG, logging.INFO, etc.)
"""
self.logger = logger or logging.getLogger('agent.ComputerAgent')
self.logger = logger or logging.getLogger("agent.ComputerAgent")
self.level = level
# Set up logger if it doesn't have handlers
if not self.logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
self.logger.addHandler(handler)
self.logger.setLevel(level)
def _update_usage(self, usage: Dict[str, Any]) -> None:
"""Update total usage statistics."""
def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
for key, value in source.items():
if isinstance(value, dict):
@@ -82,18 +82,25 @@ class LoggingCallback(AsyncCallbackHandler):
if key not in target:
target[key] = 0
target[key] += value
add_dicts(self.total_usage, usage)
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Called before the run starts."""
self.total_usage = {}
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""Called when usage information is received."""
self._update_usage(usage)
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
async def on_run_end(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> None:
"""Called after the run ends."""
def format_dict(d, indent=0):
lines = []
prefix = f" - {' ' * indent}"
@@ -106,10 +113,10 @@ class LoggingCallback(AsyncCallbackHandler):
else:
lines.append(f"{prefix}{key}: {value}")
return lines
formatted_output = "\n".join(format_dict(self.total_usage))
self.logger.info(f"Total usage:\n{formatted_output}")
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Called before LLM processing starts."""
if self.logger.isEnabledFor(logging.INFO):
@@ -118,27 +125,27 @@ class LoggingCallback(AsyncCallbackHandler):
sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
self.logger.debug(f"LLM input messages: {json.dumps(sanitized_messages, indent=2)}")
return messages
async def on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Called after LLM processing ends."""
if self.logger.isEnabledFor(logging.DEBUG):
sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
self.logger.debug(f"LLM output: {json.dumps(sanitized_messages, indent=2)}")
return messages
async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a computer call starts."""
action = item.get("action", {})
action_type = action.get("type", "unknown")
action_args = {k: v for k, v in action.items() if k != "type"}
# INFO level logging for the action
self.logger.info(f"Computer: {action_type}({action_args})")
# DEBUG level logging for full details
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Computer call started: {json.dumps(action, indent=2)}")
async def on_computer_call_end(self, item: Dict[str, Any], result: Any) -> None:
"""Called when a computer call ends."""
if self.logger.isEnabledFor(logging.DEBUG):
@@ -147,48 +154,52 @@ class LoggingCallback(AsyncCallbackHandler):
if result:
sanitized_result = sanitize_image_urls(result)
self.logger.debug(f"Computer call result: {json.dumps(sanitized_result, indent=2)}")
async def on_function_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a function call starts."""
name = item.get("name", "unknown")
arguments = item.get("arguments", "{}")
# INFO level logging for the function call
self.logger.info(f"Function: {name}({arguments})")
# DEBUG level logging for full details
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Function call started: {name}")
async def on_function_call_end(self, item: Dict[str, Any], result: Any) -> None:
"""Called when a function call ends."""
# INFO level logging for function output (similar to function_call_output)
if result:
# Handle both list and direct result formats
if isinstance(result, list) and len(result) > 0:
output = result[0].get("output", str(result)) if isinstance(result[0], dict) else str(result[0])
output = (
result[0].get("output", str(result))
if isinstance(result[0], dict)
else str(result[0])
)
else:
output = str(result)
# Truncate long outputs
if len(output) > 100:
output = output[:100] + "..."
self.logger.info(f"Output: {output}")
# DEBUG level logging for full details
if self.logger.isEnabledFor(logging.DEBUG):
name = item.get("name", "unknown")
self.logger.debug(f"Function call completed: {name}")
if result:
self.logger.debug(f"Function call result: {json.dumps(result, indent=2)}")
async def on_text(self, item: Dict[str, Any]) -> None:
"""Called when a text message is encountered."""
# Get the role to determine if it's Agent or User
role = item.get("role", "unknown")
content_items = item.get("content", [])
# Process content items to build display text
text_parts = []
for content_item in content_items:
@@ -206,10 +217,10 @@ class LoggingCallback(AsyncCallbackHandler):
else:
# Non-text content, show as [type]
text_parts.append(f"[{content_type}]")
# Join all text parts
display_text = ''.join(text_parts) if text_parts else "[empty]"
display_text = "".join(text_parts) if text_parts else "[empty]"
# Log with appropriate level and format
if role == "assistant":
self.logger.info(f"Agent: {display_text}")
@@ -219,7 +230,7 @@ class LoggingCallback(AsyncCallbackHandler):
# Fallback for unknown roles, use debug level
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Text message ({role}): {display_text}")
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
"""Called when an API call is about to start."""
if self.logger.isEnabledFor(logging.DEBUG):
@@ -232,16 +243,18 @@ class LoggingCallback(AsyncCallbackHandler):
elif "input" in kwargs:
sanitized_input = sanitize_image_urls(kwargs["input"])
self.logger.debug(f"API call input: {json.dumps(sanitized_input, indent=2)}")
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""Called when an API call has completed."""
if self.logger.isEnabledFor(logging.DEBUG):
model = kwargs.get("model", "unknown")
self.logger.debug(f"API call completed for model: {model}")
self.logger.debug(f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}")
self.logger.debug(
f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}"
)
async def on_screenshot(self, item: Union[str, bytes], name: str = "screenshot") -> None:
"""Called when a screenshot is taken."""
if self.logger.isEnabledFor(logging.DEBUG):
image_size = len(item) / 1024
self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")
self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")

View File

@@ -9,6 +9,7 @@ Ensures agent output actions conform to expected schemas by fixing common issues
This runs in on_llm_end, which receives the output array (AgentMessage[] as dicts).
The purpose is to avoid spending another LLM call to fix broken computer call syntax when possible.
"""
from __future__ import annotations
from typing import Any, Dict, List
@@ -48,6 +49,7 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
action["type"] = "type"
action_type = action.get("type")
def _keep_keys(action: Dict[str, Any], keys_to_keep: List[str]):
"""Keep only the provided keys on action; delete everything else.
Always ensures required 'type' is present if listed in keys_to_keep.
@@ -55,6 +57,7 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
for key in list(action.keys()):
if key not in keys_to_keep:
del action[key]
# rename "coordinate" to "x", "y"
if "coordinate" in action:
action["x"] = action["coordinate"][0]
@@ -100,7 +103,6 @@ class OperatorNormalizerCallback(AsyncCallbackHandler):
keep = required_keys_by_type.get(action_type or "")
if keep:
_keep_keys(action, keep)
# # Second pass: if an assistant message is immediately followed by a computer_call,
# # replace the assistant message itself with a reasoning message with summary text.
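A hedged sketch of what this normalizer is meant to do with a malformed click action; the import path and the no-argument constructor are assumptions, and the expected result follows the rename logic above rather than a verified run:

import asyncio

from agent.callbacks import OperatorNormalizerCallback  # import path assumed

output = [
    {
        "type": "computer_call",
        "call_id": "c1",
        "action": {"type": "click", "coordinate": [120, 340]},  # non-canonical coordinate key
    }
]

fixed = asyncio.run(OperatorNormalizerCallback().on_llm_end(output))  # constructor args assumed
# Expectation: "coordinate" is rewritten to top-level "x"/"y", so the computer handler
# can dispatch click(x=120, y=340) without spending another LLM call on repair.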

View File

@@ -2,38 +2,41 @@
PII anonymization callback handler using Microsoft Presidio for text and image redaction.
"""
from typing import List, Dict, Any, Optional, Tuple
from .base import AsyncCallbackHandler
import base64
import io
import logging
from typing import Any, Dict, List, Optional, Tuple
from .base import AsyncCallbackHandler
try:
# TODO: Add Presidio dependencies
from PIL import Image
PRESIDIO_AVAILABLE = True
except ImportError:
PRESIDIO_AVAILABLE = False
logger = logging.getLogger(__name__)
class PIIAnonymizationCallback(AsyncCallbackHandler):
"""
Callback handler that anonymizes PII in text and images using Microsoft Presidio.
This handler:
1. Anonymizes PII in messages before sending to the agent loop
2. Deanonymizes PII in tool calls and message outputs after the agent loop
3. Redacts PII from images in computer_call_output messages
"""
def __init__(
self,
# TODO: Any extra kwargs if needed
):
"""
Initialize the PII anonymization callback.
Args:
anonymize_text: Whether to anonymize text content
anonymize_images: Whether to redact images
@@ -46,16 +49,16 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
"Presidio is not available. Install with: "
"pip install cua-agent[pii-anonymization]"
)
# TODO: Implement __init__
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Anonymize PII in messages before sending to agent loop.
Args:
messages: List of message dictionaries
Returns:
List of messages with PII anonymized
"""
@@ -63,16 +66,16 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
for msg in messages:
anonymized_msg = await self._anonymize_message(msg)
anonymized_messages.append(anonymized_msg)
return anonymized_messages
async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Deanonymize PII in tool calls and message outputs after agent loop.
Args:
output: List of output dictionaries
Returns:
List of output with PII deanonymized for tool calls
"""
@@ -84,13 +87,13 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
deanonymized_output.append(deanonymized_item)
else:
deanonymized_output.append(item)
return deanonymized_output
async def _anonymize_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
# TODO: Implement _anonymize_message
return message
async def _deanonymize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
# TODO: Implement _deanonymize_item
return item

View File

@@ -2,17 +2,17 @@
Telemetry callback handler for Computer-Use Agent (cua-agent)
"""
import platform
import time
import uuid
from typing import List, Dict, Any, Optional, Union
from typing import Any, Dict, List, Optional, Union
from .base import AsyncCallbackHandler
from core.telemetry import (
record_event,
is_telemetry_enabled,
record_event,
)
import platform
from .base import AsyncCallbackHandler
SYSTEM_INFO = {
"os": platform.system().lower(),
@@ -20,32 +20,29 @@ SYSTEM_INFO = {
"python_version": platform.python_version(),
}
class TelemetryCallback(AsyncCallbackHandler):
"""
Telemetry callback handler for Computer-Use Agent (cua-agent)
Tracks agent usage, performance metrics, and optionally trajectory data.
"""
def __init__(
self,
agent,
log_trajectory: bool = False
):
def __init__(self, agent, log_trajectory: bool = False):
"""
Initialize telemetry callback.
Args:
agent: The ComputerAgent instance
log_trajectory: Whether to log full trajectory items (opt-in)
"""
self.agent = agent
self.log_trajectory = log_trajectory
# Generate session/run IDs
self.session_id = str(uuid.uuid4())
self.run_id = None
# Track timing and metrics
self.run_start_time = None
self.step_count = 0
@@ -54,126 +51,133 @@ class TelemetryCallback(AsyncCallbackHandler):
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"response_cost": 0.0
"response_cost": 0.0,
}
# Record agent initialization
if is_telemetry_enabled():
self._record_agent_initialization()
def _record_agent_initialization(self) -> None:
"""Record agent type/model and session initialization."""
agent_info = {
"session_id": self.session_id,
"agent_type": self.agent.agent_loop.__name__ if hasattr(self.agent, 'agent_loop') else 'unknown',
"model": getattr(self.agent, 'model', 'unknown'),
**SYSTEM_INFO
"agent_type": (
self.agent.agent_loop.__name__ if hasattr(self.agent, "agent_loop") else "unknown"
),
"model": getattr(self.agent, "model", "unknown"),
**SYSTEM_INFO,
}
record_event("agent_session_start", agent_info)
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Called at the start of an agent run loop."""
if not is_telemetry_enabled():
return
self.run_id = str(uuid.uuid4())
self.run_start_time = time.time()
self.step_count = 0
# Calculate input context size
input_context_size = self._calculate_context_size(old_items)
run_data = {
"session_id": self.session_id,
"run_id": self.run_id,
"start_time": self.run_start_time,
"input_context_size": input_context_size,
"num_existing_messages": len(old_items)
"num_existing_messages": len(old_items),
}
# Log trajectory if opted in
if self.log_trajectory:
trajectory = self._extract_trajectory(old_items)
if trajectory:
run_data["uploaded_trajectory"] = trajectory
record_event("agent_run_start", run_data)
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
async def on_run_end(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> None:
"""Called at the end of an agent run loop."""
if not is_telemetry_enabled() or not self.run_start_time:
return
run_duration = time.time() - self.run_start_time
run_data = {
"session_id": self.session_id,
"run_id": self.run_id,
"end_time": time.time(),
"duration_seconds": run_duration,
"num_steps": self.step_count,
"total_usage": self.total_usage.copy()
"total_usage": self.total_usage.copy(),
}
# Log trajectory if opted in
if self.log_trajectory:
trajectory = self._extract_trajectory(new_items)
if trajectory:
run_data["uploaded_trajectory"] = trajectory
record_event("agent_run_end", run_data)
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""Called when usage information is received."""
if not is_telemetry_enabled():
return
# Accumulate usage stats
self.total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
self.total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
self.total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
self.total_usage["total_tokens"] += usage.get("total_tokens", 0)
self.total_usage["response_cost"] += usage.get("response_cost", 0.0)
# Record individual usage event
usage_data = {
"session_id": self.session_id,
"run_id": self.run_id,
"step": self.step_count,
**usage
**usage,
}
record_event("agent_usage", usage_data)
async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
"""Called when responses are received."""
if not is_telemetry_enabled():
return
self.step_count += 1
step_duration = None
if self.step_start_time:
step_duration = time.time() - self.step_start_time
self.step_start_time = time.time()
step_data = {
"session_id": self.session_id,
"run_id": self.run_id,
"step": self.step_count,
"timestamp": self.step_start_time
"timestamp": self.step_start_time,
}
if step_duration is not None:
step_data["duration_seconds"] = step_duration
record_event("agent_step", step_data)
def _calculate_context_size(self, items: List[Dict[str, Any]]) -> int:
"""Calculate approximate context size in tokens/characters."""
total_size = 0
for item in items:
if item.get("type") == "message" and "content" in item:
content = item["content"]
@@ -185,25 +189,27 @@ class TelemetryCallback(AsyncCallbackHandler):
total_size += len(part["text"])
elif "content" in item and isinstance(item["content"], str):
total_size += len(item["content"])
return total_size
def _extract_trajectory(self, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Extract trajectory items that should be logged."""
trajectory = []
for item in items:
# Include user messages, assistant messages, reasoning, computer calls, and computer outputs
if (
item.get("role") == "user" or # User inputs
(item.get("type") == "message" and item.get("role") == "assistant") or # Model outputs
item.get("type") == "reasoning" or # Reasoning traces
item.get("type") == "computer_call" or # Computer actions
item.get("type") == "computer_call_output" # Computer outputs
item.get("role") == "user" # User inputs
or (
item.get("type") == "message" and item.get("role") == "assistant"
) # Model outputs
or item.get("type") == "reasoning" # Reasoning traces
or item.get("type") == "computer_call" # Computer actions
or item.get("type") == "computer_call_output" # Computer outputs
):
# Create a copy of the item with timestamp
trajectory_item = item.copy()
trajectory_item["logged_at"] = time.time()
trajectory.append(trajectory_item)
return trajectory
return trajectory
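A wiring sketch for this callback; the import path, and the assumption that agent.callbacks is a mutable list (the run loop only iterates it), are not confirmed by this diff:

from agent import ComputerAgent
from agent.callbacks import TelemetryCallback  # import path assumed

def attach_telemetry(agent: ComputerAgent) -> TelemetryCallback:
    # Trajectory upload is opt-in; with log_trajectory=False only session/run timing,
    # step counts, and accumulated token/cost usage are recorded.
    telemetry = TelemetryCallback(agent, log_trajectory=False)
    agent.callbacks.append(telemetry)  # assumes callbacks is a plain list
    return telemetry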

View File

@@ -2,26 +2,28 @@
Trajectory saving callback handler for ComputerAgent.
"""
import os
import json
import uuid
from datetime import datetime
import base64
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, override
from PIL import Image, ImageDraw
import io
import json
import os
import uuid
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Union, override
from PIL import Image, ImageDraw
from .base import AsyncCallbackHandler
def sanitize_image_urls(data: Any) -> Any:
"""
Recursively search for 'image_url' keys and set their values to '[omitted]'.
Args:
data: Any data structure (dict, list, or primitive type)
Returns:
A deep copy of the data with all 'image_url' values replaced with '[omitted]'
"""
@@ -35,17 +37,19 @@ def sanitize_image_urls(data: Any) -> Any:
# Recursively sanitize the value
sanitized[key] = sanitize_image_urls(value)
return sanitized
elif isinstance(data, list):
# Recursively sanitize each item in the list
return [sanitize_image_urls(item) for item in data]
else:
# For primitive types (str, int, bool, None, etc.), return as-is
return data
def extract_computer_call_outputs(items: List[Dict[str, Any]], screenshot_dir: Optional[Path]) -> List[Dict[str, Any]]:
def extract_computer_call_outputs(
items: List[Dict[str, Any]], screenshot_dir: Optional[Path]
) -> List[Dict[str, Any]]:
"""
Save any base64-encoded screenshots from computer_call_output entries to files and
replace their image_url with the saved file path when a call_id is present.
@@ -103,18 +107,21 @@ def extract_computer_call_outputs(items: List[Dict[str, Any]], screenshot_dir: Optional[Path]) -> List[Dict[str, Any]]:
updated.append(msg)
return updated
class TrajectorySaverCallback(AsyncCallbackHandler):
"""
Callback handler that saves agent trajectories to disk.
Saves each run as a separate trajectory with a unique ID, and each turn
within the trajectory gets its own folder with screenshots and responses.
"""
def __init__(self, trajectory_dir: str, reset_on_run: bool = True, screenshot_dir: Optional[str] = None):
def __init__(
self, trajectory_dir: str, reset_on_run: bool = True, screenshot_dir: Optional[str] = None
):
"""
Initialize trajectory saver.
Args:
trajectory_dir: Base directory to save trajectories
reset_on_run: If True, reset trajectory_id/turn/artifact on each run.
@@ -129,7 +136,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
self.reset_on_run = reset_on_run
# Optional directory to store extracted screenshots from metadata/new_items
self.screenshot_dir: Optional[Path] = Path(screenshot_dir) if screenshot_dir else None
# Ensure trajectory directory exists
self.trajectory_dir.mkdir(parents=True, exist_ok=True)
@@ -137,7 +144,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
"""Get the directory for the current turn."""
if not self.trajectory_id:
raise ValueError("Trajectory not initialized - call _on_run_start first")
# format: trajectory_id/turn_000
turn_dir = self.trajectory_dir / self.trajectory_id / f"turn_{self.current_turn:03d}"
turn_dir.mkdir(parents=True, exist_ok=True)
@@ -166,6 +173,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
def _update_usage(self, usage: Dict[str, Any]) -> None:
"""Update total usage statistics."""
def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
for key, value in source.items():
if isinstance(value, dict):
@@ -176,20 +184,21 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
if key not in target:
target[key] = 0
target[key] += value
add_dicts(self.total_usage, usage)
@override
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Initialize trajectory tracking for a new run."""
model = kwargs.get("model", "unknown")
# Only reset trajectory state if reset_on_run is True or no trajectory exists
if self.reset_on_run or not self.trajectory_id:
model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
if "+" in model:
model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
# strip non-alphanumeric characters from model_name_short
model_name_short = ''.join(c for c in model_name_short if c.isalnum() or c == '_')
model_name_short = "".join(c for c in model_name_short if c.isalnum() or c == "_")
# id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
now = datetime.now()
@@ -198,11 +207,11 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
self.current_artifact = 0
self.model = model
self.total_usage = {}
# Create trajectory directory
trajectory_path = self.trajectory_dir / self.trajectory_id
trajectory_path.mkdir(parents=True, exist_ok=True)
# Save trajectory metadata (optionally extract screenshots to screenshot_dir)
kwargs_to_save = kwargs.copy()
try:
@@ -219,7 +228,7 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
"status": "running",
"kwargs": kwargs_to_save,
}
with open(trajectory_path / "metadata.json", "w") as f:
json.dump(metadata, f, indent=2)
else:
@@ -227,22 +236,27 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
self.model = model
@override
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
async def on_run_end(
self,
kwargs: Dict[str, Any],
old_items: List[Dict[str, Any]],
new_items: List[Dict[str, Any]],
) -> None:
"""Finalize run tracking by updating metadata with completion status, usage, and new items."""
if not self.trajectory_id:
return
# Update metadata with completion status, total usage, and new items
trajectory_path = self.trajectory_dir / self.trajectory_id
metadata_path = trajectory_path / "metadata.json"
# Read existing metadata
if metadata_path.exists():
with open(metadata_path, "r") as f:
metadata = json.load(f)
else:
metadata = {}
# Update metadata with completion info
# Optionally extract screenshots from new_items before persisting
new_items_to_save = new_items
@@ -251,32 +265,34 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
except Exception:
pass
metadata.update({
"status": "completed",
"completed_at": str(uuid.uuid1().time),
"total_usage": self.total_usage,
"new_items": new_items_to_save,
"total_turns": self.current_turn
})
metadata.update(
{
"status": "completed",
"completed_at": str(uuid.uuid1().time),
"total_usage": self.total_usage,
"new_items": new_items_to_save,
"total_turns": self.current_turn,
}
)
# Save updated metadata
with open(metadata_path, "w") as f:
json.dump(metadata, f, indent=2)
@override
@override
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
if not self.trajectory_id:
return
self._save_artifact("api_start", { "kwargs": kwargs })
self._save_artifact("api_start", {"kwargs": kwargs})
@override
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""Save API call result."""
if not self.trajectory_id:
return
self._save_artifact("api_result", { "kwargs": kwargs, "result": result })
self._save_artifact("api_result", {"kwargs": kwargs, "result": result})
@override
async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
@@ -295,77 +311,83 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
"""Save responses to the current turn directory and update usage statistics."""
if not self.trajectory_id:
return
# Save responses
turn_dir = self._get_turn_dir()
response_data = {
"timestamp": str(uuid.uuid1().time),
"model": self.model,
"kwargs": kwargs,
"response": responses
"response": responses,
}
self._save_artifact("agent_response", response_data)
# Increment turn counter
self.current_turn += 1
def _draw_crosshair_on_image(self, image_bytes: bytes, x: int, y: int) -> bytes:
"""
Draw a red dot and crosshair at the specified coordinates on the image.
Args:
image_bytes: The original image as bytes
x: X coordinate for the crosshair
y: Y coordinate for the crosshair
Returns:
Modified image as bytes with red dot and crosshair
"""
# Open the image
image = Image.open(io.BytesIO(image_bytes))
draw = ImageDraw.Draw(image)
# Draw crosshair lines (red, 2px thick)
crosshair_size = 20
line_width = 2
color = "red"
# Horizontal line
draw.line([(x - crosshair_size, y), (x + crosshair_size, y)], fill=color, width=line_width)
# Vertical line
draw.line([(x, y - crosshair_size), (x, y + crosshair_size)], fill=color, width=line_width)
# Draw center dot (filled circle)
dot_radius = 3
draw.ellipse([(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color)
draw.ellipse(
[(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color
)
# Convert back to bytes
output = io.BytesIO()
image.save(output, format='PNG')
image.save(output, format="PNG")
return output.getvalue()
@override
async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
async def on_computer_call_end(
self, item: Dict[str, Any], result: List[Dict[str, Any]]
) -> None:
"""
Called when a computer call has completed.
Saves screenshots and computer call output.
"""
if not self.trajectory_id:
return
self._save_artifact("computer_call_result", { "item": item, "result": result })
self._save_artifact("computer_call_result", {"item": item, "result": result})
# Check if action has x/y coordinates and there's a screenshot in the result
action = item.get("action", {})
if "x" in action and "y" in action:
# Look for screenshot in the result
for result_item in result:
if (result_item.get("type") == "computer_call_output" and
result_item.get("output", {}).get("type") == "input_image"):
if (
result_item.get("type") == "computer_call_output"
and result_item.get("output", {}).get("type") == "input_image"
):
image_url = result_item["output"]["image_url"]
# Extract base64 image data
if image_url.startswith("data:image/"):
# Format: data:image/png;base64,<base64_data>
@@ -373,26 +395,24 @@ class TrajectorySaverCallback(AsyncCallbackHandler):
else:
# Assume it's just base64 data
base64_data = image_url
try:
# Decode the image
image_bytes = base64.b64decode(base64_data)
# Draw crosshair at the action coordinates
annotated_image = self._draw_crosshair_on_image(
image_bytes,
int(action["x"]),
int(action["y"])
image_bytes, int(action["x"]), int(action["y"])
)
# Save as screenshot_action
self._save_artifact("screenshot_action", annotated_image)
except Exception as e:
# If annotation fails, just log and continue
print(f"Failed to annotate screenshot: {e}")
break # Only process the first screenshot found
# Increment turn counter
self.current_turn += 1
self.current_turn += 1
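A construction sketch for this saver; the import path is assumed and the directory names are illustrative:

from agent.callbacks import TrajectorySaverCallback  # import path assumed

# Each run gets a folder named yyyy-mm-dd_<model>_hhmmss_<uuid prefix> under trajectory_dir,
# containing per-turn artifacts (api_start, api_result, agent_response) and screenshots,
# including crosshair-annotated screenshot_action images for actions with x/y coordinates.
saver = TrajectorySaverCallback(
    trajectory_dir="trajectories",
    reset_on_run=True,             # start a fresh trajectory id on every run
    screenshot_dir="screenshots",  # optionally extract base64 screenshots from metadata to files
)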

View File

@@ -3,7 +3,7 @@ CLI chat interface for agent - Computer Use Agent
Usage:
python -m agent.cli <model_string>
Examples:
python -m agent.cli openai/computer-use-preview
python -m agent.cli anthropic/claude-3-5-sonnet-20241022
@@ -11,19 +11,22 @@ Examples:
"""
try:
import asyncio
import argparse
import os
import sys
import json
from typing import List, Dict, Any
import dotenv
import asyncio
import base64
import time
import json
import os
import platform
import sys
import time
from pathlib import Path
from typing import Any, Dict, List
import dotenv
try:
from PIL import Image, ImageDraw
PIL_AVAILABLE = True
except Exception:
PIL_AVAILABLE = False
@@ -31,36 +34,44 @@ try:
except ImportError:
if __name__ == "__main__":
raise ImportError(
"CLI dependencies not found. "
"Please install with: pip install \"cua-agent[cli]\""
"CLI dependencies not found. " 'Please install with: pip install "cua-agent[cli]"'
)
# Load environment variables
dotenv.load_dotenv()
# Color codes for terminal output
class Colors:
RESET = '\033[0m'
BOLD = '\033[1m'
DIM = '\033[2m'
# Text colors
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m'
BLUE = '\033[34m'
MAGENTA = '\033[35m'
CYAN = '\033[36m'
WHITE = '\033[37m'
GRAY = '\033[90m'
# Background colors
BG_RED = '\033[41m'
BG_GREEN = '\033[42m'
BG_YELLOW = '\033[43m'
BG_BLUE = '\033[44m'
RESET = "\033[0m"
BOLD = "\033[1m"
DIM = "\033[2m"
def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = False, end: str = "\n", right: str = ""):
# Text colors
RED = "\033[31m"
GREEN = "\033[32m"
YELLOW = "\033[33m"
BLUE = "\033[34m"
MAGENTA = "\033[35m"
CYAN = "\033[36m"
WHITE = "\033[37m"
GRAY = "\033[90m"
# Background colors
BG_RED = "\033[41m"
BG_GREEN = "\033[42m"
BG_YELLOW = "\033[43m"
BG_BLUE = "\033[44m"
def print_colored(
text: str,
color: str = "",
bold: bool = False,
dim: bool = False,
end: str = "\n",
right: str = "",
):
"""Print colored text to terminal with optional right-aligned text."""
prefix = ""
if bold:
@@ -69,24 +80,25 @@ def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = Fa
prefix += Colors.DIM
if color:
prefix += color
if right:
# Get terminal width (default to 80 if unable to determine)
try:
import shutil
terminal_width = shutil.get_terminal_size().columns
except:
terminal_width = 80
# Add right margin
terminal_width -= 1
# Calculate padding needed
# Account for ANSI escape codes not taking visual space
visible_left_len = len(text)
visible_right_len = len(right)
padding = terminal_width - visible_left_len - visible_right_len
if padding > 0:
output = f"{prefix}{text}{' ' * padding}{right}{Colors.RESET}"
else:
@@ -94,7 +106,7 @@ def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = Fa
output = f"{prefix}{text} {right}{Colors.RESET}"
else:
output = f"{prefix}{text}{Colors.RESET}"
print(output, end=end)
@@ -113,29 +125,34 @@ def print_action(action_type: str, details: Dict[str, Any], total_cost: float):
args_str = f"('{details['text']}')"
elif action_type == "scroll" and "x" in details and "y" in details:
args_str = f"({details['x']}, {details['y']})"
if total_cost > 0:
print_colored(f"🛠️ {action_type}{args_str}", dim=True, right=f"💸 ${total_cost:.2f}")
else:
print_colored(f"🛠️ {action_type}{args_str}", dim=True)
def print_welcome(model: str, agent_loop: str, container_name: str):
"""Print welcome message."""
print_colored(f"Connected to {container_name} ({model}, {agent_loop})")
print_colored("Type 'exit' to quit.", dim=True)
async def ainput(prompt: str = ""):
return await asyncio.to_thread(input, prompt)
async def chat_loop(agent, model: str, container_name: str, initial_prompt: str = "", show_usage: bool = True):
async def chat_loop(
agent, model: str, container_name: str, initial_prompt: str = "", show_usage: bool = True
):
"""Main chat loop with the agent."""
print_welcome(model, agent.agent_config_info.agent_class.__name__, container_name)
history = []
if initial_prompt:
history.append({"role": "user", "content": initial_prompt})
total_cost = 0
while True:
@@ -143,31 +160,31 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
# Get user input with prompt
print_colored("> ", end="")
user_input = await ainput()
if user_input.lower() in ['exit', 'quit', 'q']:
if user_input.lower() in ["exit", "quit", "q"]:
print_colored("\n👋 Goodbye!")
break
if not user_input:
continue
# Add user message to history
history.append({"role": "user", "content": user_input})
# Stream responses from the agent with spinner
with yaspin(text="Thinking...", spinner="line", attrs=["dark"]) as spinner:
spinner.hide()
async for result in agent.run(history):
# Add agent responses to history
history.extend(result.get("output", []))
if show_usage:
total_cost += result.get("usage", {}).get("response_cost", 0)
# Process and display the output
for item in result.get("output", []):
if item.get("type") == "message":
if item.get("type") == "message" and item.get("role") == "assistant":
# Display agent text response
content = item.get("content", [])
for content_part in content:
@@ -176,7 +193,7 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
if text:
spinner.hide()
print_colored(text)
elif item.get("type") == "computer_call":
# Display computer action
action = item.get("action", {})
@@ -186,7 +203,7 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
print_action(action_type, action, total_cost)
spinner.text = f"Performing {action_type}..."
spinner.show()
elif item.get("type") == "function_call":
# Display function call
function_name = item.get("name", "")
@@ -194,18 +211,18 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
print_colored(f"🔧 Calling function: {function_name}", dim=True)
spinner.text = f"Calling {function_name}..."
spinner.show()
elif item.get("type") == "function_call_output":
# Display function output (dimmed)
output = item.get("output", "")
if output and len(output.strip()) > 0:
spinner.hide()
print_colored(f"📤 {output}", dim=True)
spinner.hide()
if show_usage and total_cost > 0:
print_colored(f"Total cost: ${total_cost:.2f}", dim=True)
async def main():
"""Main CLI function."""
@@ -218,104 +235,103 @@ Examples:
python -m agent.cli anthropic/claude-3-5-sonnet-20241022
python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
python -m agent.cli huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B
"""
""",
)
parser.add_argument(
"model",
help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')"
help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')",
)
parser.add_argument(
"--provider",
choices=["cloud", "lume", "winsandbox", "docker"],
default="cloud",
help="Computer provider to use: cloud (default), lume, winsandbox, or docker",
)
parser.add_argument(
"--images",
type=int,
default=3,
help="Number of recent images to keep in context (default: 3)"
help="Number of recent images to keep in context (default: 3)",
)
parser.add_argument("--trajectory", action="store_true", help="Save trajectory for debugging")
parser.add_argument("--budget", type=float, help="Maximum budget for the session (in dollars)")
parser.add_argument("--verbose", action="store_true", help="Enable verbose logging")
parser.add_argument(
"--trajectory",
action="store_true",
help="Save trajectory for debugging"
)
parser.add_argument(
"--budget",
type=float,
help="Maximum budget for the session (in dollars)"
)
parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose logging"
"-p",
"--prompt",
type=str,
help="Initial prompt to send to the agent. Leave blank for interactive mode.",
)
parser.add_argument(
"-p", "--prompt",
type=str,
help="Initial prompt to send to the agent. Leave blank for interactive mode."
"--prompt-file",
type=Path,
help="Path to a UTF-8 text file whose contents will be used as the initial prompt. If provided, overrides --prompt.",
)
parser.add_argument(
"--predict-click",
dest="predict_click",
type=str,
help="Instruction for click prediction. If set, runs predict_click, draws crosshair on a fresh screenshot, saves and opens it."
help="Instruction for click prediction. If set, runs predict_click, draws crosshair on a fresh screenshot, saves and opens it.",
)
parser.add_argument("-c", "--cache", action="store_true", help="Tell the API to enable caching")
parser.add_argument(
"-u", "--usage", action="store_true", help="Show total cost of the agent runs"
)
parser.add_argument(
"-c", "--cache",
action="store_true",
help="Tell the API to enable caching"
)
parser.add_argument(
"-u", "--usage",
action="store_true",
help="Show total cost of the agent runs"
)
parser.add_argument(
"-r", "--max-retries",
"-r",
"--max-retries",
type=int,
default=3,
help="Maximum number of retries for the LLM API calls"
help="Maximum number of retries for the LLM API calls",
)
args = parser.parse_args()
# Check for required environment variables
container_name = os.getenv("CUA_CONTAINER_NAME")
cua_api_key = os.getenv("CUA_API_KEY")
# Prompt for missing environment variables
# Prompt for missing environment variables (container name always required)
if not container_name:
print_colored("CUA_CONTAINER_NAME not set.", dim=True)
print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
container_name = input("Enter your CUA container name: ").strip()
if not container_name:
print_colored("❌ Container name is required.")
sys.exit(1)
if not cua_api_key:
if args.provider == "cloud":
print_colored("CUA_CONTAINER_NAME not set.", dim=True)
print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
container_name = input("Enter your CUA container name: ").strip()
if not container_name:
print_colored("❌ Container name is required.")
sys.exit(1)
else:
container_name = "cli-sandbox"
# Only require API key for cloud provider
if args.provider == "cloud" and not cua_api_key:
print_colored("CUA_API_KEY not set.", dim=True)
cua_api_key = input("Enter your CUA API key: ").strip()
if not cua_api_key:
print_colored("❌ API key is required.")
print_colored("❌ API key is required for cloud provider.")
sys.exit(1)
# Check for provider-specific API keys based on model
provider_api_keys = {
"openai/": "OPENAI_API_KEY",
"anthropic/": "ANTHROPIC_API_KEY",
"omniparser+": "OPENAI_API_KEY",
"omniparser+": "ANTHROPIC_API_KEY",
}
# Find matching provider and check for API key
for prefix, env_var in provider_api_keys.items():
if args.model.startswith(prefix):
if prefix in args.model:
if not os.getenv(env_var):
print_colored(f"{env_var} not set.", dim=True)
api_key = input(f"Enter your {env_var.replace('_', ' ').title()}: ").strip()
@@ -325,7 +341,7 @@ Examples:
# Set the environment variable for the session
os.environ[env_var] = api_key
break
# Import here to avoid import errors if dependencies are missing
try:
from agent import ComputerAgent
@@ -334,46 +350,62 @@ Examples:
print_colored(f"❌ Import error: {e}", Colors.RED, bold=True)
print_colored("Make sure agent and computer libraries are installed.", Colors.YELLOW)
sys.exit(1)
# Resolve provider -> os_type, provider_type, api key requirement
provider_map = {
"cloud": ("linux", "cloud", True),
"lume": ("macos", "lume", False),
"winsandbox": ("windows", "winsandbox", False),
"docker": ("linux", "docker", False),
}
os_type, provider_type, needs_api_key = provider_map[args.provider]
computer_kwargs = {
"os_type": os_type,
"provider_type": provider_type,
"name": container_name,
}
if needs_api_key:
computer_kwargs["api_key"] = cua_api_key # type: ignore
# Create computer instance
async with Computer(
os_type="linux",
provider_type="cloud",
name=container_name,
api_key=cua_api_key
) as computer:
async with Computer(**computer_kwargs) as computer: # type: ignore
# Create agent
agent_kwargs = {
"model": args.model,
"tools": [computer],
"trust_remote_code": True, # needed for some local models (e.g., InternVL, OpenCUA)
"trust_remote_code": True, # needed for some local models (e.g., InternVL, OpenCUA)
"verbosity": 20 if args.verbose else 30, # DEBUG vs WARNING
"max_retries": args.max_retries
"max_retries": args.max_retries,
}
if args.images > 0:
agent_kwargs["only_n_most_recent_images"] = args.images
if args.trajectory:
agent_kwargs["trajectory_dir"] = "trajectories"
if args.budget:
agent_kwargs["max_trajectory_budget"] = {
"max_budget": args.budget,
"raise_error": True,
"reset_after_each_run": False
"reset_after_each_run": False,
}
if args.cache:
agent_kwargs["use_prompt_caching"] = True
agent = ComputerAgent(**agent_kwargs)
# If predict-click mode is requested, run once and exit
if args.predict_click:
if not PIL_AVAILABLE:
print_colored("❌ Pillow (PIL) is required for --predict-click visualization. Install with: pip install pillow", Colors.RED, bold=True)
print_colored(
"❌ Pillow (PIL) is required for --predict-click visualization. Install with: pip install pillow",
Colors.RED,
bold=True,
)
sys.exit(1)
instruction = args.predict_click
@@ -408,6 +440,7 @@ Examples:
try:
from io import BytesIO
with Image.open(BytesIO(img_bytes)) as img:
img = img.convert("RGB")
draw = ImageDraw.Draw(img)
@@ -430,9 +463,9 @@ Examples:
if system == "windows":
os.startfile(str(out_path)) # type: ignore[attr-defined]
elif system == "darwin":
os.system(f"open \"{out_path}\"")
os.system(f'open "{out_path}"')
else:
os.system(f"xdg-open \"{out_path}\"")
os.system(f'xdg-open "{out_path}"')
except Exception:
pass
except Exception as e:
@@ -442,13 +475,21 @@ Examples:
# Done
sys.exit(0)
# Start chat loop (default interactive mode)
await chat_loop(agent, args.model, container_name, args.prompt, args.usage)
# Resolve initial prompt from --prompt-file or --prompt
initial_prompt = args.prompt or ""
if args.prompt_file:
try:
initial_prompt = args.prompt_file.read_text(encoding="utf-8")
except Exception as e:
print_colored(f"❌ Failed to read --prompt-file: {e}", Colors.RED, bold=True)
sys.exit(1)
# Start chat loop (default interactive mode)
await chat_loop(agent, args.model, container_name, initial_prompt, args.usage)
if __name__ == "__main__":
try:
asyncio.run(main())
except (KeyboardInterrupt, EOFError) as _:
print_colored("\n\n👋 Goodbye!")
print_colored("\n\n👋 Goodbye!")

View File

@@ -6,27 +6,32 @@ computer interface types, supporting both the ComputerHandler protocol and the
Computer library interface.
"""
from computer import Computer as cuaComputer
from .base import AsyncComputerHandler
from .cua import cuaComputerHandler
from .custom import CustomComputerHandler
from computer import Computer as cuaComputer
def is_agent_computer(computer):
"""Check if the given computer is a ComputerHandler or CUA Computer."""
return isinstance(computer, AsyncComputerHandler) or \
isinstance(computer, cuaComputer) or \
(isinstance(computer, dict)) #and "screenshot" in computer)
return (
isinstance(computer, AsyncComputerHandler)
or isinstance(computer, cuaComputer)
or (isinstance(computer, dict))
) # and "screenshot" in computer)
async def make_computer_handler(computer):
"""
Create a computer handler from a computer interface.
Args:
computer: Either a ComputerHandler instance, Computer instance, or dict of functions
Returns:
ComputerHandler: A computer handler instance
Raises:
ValueError: If the computer type is not supported
"""
@@ -38,4 +43,4 @@ async def make_computer_handler(computer):
return computer_handler
if isinstance(computer, dict):
return CustomComputerHandler(computer)
raise ValueError(f"Unsupported computer type: {type(computer)}")
raise ValueError(f"Unsupported computer type: {type(computer)}")

View File

@@ -2,23 +2,32 @@
Base computer interface protocol for agent interactions.
"""
from typing import Protocol, Literal, List, Dict, Any, Union, Optional, runtime_checkable
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
Union,
runtime_checkable,
)
@runtime_checkable
class AsyncComputerHandler(Protocol):
"""Protocol defining the interface for computer interactions."""
# ==== Computer-Use-Preview Action Space ====
# ==== Computer-Use-Preview Action Space ====
async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
"""Get the current environment type."""
...
async def get_dimensions(self) -> tuple[int, int]:
"""Get screen dimensions as (width, height)."""
...
async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.
@@ -26,49 +35,49 @@ class AsyncComputerHandler(Protocol):
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
"""
...
async def click(self, x: int, y: int, button: str = "left") -> None:
"""Click at coordinates with specified button."""
...
async def double_click(self, x: int, y: int) -> None:
"""Double click at coordinates."""
...
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
"""Scroll at coordinates with specified scroll amounts."""
...
async def type(self, text: str) -> None:
"""Type text."""
...
async def wait(self, ms: int = 1000) -> None:
"""Wait for specified milliseconds."""
...
async def move(self, x: int, y: int) -> None:
"""Move cursor to coordinates."""
...
async def keypress(self, keys: Union[List[str], str]) -> None:
"""Press key combination."""
...
async def drag(self, path: List[Dict[str, int]]) -> None:
"""Drag along specified path."""
...
async def get_current_url(self) -> str:
"""Get current URL (for browser environments)."""
...
# ==== Anthropic Action Space ====
# ==== Anthropic Action Space ====
async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse down at coordinates."""
...
async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse up at coordinates."""
...
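Editor's note: because the protocol is @runtime_checkable, isinstance() checks against it only verify that the methods exist, not their signatures or behaviour. A minimal sketch with a hypothetical NullComputer whose bodies are all stubs:

from agent.computers.base import AsyncComputerHandler

class NullComputer:
    # Every method is a stub; presence alone satisfies the protocol check.
    async def get_environment(self): return "linux"
    async def get_dimensions(self): return (1024, 768)
    async def screenshot(self, text=None): return ""
    async def click(self, x, y, button="left"): pass
    async def double_click(self, x, y): pass
    async def scroll(self, x, y, scroll_x, scroll_y): pass
    async def type(self, text): pass
    async def wait(self, ms=1000): pass
    async def move(self, x, y): pass
    async def keypress(self, keys): pass
    async def drag(self, path): pass
    async def get_current_url(self): return ""
    async def left_mouse_down(self, x=None, y=None): pass
    async def left_mouse_up(self, x=None, y=None): pass

assert isinstance(NullComputer(), AsyncComputerHandler)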

View File

@@ -3,24 +3,27 @@ Computer handler implementation for OpenAI computer-use-preview protocol.
"""
import base64
from typing import Dict, List, Any, Literal, Union, Optional
from .base import AsyncComputerHandler
from typing import Any, Dict, List, Literal, Optional, Union
from computer import Computer
from .base import AsyncComputerHandler
class cuaComputerHandler(AsyncComputerHandler):
"""Computer handler that implements the Computer protocol using the computer interface."""
def __init__(self, cua_computer: Computer):
"""Initialize with a computer interface (from tool schema)."""
self.cua_computer = cua_computer
self.interface = None
async def _initialize(self):
if hasattr(self.cua_computer, '_initialized') and not self.cua_computer._initialized:
if hasattr(self.cua_computer, "_initialized") and not self.cua_computer._initialized:
await self.cua_computer.run()
self.interface = self.cua_computer.interface
# ==== Computer-Use-Preview Action Space ====
# ==== Computer-Use-Preview Action Space ====
async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
"""Get the current environment type."""
@@ -32,7 +35,7 @@ class cuaComputerHandler(AsyncComputerHandler):
assert self.interface is not None
screen_size = await self.interface.get_screen_size()
return screen_size["width"], screen_size["height"]
async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.
@@ -41,8 +44,8 @@ class cuaComputerHandler(AsyncComputerHandler):
"""
assert self.interface is not None
screenshot_bytes = await self.interface.screenshot()
return base64.b64encode(screenshot_bytes).decode('utf-8')
return base64.b64encode(screenshot_bytes).decode("utf-8")
async def click(self, x: int, y: int, button: str = "left") -> None:
"""Click at coordinates with specified button."""
assert self.interface is not None
@@ -53,34 +56,35 @@ class cuaComputerHandler(AsyncComputerHandler):
else:
# Default to left click for unknown buttons
await self.interface.left_click(x, y)
async def double_click(self, x: int, y: int) -> None:
"""Double click at coordinates."""
assert self.interface is not None
await self.interface.double_click(x, y)
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
"""Scroll at coordinates with specified scroll amounts."""
assert self.interface is not None
await self.interface.move_cursor(x, y)
await self.interface.scroll(scroll_x, scroll_y)
async def type(self, text: str) -> None:
"""Type text."""
assert self.interface is not None
await self.interface.type_text(text)
async def wait(self, ms: int = 1000) -> None:
"""Wait for specified milliseconds."""
assert self.interface is not None
import asyncio
await asyncio.sleep(ms / 1000.0)
async def move(self, x: int, y: int) -> None:
"""Move cursor to coordinates."""
assert self.interface is not None
await self.interface.move_cursor(x, y)
async def keypress(self, keys: Union[List[str], str]) -> None:
"""Press key combination."""
assert self.interface is not None
@@ -91,38 +95,38 @@ class cuaComputerHandler(AsyncComputerHandler):
else:
# Handle key combinations
await self.interface.hotkey(*keys)
async def drag(self, path: List[Dict[str, int]]) -> None:
"""Drag along specified path."""
assert self.interface is not None
if not path:
return
# Start drag from first point
start = path[0]
await self.interface.mouse_down(start["x"], start["y"])
# Move through path
for point in path[1:]:
await self.interface.move_cursor(point["x"], point["y"])
# End drag at last point
end = path[-1]
await self.interface.mouse_up(end["x"], end["y"])
async def get_current_url(self) -> str:
"""Get current URL (for browser environments)."""
# This would need to be implemented based on the specific browser interface
# For now, return empty string
return ""
# ==== Anthropic Computer Action Space ====
# ==== Anthropic Computer Action Space ====
async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse down at coordinates."""
assert self.interface is not None
await self.interface.mouse_down(x, y, button="left")
async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse up at coordinates."""
assert self.interface is not None
await self.interface.mouse_up(x, y, button="left")
await self.interface.mouse_up(x, y, button="left")
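Editor's note: the drag implementation above presses the mouse at the first point, moves through the intermediate points, and releases at the last one, so the path is simply a list of {"x": ..., "y": ...} dicts. A minimal sketch of driving it directly, assuming computer is an already-connected Computer instance; _initialize() is the internal setup step that binds handler.interface.

from agent.computers.cua import cuaComputerHandler

async def drag_demo(computer) -> None:
    handler = cuaComputerHandler(computer)
    await handler._initialize()  # binds handler.interface to computer.interface
    # Drag from (100, 100) to (300, 200) through one intermediate point.
    await handler.drag([
        {"x": 100, "y": 100},
        {"x": 200, "y": 150},
        {"x": 300, "y": 200},
    ])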

View File

@@ -3,47 +3,49 @@ Custom computer handler implementation that accepts a dictionary of functions.
"""
import base64
from typing import Dict, List, Any, Literal, Union, Optional, Callable
from PIL import Image
import io
from typing import Any, Callable, Dict, List, Literal, Optional, Union
from PIL import Image
from .base import AsyncComputerHandler
class CustomComputerHandler(AsyncComputerHandler):
"""Computer handler that implements the Computer protocol using a dictionary of custom functions."""
def __init__(self, functions: Dict[str, Callable]):
"""
Initialize with a dictionary of functions.
Args:
functions: Dictionary where keys are method names and values are callable functions.
Only 'screenshot' is required, all others are optional.
Raises:
ValueError: If required 'screenshot' function is not provided.
"""
if 'screenshot' not in functions:
if "screenshot" not in functions:
raise ValueError("'screenshot' function is required in functions dictionary")
self.functions = functions
self._last_screenshot_size: Optional[tuple[int, int]] = None
async def _call_function(self, func, *args, **kwargs):
"""
Call a function, handling both async and sync functions.
Args:
func: The function to call
*args: Positional arguments to pass to the function
**kwargs: Keyword arguments to pass to the function
Returns:
The result of the function call
"""
import asyncio
import inspect
if callable(func):
if inspect.iscoroutinefunction(func):
return await func(*args, **kwargs)
@@ -51,14 +53,14 @@ class CustomComputerHandler(AsyncComputerHandler):
return func(*args, **kwargs)
else:
return func
async def _get_value(self, attribute: str):
"""
Get value for an attribute, checking both 'get_{attribute}' and '{attribute}' keys.
Args:
attribute: The attribute name to look for
Returns:
The value from the functions dict, called if callable, returned directly if not
"""
@@ -66,20 +68,20 @@ class CustomComputerHandler(AsyncComputerHandler):
get_key = f"get_{attribute}"
if get_key in self.functions:
return await self._call_function(self.functions[get_key])
# Check for '{attribute}'
# Check for '{attribute}'
if attribute in self.functions:
return await self._call_function(self.functions[attribute])
return None
def _to_b64_str(self, img: Union[bytes, Image.Image, str]) -> str:
"""
Convert image to base64 string.
Args:
img: Image as bytes, PIL Image, or base64 string
Returns:
str: Base64 encoded image string
"""
@@ -88,47 +90,47 @@ class CustomComputerHandler(AsyncComputerHandler):
return img
elif isinstance(img, bytes):
# Raw bytes
return base64.b64encode(img).decode('utf-8')
return base64.b64encode(img).decode("utf-8")
elif isinstance(img, Image.Image):
# PIL Image
buffer = io.BytesIO()
img.save(buffer, format='PNG')
return base64.b64encode(buffer.getvalue()).decode('utf-8')
img.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
else:
raise ValueError(f"Unsupported image type: {type(img)}")
# ==== Computer-Use-Preview Action Space ====
# ==== Computer-Use-Preview Action Space ====
async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
"""Get the current environment type."""
result = await self._get_value('environment')
result = await self._get_value("environment")
if result is None:
return "linux"
assert result in ["windows", "mac", "linux", "browser"]
return result # type: ignore
return result # type: ignore
async def get_dimensions(self) -> tuple[int, int]:
"""Get screen dimensions as (width, height)."""
result = await self._get_value('dimensions')
result = await self._get_value("dimensions")
if result is not None:
return result # type: ignore
return result # type: ignore
# Fallback: use last screenshot size if available
if not self._last_screenshot_size:
await self.screenshot()
assert self._last_screenshot_size is not None, "Failed to get screenshot size"
return self._last_screenshot_size
async def screenshot(self, text: Optional[str] = None) -> str:
"""Take a screenshot and return as base64 string.
Args:
text: Optional descriptive text (for compatibility with GPT-4o models, ignored)
"""
result = await self._call_function(self.functions['screenshot'])
b64_str = self._to_b64_str(result) # type: ignore
result = await self._call_function(self.functions["screenshot"])
b64_str = self._to_b64_str(result) # type: ignore
# Try to extract dimensions for fallback use
try:
if isinstance(result, Image.Image):
@@ -140,74 +142,75 @@ class CustomComputerHandler(AsyncComputerHandler):
except Exception:
# If we can't get dimensions, that's okay
pass
return b64_str
async def click(self, x: int, y: int, button: str = "left") -> None:
"""Click at coordinates with specified button."""
if 'click' in self.functions:
await self._call_function(self.functions['click'], x, y, button)
if "click" in self.functions:
await self._call_function(self.functions["click"], x, y, button)
# No-op if not implemented
async def double_click(self, x: int, y: int) -> None:
"""Double click at coordinates."""
if 'double_click' in self.functions:
await self._call_function(self.functions['double_click'], x, y)
if "double_click" in self.functions:
await self._call_function(self.functions["double_click"], x, y)
# No-op if not implemented
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
"""Scroll at coordinates with specified scroll amounts."""
if 'scroll' in self.functions:
await self._call_function(self.functions['scroll'], x, y, scroll_x, scroll_y)
if "scroll" in self.functions:
await self._call_function(self.functions["scroll"], x, y, scroll_x, scroll_y)
# No-op if not implemented
async def type(self, text: str) -> None:
"""Type text."""
if 'type' in self.functions:
await self._call_function(self.functions['type'], text)
if "type" in self.functions:
await self._call_function(self.functions["type"], text)
# No-op if not implemented
async def wait(self, ms: int = 1000) -> None:
"""Wait for specified milliseconds."""
if 'wait' in self.functions:
await self._call_function(self.functions['wait'], ms)
if "wait" in self.functions:
await self._call_function(self.functions["wait"], ms)
else:
# Default implementation
import asyncio
await asyncio.sleep(ms / 1000.0)
async def move(self, x: int, y: int) -> None:
"""Move cursor to coordinates."""
if 'move' in self.functions:
await self._call_function(self.functions['move'], x, y)
if "move" in self.functions:
await self._call_function(self.functions["move"], x, y)
# No-op if not implemented
async def keypress(self, keys: Union[List[str], str]) -> None:
"""Press key combination."""
if 'keypress' in self.functions:
await self._call_function(self.functions['keypress'], keys)
if "keypress" in self.functions:
await self._call_function(self.functions["keypress"], keys)
# No-op if not implemented
async def drag(self, path: List[Dict[str, int]]) -> None:
"""Drag along specified path."""
if 'drag' in self.functions:
await self._call_function(self.functions['drag'], path)
if "drag" in self.functions:
await self._call_function(self.functions["drag"], path)
# No-op if not implemented
async def get_current_url(self) -> str:
"""Get current URL (for browser environments)."""
if 'get_current_url' in self.functions:
return await self._get_value('current_url') # type: ignore
if "get_current_url" in self.functions:
return await self._get_value("current_url") # type: ignore
return "" # Default fallback
async def left_mouse_down(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse down at coordinates."""
if 'left_mouse_down' in self.functions:
await self._call_function(self.functions['left_mouse_down'], x, y)
if "left_mouse_down" in self.functions:
await self._call_function(self.functions["left_mouse_down"], x, y)
# No-op if not implemented
async def left_mouse_up(self, x: Optional[int] = None, y: Optional[int] = None) -> None:
"""Left mouse up at coordinates."""
if 'left_mouse_up' in self.functions:
await self._call_function(self.functions['left_mouse_up'], x, y)
if "left_mouse_up" in self.functions:
await self._call_function(self.functions["left_mouse_up"], x, y)
# No-op if not implemented
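Editor's note: to make the optional-function behaviour concrete, a minimal sketch where take_screenshot() and do_click() are illustrative stand-ins, plain values cover the get_* attributes, and any action without an entry in the dict is a silent no-op.

import asyncio

from PIL import Image

from agent.computers.custom import CustomComputerHandler

def take_screenshot() -> Image.Image:
    # Sync or async callables both work; bytes, PIL images, and base64 strings are accepted.
    return Image.new("RGB", (1280, 800), "white")

async def do_click(x: int, y: int, button: str = "left") -> None:
    print(f"click {button} at ({x}, {y})")

handler = CustomComputerHandler({
    "screenshot": take_screenshot,   # required
    "click": do_click,               # optional
    "dimensions": (1280, 800),       # plain values are returned as-is
})

async def demo() -> None:
    print(await handler.get_dimensions())   # (1280, 800)
    await handler.click(10, 20)
    await handler.drag([{"x": 0, "y": 0}])  # no 'drag' entry -> no-op
    print(len(await handler.screenshot()))

asyncio.run(demo())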

View File

@@ -3,47 +3,56 @@ Decorators for agent - agent_loop decorator
"""
from typing import List, Optional
from .types import AgentConfigInfo
# Global registry
_agent_configs: List[AgentConfigInfo] = []
def register_agent(models: str, priority: int = 0):
"""
Decorator to register an AsyncAgentConfig class.
Args:
models: Regex pattern to match supported models
priority: Priority for agent selection (higher = more priority)
"""
def decorator(agent_class: type):
# Validate that the class implements AsyncAgentConfig protocol
if not hasattr(agent_class, 'predict_step'):
raise ValueError(f"Agent class {agent_class.__name__} must implement predict_step method")
if not hasattr(agent_class, 'predict_click'):
raise ValueError(f"Agent class {agent_class.__name__} must implement predict_click method")
if not hasattr(agent_class, 'get_capabilities'):
raise ValueError(f"Agent class {agent_class.__name__} must implement get_capabilities method")
if not hasattr(agent_class, "predict_step"):
raise ValueError(
f"Agent class {agent_class.__name__} must implement predict_step method"
)
if not hasattr(agent_class, "predict_click"):
raise ValueError(
f"Agent class {agent_class.__name__} must implement predict_click method"
)
if not hasattr(agent_class, "get_capabilities"):
raise ValueError(
f"Agent class {agent_class.__name__} must implement get_capabilities method"
)
# Register the agent config
config_info = AgentConfigInfo(
agent_class=agent_class,
models_regex=models,
priority=priority
agent_class=agent_class, models_regex=models, priority=priority
)
_agent_configs.append(config_info)
# Sort by priority (highest first)
_agent_configs.sort(key=lambda x: x.priority, reverse=True)
return agent_class
return decorator
def get_agent_configs() -> List[AgentConfigInfo]:
"""Get all registered agent configs"""
return _agent_configs.copy()
def find_agent_config(model: str) -> Optional[AgentConfigInfo]:
"""Find the best matching agent config for a model"""
for config_info in _agent_configs:
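Editor's note: a minimal sketch of the registration contract; EchoAgentConfig and the "echo/.*" model family are hypothetical, and the method signatures below are illustrative only, since the decorator itself merely checks that the three methods exist.

from agent.decorators import find_agent_config, get_agent_configs, register_agent

@register_agent(models=r"echo/.*", priority=10)
class EchoAgentConfig:
    async def predict_step(self, *args, **kwargs): ...
    async def predict_click(self, *args, **kwargs): ...
    def get_capabilities(self): ...

# Higher-priority configs sort first and win when several regexes match.
print([cfg.models_regex for cfg in get_agent_configs()])

config = find_agent_config("echo/tiny-1b")
assert config is not None and config.agent_class is EchoAgentConfig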

View File

@@ -12,7 +12,7 @@ Components:
Usage:
# Run the server and UI
python -m agent.human_tool
# Or run components separately
python -m agent.human_tool.server # API server only
python -m agent.human_tool.ui # UI only
@@ -21,9 +21,4 @@ Usage:
from .server import CompletionQueue, completion_queue
from .ui import HumanCompletionUI, create_ui
__all__ = [
"CompletionQueue",
"completion_queue",
"HumanCompletionUI",
"create_ui"
]
__all__ = ["CompletionQueue", "completion_queue", "HumanCompletionUI", "create_ui"]

View File

@@ -8,6 +8,7 @@ with a Gradio UI for human interaction.
import gradio as gr
from fastapi import FastAPI
from .server import app as fastapi_app
from .ui import create_ui
@@ -18,6 +19,7 @@ gradio_demo = create_ui()
CUSTOM_PATH = "/gradio"
app = gr.mount_gradio_app(fastapi_app, gradio_demo, path=CUSTOM_PATH)
# Add a redirect from root to Gradio UI
@fastapi_app.get("/")
async def redirect_to_ui():
@@ -25,14 +27,16 @@ async def redirect_to_ui():
return {
"message": "Human Completion Server is running",
"ui_url": "/gradio",
"api_docs": "/docs"
"api_docs": "/docs",
}
if __name__ == "__main__":
import uvicorn
print("🚀 Starting Human-in-the-Loop Completion Server...")
print("📊 API Server: http://localhost:8002")
print("🎨 Gradio UI: http://localhost:8002/gradio")
print("📚 API Docs: http://localhost:8002/docs")
uvicorn.run(app, host="0.0.0.0", port=8002)

View File

@@ -1,9 +1,9 @@
import asyncio
import uuid
from dataclasses import asdict, dataclass
from datetime import datetime
from typing import Dict, List, Any, Optional
from dataclasses import dataclass, asdict
from enum import Enum
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
@@ -49,7 +49,7 @@ class CompletionQueue:
self._queue: Dict[str, CompletionCall] = {}
self._pending_order: List[str] = []
self._lock = asyncio.Lock()
async def add_completion(self, messages: List[Dict[str, Any]], model: str) -> str:
"""Add a completion call to the queue."""
async with self._lock:
@@ -59,42 +59,47 @@ class CompletionQueue:
messages=messages,
model=model,
status=CompletionStatus.PENDING,
created_at=datetime.now()
created_at=datetime.now(),
)
self._queue[call_id] = completion_call
self._pending_order.append(call_id)
return call_id
async def get_pending_calls(self) -> List[Dict[str, Any]]:
"""Get all pending completion calls."""
async with self._lock:
pending_calls = []
for call_id in self._pending_order:
if call_id in self._queue and self._queue[call_id].status == CompletionStatus.PENDING:
if (
call_id in self._queue
and self._queue[call_id].status == CompletionStatus.PENDING
):
call = self._queue[call_id]
pending_calls.append({
"id": call.id,
"model": call.model,
"created_at": call.created_at.isoformat(),
"messages": call.messages
})
pending_calls.append(
{
"id": call.id,
"model": call.model,
"created_at": call.created_at.isoformat(),
"messages": call.messages,
}
)
return pending_calls
async def get_call_status(self, call_id: str) -> Optional[Dict[str, Any]]:
"""Get the status of a specific completion call."""
async with self._lock:
if call_id not in self._queue:
return None
call = self._queue[call_id]
result = {
"id": call.id,
"status": call.status.value,
"created_at": call.created_at.isoformat(),
"model": call.model,
"messages": call.messages
"messages": call.messages,
}
if call.completed_at:
result["completed_at"] = call.completed_at.isoformat()
if call.response:
@@ -103,69 +108,74 @@ class CompletionQueue:
result["tool_calls"] = call.tool_calls
if call.error:
result["error"] = call.error
return result
async def complete_call(self, call_id: str, response: Optional[str] = None, tool_calls: Optional[List[Dict[str, Any]]] = None) -> bool:
async def complete_call(
self,
call_id: str,
response: Optional[str] = None,
tool_calls: Optional[List[Dict[str, Any]]] = None,
) -> bool:
"""Mark a completion call as completed with a response or tool calls."""
async with self._lock:
if call_id not in self._queue:
return False
call = self._queue[call_id]
if call.status != CompletionStatus.PENDING:
return False
call.status = CompletionStatus.COMPLETED
call.completed_at = datetime.now()
call.response = response
call.tool_calls = tool_calls
# Remove from pending order
if call_id in self._pending_order:
self._pending_order.remove(call_id)
return True
async def fail_call(self, call_id: str, error: str) -> bool:
"""Mark a completion call as failed with an error."""
async with self._lock:
if call_id not in self._queue:
return False
call = self._queue[call_id]
if call.status != CompletionStatus.PENDING:
return False
call.status = CompletionStatus.FAILED
call.completed_at = datetime.now()
call.error = error
# Remove from pending order
if call_id in self._pending_order:
self._pending_order.remove(call_id)
return True
async def wait_for_completion(self, call_id: str, timeout: float = 300.0) -> Optional[str]:
"""Wait for a completion call to be completed and return the response."""
start_time = asyncio.get_event_loop().time()
while True:
status = await self.get_call_status(call_id)
if not status:
return None
if status["status"] == CompletionStatus.COMPLETED.value:
return status.get("response")
elif status["status"] == CompletionStatus.FAILED.value:
raise Exception(f"Completion failed: {status.get('error', 'Unknown error')}")
# Check timeout
if asyncio.get_event_loop().time() - start_time > timeout:
await self.fail_call(call_id, "Timeout waiting for human response")
raise TimeoutError("Timeout waiting for human response")
# Wait a bit before checking again
await asyncio.sleep(0.5)
@@ -204,9 +214,7 @@ async def get_status(call_id: str):
async def complete_call(call_id: str, response: CompletionResponse):
"""Complete a call with a human response."""
success = await completion_queue.complete_call(
call_id,
response=response.response,
tool_calls=response.tool_calls
call_id, response=response.response, tool_calls=response.tool_calls
)
if success:
return {"status": "success", "message": "Call completed"}
@@ -219,7 +227,9 @@ async def fail_call(call_id: str, error: Dict[str, str]):
"""Mark a call as failed."""
success = await completion_queue.fail_call(call_id, error.get("error", "Unknown error"))
if not success:
raise HTTPException(status_code=404, detail="Completion call not found or already completed")
raise HTTPException(
status_code=404, detail="Completion call not found or already completed"
)
return {"status": "failed"}
@@ -231,4 +241,5 @@ async def root():
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8002)
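Editor's note: a minimal in-process sketch of the queue lifecycle, bypassing the HTTP layer; in the real flow the Gradio UI posts the answer to /complete/{call_id}. The model string here is just stored metadata.

import asyncio

from agent.human_tool.server import completion_queue

async def demo() -> None:
    call_id = await completion_queue.add_completion(
        messages=[{"role": "user", "content": "Should I click OK?"}],
        model="human-in-the-loop",
    )

    async def human_answers() -> None:
        await asyncio.sleep(0.1)  # pretend a human is reading the UI
        await completion_queue.complete_call(call_id, response="Yes, click OK.")

    answer_task = asyncio.create_task(human_answers())
    # Polls every 0.5s; raises TimeoutError after `timeout` seconds.
    response = await completion_queue.wait_for_completion(call_id, timeout=5.0)
    await answer_task
    print(response)  # "Yes, click OK."

asyncio.run(demo())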

View File

@@ -1,14 +1,17 @@
import gradio as gr
import json
import time
from typing import List, Dict, Any, Optional
from datetime import datetime
import requests
from .server import completion_queue
import base64
import io
import json
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
import gradio as gr
import requests
from PIL import Image
from .server import completion_queue
class HumanCompletionUI:
def __init__(self, server_url: str = "http://localhost:8002"):
self.server_url = server_url
@@ -20,7 +23,7 @@ class HumanCompletionUI:
self.current_button: str = "left"
self.current_scroll_x: int = 0
self.current_scroll_y: int = -120
def format_messages_for_chatbot(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Format messages for display in gr.Chatbot with type='messages'."""
formatted = []
@@ -28,7 +31,7 @@ class HumanCompletionUI:
role = msg.get("role", "user")
content = msg.get("content", "")
tool_calls = msg.get("tool_calls", [])
# Handle different content formats
if isinstance(content, list):
# Multi-modal content - can include text and images
@@ -55,7 +58,7 @@ class HumanCompletionUI:
else:
# For URL images, create gr.Image with URL
formatted_content.append(gr.Image(value=image_url))
# Determine final content format
if len(formatted_content) == 1:
content = formatted_content[0]
@@ -63,28 +66,28 @@ class HumanCompletionUI:
content = formatted_content
else:
content = "[Empty content]"
# Ensure role is valid for Gradio Chatbot
if role not in ["user", "assistant"]:
role = "assistant" if role == "system" else "user"
# Invert roles for better display in human UI context
# (what the AI says becomes "user", what human should respond becomes "assistant")
if role == "user":
role = "assistant"
else:
role = "user"
# Add the main message if it has content
if content and str(content).strip():
formatted.append({"role": role, "content": content})
# Handle tool calls - create separate messages for each tool call
if tool_calls:
for tool_call in tool_calls:
function_name = tool_call.get("function", {}).get("name", "unknown")
arguments_str = tool_call.get("function", {}).get("arguments", "{}")
try:
# Parse arguments to format them nicely
arguments = json.loads(arguments_str)
@@ -92,18 +95,20 @@ class HumanCompletionUI:
except json.JSONDecodeError:
# If parsing fails, use the raw string
formatted_args = arguments_str
# Create a formatted message for the tool call
tool_call_content = f"```json\n{formatted_args}\n```"
formatted.append({
"role": role,
"content": tool_call_content,
"metadata": {"title": f"🛠️ Used {function_name}"}
})
formatted.append(
{
"role": role,
"content": tool_call_content,
"metadata": {"title": f"🛠️ Used {function_name}"},
}
)
return formatted
def get_pending_calls(self) -> List[Dict[str, Any]]:
"""Get pending calls from the server."""
try:
@@ -113,38 +118,39 @@ class HumanCompletionUI:
except Exception as e:
print(f"Error fetching pending calls: {e}")
return []
def complete_call_with_response(self, call_id: str, response: str) -> bool:
"""Complete a call with a text response."""
try:
response_data = {"response": response}
response_obj = requests.post(
f"{self.server_url}/complete/{call_id}",
json=response_data,
timeout=10
f"{self.server_url}/complete/{call_id}", json=response_data, timeout=10
)
response_obj.raise_for_status()
return True
except requests.RequestException as e:
print(f"Error completing call: {e}")
return False
def complete_call_with_tool_calls(self, call_id: str, tool_calls: List[Dict[str, Any]]) -> bool:
"""Complete a call with tool calls."""
try:
response_data = {"tool_calls": tool_calls}
response_obj = requests.post(
f"{self.server_url}/complete/{call_id}",
json=response_data,
timeout=10
f"{self.server_url}/complete/{call_id}", json=response_data, timeout=10
)
response_obj.raise_for_status()
return True
except requests.RequestException as e:
print(f"Error completing call: {e}")
return False
def complete_call(self, call_id: str, response: Optional[str] = None, tool_calls: Optional[List[Dict[str, Any]]] = None) -> bool:
def complete_call(
self,
call_id: str,
response: Optional[str] = None,
tool_calls: Optional[List[Dict[str, Any]]] = None,
) -> bool:
"""Complete a call with either a response or tool calls."""
try:
response_data = {}
@@ -152,25 +158,23 @@ class HumanCompletionUI:
response_data["response"] = response
if tool_calls:
response_data["tool_calls"] = tool_calls
response_obj = requests.post(
f"{self.server_url}/complete/{call_id}",
json=response_data,
timeout=10
f"{self.server_url}/complete/{call_id}", json=response_data, timeout=10
)
response_obj.raise_for_status()
return True
except requests.RequestException as e:
print(f"Error completing call: {e}")
return False
def get_last_image_from_messages(self, messages: List[Dict[str, Any]]) -> Optional[Any]:
"""Extract the last image from the messages for display above conversation."""
last_image = None
for msg in reversed(messages): # Start from the last message
content = msg.get("content", "")
if isinstance(content, list):
for item in reversed(content): # Get the last image in the message
if item.get("type") == "image_url":
@@ -189,13 +193,13 @@ class HumanCompletionUI:
else:
# For URL images, return the URL
return image_url
return last_image
def refresh_pending_calls(self):
"""Refresh the list of pending calls."""
pending_calls = self.get_pending_calls()
if not pending_calls:
return (
gr.update(choices=["latest"], value="latest"), # dropdown
@@ -205,27 +209,27 @@ class HumanCompletionUI:
gr.update(visible=False), # click_actions_group hidden
gr.update(visible=False), # actions_group hidden
)
# Sort pending calls by created_at to get oldest first
sorted_calls = sorted(pending_calls, key=lambda x: x.get("created_at", ""))
# Create choices for dropdown
choices = [("latest", "latest")] # Add "latest" option first
for call in sorted_calls:
call_id = call["id"]
model = call.get("model", "unknown")
created_at = call.get("created_at", "")
# Format timestamp
try:
dt = datetime.fromisoformat(created_at.replace('Z', '+00:00'))
dt = datetime.fromisoformat(created_at.replace("Z", "+00:00"))
time_str = dt.strftime("%H:%M:%S")
except:
time_str = created_at
choice_label = f"{call_id[:8]}... ({model}) - {time_str}"
choices.append((choice_label, call_id))
# Default to "latest" which shows the oldest pending conversation
selected_call_id = "latest"
if selected_call_id == "latest" and sorted_calls:
@@ -239,7 +243,7 @@ class HumanCompletionUI:
conversation = []
self.current_call_id = None
self.last_image = None
return (
gr.update(choices=choices, value="latest"),
gr.update(value=self.last_image),
@@ -248,7 +252,7 @@ class HumanCompletionUI:
gr.update(visible=True), # click_actions_group visible when there is a call
gr.update(visible=True), # actions_group visible when there is a call
)
def on_call_selected(self, selected_choice):
"""Handle when a call is selected from the dropdown."""
if not selected_choice:
@@ -259,7 +263,7 @@ class HumanCompletionUI:
gr.update(visible=False), # click_actions_group hidden
gr.update(visible=False), # actions_group hidden
)
pending_calls = self.get_pending_calls()
if not pending_calls:
return (
@@ -269,7 +273,7 @@ class HumanCompletionUI:
gr.update(visible=False), # click_actions_group hidden
gr.update(visible=False), # actions_group hidden
)
# Handle "latest" option
if selected_choice == "latest":
# Sort calls by created_at to get oldest first
@@ -284,17 +288,17 @@ class HumanCompletionUI:
if call_id_short in selected_choice:
call_id = call["id"]
break
if not call_id:
return (
gr.update(value=None), # no image
gr.update(value=[]), # empty chatbot
gr.update(interactive=False)
gr.update(interactive=False),
)
# Find the selected call
selected_call = next((c for c in pending_calls if c["id"] == call_id), None)
if not selected_call:
return (
gr.update(value=None), # no image
@@ -303,12 +307,12 @@ class HumanCompletionUI:
gr.update(visible=False), # click_actions_group hidden
gr.update(visible=False), # actions_group hidden
)
conversation = self.format_messages_for_chatbot(selected_call.get("messages", []))
self.current_call_id = call_id
# Get the last image from messages
self.last_image = self.get_last_image_from_messages(selected_call.get("messages", []))
return (
gr.update(value=self.last_image),
gr.update(value=conversation),
@@ -316,110 +320,111 @@ class HumanCompletionUI:
gr.update(visible=True), # click_actions_group visible
gr.update(visible=True), # actions_group visible
)
def submit_response(self, response_text: str):
"""Submit a text response to the current call."""
if not self.current_call_id:
return (
gr.update(value=response_text), # keep response text
gr.update(value="❌ No call selected") # status
gr.update(value="❌ No call selected"), # status
)
if not response_text.strip():
return (
gr.update(value=response_text), # keep response text
gr.update(value="❌ Response cannot be empty") # status
gr.update(value="❌ Response cannot be empty"), # status
)
success = self.complete_call_with_response(self.current_call_id, response_text)
if success:
status_msg = "✅ Response submitted successfully!"
return (
gr.update(value=""), # clear response text
gr.update(value=status_msg) # status
gr.update(value=status_msg), # status
)
else:
return (
gr.update(value=response_text), # keep response text
gr.update(value="❌ Failed to submit response") # status
gr.update(value="❌ Failed to submit response"), # status
)
def submit_action(self, action_type: str, **kwargs) -> str:
"""Submit a computer action as a tool call."""
if not self.current_call_id:
return "❌ No call selected"
import uuid
# Create tool call structure
action_data = {"type": action_type, **kwargs}
tool_call = {
"id": f"call_{uuid.uuid4().hex[:24]}",
"type": "function",
"function": {
"name": "computer",
"arguments": json.dumps(action_data)
}
"function": {"name": "computer", "arguments": json.dumps(action_data)},
}
success = self.complete_call_with_tool_calls(self.current_call_id, [tool_call])
if success:
return f"{action_type.capitalize()} action submitted as tool call"
else:
return f"❌ Failed to submit {action_type} action"
def submit_click_action(self, x: int, y: int, action_type: str = "click", button: str = "left") -> str:
def submit_click_action(
self, x: int, y: int, action_type: str = "click", button: str = "left"
) -> str:
"""Submit a coordinate-based action."""
if action_type == "click":
return self.submit_action(action_type, x=x, y=y, button=button)
else:
return self.submit_action(action_type, x=x, y=y)
def submit_type_action(self, text: str) -> str:
"""Submit a type action."""
return self.submit_action("type", text=text)
def submit_hotkey_action(self, keys: str) -> str:
"""Submit a hotkey action."""
return self.submit_action("keypress", keys=keys)
def submit_wait_action(self) -> str:
"""Submit a wait action with no kwargs."""
return self.submit_action("wait")
def submit_description_click(self, description: str, action_type: str = "click", button: str = "left") -> str:
def submit_description_click(
self, description: str, action_type: str = "click", button: str = "left"
) -> str:
"""Submit a description-based action."""
if action_type == "click":
return self.submit_action(action_type, element_description=description, button=button)
else:
return self.submit_action(action_type, element_description=description)
def wait_for_pending_calls(self, max_seconds: float = 10.0, check_interval: float = 0.2):
"""Wait for pending calls to appear or until max_seconds elapsed.
This method loops and checks for pending calls at regular intervals,
returning as soon as a pending call is found or the maximum wait time is reached.
Args:
max_seconds: Maximum number of seconds to wait
check_interval: How often to check for pending calls (in seconds)
"""
import time
start_time = time.time()
while time.time() - start_time < max_seconds:
# Check if there are any pending calls
pending_calls = self.get_pending_calls()
if pending_calls:
# Found pending calls, return immediately
return self.refresh_pending_calls()
# Wait before checking again
time.sleep(check_interval)
# Max wait time reached, return current state
return self.refresh_pending_calls()
@@ -427,79 +432,73 @@ class HumanCompletionUI:
def create_ui():
"""Create the Gradio interface."""
ui_handler = HumanCompletionUI()
with gr.Blocks(title="Human-in-the-Loop Agent Tool", fill_width=True) as demo:
gr.Markdown("# 🤖 Human-in-the-Loop Agent Tool")
gr.Markdown("Review AI conversation requests and provide human responses.")
with gr.Row():
with gr.Column(scale=2):
with gr.Group():
screenshot_image = gr.Image(
label="Interactive Screenshot",
interactive=False,
height=600
label="Interactive Screenshot", interactive=False, height=600
)
# Action type selection for image clicks (wrapped for visibility control)
with gr.Group(visible=False) as click_actions_group:
with gr.Row():
action_type_radio = gr.Dropdown(
label="Interactive Action",
choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down", "scroll"],
choices=[
"click",
"double_click",
"move",
"left_mouse_up",
"left_mouse_down",
"scroll",
],
value="click",
scale=2
scale=2,
)
action_button_radio = gr.Dropdown(
label="Button",
choices=["left", "right", "wheel", "back", "forward"],
value="left",
visible=True,
scale=1
scale=1,
)
scroll_x_input = gr.Number(
label="scroll_x",
value=0,
visible=False,
scale=1
label="scroll_x", value=0, visible=False, scale=1
)
scroll_y_input = gr.Number(
label="scroll_y",
value=-120,
visible=False,
scale=1
label="scroll_y", value=-120, visible=False, scale=1
)
conversation_chatbot = gr.Chatbot(
label="Conversation",
type="messages",
height=500,
show_copy_button=True
label="Conversation", type="messages", height=500, show_copy_button=True
)
with gr.Column(scale=1):
with gr.Group():
call_dropdown = gr.Dropdown(
label="Select a pending conversation request",
choices=["latest"],
interactive=True,
value="latest"
value="latest",
)
refresh_btn = gr.Button("🔄 Refresh", variant="secondary")
status_display = gr.Textbox(
label="Status",
interactive=False,
value="Ready to receive requests..."
label="Status", interactive=False, value="Ready to receive requests..."
)
with gr.Group():
response_text = gr.Textbox(
label="Message",
lines=3,
placeholder="Enter your message here..."
label="Message", lines=3, placeholder="Enter your message here..."
)
submit_btn = gr.Button("📤 Submit Message", variant="primary", interactive=False)
submit_btn = gr.Button(
"📤 Submit Message", variant="primary", interactive=False
)
# Action Accordions (wrapped for visibility control)
with gr.Group(visible=False) as actions_group:
with gr.Tabs():
@@ -507,58 +506,73 @@ def create_ui():
with gr.Group():
description_text = gr.Textbox(
label="Element Description",
placeholder="e.g., 'Privacy and security option in left sidebar'"
placeholder="e.g., 'Privacy and security option in left sidebar'",
)
with gr.Row():
description_action_type = gr.Dropdown(
label="Action",
choices=["click", "double_click", "move", "left_mouse_up", "left_mouse_down"],
value="click"
choices=[
"click",
"double_click",
"move",
"left_mouse_up",
"left_mouse_down",
],
value="click",
)
description_button = gr.Dropdown(
label="Button",
choices=["left", "right", "wheel", "back", "forward"],
value="left"
value="left",
)
description_submit_btn = gr.Button("Submit Click Action")
with gr.Tab("📝 Type Action"):
with gr.Group():
type_text = gr.Textbox(
label="Text to Type",
placeholder="Enter text to type..."
label="Text to Type", placeholder="Enter text to type..."
)
type_submit_btn = gr.Button("Submit Type")
with gr.Tab("⌨️ Keypress Action"):
with gr.Group():
keypress_text = gr.Textbox(
label="Keys",
placeholder="e.g., ctrl+c, alt+tab"
label="Keys", placeholder="e.g., ctrl+c, alt+tab"
)
keypress_submit_btn = gr.Button("Submit Keypress")
with gr.Tab("🧰 Misc Actions"):
with gr.Group():
misc_action_dropdown = gr.Dropdown(
label="Action",
choices=["wait"],
value="wait"
label="Action", choices=["wait"], value="wait"
)
misc_submit_btn = gr.Button("Submit Action")
# Event handlers
refresh_btn.click(
fn=ui_handler.refresh_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
call_dropdown.change(
fn=ui_handler.on_call_selected,
inputs=[call_dropdown],
outputs=[screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
def handle_image_click(evt: gr.SelectData):
if evt.index is not None:
x, y = evt.index
@@ -568,31 +582,44 @@ def create_ui():
sx_i = int(ui_handler.current_scroll_x or 0)
sy_i = int(ui_handler.current_scroll_y or 0)
# Submit a scroll action with x,y position and scroll deltas
result = ui_handler.submit_action("scroll", x=x, y=y, scroll_x=sx_i, scroll_y=sy_i)
result = ui_handler.submit_action(
"scroll", x=x, y=y, scroll_x=sx_i, scroll_y=sy_i
)
else:
result = ui_handler.submit_click_action(x, y, action_type, button)
ui_handler.wait_for_pending_calls()
return result
return "No coordinates selected"
screenshot_image.select(
fn=handle_image_click,
outputs=[status_display]
).then(
screenshot_image.select(fn=handle_image_click, outputs=[status_display]).then(
fn=ui_handler.wait_for_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
# Response submission
submit_btn.click(
fn=ui_handler.submit_response,
inputs=[response_text],
outputs=[response_text, status_display]
outputs=[response_text, status_display],
).then(
fn=ui_handler.refresh_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
# Toggle visibility of controls based on action type
def toggle_action_controls(action_type):
# Button visible only for click
@@ -603,59 +630,63 @@ def create_ui():
# Update state
ui_handler.current_action_type = action_type or "click"
return button_vis, scroll_x_vis, scroll_y_vis
action_type_radio.change(
fn=toggle_action_controls,
inputs=[action_type_radio],
outputs=[action_button_radio, scroll_x_input, scroll_y_input]
outputs=[action_button_radio, scroll_x_input, scroll_y_input],
)
# Keep other control values in ui_handler state
def on_button_change(val):
ui_handler.current_button = (val or "left")
action_button_radio.change(
fn=on_button_change,
inputs=[action_button_radio]
)
ui_handler.current_button = val or "left"
action_button_radio.change(fn=on_button_change, inputs=[action_button_radio])
def on_scroll_x_change(val):
try:
ui_handler.current_scroll_x = int(val) if val is not None else 0
except Exception:
ui_handler.current_scroll_x = 0
scroll_x_input.change(
fn=on_scroll_x_change,
inputs=[scroll_x_input]
)
scroll_x_input.change(fn=on_scroll_x_change, inputs=[scroll_x_input])
def on_scroll_y_change(val):
try:
ui_handler.current_scroll_y = int(val) if val is not None else 0
except Exception:
ui_handler.current_scroll_y = 0
scroll_y_input.change(
fn=on_scroll_y_change,
inputs=[scroll_y_input]
)
scroll_y_input.change(fn=on_scroll_y_change, inputs=[scroll_y_input])
type_submit_btn.click(
fn=ui_handler.submit_type_action,
inputs=[type_text],
outputs=[status_display]
fn=ui_handler.submit_type_action, inputs=[type_text], outputs=[status_display]
).then(
fn=ui_handler.wait_for_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
keypress_submit_btn.click(
fn=ui_handler.submit_hotkey_action,
inputs=[keypress_text],
outputs=[status_display]
fn=ui_handler.submit_hotkey_action, inputs=[keypress_text], outputs=[status_display]
).then(
fn=ui_handler.wait_for_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
def handle_description_submit(description, action_type, button):
if description:
result = ui_handler.submit_description_click(description, action_type, button)
@@ -666,12 +697,19 @@ def create_ui():
description_submit_btn.click(
fn=handle_description_submit,
inputs=[description_text, description_action_type, description_button],
outputs=[status_display]
outputs=[status_display],
).then(
fn=ui_handler.wait_for_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
# Misc action handler
def handle_misc_submit(selected_action):
if selected_action == "wait":
@@ -681,20 +719,32 @@ def create_ui():
return f"Unsupported misc action: {selected_action}"
misc_submit_btn.click(
fn=handle_misc_submit,
inputs=[misc_action_dropdown],
outputs=[status_display]
fn=handle_misc_submit, inputs=[misc_action_dropdown], outputs=[status_display]
).then(
fn=ui_handler.wait_for_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
# Load initial data
demo.load(
fn=ui_handler.refresh_pending_calls,
outputs=[call_dropdown, screenshot_image, conversation_chatbot, submit_btn, click_actions_group, actions_group]
outputs=[
call_dropdown,
screenshot_image,
conversation_chatbot,
submit_btn,
click_actions_group,
actions_group,
],
)
return demo
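Editor's note: the tool-call payload built above has a stable shape, so external scripts can answer a pending call the same way the UI does. A minimal sketch; the call_id argument is a placeholder for a real pending call's id.

import json
import uuid

import requests

def submit_click(server_url: str, call_id: str, x: int, y: int, button: str = "left") -> None:
    action = {"type": "click", "x": x, "y": y, "button": button}
    tool_call = {
        "id": f"call_{uuid.uuid4().hex[:24]}",
        "type": "function",
        "function": {"name": "computer", "arguments": json.dumps(action)},
    }
    resp = requests.post(
        f"{server_url}/complete/{call_id}",
        json={"tool_calls": [tool_call]},
        timeout=10,
    )
    resp.raise_for_status()

submit_click("http://localhost:8002", "<pending-call-id>", x=512, y=300)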

View File

@@ -8,21 +8,22 @@ Exports:
- run_full_dataset(dataset, ...)
- MCPComputerAgent
"""
import time
from typing import Any, Optional
from agent.computers import is_agent_computer
from datasets import load_dataset, Dataset
from hud.datasets import Task, run_dataset
from datasets import Dataset, load_dataset
from hud import trace
from hud.datasets import Task, run_dataset
from .agent import MCPComputerAgent
# ---------------------------------------------------------------------------
# Single-task runner
# ---------------------------------------------------------------------------
async def run_single_task(
dataset: str | Dataset | list[dict[str, Any]],
*,
@@ -47,24 +48,20 @@ async def run_single_task(
# Load dataset and pick a sample
if isinstance(dataset, str):
dataset = load_dataset(dataset, split="train") # type: ignore[arg-type]
dataset = load_dataset(dataset, split="train") # type: ignore[arg-type]
elif isinstance(dataset, list):
dataset = dataset
else:
dataset = dataset["train"]
sample_task = dataset[task_id] # type: ignore[index]
task_prompt = sample_task.get("prompt", f"Task {sample_task.get('id', 0)}") # type: ignore[attr-defined]
# Filter any existing Computer tools
# The eval framework will add its own Computer tool per task
if tools:
tools = [
tool
for tool in tools
if not is_agent_computer(tool)
]
tools = [tool for tool in tools if not is_agent_computer(tool)]
with trace(name=task_prompt):
task = Task(**sample_task) # type: ignore[arg-type]
@@ -87,13 +84,14 @@ async def run_single_task(
)
print(f"Running: {task_prompt}")
result = await agent.run(task, max_steps=10)
print(f"✅ Reward: {getattr(result, 'reward')}")
print(f"✅ Reward: {result.reward}")
# ---------------------------------------------------------------------------
# Full-dataset runner
# ---------------------------------------------------------------------------
async def run_full_dataset(
dataset: str | Dataset | list[dict[str, Any]],
*,
@@ -121,9 +119,9 @@ async def run_full_dataset(
# Run with our MCP-based agent class.
if isinstance(dataset, str):
dataset_name = dataset.split('/')[-1]
dataset_name = dataset.split("/")[-1]
job_name = job_name or f"Evaluation {dataset_name}"
dataset = load_dataset(dataset, split=split) # type: ignore[arg-type]
dataset = load_dataset(dataset, split=split) # type: ignore[arg-type]
else:
dataset_name = "custom"
job_name = job_name or f"Evaluation {time.strftime('%H:%M %Y-%m-%d')}"
@@ -131,12 +129,8 @@ async def run_full_dataset(
# Filter any existing Computer tools
# The eval framework will add its own Computer tool per task
if tools:
tools = [
tool
for tool in tools
if not is_agent_computer(tool)
]
tools = [tool for tool in tools if not is_agent_computer(tool)]
# Execute evaluation
return await run_dataset(
name=job_name,
@@ -170,4 +164,4 @@ __all__ = [
"run_single_task",
"run_full_dataset",
"MCPComputerAgent",
]
]
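Editor's note: a minimal sketch of calling the two runners; the import path agent.integrations.hud is assumed (it is not visible in this diff), the dataset id is a placeholder, and HUD credentials are expected to be configured in the environment.

import asyncio

from agent.integrations.hud import run_full_dataset, run_single_task  # module path assumed

async def main() -> None:
    # Smoke-test a single sample by index.
    await run_single_task("<org>/<hud-task-dataset>", task_id=0)

    # Evaluate a whole split under one job name; the return value comes from hud's run_dataset.
    results = await run_full_dataset(
        "<org>/<hud-task-dataset>",
        split="train",
        job_name="Nightly evaluation",
    )
    print(results)

asyncio.run(main())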

View File

@@ -9,26 +9,26 @@ Key differences from the OpenAI OperatorAgent variant:
- Planning is executed via `ComputerAgent.run(messages)`.
- The first yielded result per step is returned as the agent response.
"""
from __future__ import annotations
import base64
import io
import uuid
from pathlib import Path
from typing import Any, ClassVar, Optional
import hud
import mcp.types as types
from agent.agent import ComputerAgent as BaseComputerAgent
from agent.callbacks import PromptInstructionsCallback
from agent.callbacks.trajectory_saver import TrajectorySaverCallback
from agent.computers import is_agent_computer
from agent.responses import make_failed_tool_call_items
from hud.agents import MCPAgent
from hud.tools.computer.settings import computer_settings
from hud.types import AgentResponse, MCPToolCall, MCPToolResult, Trace
from agent.responses import make_failed_tool_call_items
from agent.computers import is_agent_computer
from PIL import Image
import mcp.types as types
import hud
import uuid
import base64
from pathlib import Path
class MCPComputerAgent(MCPAgent):
@@ -114,8 +114,10 @@ class MCPComputerAgent(MCPAgent):
self.last_screenshot_b64 = None
buffer = io.BytesIO()
Image.new('RGB', (self.metadata["display_width"], self.metadata["display_height"])).save(buffer, format='PNG')
self.last_screenshot_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
Image.new("RGB", (self.metadata["display_width"], self.metadata["display_height"])).save(
buffer, format="PNG"
)
self.last_screenshot_b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
# Ensure a computer shim is present so width/height/environment are known
computer_shim = {
@@ -128,12 +130,8 @@ class MCPComputerAgent(MCPAgent):
}
agent_tools: list[Any] = [computer_shim]
if tools:
agent_tools.extend([
tool
for tool in tools
if not is_agent_computer(tool)
])
agent_tools.extend([tool for tool in tools if not is_agent_computer(tool)])
agent_kwargs = {
"model": self.model,
"trajectory_dir": trajectory_dir,
@@ -150,9 +148,7 @@ class MCPComputerAgent(MCPAgent):
"telemetry_enabled": telemetry_enabled,
}
self.computer_agent = BaseComputerAgent(
**agent_kwargs
)
self.computer_agent = BaseComputerAgent(**agent_kwargs)
async def get_system_messages(self) -> list[Any]:
"""Create initial messages.
@@ -161,9 +157,7 @@ class MCPComputerAgent(MCPAgent):
"""
return []
async def format_blocks(
self, blocks: list[types.ContentBlock]
) -> list[dict[str, Any]]:
async def format_blocks(self, blocks: list[types.ContentBlock]) -> list[dict[str, Any]]:
"""
Format blocks for OpenAI input format.
@@ -200,42 +194,49 @@ class MCPComputerAgent(MCPAgent):
# Call the ComputerAgent LLM API
async for result in self.computer_agent.run(messages): # type: ignore[arg-type]
items = result['output']
items = result["output"]
if not items or tool_calls:
break
for item in items:
if item['type'] in ['reasoning', 'message', 'computer_call', 'function_call', 'function_call_output']:
if item["type"] in [
"reasoning",
"message",
"computer_call",
"function_call",
"function_call_output",
]:
agent_result.append(item)
# Add messages to output text
if item['type'] == 'reasoning':
if item["type"] == "reasoning":
output_text.extend(
f"Reasoning: {summary['text']}"
for summary in item['summary']
f"Reasoning: {summary['text']}" for summary in item["summary"]
)
elif item['type'] == 'message':
if isinstance(item['content'], list):
elif item["type"] == "message":
if isinstance(item["content"], list):
output_text.extend(
item['text']
for item in item['content']
if item['type'] == 'output_text'
item["text"]
for item in item["content"]
if item["type"] == "output_text"
)
elif isinstance(item['content'], str):
output_text.append(item['content'])
elif isinstance(item["content"], str):
output_text.append(item["content"])
# If we get a tool call, we're not done
if item['type'] == 'computer_call':
if item["type"] == "computer_call":
id = item["call_id"]
tool_calls.append(MCPToolCall(
name="openai_computer",
arguments=item["action"],
id=id,
))
tool_calls.append(
MCPToolCall(
name="openai_computer",
arguments=item["action"],
id=id,
)
)
is_done = False
self.tool_call_inputs[id] = agent_result
break
# if we have tool calls, we should exit the loop
if tool_calls:
break
@@ -247,7 +248,7 @@ class MCPComputerAgent(MCPAgent):
tool_calls=tool_calls,
done=is_done,
)
def _log_image(self, image_b64: str):
callbacks = self.computer_agent.callbacks
for callback in callbacks:
@@ -257,9 +258,7 @@ class MCPComputerAgent(MCPAgent):
callback._save_artifact("screenshot_after", image_bytes)
async def format_tool_results(
self,
tool_calls: list[MCPToolCall],
tool_results: list[MCPToolResult]
self, tool_calls: list[MCPToolCall], tool_results: list[MCPToolResult]
) -> list[dict[str, Any]]:
"""Extract latest screenshot from tool results in dict form.
@@ -274,45 +273,60 @@ class MCPComputerAgent(MCPAgent):
previous_output = self.previous_output.copy() or []
# First we need to remove any pending computer_calls from the end of previous_output
while previous_output and previous_output[-1]['type'] == 'computer_call':
while previous_output and previous_output[-1]["type"] == "computer_call":
previous_output.pop()
messages.extend(previous_output)
# If the call is a 'response', don't add the result
if call.name == 'response':
if call.name == "response":
continue
# Otherwise, if we have a result, we should add it to the messages
content = [
{ "type": "input_text", "text": content.text } if isinstance(content, types.TextContent)
else { "type": "input_image", "image_url": f"data:image/png;base64,{content.data}" } if isinstance(content, types.ImageContent)
else { "type": "input_text", "text": "" }
(
{"type": "input_text", "text": content.text}
if isinstance(content, types.TextContent)
else (
{
"type": "input_image",
"image_url": f"data:image/png;base64,{content.data}",
}
if isinstance(content, types.ImageContent)
else {"type": "input_text", "text": ""}
)
)
for content in result.content
]
messages.append({
"role": "user",
"content": content,
})
messages.append(
{
"role": "user",
"content": content,
}
)
continue
# Add the assistant's computer call
messages.extend(self.tool_call_inputs[call.id])
if result.isError:
error_text = "".join([
content.text
for content in result.content
if isinstance(content, types.TextContent)
])
error_text = "".join(
[
content.text
for content in result.content
if isinstance(content, types.TextContent)
]
)
# Replace computer call with failed tool call
messages.pop()
messages.extend(make_failed_tool_call_items(
tool_name=call.name,
tool_kwargs=call.arguments or {},
error_message=error_text,
call_id=call.id,
))
messages.extend(
make_failed_tool_call_items(
tool_name=call.name,
tool_kwargs=call.arguments or {},
error_message=error_text,
call_id=call.id,
)
)
else:
# Get the latest screenshot
screenshots = [
@@ -325,23 +339,27 @@ class MCPComputerAgent(MCPAgent):
if screenshots:
self._log_image(screenshots[0])
self.last_screenshot_b64 = screenshots[0]
messages.append({
"type": "computer_call_output",
"call_id": call.id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshots[0]}"
},
})
messages.append(
{
"type": "computer_call_output",
"call_id": call.id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshots[0]}",
},
}
)
else:
# Otherwise, replace computer call with failed tool call
messages.pop()
messages.extend(make_failed_tool_call_items(
tool_name=call.name,
tool_kwargs=call.arguments or {},
error_message="No screenshots returned.",
call_id=call.id,
))
messages.extend(
make_failed_tool_call_items(
tool_name=call.name,
tool_kwargs=call.arguments or {},
error_message="No screenshots returned.",
call_id=call.id,
)
)
return messages

View File

@@ -7,30 +7,33 @@ OpenAI-like response blocks. We intentionally only support a single-step call
by consuming the first yielded result from `ComputerAgent.run()`.
"""
import traceback
import time
import traceback
import uuid
from typing import Any, Dict, List, Optional
from agent.agent import ComputerAgent as BaseComputerAgent
from agent.callbacks import PromptInstructionsCallback
from hud.tools.computer.settings import computer_settings
from PIL import Image
from hud.agents import OperatorAgent
from hud.tools.computer.settings import computer_settings
# OpenAI Responses typed models (required)
from openai.types.responses import (
Response,
ResponseComputerToolCall,
ResponseInputParam,
ResponseOutputItem,
ResponseComputerToolCall,
ResponseOutputMessage,
ResponseOutputText,
ResponseReasoningItem,
ResponseUsage,
)
from PIL import Image
def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> List[ResponseOutputItem]:
def _map_agent_output_to_openai_blocks(
output_items: List[Dict[str, Any]],
) -> List[ResponseOutputItem]:
"""Map our agent output items to OpenAI ResponseOutputItem typed models.
Only a subset is supported: computer_call, assistant message (text), and reasoning.
@@ -40,14 +43,16 @@ def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> Li
for item in output_items or []:
t = item.get("type")
if t == "computer_call":
comp = ResponseComputerToolCall.model_validate({
"id": item.get("id") or f"cu_{uuid.uuid4().hex}",
"type": "computer_call",
"call_id": item["call_id"],
"action": item["action"],
"pending_safety_checks": item.get("pending_safety_checks", []),
"status": "completed",
})
comp = ResponseComputerToolCall.model_validate(
{
"id": item.get("id") or f"cu_{uuid.uuid4().hex}",
"type": "computer_call",
"call_id": item["call_id"],
"action": item["action"],
"pending_safety_checks": item.get("pending_safety_checks", []),
"status": "completed",
}
)
blocks.append(comp)
# Exit early here, as the Responses API only supports a single step per call
break
@@ -55,31 +60,38 @@ def _map_agent_output_to_openai_blocks(output_items: List[Dict[str, Any]]) -> Li
content_blocks: List[ResponseOutputText] = []
for c in item.get("content", []) or []:
content_blocks.append(
ResponseOutputText.model_validate({
"type": "output_text",
"text": c["text"],
"annotations": [],
})
ResponseOutputText.model_validate(
{
"type": "output_text",
"text": c["text"],
"annotations": [],
}
)
)
if content_blocks:
msg = ResponseOutputMessage.model_validate({
"id": item.get("id") or f"msg_{uuid.uuid4()}",
"type": "message",
"role": "assistant",
"status": "completed",
"content": [ct.model_dump() for ct in content_blocks],
})
msg = ResponseOutputMessage.model_validate(
{
"id": item.get("id") or f"msg_{uuid.uuid4()}",
"type": "message",
"role": "assistant",
"status": "completed",
"content": [ct.model_dump() for ct in content_blocks],
}
)
blocks.append(msg)
elif t == "reasoning":
reasoning = ResponseReasoningItem.model_validate({
"id": item.get("id") or f"rsn_{uuid.uuid4()}",
"type": "reasoning",
"summary": item["summary"],
})
reasoning = ResponseReasoningItem.model_validate(
{
"id": item.get("id") or f"rsn_{uuid.uuid4()}",
"type": "reasoning",
"summary": item["summary"],
}
)
blocks.append(reasoning)
# Unhandled types are ignored
return blocks
def _to_plain_dict_list(items: Any) -> List[Dict[str, Any]]:
out: List[Dict[str, Any]] = []
for it in list(items):
@@ -92,6 +104,7 @@ def _to_plain_dict_list(items: Any) -> List[Dict[str, Any]]:
out.append(dict(it)) # may raise if not mapping
return out
class FakeAsyncOpenAI:
"""Minimal fake OpenAI client with only `responses.create` implemented.
@@ -132,10 +145,12 @@ class FakeAsyncOpenAI:
# Pre-pend instructions message
effective_input = full_input
if instructions:
effective_input = [{
"role": "user",
"content": instructions,
}] + full_input
effective_input = [
{
"role": "user",
"content": instructions,
}
] + full_input
# Run a single iteration of the ComputerAgent
agent_result: Optional[Dict[str, Any]] = None
@@ -152,32 +167,43 @@ class FakeAsyncOpenAI:
blocks_to_cache = full_input + output
for b in blocks_to_cache:
bid = getattr(b, "id", None) or f"tmp-{hash(repr(b))}"
self.blocks_cache[bid] = b # type: ignore[assignment]
self.blocks_cache[bid] = b # type: ignore[assignment]
block_ids.append(bid)
response_id = agent_result.get("id") or f"fake-{int(time.time()*1000)}"
self.context_cache[response_id] = block_ids
try:
return Response.model_validate({
"id": response_id,
"created_at": time.time(),
"object": "response",
"model": model,
"output": output,
"parallel_tool_calls": False,
"tool_choice": "auto",
"tools": [],
"previous_response_id": previous_response_id,
"usage": ResponseUsage.model_validate({
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
"input_tokens_details": usage.get("input_tokens_details", { "cached_tokens": 0 }),
"output_tokens_details": usage.get("output_tokens_details", { "reasoning_tokens": 0 }),
}),
})
return Response.model_validate(
{
"id": response_id,
"created_at": time.time(),
"object": "response",
"model": model,
"output": output,
"parallel_tool_calls": False,
"tool_choice": "auto",
"tools": [],
"previous_response_id": previous_response_id,
"usage": ResponseUsage.model_validate(
{
"input_tokens": usage.get("input_tokens", 0),
"output_tokens": usage.get("output_tokens", 0),
"total_tokens": usage.get("total_tokens", 0),
"input_tokens_details": usage.get(
"input_tokens_details", {"cached_tokens": 0}
),
"output_tokens_details": usage.get(
"output_tokens_details", {"reasoning_tokens": 0}
),
}
),
}
)
except Exception as e:
print(f"Error while validating agent response (attempt {attempt + 1}/{max_retries}): ", e)
print(
f"Error while validating agent response (attempt {attempt + 1}/{max_retries}): ",
e,
)
if attempt == max_retries - 1:
print(traceback.format_exc())
raise e
@@ -221,9 +247,15 @@ class ProxyOperatorAgent(OperatorAgent):
allowed_tools = allowed_tools or ["openai_computer"]
computer_shim = {
'screenshot': lambda: Image.new('RGB', (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT)),
'environment': 'linux',
'dimensions': (computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT)
"screenshot": lambda: Image.new(
"RGB",
(computer_settings.OPENAI_COMPUTER_WIDTH, computer_settings.OPENAI_COMPUTER_HEIGHT),
),
"environment": "linux",
"dimensions": (
computer_settings.OPENAI_COMPUTER_WIDTH,
computer_settings.OPENAI_COMPUTER_HEIGHT,
),
}
# Build tools ensuring the computer_shim is included
agent_tools: list[Any] = [computer_shim]
@@ -258,6 +290,7 @@ class ProxyOperatorAgent(OperatorAgent):
**kwargs,
)
__all__ = [
"FakeAsyncOpenAI",
"ProxyOperatorAgent",

View File

@@ -3,26 +3,34 @@ Agent loops for agent
"""
# Import the loops to register them
from . import anthropic
from . import openai
from . import uitars
from . import omniparser
from . import gta1
from . import composed_grounded
from . import glm45v
from . import opencua
from . import internvl
from . import holo
from . import (
anthropic,
composed_grounded,
gemini,
glm45v,
gta1,
holo,
internvl,
moondream3,
omniparser,
openai,
opencua,
qwen,
uitars,
)
__all__ = [
"anthropic",
"openai",
"uitars",
"omniparser",
"gta1",
"composed_grounded",
"glm45v",
"anthropic",
"openai",
"uitars",
"omniparser",
"gta1",
"composed_grounded",
"glm45v",
"opencua",
"internvl",
"holo",
]
"moondream3",
"gemini",
"qwen",
]

File diff suppressed because it is too large

View File

@@ -2,13 +2,15 @@
Base protocol for async agent configurations
"""
from typing import Protocol, List, Dict, Any, Optional, Tuple, Union
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Protocol, Tuple, Union
from ..types import AgentCapability
class AsyncAgentConfig(Protocol):
"""Protocol defining the interface for async agent configurations."""
@abstractmethod
async def predict_step(
self,
@@ -22,11 +24,11 @@ class AsyncAgentConfig(Protocol):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Predict the next step based on input items.
Args:
messages: Input items following Responses format (message, function_call, computer_call)
model: Model name to use
@@ -39,37 +41,34 @@ class AsyncAgentConfig(Protocol):
_on_usage: Callback for usage tracking
_on_screenshot: Callback for screenshot events
**kwargs: Additional arguments
Returns:
Dictionary with "output" (output items) and "usage" array
"""
...
@abstractmethod
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str
self, model: str, image_b64: str, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.
Args:
model: Model name to use
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
None or tuple with (x, y) coordinates
"""
...
@abstractmethod
def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by this agent config.
Returns:
List of capability strings (e.g., ["step", "click"])
"""

View File
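As a minimal sketch of a class conforming to the AsyncAgentConfig protocol above; the model-name pattern, the module location implied by the relative imports, and the returned payload are illustrative assumptions, not part of the protocol:

from typing import Any, Dict, List, Optional, Tuple

from ..decorators import register_agent      # assumes the sketch lives alongside the other loop modules
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability


@register_agent(models=r".*example-model.*")  # hypothetical model-name pattern
class ExampleConfig(AsyncAgentConfig):
    async def predict_step(self, messages, model, tools=None, **kwargs) -> Dict[str, Any]:
        # A real loop would call an LLM here; this stub returns a single assistant message.
        output = [
            {
                "type": "message",
                "role": "assistant",
                "content": [{"type": "output_text", "text": "done"}],
            }
        ]
        return {"output": output, "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}}

    async def predict_click(self, model: str, image_b64: str, instruction: str) -> Optional[Tuple[int, int]]:
        return None  # this stub does no grounding

    def get_capabilities(self) -> List[AgentCapability]:
        return ["step"]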

@@ -3,122 +3,117 @@ Composed-grounded agent loop implementation that combines grounding and thinking
Uses a two-stage approach: grounding model for element detection, thinking model for reasoning.
"""
import uuid
import asyncio
import json
import base64
from typing import Dict, List, Any, Optional, Tuple
import json
import uuid
from io import BytesIO
from PIL import Image
import litellm
from typing import Any, Dict, List, Optional, Tuple
import litellm
from PIL import Image
from ..agent import find_agent_config
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..responses import (
convert_computer_calls_xy2desc,
convert_responses_items_to_completion_messages,
convert_completion_messages_to_responses_items,
convert_computer_calls_desc2xy,
get_all_element_descriptions
convert_computer_calls_xy2desc,
convert_responses_items_to_completion_messages,
get_all_element_descriptions,
)
from ..agent import find_agent_config
from ..types import AgentCapability, AgentResponse, Messages, Tools
GROUNDED_COMPUTER_TOOL_SCHEMA = {
"type": "function",
"function": {
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment"
],
"description": "The action to perform (required for all actions)"
},
"element_description": {
"type": "string",
"description": "Description of the element to interact with (required for click, double_click, move, scroll actions)"
},
"start_element_description": {
"type": "string",
"description": "Description of the element to start dragging from (required for drag action)"
},
"end_element_description": {
"type": "string",
"description": "Description of the element to drag to (required for drag action)"
},
"text": {
"type": "string",
"description": "The text to type (required for type action)"
},
"keys": {
"type": "array",
"items": {
"type": "string"
"type": "function",
"function": {
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment",
],
"description": "The action to perform (required for all actions)",
},
"element_description": {
"type": "string",
"description": "Description of the element to interact with (required for click, double_click, move, scroll actions)",
},
"start_element_description": {
"type": "string",
"description": "Description of the element to start dragging from (required for drag action)",
},
"end_element_description": {
"type": "string",
"description": "Description of the element to drag to (required for drag action)",
},
"text": {
"type": "string",
"description": "The text to type (required for type action)",
},
"keys": {
"type": "array",
"items": {"type": "string"},
"description": "Key(s) to press (required for keypress action)",
},
"button": {
"type": "string",
"enum": ["left", "right", "wheel", "back", "forward"],
"description": "The mouse button to use for click action (required for click and double_click action)",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (required for scroll action)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (required for scroll action)",
},
},
"description": "Key(s) to press (required for keypress action)"
"required": ["action"],
},
"button": {
"type": "string",
"enum": [
"left",
"right",
"wheel",
"back",
"forward"
],
"description": "The mouse button to use for click action (required for click and double_click action)",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (required for scroll action)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (required for scroll action)",
},
},
"required": [
"action"
]
}
}
},
}
def _prepare_tools_for_grounded(tool_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Prepare tools for grounded API format"""
grounded_tools = []
for schema in tool_schemas:
if schema["type"] == "computer":
grounded_tools.append(GROUNDED_COMPUTER_TOOL_SCHEMA)
else:
grounded_tools.append(schema)
return grounded_tools
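A function call emitted against the grounded schema above therefore carries element descriptions instead of raw coordinates; for example (argument values purely illustrative):

grounded_call_arguments = {
    "action": "click",
    "button": "left",
    "element_description": "blue 'Submit' button at the bottom of the form",
}
# The grounding model later resolves the description to concrete (x, y) pixels before execution.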
def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str]:
"""Get the last computer call output image from messages."""
for message in reversed(messages):
if (isinstance(message, dict) and
message.get("type") == "computer_call_output" and
isinstance(message.get("output"), dict) and
message["output"].get("type") == "input_image"):
if (
isinstance(message, dict)
and message.get("type") == "computer_call_output"
and isinstance(message.get("output"), dict)
and message["output"].get("type") == "input_image"
):
image_url = message["output"].get("image_url", "")
if image_url.startswith("data:image/png;base64,"):
return image_url.split(",", 1)[1]
@@ -129,14 +124,14 @@ def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str
class ComposedGroundedConfig(AsyncAgentConfig):
"""
Composed-grounded agent configuration that uses both grounding and thinking models.
The model parameter should be in format: "grounding_model+thinking_model"
e.g., "huggingface-local/HelloKKMe/GTA1-7B+gemini/gemini-1.5-pro"
"""
def __init__(self):
self.desc2xy: Dict[str, Tuple[float, float]] = {}
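Splitting such a composed model string is a single split on the first "+"; the example value below is the one given in the docstring above:

model = "huggingface-local/HelloKKMe/GTA1-7B+gemini/gemini-1.5-pro"
grounding_model, thinking_model = model.split("+", 1)
# grounding_model -> "huggingface-local/HelloKKMe/GTA1-7B"  (resolves element descriptions to pixels)
# thinking_model  -> "gemini/gemini-1.5-pro"                (plans the next UI action)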
async def predict_step(
self,
messages: List[Dict[str, Any]],
@@ -150,11 +145,11 @@ class ComposedGroundedConfig(AsyncAgentConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Composed-grounded predict step implementation.
Process:
0. Store last computer call image, if none then take a screenshot
1. Convert computer calls from xy to descriptions
@@ -167,18 +162,20 @@ class ComposedGroundedConfig(AsyncAgentConfig):
"""
# Parse the composed model
if "+" not in model:
raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
raise ValueError(
f"Composed model must be in format 'grounding_model+thinking_model', got: {model}"
)
grounding_model, thinking_model = model.split("+", 1)
pre_output_items = []
# Step 0: Store last computer call image, if none then take a screenshot
last_image_b64 = get_last_computer_call_image(messages)
if last_image_b64 is None:
# Take a screenshot
screenshot_b64 = await computer_handler.screenshot() # type: ignore
screenshot_b64 = await computer_handler.screenshot() # type: ignore
if screenshot_b64:
call_id = uuid.uuid4().hex
pre_output_items += [
{
@@ -187,45 +184,42 @@ class ComposedGroundedConfig(AsyncAgentConfig):
"content": [
{
"type": "output_text",
"text": "Taking a screenshot to see the current computer screen."
"text": "Taking a screenshot to see the current computer screen.",
}
]
],
},
{
"action": {
"type": "screenshot"
},
"action": {"type": "screenshot"},
"call_id": call_id,
"status": "completed",
"type": "computer_call"
"type": "computer_call",
},
{
"type": "computer_call_output",
"call_id": call_id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_b64}"
}
"image_url": f"data:image/png;base64,{screenshot_b64}",
},
},
]
last_image_b64 = screenshot_b64
# Call screenshot callback if provided
if _on_screenshot:
await _on_screenshot(screenshot_b64)
tool_schemas = _prepare_tools_for_grounded(tools) # type: ignore
tool_schemas = _prepare_tools_for_grounded(tools) # type: ignore
# Step 1: Convert computer calls from xy to descriptions
input_messages = messages + pre_output_items
messages_with_descriptions = convert_computer_calls_xy2desc(input_messages, self.desc2xy)
# Step 2: Convert responses items to completion messages
completion_messages = convert_responses_items_to_completion_messages(
messages_with_descriptions,
allow_images_in_tool_results=False
messages_with_descriptions, allow_images_in_tool_results=False
)
# Step 3: Call thinking model with litellm.acompletion
api_kwargs = {
"model": thinking_model,
@@ -233,98 +227,90 @@ class ComposedGroundedConfig(AsyncAgentConfig):
"tools": tool_schemas,
"max_retries": max_retries,
"stream": stream,
**kwargs
**kwargs,
}
if use_prompt_caching:
api_kwargs["use_prompt_caching"] = use_prompt_caching
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
# Make the completion call
response = await litellm.acompletion(**api_kwargs)
# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
# Extract usage information
usage = {
**response.usage.model_dump(), # type: ignore
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(usage)
# Step 4: Convert completion messages back to responses items format
response_dict = response.model_dump() # type: ignore
response_dict = response.model_dump() # type: ignore
choice_messages = [choice["message"] for choice in response_dict["choices"]]
thinking_output_items = []
for choice_message in choice_messages:
thinking_output_items.extend(convert_completion_messages_to_responses_items([choice_message]))
thinking_output_items.extend(
convert_completion_messages_to_responses_items([choice_message])
)
# Step 5: Get all element descriptions and populate desc2xy mapping
element_descriptions = get_all_element_descriptions(thinking_output_items)
if element_descriptions and last_image_b64:
# Use grounding model to predict coordinates for each description
grounding_agent_conf = find_agent_config(grounding_model)
if grounding_agent_conf:
grounding_agent = grounding_agent_conf.agent_class()
for desc in element_descriptions:
for _ in range(3): # try 3 times
for _ in range(3): # try 3 times
coords = await grounding_agent.predict_click(
model=grounding_model,
image_b64=last_image_b64,
instruction=desc
model=grounding_model, image_b64=last_image_b64, instruction=desc
)
if coords:
self.desc2xy[desc] = coords
break
# Step 6: Convert computer calls from descriptions back to xy coordinates
final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)
# Step 7: Return output and usage
return {
"output": pre_output_items + final_output_items,
"usage": usage
}
return {"output": pre_output_items + final_output_items, "usage": usage}
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using the grounding model.
For composed models, uses only the grounding model part for click prediction.
"""
# Parse the composed model to get grounding model
if "+" not in model:
raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
raise ValueError(
f"Composed model must be in format 'grounding_model+thinking_model', got: {model}"
)
grounding_model, thinking_model = model.split("+", 1)
# Find and use the grounding agent
grounding_agent_conf = find_agent_config(grounding_model)
if grounding_agent_conf:
grounding_agent = grounding_agent_conf.agent_class()
return await grounding_agent.predict_click(
model=grounding_model,
image_b64=image_b64,
instruction=instruction,
**kwargs
model=grounding_model, image_b64=image_b64, instruction=instruction, **kwargs
)
return None
def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["click", "step"]

View File

@@ -0,0 +1,410 @@
"""
Gemini 2.5 Computer Use agent loop
Maps internal Agent SDK message format to Google's Gemini Computer Use API and back.
Key features:
- Lazy import of google.genai
- Configure Computer Use tool with excluded browser-specific predefined functions
- Optional custom function declarations hook for computer-call specific functions
- Convert Gemini function_call parts into internal computer_call actions
"""
from __future__ import annotations
import base64
import io
import uuid
from typing import Any, Dict, List, Optional, Tuple
from PIL import Image
from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability
def _lazy_import_genai():
"""Import google.genai lazily to avoid hard dependency unless used."""
try:
from google import genai # type: ignore
from google.genai import types # type: ignore
return genai, types
except Exception as e: # pragma: no cover
raise RuntimeError(
"google.genai is required for the Gemini Computer Use loop. Install the Google Gemini SDK."
) from e
def _data_url_to_bytes(data_url: str) -> Tuple[bytes, str]:
"""Convert a data URL to raw bytes and mime type."""
if not data_url.startswith("data:"):
# Assume it's base64 png payload
try:
return base64.b64decode(data_url), "image/png"
except Exception:
return b"", "application/octet-stream"
header, b64 = data_url.split(",", 1)
mime = "image/png"
if ";" in header:
mime = header.split(";")[0].split(":", 1)[1] or "image/png"
return base64.b64decode(b64), mime
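For example, with a tiny illustrative payload (not a real PNG), the helper above splits out the mime type and decodes the base64 body:

data, mime = _data_url_to_bytes("data:image/png;base64,aGVsbG8=")
# mime == "image/png" and data == b"hello"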
def _bytes_image_size(img_bytes: bytes) -> Tuple[int, int]:
try:
img = Image.open(io.BytesIO(img_bytes))
return img.size
except Exception:
return (1024, 768)
def _find_last_user_text(messages: List[Dict[str, Any]]) -> List[str]:
texts: List[str] = []
for msg in reversed(messages):
if msg.get("type") in (None, "message") and msg.get("role") == "user":
content = msg.get("content")
if isinstance(content, str):
return [content]
elif isinstance(content, list):
for c in content:
if c.get("type") in ("input_text", "output_text") and c.get("text"):
texts.append(c["text"]) # newest first
if texts:
return list(reversed(texts))
return []
def _find_last_screenshot(messages: List[Dict[str, Any]]) -> Optional[bytes]:
for msg in reversed(messages):
if msg.get("type") == "computer_call_output":
out = msg.get("output", {})
if isinstance(out, dict) and out.get("type") in ("input_image", "computer_screenshot"):
image_url = out.get("image_url", "")
if image_url:
data, _ = _data_url_to_bytes(image_url)
return data
return None
def _denormalize(v: int, size: int) -> int:
# Gemini returns 0-999 normalized
try:
return max(0, min(size - 1, int(round(v / 1000 * size))))
except Exception:
return 0
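A quick worked example of the 0-999 denormalization above, with screen sizes assumed for illustration:

# Coordinate 500 of 1000 on a 1024-pixel-wide screen: round(500 / 1000 * 1024) = 512, clamped to [0, 1023].
assert _denormalize(500, 1024) == 512
assert _denormalize(999, 768) == 767  # clamps to size - 1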
def _map_gemini_fc_to_computer_call(
fc: Dict[str, Any],
screen_w: int,
screen_h: int,
) -> Optional[Dict[str, Any]]:
name = fc.get("name")
args = fc.get("args", {}) or {}
action: Dict[str, Any] = {}
if name == "click_at":
x = _denormalize(int(args.get("x", 0)), screen_w)
y = _denormalize(int(args.get("y", 0)), screen_h)
action = {"type": "click", "x": x, "y": y, "button": "left"}
elif name == "type_text_at":
x = _denormalize(int(args.get("x", 0)), screen_w)
y = _denormalize(int(args.get("y", 0)), screen_h)
text = args.get("text", "")
if args.get("press_enter") == True:
text += "\n"
action = {"type": "type", "x": x, "y": y, "text": text}
elif name == "hover_at":
x = _denormalize(int(args.get("x", 0)), screen_w)
y = _denormalize(int(args.get("y", 0)), screen_h)
action = {"type": "move", "x": x, "y": y}
elif name == "key_combination":
keys = str(args.get("keys", ""))
action = {"type": "keypress", "keys": keys}
elif name == "scroll_document":
direction = args.get("direction", "down")
magnitude = 800
dx, dy = 0, 0
if direction == "down":
dy = magnitude
elif direction == "up":
dy = -magnitude
elif direction == "right":
dx = magnitude
elif direction == "left":
dx = -magnitude
action = {
"type": "scroll",
"scroll_x": dx,
"scroll_y": dy,
"x": int(screen_w / 2),
"y": int(screen_h / 2),
}
elif name == "scroll_at":
x = _denormalize(int(args.get("x", 500)), screen_w)
y = _denormalize(int(args.get("y", 500)), screen_h)
direction = args.get("direction", "down")
magnitude = int(args.get("magnitude", 800))
dx, dy = 0, 0
if direction == "down":
dy = magnitude
elif direction == "up":
dy = -magnitude
elif direction == "right":
dx = magnitude
elif direction == "left":
dx = -magnitude
action = {"type": "scroll", "scroll_x": dx, "scroll_y": dy, "x": x, "y": y}
elif name == "drag_and_drop":
x = _denormalize(int(args.get("x", 0)), screen_w)
y = _denormalize(int(args.get("y", 0)), screen_h)
dx = _denormalize(int(args.get("destination_x", x)), screen_w)
dy = _denormalize(int(args.get("destination_y", y)), screen_h)
action = {
"type": "drag",
"start_x": x,
"start_y": y,
"end_x": dx,
"end_y": dy,
"button": "left",
}
elif name == "wait_5_seconds":
action = {"type": "wait"}
else:
# Unsupported / excluded browser-specific or custom function; ignore
return None
return {
"type": "computer_call",
"call_id": uuid.uuid4().hex,
"status": "completed",
"action": action,
}
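For instance, on an assumed 1024x768 screenshot a Gemini click_at call maps to an internal click action via the function above (the generated call_id is random):

item = _map_gemini_fc_to_computer_call(
    {"name": "click_at", "args": {"x": 500, "y": 500}}, 1024, 768
)
# item["action"] == {"type": "click", "x": 512, "y": 384, "button": "left"}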
@register_agent(models=r"^gemini-2\.5-computer-use-preview-10-2025$")
class GeminiComputerUseConfig(AsyncAgentConfig):
async def predict_step(
self,
messages: List[Dict[str, Any]],
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs,
) -> Dict[str, Any]:
genai, types = _lazy_import_genai()
client = genai.Client()
# Build excluded predefined functions for browser-specific behavior
excluded = [
"open_web_browser",
"search",
"navigate",
"go_forward",
"go_back",
"scroll_document",
]
# Optional custom functions: can be extended by host code via `tools` parameter later if desired
CUSTOM_FUNCTION_DECLARATIONS: List[Any] = []
# Compose tools config
generate_content_config = types.GenerateContentConfig(
tools=[
types.Tool(
computer_use=types.ComputerUse(
environment=types.Environment.ENVIRONMENT_BROWSER,
excluded_predefined_functions=excluded,
)
),
# types.Tool(function_declarations=CUSTOM_FUNCTION_DECLARATIONS), # enable when custom functions needed
]
)
# Prepare contents: last user text + latest screenshot
user_texts = _find_last_user_text(messages)
screenshot_bytes = _find_last_screenshot(messages)
parts: List[Any] = []
for t in user_texts:
parts.append(types.Part(text=t))
screen_w, screen_h = 1024, 768
if screenshot_bytes:
screen_w, screen_h = _bytes_image_size(screenshot_bytes)
parts.append(types.Part.from_bytes(data=screenshot_bytes, mime_type="image/png"))
# If there is no content at all, pass a minimal user prompt so the model still proposes an action
if not parts:
parts = [types.Part(text="Proceed to the next action.")]
contents = [types.Content(role="user", parts=parts)]
api_kwargs = {
"model": model,
"contents": contents,
"config": generate_content_config,
}
if _on_api_start:
await _on_api_start(
{
"model": api_kwargs["model"],
# "contents": api_kwargs["contents"], # Disabled for now
"config": api_kwargs["config"],
}
)
response = client.models.generate_content(**api_kwargs)
if _on_api_end:
await _on_api_end(
{
"model": api_kwargs["model"],
# "contents": api_kwargs["contents"], # Disabled for now
"config": api_kwargs["config"],
},
response,
)
# Usage (Gemini SDK may not always provide token usage; populate when available)
usage: Dict[str, Any] = {}
try:
# Some SDKs expose response.usage; if available, copy
if getattr(response, "usage_metadata", None):
md = response.usage_metadata
usage = {
"prompt_tokens": getattr(md, "prompt_token_count", None) or 0,
"completion_tokens": getattr(md, "candidates_token_count", None) or 0,
"total_tokens": getattr(md, "total_token_count", None) or 0,
}
except Exception:
pass
if _on_usage and usage:
await _on_usage(usage)
# Parse output into internal items
output_items: List[Dict[str, Any]] = []
candidate = response.candidates[0]
# Text parts from the model (assistant message)
text_parts: List[str] = []
function_calls: List[Dict[str, Any]] = []
for p in candidate.content.parts:
if getattr(p, "text", None):
text_parts.append(p.text)
if getattr(p, "function_call", None):
# p.function_call has name and args
fc = {
"name": getattr(p.function_call, "name", None),
"args": dict(getattr(p.function_call, "args", {}) or {}),
}
function_calls.append(fc)
if text_parts:
output_items.append(
{
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "\n".join(text_parts)}],
}
)
# Map function calls to internal computer_call actions
for fc in function_calls:
item = _map_gemini_fc_to_computer_call(fc, screen_w, screen_h)
if item is not None:
output_items.append(item)
return {"output": output_items, "usage": usage}
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs,
) -> Optional[Tuple[float, float]]:
"""Ask Gemini CUA to output a single click action for the given instruction.
Excludes all predefined tools except `click_at` and sends the screenshot.
Returns pixel (x, y) if a click is proposed, else None.
"""
genai, types = _lazy_import_genai()
client = genai.Client()
# Exclude all but click_at
exclude_all_but_click = [
"open_web_browser",
"wait_5_seconds",
"go_back",
"go_forward",
"search",
"navigate",
"hover_at",
"type_text_at",
"key_combination",
"scroll_document",
"scroll_at",
"drag_and_drop",
]
config = types.GenerateContentConfig(
tools=[
types.Tool(
computer_use=types.ComputerUse(
environment=types.Environment.ENVIRONMENT_BROWSER,
excluded_predefined_functions=exclude_all_but_click,
)
)
]
)
# Prepare prompt parts
try:
img_bytes = base64.b64decode(image_b64)
except Exception:
img_bytes = b""
w, h = _bytes_image_size(img_bytes) if img_bytes else (1024, 768)
parts: List[Any] = [types.Part(text=f"Click {instruction}.")]
if img_bytes:
parts.append(types.Part.from_bytes(data=img_bytes, mime_type="image/png"))
contents = [types.Content(role="user", parts=parts)]
response = client.models.generate_content(
model=model,
contents=contents,
config=config,
)
# Parse first click_at
try:
candidate = response.candidates[0]
for p in candidate.content.parts:
fc = getattr(p, "function_call", None)
if fc and getattr(fc, "name", None) == "click_at":
args = dict(getattr(fc, "args", {}) or {})
x = _denormalize(int(args.get("x", 0)), w)
y = _denormalize(int(args.get("y", 0)), h)
return float(x), float(y)
except Exception:
return None
return None
def get_capabilities(self) -> List[AgentCapability]:
return ["click", "step"]

View File

@@ -4,33 +4,36 @@ Supports vision-language models for computer control with bounding box parsing.
"""
import asyncio
import json
import base64
import json
import re
from typing import Dict, List, Any, Optional, Tuple
from io import BytesIO
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple
import litellm
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
from litellm.types.utils import ModelResponse
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
from PIL import Image
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..responses import (
convert_responses_items_to_completion_messages,
convert_completion_messages_to_responses_items,
make_reasoning_item,
make_output_text_item,
convert_responses_items_to_completion_messages,
make_click_item,
make_double_click_item,
make_drag_item,
make_input_image_item,
make_keypress_item,
make_output_text_item,
make_reasoning_item,
make_scroll_item,
make_type_item,
make_wait_item,
make_input_image_item
)
from ..types import AgentCapability, AgentResponse, Messages, Tools
# GLM-4.5V specific constants
GLM_ACTION_SPACE = """
@@ -251,16 +254,18 @@ Call rule: `FAIL()`
}
}"""
def encode_image_to_base64(image_path: str) -> str:
"""Encode image file to base64 string with data URI."""
with open(image_path, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
return f"data:image/png;base64,{encoded_string}"
def parse_glm_response(response: str) -> Dict[str, Any]:
"""
Parse GLM-4.5V response to extract action and memory.
The special tokens <|begin_of_box|> and <|end_of_box|> mark bounding boxes.
Coordinates are normalized values in the 0-999 range.
"""
@@ -274,26 +279,23 @@ def parse_glm_response(response: str) -> Dict[str, Any]:
action_pattern = r"[\w_]+\([^)]*\)"
matches = re.findall(action_pattern, response)
action = matches[0] if matches else None
# Extract memory section
memory_pattern = r"Memory:(.*?)$"
memory_match = re.search(memory_pattern, response, re.DOTALL)
memory = memory_match.group(1).strip() if memory_match else "[]"
# Extract action text (everything before Memory:)
action_text_pattern = r'^(.*?)Memory:'
action_text_pattern = r"^(.*?)Memory:"
action_text_match = re.search(action_text_pattern, response, re.DOTALL)
action_text = action_text_match.group(1).strip() if action_text_match else response
# Clean up action text by removing special tokens
if action_text:
action_text = action_text.replace("<|begin_of_box|>", "").replace("<|end_of_box|>", "")
return {
"action": action,
"action_text": action_text,
"memory": memory
}
return {"action": action, "action_text": action_text, "memory": memory}
def get_last_image_from_messages(messages: Messages) -> Optional[str]:
"""Extract the last image from messages for processing."""
@@ -314,23 +316,28 @@ def get_last_image_from_messages(messages: Messages) -> Optional[str]:
image_url_obj = item.get("image_url", {})
if isinstance(image_url_obj, dict):
image_url = image_url_obj.get("url", "")
if isinstance(image_url, str) and image_url.startswith("data:image/"):
if isinstance(image_url, str) and image_url.startswith(
"data:image/"
):
return image_url.split(",", 1)[1]
return None
def convert_responses_items_to_glm45v_pc_prompt(messages: Messages, task: str, memory: str = "") -> List[Dict[str, Any]]:
def convert_responses_items_to_glm45v_pc_prompt(
messages: Messages, task: str, memory: str = ""
) -> List[Dict[str, Any]]:
"""Convert responses items to GLM-4.5V PC prompt format with historical actions.
Args:
messages: List of message items from the conversation
task: The task description
memory: Current memory state
Returns:
List of content items for the prompt (text and image_url items)
"""
action_space = GLM_ACTION_SPACE
# Template head
head_text = f"""You are a GUI Agent, and your primary task is to respond accurately to user requests or questions. In addition to directly answering the user's queries, you can also use tools or perform GUI operations directly until you fulfill the user's request or provide a correct answer. You should carefully read and understand the images and questions provided by the user, and engage in thinking and reflection when appropriate. The coordinates involved are all represented in thousandths (0-999).
@@ -345,7 +352,7 @@ Ubuntu
# Historical Actions and Current Memory
History:"""
# Template tail
tail_text = f"""
Memory:
@@ -363,18 +370,18 @@ Memory:
Current Screenshot:
"""
# Build history from messages
history = []
history_images = []
# Group messages into steps
current_step = []
step_num = 0
for message in messages:
msg_type = message.get("type")
if msg_type == "reasoning":
current_step.append(message)
elif msg_type == "message" and message.get("role") == "assistant":
@@ -386,7 +393,7 @@ Current Screenshot:
# End of step - process it
if current_step:
step_num += 1
# Extract bot thought from message content
bot_thought = ""
for item in current_step:
@@ -397,14 +404,14 @@ Current Screenshot:
bot_thought = content_item.get("text", "")
break
break
# Extract action from computer_call
action_text = ""
for item in current_step:
if item.get("type") == "computer_call":
action = item.get("action", {})
action_type = action.get("type", "")
if action_type == "click":
x, y = action.get("x", 0), action.get("y", 0)
# Convert to 0-999 range (assuming screen dimensions)
@@ -436,7 +443,7 @@ Current Screenshot:
elif action_type == "wait":
action_text = "WAIT()"
break
# Extract screenshot from computer_call_output
screenshot_url = None
for item in current_step:
@@ -445,34 +452,34 @@ Current Screenshot:
if output.get("type") == "input_image":
screenshot_url = output.get("image_url", "")
break
# Store step info
step_info = {
"step_num": step_num,
"bot_thought": bot_thought,
"action_text": action_text,
"screenshot_url": screenshot_url
"screenshot_url": screenshot_url,
}
history.append(step_info)
# Store screenshot for last 4 steps
if screenshot_url:
history_images.append(screenshot_url)
current_step = []
# Build content array with head, history, and tail
content = []
current_text = head_text
total_history_steps = len(history)
history_image_count = min(4, len(history_images)) # Last 4 images
for step_idx, step_info in enumerate(history):
step_num = step_info["step_num"]
bot_thought = step_info["bot_thought"]
action_text = step_info["action_text"]
if step_idx < total_history_steps - history_image_count:
# For steps beyond the last 4, use text placeholder
current_text += f"\nstep {step_num}: Screenshot:(Omitted in context.) Thought: {bot_thought}\nAction: {action_text}"
@@ -480,20 +487,21 @@ Current Screenshot:
# For the last 4 steps, insert images
current_text += f"\nstep {step_num}: Screenshot:"
content.append({"type": "text", "text": current_text})
# Add image
img_idx = step_idx - (total_history_steps - history_image_count)
if img_idx < len(history_images):
content.append({"type": "image_url", "image_url": {"url": history_images[img_idx]}})
current_text = f" Thought: {bot_thought}\nAction: {action_text}"
# Add tail
current_text += tail_text
content.append({"type": "text", "text": current_text})
return content
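The windowing above keeps screenshots inline only for the most recent four steps; with six prior steps, the returned content list has roughly this shape (step contents abbreviated for illustration):

content_shape = [
    {"type": "text", "text": "<head + steps 1-2 as text placeholders + 'step 3: Screenshot:'>"},
    {"type": "image_url", "image_url": {"url": "<data URL of the step 3 screenshot>"}},
    {"type": "text", "text": "<thought/action of step 3 + 'step 4: Screenshot:'>"},
    {"type": "image_url", "image_url": {"url": "<data URL of the step 4 screenshot>"}},
    # ... steps 5 and 6 follow the same text/image alternation ...
    {"type": "text", "text": "<thought/action of step 6 + memory tail + 'Current Screenshot:'>"},
]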
def model_dump(obj) -> Dict[str, Any]:
if isinstance(obj, dict):
return {k: model_dump(v) for k, v in obj.items()}
@@ -502,58 +510,61 @@ def model_dump(obj) -> Dict[str, Any]:
else:
return obj
def convert_glm_completion_to_responses_items(response: ModelResponse, image_width: int, image_height: int) -> List[Dict[str, Any]]:
def convert_glm_completion_to_responses_items(
response: ModelResponse, image_width: int, image_height: int
) -> List[Dict[str, Any]]:
"""
Convert GLM-4.5V completion response to responses items format.
Args:
response: LiteLLM ModelResponse from GLM-4.5V
image_width: Original image width for coordinate scaling
image_height: Original image height for coordinate scaling
Returns:
List of response items in the proper format
"""
import uuid
response_items = []
if not response.choices or not response.choices[0].message:
return response_items
message = response.choices[0].message
content = message.content or ""
reasoning_content = getattr(message, 'reasoning_content', None)
reasoning_content = getattr(message, "reasoning_content", None)
# Add reasoning item if present
if reasoning_content:
reasoning_item = model_dump(make_reasoning_item(reasoning_content))
response_items.append(reasoning_item)
# Parse the content to extract action and text
parsed_response = parse_glm_response(content)
action = parsed_response.get("action", "")
action_text = parsed_response.get("action_text", "")
# Add message item with text content (excluding action and memory)
if action_text:
# Remove action from action_text if it's there
clean_text = action_text
if action and action in clean_text:
clean_text = clean_text.replace(action, "").strip()
# Remove memory section
memory_pattern = r"Memory:\s*\[.*?\]\s*$"
clean_text = re.sub(memory_pattern, "", clean_text, flags=re.DOTALL).strip()
if clean_text:
message_item = model_dump(make_output_text_item(clean_text))
response_items.append(message_item)
# Convert action to computer call if present
if action:
call_id = f"call_{uuid.uuid4().hex[:8]}"
# Parse different action types and create appropriate computer calls
if action.startswith("left_click"):
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
@@ -566,7 +577,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action.startswith("right_click"):
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
if coord_match:
@@ -577,7 +588,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action.startswith("left_double_click"):
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
if coord_match:
@@ -588,7 +599,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action.startswith("left_drag"):
start_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
end_match = re.search(r"end_box='?\[(\d+),\s*(\d+)\]'?", action)
@@ -605,18 +616,18 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action.startswith("key"):
key_match = re.search(r"keys='([^']+)'", action)
if key_match:
keys = key_match.group(1)
# Split keys by '+' for key combinations, or use as single key
key_list = keys.split('+') if '+' in keys else [keys]
key_list = keys.split("+") if "+" in keys else [keys]
computer_call = model_dump(make_keypress_item(key_list))
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action.startswith("type"):
content_match = re.search(r"content='([^']*)'", action)
if content_match:
@@ -625,7 +636,7 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action.startswith("scroll"):
coord_match = re.search(r"start_box='?\[(\d+),\s*(\d+)\]'?", action)
direction_match = re.search(r"direction='([^']+)'", action)
@@ -648,15 +659,16 @@ def convert_glm_completion_to_responses_items(response: ModelResponse, image_wid
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
elif action == "WAIT()":
computer_call = model_dump(make_wait_item())
computer_call["call_id"] = call_id
computer_call["status"] = "completed"
response_items.append(computer_call)
return response_items
@register_agent(models=r"(?i).*GLM-4\.5V.*")
class Glm4vConfig(AsyncAgentConfig):
"""GLM-4.5V agent configuration using liteLLM."""
@@ -674,11 +686,11 @@ class Glm4vConfig(AsyncAgentConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Predict the next step using GLM-4.5V model.
Args:
messages: Input messages following Responses format
model: Model name to use
@@ -691,7 +703,7 @@ class Glm4vConfig(AsyncAgentConfig):
_on_api_end: Callback for API end
_on_usage: Callback for usage tracking
_on_screenshot: Callback for screenshot events
Returns:
Dict with "output" and "usage" keys
"""
@@ -708,7 +720,7 @@ class Glm4vConfig(AsyncAgentConfig):
user_instruction = item.get("text", "")
break
break
# Get the last image for processing
last_image_b64 = get_last_image_from_messages(messages)
if not last_image_b64 and computer_handler:
@@ -718,35 +730,28 @@ class Glm4vConfig(AsyncAgentConfig):
last_image_b64 = screenshot_b64
if _on_screenshot:
await _on_screenshot(screenshot_b64)
if not last_image_b64:
raise ValueError("No image available for GLM-4.5V processing")
# Convert responses items to GLM-4.5V PC prompt format with historical actions
prompt_content = convert_responses_items_to_glm45v_pc_prompt(
messages=messages,
task=user_instruction,
memory="[]" # Initialize with empty memory for now
memory="[]", # Initialize with empty memory for now
)
# Add the current screenshot to the end
prompt_content.append({
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{last_image_b64}"}
})
prompt_content.append(
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{last_image_b64}"}}
)
# Prepare messages for liteLLM
litellm_messages = [
{
"role": "system",
"content": "You are a helpful GUI agent assistant."
},
{
"role": "user",
"content": prompt_content
}
{"role": "system", "content": "You are a helpful GUI agent assistant."},
{"role": "user", "content": prompt_content},
]
# Prepare API call kwargs
api_kwargs = {
"model": model,
@@ -757,20 +762,20 @@ class Glm4vConfig(AsyncAgentConfig):
# "skip_special_tokens": False,
# }
}
# Add API callbacks
if _on_api_start:
await _on_api_start(api_kwargs)
# Call liteLLM
response = await litellm.acompletion(**api_kwargs)
if _on_api_end:
await _on_api_end(api_kwargs, response)
# Get image dimensions for coordinate scaling
image_width, image_height = 1920, 1080 # Default dimensions
# Try to get actual dimensions from the image
try:
image_data = base64.b64decode(last_image_b64)
@@ -778,41 +783,38 @@ class Glm4vConfig(AsyncAgentConfig):
image_width, image_height = image.size
except Exception:
pass # Use default dimensions
# Convert GLM completion response to responses items
response_items = convert_glm_completion_to_responses_items(response, image_width, image_height)
response_items = convert_glm_completion_to_responses_items(
response, image_width, image_height
)
# Extract usage information
response_usage = {
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(
response.usage
).model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(response_usage)
# Create agent response
agent_response = {
"output": response_items,
"usage": response_usage
}
agent_response = {"output": response_items, "usage": response_usage}
return agent_response
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using GLM-4.5V model.
Args:
model: Model name to use
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
Tuple with (x, y) coordinates or None
"""
@@ -824,22 +826,22 @@ Respond with a single click action in this format:
left_click(start_box='[x,y]')
Where x,y are coordinates normalized to 0-999 range."""
# Prepare messages for liteLLM
litellm_messages = [
{
"role": "system",
"content": "You are a helpful GUI agent assistant."
},
{"role": "system", "content": "You are a helpful GUI agent assistant."},
{
"role": "user",
"content": [
{"type": "text", "text": click_prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
]
}
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_b64}"},
},
],
},
]
# Prepare API call kwargs
api_kwargs = {
"model": model,
@@ -848,21 +850,21 @@ Where x,y are coordinates normalized to 0-999 range."""
"temperature": 0.001,
"extra_body": {
"skip_special_tokens": False,
}
},
}
# Call liteLLM
response = await litellm.acompletion(**api_kwargs)
# Extract response content
response_content = response.choices[0].message.content.strip()
print(response)
# Parse response for click coordinates
# Look for coordinates in the response, handling special tokens
coord_pattern = r"<\|begin_of_box\|>.*?left_click\(start_box='?\[(\d+),(\d+)\]'?\).*?<\|end_of_box\|>"
match = re.search(coord_pattern, response_content)
if not match:
# Fallback: look for coordinates without special tokens
coord_pattern = r"left_click\(start_box='?\[(\d+),(\d+)\]'?\)"
@@ -870,7 +872,7 @@ Where x,y are coordinates normalized to 0-999 range."""
if match:
x, y = int(match.group(1)), int(match.group(2))
# Get actual image dimensions for scaling
try:
image_data = base64.b64decode(image_b64)
@@ -879,15 +881,15 @@ Where x,y are coordinates normalized to 0-999 range."""
except Exception:
# Use default dimensions
image_width, image_height = 1920, 1080
# Convert from 0-999 normalized coordinates to actual pixel coordinates
actual_x = int((x / 999.0) * image_width)
actual_y = int((y / 999.0) * image_height)
return (actual_x, actual_y)
return None
except Exception as e:
# Log error and return None
print(f"Error in predict_click: {e}")
@@ -896,7 +898,7 @@ Where x,y are coordinates normalized to 0-999 range."""
def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by this agent config.
Returns:
List of capability strings
"""

View File

@@ -5,75 +5,80 @@ Code: https://github.com/Yan98/GTA1
"""
import asyncio
import json
import re
import base64
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from io import BytesIO
import uuid
from PIL import Image
import litellm
import json
import math
import re
import uuid
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import litellm
from PIL import Image
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability, AgentResponse, Messages, Tools
SYSTEM_PROMPT = '''
SYSTEM_PROMPT = """
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
Output the coordinate pair exactly:
(x,y)
'''.strip()
""".strip()
def extract_coordinates(raw_string: str) -> Tuple[float, float]:
"""Extract coordinates from model output."""
try:
matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
return tuple(map(float, matches[0])) # type: ignore
return tuple(map(float, matches[0])) # type: ignore
except:
return (0.0, 0.0)
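# Illustrative usage (the input strings are assumptions, not from the diff):
# >>> extract_coordinates("The target element is at (123.5, 456).")
# (123.5, 456.0)
# >>> extract_coordinates("no coordinates in this output")
# (0.0, 0.0)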
def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
def smart_resize(
height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360
) -> Tuple[int, int]:
"""Smart resize function similar to qwen_vl_utils."""
# Calculate the total pixels
total_pixels = height * width
# If already within bounds, return original dimensions
if min_pixels <= total_pixels <= max_pixels:
# Round to nearest factor
new_height = (height // factor) * factor
new_width = (width // factor) * factor
return new_height, new_width
# Calculate scaling factor
if total_pixels > max_pixels:
scale = (max_pixels / total_pixels) ** 0.5
else:
scale = (min_pixels / total_pixels) ** 0.5
# Apply scaling
new_height = int(height * scale)
new_width = int(width * scale)
# Round to nearest factor
new_height = (new_height // factor) * factor
new_width = (new_width // factor) * factor
# Ensure minimum size
new_height = max(new_height, factor)
new_width = max(new_width, factor)
return new_height, new_width
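# Worked example (input size is an assumption): a 1920x1080 screenshot already
# falls inside [min_pixels, max_pixels], so each side is simply floored to a
# multiple of 28:
# >>> smart_resize(1080, 1920)
# (1064, 1904)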
@register_agent(models=r".*GTA1.*")
class GTA1Config(AsyncAgentConfig):
"""GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction."""
def __init__(self):
self.current_model = None
self.last_screenshot_b64 = None
async def predict_step(
self,
@@ -87,25 +92,21 @@ class GTA1Config(AsyncAgentConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
raise NotImplementedError()
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[float, float]]:
"""
Predict click coordinates using GTA1 model via litellm.acompletion.
Args:
model: The GTA1 model name
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -113,66 +114,62 @@ class GTA1Config(AsyncAgentConfig):
image_data = base64.b64decode(image_b64)
image = Image.open(BytesIO(image_data))
width, height = image.width, image.height
# Smart resize the image (similar to qwen_vl_utils)
resized_height, resized_width = smart_resize(
height, width,
height,
width,
factor=28, # Default factor for Qwen models
min_pixels=3136,
max_pixels=4096 * 2160
max_pixels=4096 * 2160,
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height
# Convert resized image back to base64
buffered = BytesIO()
resized_image.save(buffered, format="PNG")
resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()
# Prepare system and user messages
system_message = {
"role": "system",
"content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
"content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width),
}
user_message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{resized_image_b64}"
}
"image_url": {"url": f"data:image/png;base64,{resized_image_b64}"},
},
{
"type": "text",
"text": instruction
}
]
{"type": "text", "text": instruction},
],
}
# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": [system_message, user_message],
"max_tokens": 2056,
"temperature": 0.0,
**kwargs
**kwargs,
}
# Use liteLLM acompletion
response = await litellm.acompletion(**api_kwargs)
# Extract response text
output_text = response.choices[0].message.content # type: ignore
output_text = response.choices[0].message.content # type: ignore
# Extract and rescale coordinates
pred_x, pred_y = extract_coordinates(output_text) # type: ignore
pred_x, pred_y = extract_coordinates(output_text) # type: ignore
pred_x *= scale_x
pred_y *= scale_y
return (math.floor(pred_x), math.floor(pred_y))
def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["click"]

View File

@@ -21,8 +21,8 @@ import litellm
from PIL import Image
from ..decorators import register_agent
from .base import AsyncAgentConfig
from ..types import AgentCapability
from .base import AsyncAgentConfig
def _strip_hf_prefix(model: str) -> str:
@@ -53,7 +53,9 @@ def _maybe_smart_resize(image: Image.Image, model: str) -> Tuple[Image.Image, Tu
if image_processor is None:
return image, (orig_w, orig_h)
factor = getattr(image_processor, "patch_size", 14) * getattr(image_processor, "merge_size", 1)
factor = getattr(image_processor, "patch_size", 14) * getattr(
image_processor, "merge_size", 1
)
min_pixels = getattr(image_processor, "min_pixels", 256 * 256)
max_pixels = getattr(image_processor, "max_pixels", 1536 * 1536)

View File

@@ -18,13 +18,12 @@ import re
from io import BytesIO
from typing import Any, Dict, List, Optional, Tuple
from PIL import Image
import litellm
from PIL import Image
from ..decorators import register_agent
from .composed_grounded import ComposedGroundedConfig
from ..types import AgentCapability
from .composed_grounded import ComposedGroundedConfig
# Regex patterns for extracting coordinates
# Accept optional whitespace and optional decimal fractions
@@ -91,7 +90,7 @@ class InternVLConfig(ComposedGroundedConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""Fallback to a self-composed model"""
return await super().predict_step(
@@ -105,15 +104,11 @@ class InternVLConfig(ComposedGroundedConfig):
_on_api_end=_on_api_end,
_on_usage=_on_usage,
_on_screenshot=_on_screenshot,
**kwargs
**kwargs,
)
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using InternVL via litellm.acompletion.

View File

@@ -0,0 +1,493 @@
"""
Moondream3+ composed-grounded agent loop implementation.
Grounding is handled by a local Moondream3 preview model via Transformers.
Thinking is delegated to the trailing LLM in the composed model string: "moondream3+<thinking_model>".
Differences from composed_grounded:
- Provides a singleton Moondream3 client outside the class.
- predict_click uses model.point(image, instruction, settings={"max_objects": 1}) and returns pixel coordinates.
- If the last image was a screenshot (or we take one), run model.detect(image, "all ui elements") to get bboxes, then
run model.query on each cropped bbox to caption it. Overlay labels on the screenshot and emit via _on_screenshot.
- Add a user message listing all detected form UI names so the thinker can reference them.
- If the thinking model doesn't support vision, filter out image content before calling litellm.
"""
from __future__ import annotations
import base64
import io
import uuid
from typing import Any, Dict, List, Optional, Tuple
import litellm
from PIL import Image, ImageDraw, ImageFont
from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..responses import (
convert_completion_messages_to_responses_items,
convert_computer_calls_desc2xy,
convert_computer_calls_xy2desc,
convert_responses_items_to_completion_messages,
get_all_element_descriptions,
)
from ..types import AgentCapability
_MOONDREAM_SINGLETON = None
def get_moondream_model() -> Any:
"""Get a singleton instance of the Moondream3 preview model."""
global _MOONDREAM_SINGLETON
if _MOONDREAM_SINGLETON is None:
try:
import torch
from transformers import AutoModelForCausalLM
_MOONDREAM_SINGLETON = AutoModelForCausalLM.from_pretrained(
"moondream/moondream3-preview",
trust_remote_code=True,
torch_dtype=torch.bfloat16,
device_map="cuda",
)
except ImportError as e:
raise RuntimeError(
"moondream3 requires torch and transformers. Install with: pip install cua-agent[moondream3]"
) from e
return _MOONDREAM_SINGLETON
def _decode_image_b64(image_b64: str) -> Image.Image:
data = base64.b64decode(image_b64)
return Image.open(io.BytesIO(data)).convert("RGB")
def _image_to_b64(img: Image.Image) -> str:
buf = io.BytesIO()
img.save(buf, format="PNG")
return base64.b64encode(buf.getvalue()).decode("utf-8")
def _supports_vision(model: str) -> bool:
"""Heuristic vision support detection for thinking model."""
m = model.lower()
vision_markers = [
"gpt-4o",
"gpt-4.1",
"o1",
"o3",
"claude-3",
"claude-3.5",
"sonnet",
"haiku",
"opus",
"gemini-1.5",
"llava",
]
return any(v in m for v in vision_markers)
def _filter_images_from_completion_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
filtered: List[Dict[str, Any]] = []
for msg in messages:
msg_copy = {**msg}
content = msg_copy.get("content")
if isinstance(content, list):
msg_copy["content"] = [c for c in content if c.get("type") != "image_url"]
filtered.append(msg_copy)
return filtered
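# Illustrative behavior (message content is an assumption): image_url parts are
# dropped while text parts are kept, e.g.
#   [{"role": "user", "content": [{"type": "text", "text": "what is on screen?"},
#                                 {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}]
# becomes
#   [{"role": "user", "content": [{"type": "text", "text": "what is on screen?"}]}]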
def _annotate_detect_and_label_ui(base_img: Image.Image, model_md) -> Tuple[str, List[str]]:
"""Detect UI elements with Moondream, caption each, draw labels with backgrounds.
Args:
base_img: PIL image of the screenshot (RGB or RGBA). Will be copied/converted internally.
model_md: Moondream model instance with .detect() and .query() methods.
Returns:
A tuple of (annotated_image_base64_png, detected_names)
"""
# Ensure RGBA for semi-transparent fills
if base_img.mode != "RGBA":
base_img = base_img.convert("RGBA")
W, H = base_img.width, base_img.height
# Detect objects
try:
detect_result = model_md.detect(base_img, "all ui elements")
objects = detect_result.get("objects", []) if isinstance(detect_result, dict) else []
except Exception:
objects = []
draw = ImageDraw.Draw(base_img)
try:
font = ImageFont.load_default()
except Exception:
font = None
detected_names: List[str] = []
for i, obj in enumerate(objects):
try:
# Clamp normalized coords and crop
x_min = max(0.0, min(1.0, float(obj.get("x_min", 0.0))))
y_min = max(0.0, min(1.0, float(obj.get("y_min", 0.0))))
x_max = max(0.0, min(1.0, float(obj.get("x_max", 0.0))))
y_max = max(0.0, min(1.0, float(obj.get("y_max", 0.0))))
left, top, right, bottom = (
int(x_min * W),
int(y_min * H),
int(x_max * W),
int(y_max * H),
)
left, top = max(0, left), max(0, top)
right, bottom = min(W - 1, right), min(H - 1, bottom)
crop = base_img.crop((left, top, right, bottom))
# Prompted short caption
try:
result = model_md.query(crop, "Caption this UI element in few words.")
caption_text = (result or {}).get("answer", "")
except Exception:
caption_text = ""
name = (caption_text or "").strip() or f"element_{i+1}"
detected_names.append(name)
# Draw bbox
draw.rectangle([left, top, right, bottom], outline=(255, 215, 0, 255), width=2)
# Label background with padding and rounded corners
label = f"{i+1}. {name}"
padding = 3
if font:
text_bbox = draw.textbbox((0, 0), label, font=font)
else:
text_bbox = draw.textbbox((0, 0), label)
text_w = text_bbox[2] - text_bbox[0]
text_h = text_bbox[3] - text_bbox[1]
tx = left + 3
ty = top - (text_h + 2 * padding + 4)
if ty < 0:
ty = top + 3
bg_left = tx - padding
bg_top = ty - padding
bg_right = tx + text_w + padding
bg_bottom = ty + text_h + padding
try:
draw.rounded_rectangle(
[bg_left, bg_top, bg_right, bg_bottom],
radius=4,
fill=(0, 0, 0, 160),
outline=(255, 215, 0, 200),
width=1,
)
except Exception:
draw.rectangle(
[bg_left, bg_top, bg_right, bg_bottom],
fill=(0, 0, 0, 160),
outline=(255, 215, 0, 200),
width=1,
)
text_fill = (255, 255, 255, 255)
if font:
draw.text((tx, ty), label, fill=text_fill, font=font)
else:
draw.text((tx, ty), label, fill=text_fill)
except Exception:
continue
# Encode PNG base64
annotated = base_img
if annotated.mode not in ("RGBA", "RGB"):
annotated = annotated.convert("RGBA")
annotated_b64 = _image_to_b64(annotated)
return annotated_b64, detected_names
GROUNDED_COMPUTER_TOOL_SCHEMA = {
"type": "function",
"function": {
"name": "computer",
"description": (
"Control a computer by taking screenshots and interacting with UI elements. "
"The screenshot action will include a list of detected form UI element names when available. "
"Use element descriptions to locate and interact with UI elements on the screen."
),
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment",
],
"description": "The action to perform (required for all actions)",
},
"element_description": {
"type": "string",
"description": "Description of the element to interact with (required for click/double_click/move/scroll)",
},
"start_element_description": {
"type": "string",
"description": "Description of the element to start dragging from (required for drag)",
},
"end_element_description": {
"type": "string",
"description": "Description of the element to drag to (required for drag)",
},
"text": {
"type": "string",
"description": "The text to type (required for type)",
},
"keys": {
"type": "array",
"items": {"type": "string"},
"description": "Key(s) to press (required for keypress)",
},
"button": {
"type": "string",
"enum": ["left", "right", "wheel", "back", "forward"],
"description": "The mouse button to use for click/double_click",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount (required for scroll)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount (required for scroll)",
},
},
"required": ["action"],
},
},
}
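# Illustrative function_call arguments that satisfy this schema (values are
# assumptions, not from the original diff):
#   {"action": "click", "element_description": "blue 'Submit' button", "button": "left"}
#   {"action": "drag",
#    "start_element_description": "file icon labeled report.pdf",
#    "end_element_description": "trash can icon"}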
@register_agent(r"moondream3\+.*", priority=2)
class Moondream3PlusConfig(AsyncAgentConfig):
def __init__(self):
self.desc2xy: Dict[str, Tuple[float, float]] = {}
async def predict_step(
self,
messages: List[Dict[str, Any]],
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs,
) -> Dict[str, Any]:
# Parse composed model: moondream3+<thinking_model>
if "+" not in model:
raise ValueError(f"Composed model must be 'moondream3+<thinking_model>', got: {model}")
_, thinking_model = model.split("+", 1)
pre_output_items: List[Dict[str, Any]] = []
# Acquire last screenshot; if missing, take one
last_image_b64: Optional[str] = None
for message in reversed(messages):
if (
isinstance(message, dict)
and message.get("type") == "computer_call_output"
and isinstance(message.get("output"), dict)
and message["output"].get("type") == "input_image"
):
image_url = message["output"].get("image_url", "")
if image_url.startswith("data:image/png;base64,"):
last_image_b64 = image_url.split(",", 1)[1]
break
if last_image_b64 is None and computer_handler is not None:
# Take a screenshot
screenshot_b64 = await computer_handler.screenshot() # type: ignore
if screenshot_b64:
call_id = uuid.uuid4().hex
pre_output_items += [
{
"type": "message",
"role": "assistant",
"content": [
{
"type": "output_text",
"text": "Taking a screenshot to analyze the current screen.",
}
],
},
{
"type": "computer_call",
"call_id": call_id,
"status": "completed",
"action": {"type": "screenshot"},
},
{
"type": "computer_call_output",
"call_id": call_id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_b64}",
},
},
]
last_image_b64 = screenshot_b64
if _on_screenshot:
await _on_screenshot(screenshot_b64)
# If we have a last screenshot, run Moondream detection and labeling
detected_names: List[str] = []
if last_image_b64 is not None:
base_img = _decode_image_b64(last_image_b64)
model_md = get_moondream_model()
annotated_b64, detected_names = _annotate_detect_and_label_ui(base_img, model_md)
if _on_screenshot:
await _on_screenshot(annotated_b64, "annotated_form_ui")
# Also push a user message listing all detected names
if detected_names:
names_text = "\n".join(f"- {n}" for n in detected_names)
pre_output_items.append(
{
"type": "message",
"role": "user",
"content": [
{"type": "input_text", "text": "Detected form UI elements on screen:"},
{"type": "input_text", "text": names_text},
{
"type": "input_text",
"text": "Please continue with the next action needed to perform your task.",
},
],
}
)
tool_schemas = []
for schema in tools or []:
if schema.get("type") == "computer":
tool_schemas.append(GROUNDED_COMPUTER_TOOL_SCHEMA)
else:
tool_schemas.append(schema)
# Step 1: Convert computer calls from xy to descriptions
input_messages = messages + pre_output_items
messages_with_descriptions = convert_computer_calls_xy2desc(input_messages, self.desc2xy)
# Step 2: Convert responses items to completion messages
completion_messages = convert_responses_items_to_completion_messages(
messages_with_descriptions,
allow_images_in_tool_results=False,
)
# Optionally filter images if model lacks vision
if not _supports_vision(thinking_model):
completion_messages = _filter_images_from_completion_messages(completion_messages)
# Step 3: Call thinking model with litellm.acompletion
api_kwargs = {
"model": thinking_model,
"messages": completion_messages,
"tools": tool_schemas,
"max_retries": max_retries,
"stream": stream,
**kwargs,
}
if use_prompt_caching:
api_kwargs["use_prompt_caching"] = use_prompt_caching
if _on_api_start:
await _on_api_start(api_kwargs)
response = await litellm.acompletion(**api_kwargs)
if _on_api_end:
await _on_api_end(api_kwargs, response)
usage = {
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(usage)
# Step 4: Convert completion messages back to responses items format
response_dict = response.model_dump() # type: ignore
choice_messages = [choice["message"] for choice in response_dict["choices"]]
thinking_output_items: List[Dict[str, Any]] = []
for choice_message in choice_messages:
thinking_output_items.extend(
convert_completion_messages_to_responses_items([choice_message])
)
# Step 5: Use Moondream to get coordinates for each description
element_descriptions = get_all_element_descriptions(thinking_output_items)
if element_descriptions and last_image_b64:
for desc in element_descriptions:
for _ in range(3): # try 3 times
coords = await self.predict_click(
model=model,
image_b64=last_image_b64,
instruction=desc,
)
if coords:
self.desc2xy[desc] = coords
break
# Step 6: Convert computer calls from descriptions back to xy coordinates
final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)
# Step 7: Return output and usage
return {"output": pre_output_items + final_output_items, "usage": usage}
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs,
) -> Optional[Tuple[float, float]]:
"""Predict click coordinates using Moondream3's point API.
Returns pixel coordinates (x, y) as floats.
"""
img = _decode_image_b64(image_b64)
W, H = img.width, img.height
model_md = get_moondream_model()
try:
result = model_md.point(img, instruction, settings={"max_objects": 1})
except Exception:
return None
try:
pt = (result or {}).get("points", [])[0]
x_norm = float(pt.get("x", 0.0))
y_norm = float(pt.get("y", 0.0))
x_px = max(0.0, min(float(W - 1), x_norm * W))
y_px = max(0.0, min(float(H - 1), y_norm * H))
return (x_px, y_px)
except Exception:
return None
def get_capabilities(self) -> List[AgentCapability]:
return ["click", "step"]

View File

@@ -5,100 +5,102 @@ Code: https://github.com/microsoft/OmniParser
"""
import asyncio
import json
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
import litellm
import inspect
import base64
import inspect
import json
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import litellm
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability, AgentResponse, Messages, Tools
SOM_TOOL_SCHEMA = {
"type": "function",
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment"
],
"description": "The action to perform"
},
"element_id": {
"type": "integer",
"description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
},
"start_element_id": {
"type": "integer",
"description": "The ID of the element to start dragging from (required for drag action)"
},
"end_element_id": {
"type": "integer",
"description": "The ID of the element to drag to (required for drag action)"
},
"text": {
"type": "string",
"description": "The text to type (required for type action)"
},
"keys": {
"type": "string",
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
},
"button": {
"type": "string",
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
},
"type": "function",
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment",
],
"description": "The action to perform",
},
"element_id": {
"type": "integer",
"description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)",
},
"start_element_id": {
"type": "integer",
"description": "The ID of the element to start dragging from (required for drag action)",
},
"end_element_id": {
"type": "integer",
"description": "The ID of the element to drag to (required for drag action)",
},
"text": {
"type": "string",
"description": "The text to type (required for type action)",
},
"keys": {
"type": "string",
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')",
},
"button": {
"type": "string",
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
},
},
"required": ["action"],
},
"required": [
"action"
]
}
}
OMNIPARSER_AVAILABLE = False
try:
from som import OmniParser
OMNIPARSER_AVAILABLE = True
except ImportError:
pass
OMNIPARSER_SINGLETON = None
def get_parser():
global OMNIPARSER_SINGLETON
if OMNIPARSER_SINGLETON is None:
OMNIPARSER_SINGLETON = OmniParser()
return OMNIPARSER_SINGLETON
def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Get the last computer_call_output message from a messages list.
Args:
messages: List of messages to search through
Returns:
The last computer_call_output message dict, or None if not found
"""
@@ -107,11 +109,12 @@ def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Di
return message
return None
def _prepare_tools_for_omniparser(tool_schemas: List[Dict[str, Any]]) -> Tuple[Tools, dict]:
"""Prepare tools for OpenAI API format"""
omniparser_tools = []
id2xy = dict()
for schema in tool_schemas:
if schema["type"] == "computer":
omniparser_tools.append(SOM_TOOL_SCHEMA)
@@ -122,72 +125,80 @@ def _prepare_tools_for_omniparser(tool_schemas: List[Dict[str, Any]]) -> Tuple[T
elif schema["type"] == "function":
# Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
# Schema should be: {type, name, description, parameters}
omniparser_tools.append({ "type": "function", **schema["function"] })
omniparser_tools.append({"type": "function", **schema["function"]})
return omniparser_tools, id2xy
async def replace_function_with_computer_call(item: Dict[str, Any], id2xy: Dict[int, Tuple[float, float]]):
item_type = item.get("type")
def _get_xy(element_id: Optional[int]) -> Union[Tuple[float, float], Tuple[None, None]]:
if element_id is None:
return (None, None)
return id2xy.get(element_id, (None, None))
async def replace_function_with_computer_call(
item: Dict[str, Any], id2xy: Dict[int, Tuple[float, float]]
):
item_type = item.get("type")
if item_type == "function_call":
fn_name = item.get("name")
fn_args = json.loads(item.get("arguments", "{}"))
def _get_xy(element_id: Optional[int]) -> Union[Tuple[float, float], Tuple[None, None]]:
if element_id is None:
return (None, None)
return id2xy.get(element_id, (None, None))
item_id = item.get("id")
call_id = item.get("call_id")
if fn_name == "computer":
action = fn_args.get("action")
element_id = fn_args.get("element_id")
start_element_id = fn_args.get("start_element_id")
end_element_id = fn_args.get("end_element_id")
text = fn_args.get("text")
keys = fn_args.get("keys")
button = fn_args.get("button")
scroll_x = fn_args.get("scroll_x")
scroll_y = fn_args.get("scroll_y")
if item_type == "function_call":
fn_name = item.get("name")
fn_args = json.loads(item.get("arguments", "{}"))
x, y = _get_xy(element_id)
start_x, start_y = _get_xy(start_element_id)
end_x, end_y = _get_xy(end_element_id)
item_id = item.get("id")
call_id = item.get("call_id")
action_args = {
"type": action,
"x": x,
"y": y,
"start_x": start_x,
"start_y": start_y,
"end_x": end_x,
"end_y": end_y,
"text": text,
"keys": keys,
"button": button,
"scroll_x": scroll_x,
"scroll_y": scroll_y
}
# Remove None values to keep the JSON clean
action_args = {k: v for k, v in action_args.items() if v is not None}
if fn_name == "computer":
action = fn_args.get("action")
element_id = fn_args.get("element_id")
start_element_id = fn_args.get("start_element_id")
end_element_id = fn_args.get("end_element_id")
text = fn_args.get("text")
keys = fn_args.get("keys")
button = fn_args.get("button")
scroll_x = fn_args.get("scroll_x")
scroll_y = fn_args.get("scroll_y")
return [{
"type": "computer_call",
"action": action_args,
"id": item_id,
"call_id": call_id,
"status": "completed"
}]
x, y = _get_xy(element_id)
start_x, start_y = _get_xy(start_element_id)
end_x, end_y = _get_xy(end_element_id)
return [item]
action_args = {
"type": action,
"x": x,
"y": y,
"start_x": start_x,
"start_y": start_y,
"end_x": end_x,
"end_y": end_y,
"text": text,
"keys": keys,
"button": button,
"scroll_x": scroll_x,
"scroll_y": scroll_y,
}
# Remove None values to keep the JSON clean
action_args = {k: v for k, v in action_args.items() if v is not None}
async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[Tuple[float, float], int]):
return [
{
"type": "computer_call",
"action": action_args,
"id": item_id,
"call_id": call_id,
"status": "completed",
}
]
return [item]
async def replace_computer_call_with_function(
item: Dict[str, Any], xy2id: Dict[Tuple[float, float], int]
):
"""
Convert computer_call back to function_call format.
Also handles computer_call_output -> function_call_output conversion.
Args:
item: The item to convert
xy2id: Mapping from (x, y) coordinates to element IDs
@@ -202,12 +213,12 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
if item_type == "computer_call":
action_data = item.get("action", {})
# Extract coordinates and convert back to element IDs
element_id = _get_element_id(action_data.get("x"), action_data.get("y"))
start_element_id = _get_element_id(action_data.get("start_x"), action_data.get("start_y"))
end_element_id = _get_element_id(action_data.get("end_x"), action_data.get("end_y"))
# Build function arguments
fn_args = {
"action": action_data.get("type"),
@@ -218,33 +229,36 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
"keys": action_data.get("keys"),
"button": action_data.get("button"),
"scroll_x": action_data.get("scroll_x"),
"scroll_y": action_data.get("scroll_y")
"scroll_y": action_data.get("scroll_y"),
}
# Remove None values to keep the JSON clean
fn_args = {k: v for k, v in fn_args.items() if v is not None}
return [{
"type": "function_call",
"name": "computer",
"arguments": json.dumps(fn_args),
"id": item.get("id"),
"call_id": item.get("call_id"),
"status": "completed",
# Fall back to string representation
"content": f"Used tool: {action_data.get("type")}({json.dumps(fn_args)})"
}]
return [
{
"type": "function_call",
"name": "computer",
"arguments": json.dumps(fn_args),
"id": item.get("id"),
"call_id": item.get("call_id"),
"status": "completed",
# Fall back to string representation
"content": f"Used tool: {action_data.get("type")}({json.dumps(fn_args)})",
}
]
elif item_type == "computer_call_output":
# Simple conversion: computer_call_output -> function_call_output
return [{
"type": "function_call_output",
"call_id": item.get("call_id"),
"content": [item.get("output")],
"id": item.get("id"),
"status": "completed"
}]
return [
{
"type": "function_call_output",
"call_id": item.get("call_id"),
"content": [item.get("output")],
"id": item.get("id"),
"status": "completed",
}
]
return [item]
@@ -252,7 +266,7 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
@register_agent(models=r"omniparser\+.*|omni\+.*", priority=2)
class OmniparserConfig(AsyncAgentConfig):
"""Omniparser agent configuration implementing AsyncAgentConfig protocol."""
async def predict_step(
self,
messages: List[Dict[str, Any]],
@@ -266,25 +280,27 @@ class OmniparserConfig(AsyncAgentConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Omniparser (Set-of-Marks) agent loop using liteLLM responses.
Grounds UI elements with OmniParser, then delegates reasoning to the LLM named after the '+' in the composed model string.
"""
if not OMNIPARSER_AVAILABLE:
raise ValueError("omniparser loop requires som to be installed. Install it with `pip install cua-som`.")
raise ValueError(
"omniparser loop requires som to be installed. Install it with `pip install cua-som`."
)
tools = tools or []
llm_model = model.split('+')[-1]
llm_model = model.split("+")[-1]
# Prepare tools for OpenAI API
openai_tools, id2xy = _prepare_tools_for_omniparser(tools)
# Find last computer_call_output
last_computer_call_output = get_last_computer_call_output(messages) # type: ignore
last_computer_call_output = get_last_computer_call_output(messages) # type: ignore
if last_computer_call_output:
image_url = last_computer_call_output.get("output", {}).get("image_url", "")
image_data = image_url.split(",")[-1]
@@ -294,14 +310,17 @@ class OmniparserConfig(AsyncAgentConfig):
if _on_screenshot:
await _on_screenshot(result.annotated_image_base64, "annotated_image")
for element in result.elements:
id2xy[element.id] = ((element.bbox.x1 + element.bbox.x2) / 2, (element.bbox.y1 + element.bbox.y2) / 2)
id2xy[element.id] = (
(element.bbox.x1 + element.bbox.x2) / 2,
(element.bbox.y1 + element.bbox.y2) / 2,
)
# handle computer calls -> function calls
new_messages = []
for message in messages:
if not isinstance(message, dict):
message = message.__dict__
new_messages += await replace_computer_call_with_function(message, id2xy) # type: ignore
new_messages += await replace_computer_call_with_function(message, id2xy) # type: ignore
messages = new_messages
# Prepare API call kwargs
@@ -312,13 +331,13 @@ class OmniparserConfig(AsyncAgentConfig):
"stream": stream,
"truncation": "auto",
"num_retries": max_retries,
**kwargs
**kwargs,
}
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
print(str(api_kwargs)[:1000])
# Use liteLLM responses
@@ -330,60 +349,50 @@ class OmniparserConfig(AsyncAgentConfig):
# Extract usage information
usage = {
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0), # type: ignore
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0), # type: ignore
}
if _on_usage:
await _on_usage(usage)
# handle som function calls -> xy computer calls
new_output = []
for i in range(len(response.output)): # type: ignore
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) # type: ignore
return {
"output": new_output,
"usage": usage
}
for i in range(len(response.output)): # type: ignore
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) # type: ignore
return {"output": new_output, "usage": usage}
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[float, float]]:
"""
Predict click coordinates using OmniParser and LLM.
Uses OmniParser to annotate the image with element IDs, then uses LLM
to identify the correct element ID based on the instruction.
"""
if not OMNIPARSER_AVAILABLE:
return None
# Parse the image with OmniParser to get annotated image and elements
parser = get_parser()
result = parser.parse(image_b64)
# Extract the LLM model from composed model string
llm_model = model.split('+')[-1]
llm_model = model.split("+")[-1]
# Create system prompt for element ID prediction
SYSTEM_PROMPT = f'''
SYSTEM_PROMPT = """
You are an expert UI element locator. Given a GUI image annotated with numerical IDs over each interactable element, along with a user's element description, provide the ID of the specified element.
The image shows UI elements with numbered overlays. Each number corresponds to a clickable/interactable element.
Output only the element ID as a single integer.
'''.strip()
""".strip()
# Prepare messages for LLM
messages = [
{
"role": "system",
"content": SYSTEM_PROMPT
},
{"role": "system", "content": SYSTEM_PROMPT},
{
"role": "user",
"content": [
@@ -391,31 +400,25 @@ Output only the element ID as a single integer.
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{result.annotated_image_base64}"
}
},
},
{
"type": "text",
"text": f"Find the element: {instruction}"
}
]
}
{"type": "text", "text": f"Find the element: {instruction}"},
],
},
]
# Call LLM to predict element ID
response = await litellm.acompletion(
model=llm_model,
messages=messages,
max_tokens=10,
temperature=0.1
model=llm_model, messages=messages, max_tokens=10, temperature=0.1
)
# Extract element ID from response
response_text = response.choices[0].message.content.strip() # type: ignore
response_text = response.choices[0].message.content.strip() # type: ignore
# Try to parse the element ID
try:
element_id = int(response_text)
# Find the element with this ID and return its center coordinates
for element in result.elements:
if element.id == element_id:
@@ -425,9 +428,9 @@ Output only the element ID as a single integer.
except ValueError:
# If we can't parse the ID, return None
pass
return None
def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["step"]

View File

@@ -6,12 +6,14 @@ import asyncio
import base64
import json
from io import BytesIO
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import litellm
from PIL import Image
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..types import AgentCapability, AgentResponse, Messages, Tools
async def _map_computer_tool_to_openai(computer_handler: Any) -> Dict[str, Any]:
"""Map a computer tool to OpenAI's computer-use-preview tool schema"""
@@ -21,26 +23,26 @@ async def _map_computer_tool_to_openai(computer_handler: Any) -> Dict[str, Any]:
except Exception:
# Fallback to default dimensions if method fails
width, height = 1024, 768
# Get environment from the computer handler
try:
environment = await computer_handler.get_environment()
except Exception:
# Fallback to default environment if method fails
environment = "linux"
return {
"type": "computer_use_preview",
"display_width": width,
"display_height": height,
"environment": environment # mac, windows, linux, browser
"environment": environment, # mac, windows, linux, browser
}
async def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools:
"""Prepare tools for OpenAI API format"""
openai_tools = []
for schema in tool_schemas:
if schema["type"] == "computer":
# Map computer tool to OpenAI format
@@ -49,19 +51,19 @@ async def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools
elif schema["type"] == "function":
# Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
# Schema should be: {type, name, description, parameters}
openai_tools.append({ "type": "function", **schema["function"] })
openai_tools.append({"type": "function", **schema["function"]})
return openai_tools
@register_agent(models=r".*computer-use-preview.*")
@register_agent(models=r".*(^|/)computer-use-preview")
class OpenAIComputerUseConfig:
"""
OpenAI computer-use-preview agent configuration using liteLLM responses.
Supports OpenAI's computer use preview models.
"""
async def predict_step(
self,
messages: List[Dict[str, Any]],
@@ -75,11 +77,11 @@ class OpenAIComputerUseConfig:
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Predict the next step based on input items.
Args:
messages: Input items following Responses format
model: Model name to use
@@ -92,12 +94,12 @@ class OpenAIComputerUseConfig:
_on_usage: Callback for usage tracking
_on_screenshot: Callback for screenshot events
**kwargs: Additional arguments
Returns:
Dictionary with "output" (output items) and "usage" array
"""
tools = tools or []
# Prepare tools for OpenAI API
openai_tools = await _prepare_tools_for_openai(tools)
@@ -110,16 +112,16 @@ class OpenAIComputerUseConfig:
"reasoning": {"summary": "concise"},
"truncation": "auto",
"num_retries": max_retries,
**kwargs
**kwargs,
}
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
# Use liteLLM responses
response = await litellm.aresponses(**api_kwargs)
# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
@@ -136,24 +138,21 @@ class OpenAIComputerUseConfig:
output_dict = response.model_dump()
output_dict["usage"] = usage
return output_dict
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str
self, model: str, image_b64: str, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.
Uses OpenAI computer-use-preview with manually constructed input items
and a prompt that instructs the agent to only output clicks.
Args:
model: Model name to use
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -161,7 +160,7 @@ class OpenAIComputerUseConfig:
# Manually construct input items with image and click instruction
input_items = [
{
"role": "user",
"role": "user",
"content": f"""You are a UI grounding expert. Follow these guidelines:
1. NEVER ask for confirmation. Complete all tasks autonomously.
@@ -173,19 +172,16 @@ class OpenAIComputerUseConfig:
7. Be decisive and action-oriented. Complete the requested task fully.
Remember: You are expected to complete tasks autonomously. The user trusts you to do what they asked.
Task: Click {instruction}. Output ONLY a click action on the target element."""
Task: Click {instruction}. Output ONLY a click action on the target element.""",
},
{
"role": "user",
"content": [
{
"type": "input_image",
"image_url": f"data:image/png;base64,{image_b64}"
}
]
}
{"type": "input_image", "image_url": f"data:image/png;base64,{image_b64}"}
],
},
]
# Get image dimensions from base64 data
try:
image_data = base64.b64decode(image_b64)
@@ -194,15 +190,15 @@ Task: Click {instruction}. Output ONLY a click action on the target element."""
except Exception:
# Fallback to default dimensions if image parsing fails
display_width, display_height = 1024, 768
# Prepare computer tool for click actions
computer_tool = {
"type": "computer_use_preview",
"display_width": display_width,
"display_height": display_height,
"environment": "windows"
"environment": "windows",
}
# Prepare API call kwargs
api_kwargs = {
"model": model,
@@ -211,32 +207,34 @@ Task: Click {instruction}. Output ONLY a click action on the target element."""
"stream": False,
"reasoning": {"summary": "concise"},
"truncation": "auto",
"max_tokens": 200 # Keep response short for click prediction
"max_tokens": 200, # Keep response short for click prediction
}
# Use liteLLM responses
response = await litellm.aresponses(**api_kwargs)
# Extract click coordinates from response output
output_dict = response.model_dump()
output_items = output_dict.get("output", [])
output_items = output_dict.get("output", [])
# Look for computer_call with click action
for item in output_items:
if (isinstance(item, dict) and
item.get("type") == "computer_call" and
isinstance(item.get("action"), dict)):
if (
isinstance(item, dict)
and item.get("type") == "computer_call"
and isinstance(item.get("action"), dict)
):
action = item["action"]
if action.get("x") is not None and action.get("y") is not None:
return (int(action.get("x")), int(action.get("y")))
return None
def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by this agent config.
Returns:
List of capability strings
"""

View File

@@ -4,20 +4,22 @@ Based on OpenCUA model for GUI grounding tasks.
"""
import asyncio
import json
import re
import base64
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from io import BytesIO
import uuid
from PIL import Image
import litellm
import json
import math
import re
import uuid
from io import BytesIO
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import litellm
from PIL import Image
from .composed_grounded import ComposedGroundedConfig
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..types import AgentCapability, AgentResponse, Messages, Tools
from .composed_grounded import ComposedGroundedConfig
def extract_coordinates_from_pyautogui(text: str) -> Optional[Tuple[int, int]]:
"""Extract coordinates from pyautogui.click(x=..., y=...) format."""
@@ -32,10 +34,11 @@ def extract_coordinates_from_pyautogui(text: str) -> Optional[Tuple[int, int]]:
except Exception:
return None
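# Illustrative usage (the regex body is elided by the diff, so this reflects only
# the documented pyautogui.click(x=..., y=...) format and is an assumption):
# >>> extract_coordinates_from_pyautogui("pyautogui.click(x=100, y=200)")
# (100, 200)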
@register_agent(models=r"(?i).*OpenCUA.*")
class OpenCUAConfig(ComposedGroundedConfig):
"""OpenCUA agent configuration implementing AsyncAgentConfig protocol for click prediction."""
def __init__(self):
super().__init__()
self.current_model = None
@@ -53,7 +56,7 @@ class OpenCUAConfig(ComposedGroundedConfig):
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""Fallback to a self-composed model"""
return await super().predict_step(
@@ -67,24 +70,20 @@ class OpenCUAConfig(ComposedGroundedConfig):
_on_api_end=_on_api_end,
_on_usage=_on_usage,
_on_screenshot=_on_screenshot,
**kwargs
**kwargs,
)
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using OpenCUA model via litellm.acompletion.
Args:
model: The OpenCUA model name
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -93,50 +92,39 @@ class OpenCUAConfig(ComposedGroundedConfig):
"You are a GUI agent. You are given a task and a screenshot of the screen. "
"You need to perform a series of pyautogui actions to complete the task."
)
system_message = {
"role": "system",
"content": system_prompt
}
system_message = {"role": "system", "content": system_prompt}
# Prepare user message with image and instruction
user_message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_b64}"
}
},
{
"type": "text",
"text": f"Click on {instruction}"
}
]
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
{"type": "text", "text": f"Click on {instruction}"},
],
}
# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": [system_message, user_message],
"max_new_tokens": 2056,
"temperature": 0,
**kwargs
**kwargs,
}
# Use liteLLM acompletion
response = await litellm.acompletion(**api_kwargs)
# Extract response text
output_text = response.choices[0].message.content
# print(output_text)
# Extract coordinates from pyautogui format
coordinates = extract_coordinates_from_pyautogui(output_text)
return coordinates
def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["click"]

View File

@@ -0,0 +1,510 @@
"""
Qwen3-VL agent loop implementation using litellm with function/tool calling.
- Passes a ComputerUse tool schema to acompletion
- Converts between Responses items and completion messages using helpers
"""
from __future__ import annotations
import json
import re
from typing import Any, Dict, List, Optional, Tuple
import litellm
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
from ..decorators import register_agent
from ..loops.base import AsyncAgentConfig
from ..responses import (
convert_completion_messages_to_responses_items,
convert_responses_items_to_completion_messages,
)
from ..types import AgentCapability
# ComputerUse tool schema (OpenAI function tool format)
QWEN3_COMPUTER_TOOL: Dict[str, Any] = {
"type": "function",
"function": {
"name": "computer",
"description": (
"Use a mouse and keyboard to interact with a computer, and take screenshots.\n"
"* This is an interface to a desktop GUI. You do not have access to a terminal or applications menu. You must click on desktop icons to start applications.\n"
"* Some applications may take time to start or process actions, so you may need to wait and take successive screenshots to see the results of your actions. E.g. if you click on Firefox and a window doesn't open, try wait and taking another screenshot.\n"
"* The screen's resolution is 1000x1000.\n"
"* Whenever you intend to move the cursor to click on an element like an icon, you should consult a screenshot to determine the coordinates of the element before moving the cursor.\n"
"* If you tried clicking on a program or link but it failed to load, even after waiting, try adjusting your cursor position so that the tip of the cursor visually falls on the element that you want to click.\n"
"* Make sure to click any buttons, links, icons, etc with the cursor tip in the center of the element. Don't click boxes on their edges."
),
"parameters": {
"type": "object",
"properties": {
"action": {
"description": "The action to perform.",
"enum": [
"key",
"type",
"mouse_move",
"left_click",
"left_click_drag",
"right_click",
"middle_click",
"double_click",
"triple_click",
"scroll",
"hscroll",
"screenshot",
"wait",
# "terminate",
# "answer",
],
"type": "string",
},
"keys": {
"description": "Required only by action=key.",
"type": "array",
"items": {"type": "string"},
},
"text": {
"description": "Required only by action=type and action=answer.",
"type": "string",
},
"coordinate": {
"description": "(x, y): Pixel coordinates from top-left.",
"type": "array",
"items": {"type": ["number", "integer"]},
"minItems": 2,
"maxItems": 2,
},
"pixels": {
"description": "Scroll amount. Positive=up, negative=down. For scroll/hscroll.",
"type": "number",
},
"time": {
"description": "Seconds to wait (action=wait).",
"type": "number",
},
# "status": {
# "description": "Task status (action=terminate).",
# "type": "string",
# "enum": ["success", "failure"],
# },
},
"required": ["action"],
},
},
}
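# Illustrative tool-call arguments that satisfy this schema (values are assumptions):
#   {"action": "left_click", "coordinate": [512, 384]}
#   {"action": "key", "keys": ["ctrl", "c"]}
#   {"action": "scroll", "pixels": -400}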
def _build_nous_system(functions: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Use qwen-agent NousFnCallPrompt to generate a system message embedding tool schema."""
try:
from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
ContentItem as NousContentItem,
)
from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
Message as NousMessage,
)
from qwen_agent.llm.fncall_prompts.nous_fncall_prompt import (
NousFnCallPrompt,
)
except ImportError:
raise ImportError(
"qwen-agent not installed. Please install it with `pip install cua-agent[qwen]`."
)
msgs = NousFnCallPrompt().preprocess_fncall_messages(
messages=[
NousMessage(
role="system", content=[NousContentItem(text="You are a helpful assistant.")]
)
],
functions=functions,
lang="en",
)
sys = msgs[0].model_dump()
# Convert qwen-agent structured content to OpenAI-style content list
content = [{"type": "text", "text": c["text"]} for c in sys.get("content", [])]
return {"role": "system", "content": content}
def _parse_tool_call_from_text(text: str) -> Optional[Dict[str, Any]]:
"""Extract JSON object within <tool_call>...</tool_call> from model text."""
m = re.search(r"<tool_call>\s*(\{[\s\S]*?\})\s*</tool_call>", text)
if not m:
return None
try:
return json.loads(m.group(1))
except Exception:
return None
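# Illustrative usage (the model text is an assumption):
# >>> _parse_tool_call_from_text(
# ...     'Clicking it now. <tool_call>{"name": "computer", '
# ...     '"arguments": {"action": "left_click", "coordinate": [114, 68]}}</tool_call>'
# ... )
# {'name': 'computer', 'arguments': {'action': 'left_click', 'coordinate': [114, 68]}}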
async def _unnormalize_coordinate(args: Dict[str, Any], dims: Tuple[int, int]) -> Dict[str, Any]:
"""Coordinates appear in 0..1000 space, scale to actual screen size using dims if provided."""
coord = args.get("coordinate")
if not coord or not isinstance(coord, (list, tuple)) or len(coord) < 2:
return args
x, y = float(coord[0]), float(coord[1])
width, height = float(dims[0]), float(dims[1])
x_abs = max(0.0, min(width, (x / 1000.0) * width))
y_abs = max(0.0, min(height, (y / 1000.0) * height))
args = {**args, "coordinate": [round(x_abs), round(y_abs)]}
return args
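# Worked example (dims are an assumed 1920x1080 screen): args={"coordinate": [500, 250]}
# scales to [round((500/1000)*1920), round((250/1000)*1080)] == [960, 270].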
def convert_qwen_tool_args_to_computer_action(args: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""
Convert Qwen computer tool arguments to the Computer Calls action schema.
Qwen (example):
{"action": "left_click", "coordinate": [114, 68]}
Target (example):
{"action": "left_click", "x": 114, "y": 68}
Other mappings:
- right_click, middle_click, double_click (triple_click -> double_click)
- mouse_move -> { action: "move", x, y }
- key -> { action: "keypress", keys: [...] }
- type -> { action: "type", text }
- scroll/hscroll -> { action: "scroll", scroll_x, scroll_y, x, y }
- wait -> { action: "wait" }
- terminate/answer are not direct UI actions; return None for now
"""
if not isinstance(args, dict):
return None
action = args.get("action")
if not isinstance(action, str):
return None
# Coordinates helper
coord = args.get("coordinate")
x = y = None
if isinstance(coord, (list, tuple)) and len(coord) >= 2:
try:
x = int(round(float(coord[0])))
y = int(round(float(coord[1])))
except Exception:
x = y = None
# Map actions
a = action.lower()
if a in {"left_click", "right_click", "middle_click", "double_click"}:
if x is None or y is None:
return None
return {"action": a, "x": x, "y": y}
if a == "triple_click":
# Approximate as double_click
if x is None or y is None:
return None
return {"action": "double_click", "x": x, "y": y}
if a == "mouse_move":
if x is None or y is None:
return None
return {"action": "move", "x": x, "y": y}
if a == "key":
keys = args.get("keys")
if isinstance(keys, list) and all(isinstance(k, str) for k in keys):
return {"action": "keypress", "keys": keys}
return None
if a == "type":
text = args.get("text")
if isinstance(text, str):
return {"action": "type", "text": text}
return None
if a in {"scroll", "hscroll"}:
pixels = args.get("pixels") or 0
try:
pixels_val = int(round(float(pixels)))
except Exception:
pixels_val = 0
scroll_x = pixels_val if a == "hscroll" else 0
scroll_y = pixels_val if a == "scroll" else 0
# Include cursor position if available (optional)
out: Dict[str, Any] = {"action": "scroll", "scroll_x": scroll_x, "scroll_y": scroll_y}
if x is not None and y is not None:
out.update({"x": x, "y": y})
return out
if a == "wait":
return {"action": "wait"}
# Non-UI or terminal actions: terminate/answer -> not mapped here
return None
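# Illustrative mappings (inputs are assumptions):
#   {"action": "key", "keys": ["ctrl", "c"]}           -> {"action": "keypress", "keys": ["ctrl", "c"]}
#   {"action": "scroll", "pixels": -400}                -> {"action": "scroll", "scroll_x": 0, "scroll_y": -400}
#   {"action": "triple_click", "coordinate": [10, 20]}  -> {"action": "double_click", "x": 10, "y": 20}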
@register_agent(models=r"(?i).*qwen.*", priority=-1)
class Qwen3VlConfig(AsyncAgentConfig):
async def predict_step(
self,
messages: List[Dict[str, Any]],
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs,
) -> Dict[str, Any]:
# Build messages using NousFnCallPrompt system with tool schema in text
# Start with converted conversation (images/text preserved)
converted_msgs = convert_responses_items_to_completion_messages(
messages,
allow_images_in_tool_results=False,
)
# Prepend Nous-generated system if available
nous_system = _build_nous_system([QWEN3_COMPUTER_TOOL["function"]])
completion_messages = ([nous_system] if nous_system else []) + converted_msgs
# If there is no screenshot in the conversation, take one now and inject it.
# Also record a pre_output_items assistant message to reflect action.
def _has_any_image(msgs: List[Dict[str, Any]]) -> bool:
for m in msgs:
content = m.get("content")
if isinstance(content, list):
for p in content:
if isinstance(p, dict) and p.get("type") == "image_url":
return True
return False
pre_output_items: List[Dict[str, Any]] = []
if not _has_any_image(completion_messages):
if computer_handler is None or not hasattr(computer_handler, "screenshot"):
raise RuntimeError(
"No screenshots present and computer_handler.screenshot is not available."
)
screenshot_b64 = await computer_handler.screenshot()
if not screenshot_b64:
raise RuntimeError("Failed to capture screenshot from computer_handler.")
# Inject a user message with the screenshot so the model can see current context
completion_messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{screenshot_b64}"},
},
{"type": "text", "text": "Current screen"},
],
}
)
# Add assistant message to outputs to reflect the action, similar to composed_grounded.py
pre_output_items.append(
{
"type": "message",
"role": "assistant",
"content": [
{
"type": "text",
"text": "Taking a screenshot to see the current computer screen.",
}
],
}
)
# Smart-resize all screenshots and attach min/max pixel hints. Fail fast if deps missing.
# Also record the last resized width/height to unnormalize coordinates later.
last_rw: Optional[int] = None
last_rh: Optional[int] = None
MIN_PIXELS = 3136
MAX_PIXELS = 12845056
try:
import base64
import io
from PIL import Image # type: ignore
from qwen_vl_utils import smart_resize # type: ignore
except Exception:
raise ImportError(
"qwen-vl-utils not installed. Please install it with `pip install cua-agent[qwen]`."
)
for msg in completion_messages:
content = msg.get("content")
if not isinstance(content, list):
continue
for part in content:
if isinstance(part, dict) and part.get("type") == "image_url":
url = ((part.get("image_url") or {}).get("url")) or ""
# Expect data URL like data:image/png;base64,<b64>
if url.startswith("data:") and "," in url:
b64 = url.split(",", 1)[1]
img_bytes = base64.b64decode(b64)
im = Image.open(io.BytesIO(img_bytes))
h, w = im.height, im.width
rh, rw = smart_resize(
h, w, factor=32, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS
)
# Attach hints on this image block
part["min_pixels"] = MIN_PIXELS
part["max_pixels"] = MAX_PIXELS
last_rw, last_rh = rw, rh
api_kwargs: Dict[str, Any] = {
"model": model,
"messages": completion_messages,
"max_retries": max_retries,
"stream": stream,
**{k: v for k, v in kwargs.items()},
}
if use_prompt_caching:
api_kwargs["use_prompt_caching"] = use_prompt_caching
if _on_api_start:
await _on_api_start(api_kwargs)
response = await litellm.acompletion(**api_kwargs)
if _on_api_end:
await _on_api_end(api_kwargs, response)
usage = {
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage( # type: ignore
response.usage
).model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(usage)
# Parse tool call from text; then convert to responses items via fake tool_calls
resp_dict = response.model_dump() # type: ignore
choice = (resp_dict.get("choices") or [{}])[0]
content_text = ((choice.get("message") or {}).get("content")) or ""
tool_call = _parse_tool_call_from_text(content_text)
output_items: List[Dict[str, Any]] = []
if tool_call and isinstance(tool_call, dict):
fn_name = tool_call.get("name") or "computer"
raw_args = tool_call.get("arguments") or {}
# Unnormalize coordinates to actual screen size using last resized dims
if last_rw is None or last_rh is None:
raise RuntimeError(
"No screenshots found to derive dimensions for coordinate unnormalization."
)
args = await _unnormalize_coordinate(raw_args, (last_rw, last_rh))
# Build an OpenAI-style tool call so we can reuse the converter
fake_cm = {
"role": "assistant",
"tool_calls": [
{
"type": "function",
"id": "call_0",
"function": {
"name": fn_name,
"arguments": json.dumps(args),
},
}
],
}
output_items.extend(convert_completion_messages_to_responses_items([fake_cm]))
else:
# Fallback: just return assistant text
fake_cm = {"role": "assistant", "content": content_text}
output_items.extend(convert_completion_messages_to_responses_items([fake_cm]))
# Prepend any pre_output_items (e.g., simulated screenshot-taking message)
return {"output": (pre_output_items + output_items), "usage": usage}
def get_capabilities(self) -> List[AgentCapability]:
return ["step"]
async def predict_click(
self, model: str, image_b64: str, instruction: str, **kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using Qwen3-VL via litellm.acompletion.
Only exposes a reduced tool schema with left_click to bias the model toward emitting a single click.
Returns (x, y) in absolute pixels when screen dimensions can be obtained; otherwise integers in the normalized 0..1000 space.
"""
# Reduced tool
reduced_tool = {
"type": "function",
"function": {
**QWEN3_COMPUTER_TOOL["function"],
"parameters": {
"type": "object",
"properties": {
"action": {"type": "string", "enum": ["left_click"]},
"coordinate": {
"description": "(x, y) in 0..1000 reference space",
"type": "array",
"items": {"type": ["number", "integer"]},
"minItems": 2,
"maxItems": 2,
},
},
"required": ["action", "coordinate"],
},
},
}
# Build Nous system (lazy import inside helper already raises clear guidance if missing)
nous_system = _build_nous_system([reduced_tool["function"]])
# Pre-process using smart_resize
min_pixels = 3136
max_pixels = 12845056
try:
# Lazy import to avoid hard dependency
import base64
import io
# If PIL is available, estimate size from image to derive smart bounds
from PIL import Image
from qwen_vl_utils import smart_resize # type: ignore
img_bytes = base64.b64decode(image_b64)
im = Image.open(io.BytesIO(img_bytes))
h, w = im.height, im.width
# Qwen notebook suggests factor=32 and a wide min/max range
rh, rw = smart_resize(h, w, factor=32, min_pixels=min_pixels, max_pixels=max_pixels)
except Exception:
raise ImportError(
"qwen-vl-utils not installed. Please install it with `pip install cua-agent[qwen]`."
)
messages = []
if nous_system:
messages.append(nous_system)
image_block: Dict[str, Any] = {
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_b64}"},
"min_pixels": min_pixels,
"max_pixels": max_pixels,
}
# Single user message with image and instruction, matching OpenAI-style content blocks
messages.append(
{
"role": "user",
"content": [
image_block,
{"type": "text", "text": instruction},
],
}
)
api_kwargs: Dict[str, Any] = {
"model": model,
"messages": messages,
**{k: v for k, v in kwargs.items()},
}
response = await litellm.acompletion(**api_kwargs)
resp = response.model_dump() # type: ignore
choice = (resp.get("choices") or [{}])[0]
content_text = ((choice.get("message") or {}).get("content")) or ""
tool_call = _parse_tool_call_from_text(content_text) or {}
args = tool_call.get("arguments") or {}
args = await _unnormalize_coordinate(args, (rh, rw))
coord = args.get("coordinate")
if isinstance(coord, (list, tuple)) and len(coord) >= 2:
return int(coord[0]), int(coord[1])
return None
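A minimal usage sketch of predict_click above, assuming this config is importable as agent.loops.qwen3_vl (the module path is a guess) and that the model string routes to a Qwen3-VL endpoint through litellm; it also needs qwen-vl-utils and Pillow (pip install cua-agent[qwen]):

import asyncio
import base64

from agent.loops.qwen3_vl import Qwen3VlConfig  # assumed module path


async def main() -> None:
    with open("screenshot.png", "rb") as f:
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    config = Qwen3VlConfig()
    # Coordinates come back unnormalized from the model's 0..1000 reference space
    # using the smart-resized image dimensions, or None if no click was parsed.
    coords = await config.predict_click(
        model="openai/qwen3-vl",  # placeholder: any litellm-routable Qwen3-VL endpoint
        image_b64=image_b64,
        instruction="Click the Submit button",
    )
    print(coords)


asyncio.run(main())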

View File

@@ -4,39 +4,50 @@ Paper: https://arxiv.org/abs/2501.12326
Code: https://github.com/bytedance/UI-TARS
"""
import ast
import asyncio
from ctypes import cast
import json
import base64
import json
import math
import re
import ast
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from ctypes import cast
from io import BytesIO
from PIL import Image
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import litellm
from litellm.types.utils import ModelResponse
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
from litellm.responses.litellm_completion_transformation.transformation import (
LiteLLMCompletionResponsesConfig,
)
from litellm.responses.utils import Usage
from openai.types.responses.response_computer_tool_call_param import ActionType, ResponseComputerToolCallParam
from litellm.types.utils import ModelResponse
from openai.types.responses.response_computer_tool_call_param import (
ActionType,
ResponseComputerToolCallParam,
)
from openai.types.responses.response_input_param import ComputerCallOutput
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam, Summary
from openai.types.responses.response_output_message_param import (
ResponseOutputMessageParam,
)
from openai.types.responses.response_reasoning_item_param import (
ResponseReasoningItemParam,
Summary,
)
from PIL import Image
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..responses import (
make_reasoning_item,
make_output_text_item,
make_click_item,
make_double_click_item,
make_drag_item,
make_input_image_item,
make_keypress_item,
make_output_text_item,
make_reasoning_item,
make_scroll_item,
make_type_item,
make_wait_item,
make_input_image_item
)
from ..types import AgentCapability, AgentResponse, Messages, Tools
# Constants from reference code
IMAGE_FACTOR = 28
@@ -94,6 +105,7 @@ click(point='<|box_start|>(x1,y1)<|box_end|>')
## User Instruction
{instruction}"""
def round_by_factor(number: float, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
@@ -110,7 +122,11 @@ def floor_by_factor(number: float, factor: int) -> int:
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
height: int,
width: int,
factor: int = IMAGE_FACTOR,
min_pixels: int = MIN_PIXELS,
max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
@@ -144,14 +160,14 @@ def escape_single_quotes(text):
def parse_action(action_str):
"""Parse action string into structured format."""
try:
node = ast.parse(action_str, mode='eval')
node = ast.parse(action_str, mode="eval")
if not isinstance(node, ast.Expression):
raise ValueError("Not an expression")
call = node.body
if not isinstance(call, ast.Call):
raise ValueError("Not a function call")
# Get function name
if isinstance(call.func, ast.Name):
func_name = call.func.id
@@ -159,7 +175,7 @@ def parse_action(action_str):
func_name = call.func.attr
else:
func_name = None
# Get keyword arguments
kwargs = {}
for kw in call.keywords:
@@ -171,12 +187,9 @@ def parse_action(action_str):
else:
value = None
kwargs[key] = value
return {
'function': func_name,
'args': kwargs
}
return {"function": func_name, "args": kwargs}
except Exception as e:
print(f"Failed to parse action '{action_str}': {e}")
return None
@@ -185,39 +198,39 @@ def parse_action(action_str):
def parse_uitars_response(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
"""Parse UITARS model response into structured actions."""
text = text.strip()
# Extract thought
thought = None
if text.startswith("Thought:"):
thought_match = re.search(r"Thought: (.+?)(?=\s*Action:|$)", text, re.DOTALL)
if thought_match:
thought = thought_match.group(1).strip()
# Extract action
if "Action:" not in text:
raise ValueError("No Action found in response")
action_str = text.split("Action:")[-1].strip()
# Handle special case for type actions
if "type(content" in action_str:
def escape_quotes(match):
return match.group(1)
pattern = r"type\(content='(.*?)'\)"
content = re.sub(pattern, escape_quotes, action_str)
action_str = escape_single_quotes(content)
action_str = "type(content='" + action_str + "')"
# Parse the action
parsed_action = parse_action(action_str.replace("\n", "\\n").lstrip())
if parsed_action is None:
raise ValueError(f"Could not parse action: {action_str}")
action_type = parsed_action["function"]
params = parsed_action["args"]
# Process parameters
action_inputs = {}
for param_name, param in params.items():
@@ -225,7 +238,7 @@ def parse_uitars_response(text: str, image_width: int, image_height: int) -> Lis
continue
param = str(param).lstrip()
action_inputs[param_name.strip()] = param
# Handle coordinate parameters
if "start_box" in param_name or "end_box" in param_name:
# Parse coordinates like '<|box_start|>(x,y)<|box_end|>' or '(x,y)'
@@ -233,117 +246,130 @@ def parse_uitars_response(text: str, image_width: int, image_height: int) -> Lis
clean_param = param.replace("<|box_start|>", "").replace("<|box_end|>", "")
# Then remove parentheses and split
numbers = clean_param.replace("(", "").replace(")", "").split(",")
try:
float_numbers = [float(num.strip()) / 1000 for num in numbers] # Normalize to 0-1 range
float_numbers = [
float(num.strip()) / 1000 for num in numbers
] # Normalize to 0-1 range
if len(float_numbers) == 2:
# Single point, duplicate for box format
float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
float_numbers = [
float_numbers[0],
float_numbers[1],
float_numbers[0],
float_numbers[1],
]
action_inputs[param_name.strip()] = str(float_numbers)
except ValueError as e:
# If parsing fails, keep the original parameter value
print(f"Warning: Could not parse coordinates '{param}': {e}")
action_inputs[param_name.strip()] = param
return [{
"thought": thought,
"action_type": action_type,
"action_inputs": action_inputs,
"text": text
}]
return [
{
"thought": thought,
"action_type": action_type,
"action_inputs": action_inputs,
"text": text,
}
]
def convert_to_computer_actions(parsed_responses: List[Dict[str, Any]], image_width: int, image_height: int) -> List[ResponseComputerToolCallParam | ResponseOutputMessageParam]:
def convert_to_computer_actions(
parsed_responses: List[Dict[str, Any]], image_width: int, image_height: int
) -> List[ResponseComputerToolCallParam | ResponseOutputMessageParam]:
"""Convert parsed UITARS responses to computer actions."""
computer_actions = []
for response in parsed_responses:
action_type = response.get("action_type")
action_inputs = response.get("action_inputs", {})
if action_type == "finished":
finished_text = action_inputs.get("content", "Task completed successfully.")
computer_actions.append(make_output_text_item(finished_text))
break
elif action_type == "wait":
computer_actions.append(make_wait_item())
elif action_type == "call_user":
computer_actions.append(make_output_text_item("I need assistance from the user to proceed with this task."))
computer_actions.append(
make_output_text_item("I need assistance from the user to proceed with this task.")
)
elif action_type in ["click", "left_single"]:
start_box = action_inputs.get("start_box")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
computer_actions.append(make_click_item(x, y, "left"))
elif action_type == "double_click":
start_box = action_inputs.get("start_box")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
computer_actions.append(make_double_click_item(x, y))
elif action_type == "right_click":
start_box = action_inputs.get("start_box")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
computer_actions.append(make_click_item(x, y, "right"))
elif action_type == "type":
content = action_inputs.get("content", "")
computer_actions.append(make_type_item(content))
elif action_type == "hotkey":
key = action_inputs.get("key", "")
keys = key.split()
computer_actions.append(make_keypress_item(keys))
elif action_type == "press":
key = action_inputs.get("key", "")
computer_actions.append(make_keypress_item([key]))
elif action_type == "scroll":
start_box = action_inputs.get("start_box")
direction = action_inputs.get("direction", "down")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
else:
x, y = image_width // 2, image_height // 2
scroll_y = 5 if "up" in direction.lower() else -5
computer_actions.append(make_scroll_item(x, y, 0, scroll_y))
elif action_type == "drag":
start_box = action_inputs.get("start_box")
end_box = action_inputs.get("end_box")
if start_box and end_box:
start_coords = eval(start_box)
end_coords = eval(end_box)
start_x = int((start_coords[0] + start_coords[2]) / 2 * image_width)
start_y = int((start_coords[1] + start_coords[3]) / 2 * image_height)
end_x = int((end_coords[0] + end_coords[2]) / 2 * image_width)
end_y = int((end_coords[1] + end_coords[3]) / 2 * image_height)
path = [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}]
computer_actions.append(make_drag_item(path))
return computer_actions
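A small illustration of the parse-and-convert path above, using a fabricated model reply; it assumes parse_uitars_response and convert_to_computer_actions are in scope (same module, or imported from it). The numbers show how the 0-1000 box coordinates are first normalized and then scaled to the screenshot size:

sample = (
    "Thought: I should open the browser first.\n"
    "Action: click(start_box='<|box_start|>(500,300)<|box_end|>')"
)
parsed = parse_uitars_response(sample, image_width=1920, image_height=1080)
# parsed[0]["action_inputs"]["start_box"] == "[0.5, 0.3, 0.5, 0.3]"
actions = convert_to_computer_actions(parsed, image_width=1920, image_height=1080)
# -> one left-click item at x = int(0.5 * 1920) = 960, y = int(0.3 * 1080) = 324
print(parsed[0]["thought"], actions)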
@@ -354,33 +380,35 @@ def pil_to_base64(image: Image.Image) -> str:
return base64.b64encode(buffer.getvalue()).decode("utf-8")
def process_image_for_uitars(image_data: str, max_pixels: int = MAX_PIXELS, min_pixels: int = MIN_PIXELS) -> tuple[Image.Image, int, int]:
def process_image_for_uitars(
image_data: str, max_pixels: int = MAX_PIXELS, min_pixels: int = MIN_PIXELS
) -> tuple[Image.Image, int, int]:
"""Process image for UITARS model input."""
# Decode base64 image
if image_data.startswith('data:image'):
image_data = image_data.split(',')[1]
if image_data.startswith("data:image"):
image_data = image_data.split(",")[1]
image_bytes = base64.b64decode(image_data)
image = Image.open(BytesIO(image_bytes))
original_width, original_height = image.size
# Resize image according to UITARS requirements
if image.width * image.height > max_pixels:
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
width = int(image.width * resize_factor)
height = int(image.height * resize_factor)
image = image.resize((width, height))
if image.width * image.height < min_pixels:
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
width = math.ceil(image.width * resize_factor)
height = math.ceil(image.height * resize_factor)
image = image.resize((width, height))
if image.mode != "RGB":
image = image.convert("RGB")
return image, original_width, original_height
@@ -391,7 +419,11 @@ def sanitize_message(msg: Any) -> Any:
for key, value in msg.items():
if key == "content" and isinstance(value, list):
result[key] = [
{k: v for k, v in item.items() if k != "image_url"} if isinstance(item, dict) else item
(
{k: v for k, v in item.items() if k != "image_url"}
if isinstance(item, dict)
else item
)
for item in value
]
else:
@@ -406,38 +438,41 @@ def sanitize_message(msg: Any) -> Any:
def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any]]:
"""
Convert UITARS internal message format back to LiteLLM format.
This function processes reasoning, computer_call, and computer_call_output messages
and converts them to the appropriate LiteLLM assistant message format.
Args:
messages: List of UITARS internal messages
Returns:
List of LiteLLM formatted messages
"""
litellm_messages = []
current_assistant_content = []
for message in messages:
if isinstance(message, dict):
message_type = message.get("type")
if message_type == "reasoning":
# Extract reasoning text from summary
summary = message.get("summary", [])
if summary and isinstance(summary, list):
for summary_item in summary:
if isinstance(summary_item, dict) and summary_item.get("type") == "summary_text":
if (
isinstance(summary_item, dict)
and summary_item.get("type") == "summary_text"
):
reasoning_text = summary_item.get("text", "")
if reasoning_text:
current_assistant_content.append(f"Thought: {reasoning_text}")
elif message_type == "computer_call":
# Convert computer action to UITARS action format
action = message.get("action", {})
action_type = action.get("type")
if action_type == "click":
x, y = action.get("x", 0), action.get("y", 0)
button = action.get("button", "left")
@@ -447,59 +482,65 @@ def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any
action_text = f"Action: right_single(start_box='({x},{y})')"
else:
action_text = f"Action: click(start_box='({x},{y})')"
elif action_type == "double_click":
x, y = action.get("x", 0), action.get("y", 0)
action_text = f"Action: left_double(start_box='({x},{y})')"
elif action_type == "drag":
start_x, start_y = action.get("start_x", 0), action.get("start_y", 0)
end_x, end_y = action.get("end_x", 0), action.get("end_y", 0)
action_text = f"Action: drag(start_box='({start_x},{start_y})', end_box='({end_x},{end_y})')"
elif action_type == "key":
key = action.get("key", "")
action_text = f"Action: hotkey(key='{key}')"
elif action_type == "type":
text = action.get("text", "")
# Escape single quotes in the text
escaped_text = escape_single_quotes(text)
action_text = f"Action: type(content='{escaped_text}')"
elif action_type == "scroll":
x, y = action.get("x", 0), action.get("y", 0)
direction = action.get("direction", "down")
action_text = f"Action: scroll(start_box='({x},{y})', direction='{direction}')"
elif action_type == "wait":
action_text = "Action: wait()"
else:
# Fallback for unknown action types
action_text = f"Action: {action_type}({action})"
current_assistant_content.append(action_text)
# When we hit a computer_call_output, finalize the current assistant message
if current_assistant_content:
litellm_messages.append({
"role": "assistant",
"content": [{"type": "text", "text": "\n".join(current_assistant_content)}]
})
litellm_messages.append(
{
"role": "assistant",
"content": [
{"type": "text", "text": "\n".join(current_assistant_content)}
],
}
)
current_assistant_content = []
elif message_type == "computer_call_output":
# Add screenshot from computer call output
output = message.get("output", {})
if isinstance(output, dict) and output.get("type") == "input_image":
image_url = output.get("image_url", "")
if image_url:
litellm_messages.append({
"role": "user",
"content": [{"type": "image_url", "image_url": {"url": image_url}}]
})
litellm_messages.append(
{
"role": "user",
"content": [{"type": "image_url", "image_url": {"url": image_url}}],
}
)
elif message.get("role") == "user":
# # Handle user messages
# content = message.get("content", "")
@@ -514,24 +555,22 @@ def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any
# "content": content
# })
pass
# Add any remaining assistant content
if current_assistant_content:
litellm_messages.append({
"role": "assistant",
"content": current_assistant_content
})
litellm_messages.append({"role": "assistant", "content": current_assistant_content})
return litellm_messages
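A hedged sketch of the conversion this helper performs, with a fabricated history (the data URL is a placeholder); the dict shapes mirror the branches above:

history = [
    {"type": "reasoning", "summary": [{"type": "summary_text", "text": "Open the browser"}]},
    {"type": "computer_call", "action": {"type": "click", "x": 960, "y": 324, "button": "left"}},
    {"type": "computer_call_output", "output": {"type": "input_image", "image_url": "data:image/png;base64,<b64>"}},
]
litellm_msgs = convert_uitars_messages_to_litellm(history)
# litellm_msgs[0] -> assistant text block:
#   "Thought: Open the browser\nAction: click(start_box='(960,324)')"
# litellm_msgs[1] -> user message carrying the screenshot image_url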
@register_agent(models=r"(?i).*ui-?tars.*")
class UITARSConfig:
"""
UITARS agent configuration using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B model.
Supports UITARS vision-language models for computer control.
"""
async def predict_step(
self,
messages: List[Dict[str, Any]],
@@ -545,11 +584,11 @@ class UITARSConfig:
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
**kwargs,
) -> Dict[str, Any]:
"""
Predict the next step based on input messages.
Args:
messages: Input messages following Responses format
model: Model name to use
@@ -562,22 +601,22 @@ class UITARSConfig:
_on_usage: Callback for usage tracking
_on_screenshot: Callback for screenshot events
**kwargs: Additional arguments
Returns:
Dictionary with "output" (output items) and "usage" array
"""
tools = tools or []
# Create response items
response_items = []
# Find computer tool for screen dimensions
computer_tool = None
for tool_schema in tools:
if tool_schema["type"] == "computer":
computer_tool = tool_schema["computer"]
break
# Get screen dimensions
screen_width, screen_height = 1024, 768
if computer_tool:
@@ -585,20 +624,20 @@ class UITARSConfig:
screen_width, screen_height = await computer_tool.get_dimensions()
except:
pass
# Process messages to extract instruction and image
instruction = ""
image_data = None
# Convert messages to list if string
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
# Extract instruction and latest screenshot
for message in reversed(messages):
if isinstance(message, dict):
content = message.get("content", "")
# Handle different content formats
if isinstance(content, str):
if not instruction and message.get("role") == "user":
@@ -614,46 +653,41 @@ class UITARSConfig:
image_data = image_url.get("url", "")
else:
image_data = image_url
# Also check for computer_call_output with screenshots
if message.get("type") == "computer_call_output" and not image_data:
output = message.get("output", {})
if isinstance(output, dict) and output.get("type") == "input_image":
image_data = output.get("image_url", "")
if instruction and image_data:
break
if not instruction:
instruction = "Help me complete this task by analyzing the screen and taking appropriate actions."
instruction = (
"Help me complete this task by analyzing the screen and taking appropriate actions."
)
# Create prompt
user_prompt = UITARS_PROMPT_TEMPLATE.format(
instruction=instruction,
action_space=UITARS_ACTION_SPACE,
language="English"
instruction=instruction, action_space=UITARS_ACTION_SPACE, language="English"
)
# Convert conversation history to LiteLLM format
history_messages = convert_uitars_messages_to_litellm(messages)
# Prepare messages for liteLLM
litellm_messages = [
{
"role": "system",
"content": "You are a helpful assistant."
}
]
litellm_messages = [{"role": "system", "content": "You are a helpful assistant."}]
# Add current user instruction with screenshot
current_user_message = {
"role": "user",
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
]
],
}
litellm_messages.append(current_user_message)
# Process image for UITARS
if not image_data:
# Take screenshot if none found in messages
@@ -667,17 +701,22 @@ class UITARSConfig:
raise ValueError("No screenshot found in messages and no computer_handler provided")
processed_image, original_width, original_height = process_image_for_uitars(image_data)
encoded_image = pil_to_base64(processed_image)
# Add conversation history
if history_messages:
litellm_messages.extend(history_messages)
else:
litellm_messages.append({
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
]
})
litellm_messages.append(
{
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{encoded_image}"},
}
],
}
)
# Prepare API call kwargs
api_kwargs = {
@@ -687,146 +726,142 @@ class UITARSConfig:
"temperature": kwargs.get("temperature", 0.0),
"do_sample": kwargs.get("temperature", 0.0) > 0.0,
"num_retries": max_retries,
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]},
}
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
# Call liteLLM with UITARS model
response = await litellm.acompletion(**api_kwargs)
# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
# Extract response content
response_content = response.choices[0].message.content.strip() # type: ignore
response_content = response.choices[0].message.content.strip() # type: ignore
# Parse UITARS response
parsed_responses = parse_uitars_response(response_content, original_width, original_height)
# Convert to computer actions
computer_actions = convert_to_computer_actions(parsed_responses, original_width, original_height)
computer_actions = convert_to_computer_actions(
parsed_responses, original_width, original_height
)
# Add computer actions to response items
thought = parsed_responses[0].get("thought", "")
if thought:
response_items.append(make_reasoning_item(thought))
response_items.extend(computer_actions)
# Extract usage information
response_usage = {
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(
response.usage
).model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(response_usage)
# Create agent response
agent_response = {
"output": response_items,
"usage": response_usage
}
agent_response = {"output": response_items, "usage": response_usage}
return agent_response
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str
self, model: str, image_b64: str, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.
UITARS supports click prediction through its action parsing.
Args:
model: Model name to use
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
Tuple with (x, y) coordinates or None
"""
try:
# Create prompt using grounding template
user_prompt = GROUNDING_UITARS_PROMPT_TEMPLATE.format(
instruction=instruction
)
user_prompt = GROUNDING_UITARS_PROMPT_TEMPLATE.format(instruction=instruction)
# Process image for UITARS
processed_image, original_width, original_height = process_image_for_uitars(image_b64)
encoded_image = pil_to_base64(processed_image)
# Prepare messages for liteLLM
litellm_messages = [
{
"role": "system",
"content": "You are a helpful assistant."
},
{"role": "system", "content": "You are a helpful assistant."},
{
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
]
}
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{encoded_image}"},
},
],
},
]
# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": litellm_messages,
"max_tokens": 2056,
"temperature": 0.0,
"do_sample": False
"do_sample": False,
}
# Call liteLLM with UITARS model
response = await litellm.acompletion(**api_kwargs)
# Extract response content
response_content = response.choices[0].message.content.strip() # type: ignore
response_content = response.choices[0].message.content.strip() # type: ignore
print(response_content)
# Parse the response to extract click coordinates
# Look for click action with coordinates (with special tokens)
click_pattern = r"click\(point='<\|box_start\|>\((\d+),(\d+)\)<\|box_end\|>'\)"
match = re.search(click_pattern, response_content)
# Fallback: Look for simpler format without special tokens
if not match:
# Pattern for: click(start_box='(x,y)') or click(point='(x,y)')
fallback_pattern = r"click\((?:start_box|point)='\((\d+),(\d+)\)'\)"
match = re.search(fallback_pattern, response_content)
if match:
x, y = int(match.group(1)), int(match.group(2))
# Scale coordinates back to original image dimensions
scale_x = original_width / processed_image.width
scale_y = original_height / processed_image.height
scaled_x = int(x * scale_x)
scaled_y = int(y * scale_y)
return (scaled_x, scaled_y)
return None
except Exception as e:
# Log error and return None
print(f"Error in predict_click: {e}")
return None
def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by this agent config.
Returns:
List of capability strings
"""
return ["step", "click"]
return ["step", "click"]
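The fallback click pattern used by predict_click above can be exercised on its own; a standard-library-only sketch with a made-up reply:

import re

reply = "click(start_box='(412,388)')"
match = re.search(r"click\((?:start_box|point)='\((\d+),(\d+)\)'\)", reply)
if match:
    x, y = int(match.group(1)), int(match.group(2))
    print(x, y)  # 412 388 -- still in processed-image pixels; scaled back afterwards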

View File

@@ -1,19 +1,22 @@
"""
Example usage of the proxy server and client requests.
"""
import dotenv
dotenv.load_dotenv()
import asyncio
import json
import os
from typing import Any, Dict
import aiohttp
from typing import Dict, Any
async def test_http_endpoint():
"""Test the HTTP /responses endpoint."""
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
assert isinstance(anthropic_api_key, str), "ANTHROPIC_API_KEY environment variable must be set"
@@ -21,11 +24,9 @@ async def test_http_endpoint():
simple_request = {
"model": "anthropic/claude-3-5-sonnet-20241022",
"input": "Tell me a three sentence bedtime story about a unicorn.",
"env": {
"ANTHROPIC_API_KEY": anthropic_api_key
}
"env": {"ANTHROPIC_API_KEY": anthropic_api_key},
}
# Example 2: Multi-modal request with image
multimodal_request = {
"model": "anthropic/claude-3-5-sonnet-20241022",
@@ -36,70 +37,72 @@ async def test_http_endpoint():
{"type": "input_text", "text": "what is in this image?"},
{
"type": "input_image",
"image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
}
]
"image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg",
},
],
}
],
"env": {
"ANTHROPIC_API_KEY": anthropic_api_key
}
"env": {"ANTHROPIC_API_KEY": anthropic_api_key},
}
# Example 3: Request with custom agent and computer kwargs
custom_request = {
"model": "anthropic/claude-3-5-sonnet-20241022",
"input": "Take a screenshot and tell me what you see",
"env": {
"ANTHROPIC_API_KEY": anthropic_api_key
}
"env": {"ANTHROPIC_API_KEY": anthropic_api_key},
}
# Test requests
base_url = "https://m-linux-96lcxd2c2k.containers.cloud.trycua.com:8443"
# base_url = "http://localhost:8000"
api_key = os.getenv("CUA_API_KEY")
assert isinstance(api_key, str), "CUA_API_KEY environment variable must be set"
async with aiohttp.ClientSession() as session:
for i, request_data in enumerate([
simple_request,
# multimodal_request,
custom_request
], 1):
for i, request_data in enumerate(
[
simple_request,
# multimodal_request,
custom_request,
],
1,
):
print(f"\n--- Test {i} ---")
print(f"Request: {json.dumps(request_data, indent=2)}")
try:
print(f"Sending request to {base_url}/responses")
async with session.post(
f"{base_url}/responses",
json=request_data,
headers={"Content-Type": "application/json", "X-API-Key": api_key}
headers={"Content-Type": "application/json", "X-API-Key": api_key},
) as response:
result = await response.json()
print(f"Status: {response.status}")
print(f"Response: {json.dumps(result, indent=2)}")
except Exception as e:
print(f"Error: {e}")
def curl_examples():
"""Print curl command examples."""
print("=== CURL Examples ===\n")
print("1. Simple text request:")
print("""curl http://localhost:8000/responses \\
print(
"""curl http://localhost:8000/responses \\
-H "Content-Type: application/json" \\
-d '{
"model": "anthropic/claude-3-5-sonnet-20241022",
"input": "Tell me a three sentence bedtime story about a unicorn."
}'""")
}'"""
)
print("\n2. Multi-modal request with image:")
print("""curl http://localhost:8000/responses \\
print(
"""curl http://localhost:8000/responses \\
-H "Content-Type: application/json" \\
-d '{
"model": "anthropic/claude-3-5-sonnet-20241022",
@@ -115,10 +118,12 @@ def curl_examples():
]
}
]
}'""")
}'"""
)
print("\n3. Request with custom configuration:")
print("""curl http://localhost:8000/responses \\
print(
"""curl http://localhost:8000/responses \\
-H "Content-Type: application/json" \\
-d '{
"model": "anthropic/claude-3-5-sonnet-20241022",
@@ -131,50 +136,49 @@ def curl_examples():
"os_type": "linux",
"provider_type": "cloud"
}
}'""")
}'"""
)
async def test_p2p_client():
"""Example P2P client using peerjs-python."""
try:
from peerjs import Peer, PeerOptions, ConnectionEventType
from aiortc import RTCConfiguration, RTCIceServer
from peerjs import ConnectionEventType, Peer, PeerOptions
# Set up client peer
options = PeerOptions(
host="0.peerjs.com",
port=443,
secure=True,
config=RTCConfiguration(
iceServers=[RTCIceServer(urls="stun:stun.l.google.com:19302")]
)
config=RTCConfiguration(iceServers=[RTCIceServer(urls="stun:stun.l.google.com:19302")]),
)
client_peer = Peer(id="test-client", peer_options=options)
await client_peer.start()
# Connect to proxy server
connection = client_peer.connect("computer-agent-proxy")
@connection.on(ConnectionEventType.Open)
async def connection_open():
print("Connected to proxy server")
# Send a test request
request = {
"model": "anthropic/claude-3-5-sonnet-20241022",
"input": "Hello from P2P client!"
"input": "Hello from P2P client!",
}
await connection.send(json.dumps(request))
@connection.on(ConnectionEventType.Data)
async def connection_data(data):
print(f"Received response: {data}")
await client_peer.destroy()
# Wait for connection
await asyncio.sleep(10)
except ImportError:
print("P2P dependencies not available. Install peerjs-python for P2P testing.")
except Exception as e:
@@ -183,7 +187,7 @@ async def test_p2p_client():
if __name__ == "__main__":
import sys
if len(sys.argv) > 1 and sys.argv[1] == "curl":
curl_examples()
elif len(sys.argv) > 1 and sys.argv[1] == "p2p":

View File

@@ -7,24 +7,25 @@ import json
import logging
import os
from contextlib import contextmanager
from typing import Dict, Any, List, Union, Optional
from typing import Any, Dict, List, Optional, Union
from computer import Computer
from ..agent import ComputerAgent
from computer import Computer
logger = logging.getLogger(__name__)
class ResponsesHandler:
"""Handler for /responses endpoint that processes agent requests."""
def __init__(self):
self.computer = None
self.agent = None
# Simple in-memory caches
self._computer_cache: Dict[str, Any] = {}
self._agent_cache: Dict[str, Any] = {}
async def setup_computer_agent(
self,
model: str,
@@ -75,7 +76,9 @@ class ResponsesHandler:
computer = Computer(**default_c_config)
await computer.__aenter__()
self._computer_cache[comp_key] = computer
logger.info(f"Computer created and cached with key={comp_key} config={default_c_config}")
logger.info(
f"Computer created and cached with key={comp_key} config={default_c_config}"
)
else:
logger.info(f"Reusing cached computer for key={comp_key}")
@@ -115,14 +118,14 @@ class ResponsesHandler:
# Bind current agent reference
self.agent = agent
async def process_request(self, request_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Process a /responses request and return the result.
Args:
request_data: Dictionary containing model, input, and optional kwargs
Returns:
Dictionary with the agent's response
"""
@@ -133,12 +136,12 @@ class ResponsesHandler:
agent_kwargs = request_data.get("agent_kwargs", {})
computer_kwargs = request_data.get("computer_kwargs", {})
env_overrides = request_data.get("env", {}) or {}
if not model:
raise ValueError("Model is required")
if not input_data:
raise ValueError("Input is required")
# Apply env overrides for the duration of this request
with self._env_overrides(env_overrides):
# Set up (and possibly reuse) computer and agent via caches
@@ -155,28 +158,22 @@ class ResponsesHandler:
# Run agent and get first result
async for result in agent.run(messages):
# Return the first result and break
return {
"success": True,
"result": result,
"model": model
}
return {"success": True, "result": result, "model": model}
# If no results were yielded
return {
"success": False,
"error": "No results from agent",
"model": model
}
return {"success": False, "error": "No results from agent", "model": model}
except Exception as e:
logger.error(f"Error processing request: {e}")
return {
"success": False,
"error": str(e),
"model": request_data.get("model", "unknown")
"model": request_data.get("model", "unknown"),
}
def _convert_input_to_messages(self, input_data: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
def _convert_input_to_messages(
self, input_data: Union[str, List[Dict[str, Any]]]
) -> List[Dict[str, Any]]:
"""Convert input data to messages format."""
if isinstance(input_data, str):
# Simple string input
@@ -192,22 +189,18 @@ class ResponsesHandler:
if part.get("type") == "input_text":
content_parts.append({"type": "text", "text": part["text"]})
elif part.get("type") == "input_image":
content_parts.append({
"type": "image_url",
"image_url": {"url": part["image_url"]}
})
content_parts.append(
{"type": "image_url", "image_url": {"url": part["image_url"]}}
)
else:
content_parts.append(part)
messages.append({
"role": msg["role"],
"content": content_parts
})
messages.append({"role": msg["role"], "content": content_parts})
else:
messages.append(msg)
return messages
else:
raise ValueError("Input must be string or list of messages")
async def cleanup(self):
"""Clean up resources."""
if self.computer:

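For reference, a fabricated /responses payload in the shape process_request expects; only "model" and "input" are required, the other keys are the optional ones read above (all values are placeholders):

request_data = {
    "model": "anthropic/claude-3-5-sonnet-20241022",
    "input": [
        {
            "role": "user",
            "content": [
                {"type": "input_text", "text": "What is on the screen?"},
                {"type": "input_image", "image_url": "https://example.com/screenshot.png"},
            ],
        }
    ],
    "agent_kwargs": {"only_n_most_recent_images": 3},  # presumably forwarded to ComputerAgent
    "computer_kwargs": {"os_type": "linux", "provider_type": "cloud"},  # forwarded to Computer
    "env": {"ANTHROPIC_API_KEY": "<key>"},  # applied only for the duration of this request
}
# result = await ResponsesHandler().process_request(request_data)
# -> {"success": True, "result": <first agent result>, "model": "..."} on success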
File diff suppressed because it is too large

View File

@@ -2,37 +2,43 @@
Type definitions for agent
"""
from typing import Dict, List, Any, Optional, Callable, Protocol, Literal
from pydantic import BaseModel
import re
from litellm import ResponseInputParam, ResponsesAPIResponse, ToolParam
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Literal, Optional, Protocol
from litellm import ResponseInputParam, ResponsesAPIResponse, ToolParam
from pydantic import BaseModel
# Agent input types
Messages = str | ResponseInputParam | List[Dict[str, Any]]
Tools = Optional[Iterable[ToolParam]]
# Agent output types
AgentResponse = ResponsesAPIResponse
AgentResponse = ResponsesAPIResponse
AgentCapability = Literal["step", "click"]
# Exception types
class ToolError(RuntimeError):
"""Base exception for tool-related errors"""
pass
class IllegalArgumentError(ToolError):
"""Exception raised when function arguments are invalid"""
pass
# Agent config registration
class AgentConfigInfo(BaseModel):
"""Information about a registered agent config"""
agent_class: type
models_regex: str
priority: int = 0
def matches_model(self, model: str) -> bool:
"""Check if this agent config matches the given model"""
return bool(re.match(self.models_regex, model))
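A quick sketch of how the registration metadata above is matched against a model string (agent.types is the import path used elsewhere in this diff):

from agent.types import AgentConfigInfo

info = AgentConfigInfo(agent_class=object, models_regex=r"(?i).*qwen.*", priority=-1)  # any class works for the sketch
print(info.matches_model("huggingface/Qwen3-VL-8B"))                 # True
print(info.matches_model("anthropic/claude-3-5-sonnet-20241022"))    # False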

View File

@@ -2,6 +2,6 @@
UI components for agent
"""
from .gradio import launch_ui, create_gradio_ui
from .gradio import create_gradio_ui, launch_ui
__all__ = ["launch_ui", "create_gradio_ui"]

View File

@@ -1,4 +1,4 @@
from .gradio import launch_ui
if __name__ == "__main__":
launch_ui()
launch_ui()

View File

@@ -18,21 +18,21 @@ Requirements:
- OpenAI or Anthropic API key
"""
import os
import asyncio
import logging
import json
import logging
import os
import platform
from pathlib import Path
from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union, cast
import gradio as gr
from gradio.components.chatbot import MetadataDict
from typing import cast
# Import from agent package
from agent import ComputerAgent
from agent.types import Messages, AgentResponse
from agent.types import AgentResponse, Messages
from computer import Computer
from gradio.components.chatbot import MetadataDict
# Global variables
global_agent = None
@@ -42,11 +42,13 @@ SETTINGS_FILE = Path(".gradio_settings.json")
logging.basicConfig(level=logging.INFO)
import dotenv
if dotenv.load_dotenv():
print(f"DEBUG - Loaded environment variables from {dotenv.find_dotenv()}")
else:
print("DEBUG - No .env file found")
# --- Settings Load/Save Functions ---
def load_settings() -> Dict[str, Any]:
"""Loads settings from the JSON file."""
@@ -84,7 +86,7 @@ def save_settings(settings: Dict[str, Any]):
# async def on_screenshot(self, screenshot_base64: str, action_type: str = "") -> None:
# """Add screenshot to chatbot when a screenshot is taken."""
# image_markdown = f"![Screenshot after {action_type}](data:image/png;base64,{screenshot_base64})"
# if self.chatbot_history is not None:
# self.chatbot_history.append(
# gr.ChatMessage(
@@ -141,7 +143,7 @@ def get_model_string(model_name: str, loop_provider: str) -> str:
ollama_model = model_name.split("OMNI: Ollama ", 1)[1]
return f"omniparser+ollama_chat/{ollama_model}"
return "omniparser+ollama_chat/llama3"
# Map based on loop provider
mapping = MODEL_MAPPINGS.get(loop_provider.lower(), MODEL_MAPPINGS["openai"])
return mapping.get(model_name, mapping["default"])
@@ -151,6 +153,7 @@ def get_ollama_models() -> List[str]:
"""Get available models from Ollama if installed."""
try:
import subprocess
result = subprocess.run(["ollama", "list"], capture_output=True, text=True)
if result.returncode == 0:
lines = result.stdout.strip().split("\n")
@@ -174,16 +177,14 @@ def create_computer_instance(
os_type: str = "macos",
provider_type: str = "lume",
name: Optional[str] = None,
api_key: Optional[str] = None
api_key: Optional[str] = None,
) -> Computer:
"""Create or get the global Computer instance."""
global global_computer
if global_computer is None:
if provider_type == "localhost":
global_computer = Computer(
verbosity=verbosity,
os_type=os_type,
use_host_computer_server=True
verbosity=verbosity, os_type=os_type, use_host_computer_server=True
)
else:
global_computer = Computer(
@@ -191,7 +192,7 @@ def create_computer_instance(
os_type=os_type,
provider_type=provider_type,
name=name if name else "",
api_key=api_key
api_key=api_key,
)
return global_computer
@@ -217,7 +218,7 @@ def create_agent(
os_type=computer_os,
provider_type=computer_provider,
name=computer_name,
api_key=computer_api_key
api_key=computer_api_key,
)
# Handle custom models
@@ -233,12 +234,15 @@ def create_agent(
"only_n_most_recent_images": only_n_most_recent_images,
"verbosity": verbosity,
}
if save_trajectory:
agent_kwargs["trajectory_dir"] = "trajectories"
if max_trajectory_budget:
agent_kwargs["max_trajectory_budget"] = {"max_budget": max_trajectory_budget, "raise_error": True}
agent_kwargs["max_trajectory_budget"] = {
"max_budget": max_trajectory_budget,
"raise_error": True,
}
global_agent = ComputerAgent(**agent_kwargs)
return global_agent
@@ -247,7 +251,8 @@ def create_agent(
def launch_ui():
"""Standalone function to launch the Gradio app."""
from agent.ui.gradio.ui_components import create_gradio_ui
print(f"Starting Gradio app for CUA Agent...")
print("Starting Gradio app for CUA Agent...")
demo = create_gradio_ui()
demo.launch(share=False, inbrowser=True)
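A small sketch of the Ollama branch of get_model_string shown above (the guard for this branch sits above the excerpt, so treat the call as illustrative; the import path is assumed):

from agent.ui.gradio.app import get_model_string  # assumed import path

# "OMNI: Ollama <model>" entries are rewritten to an omniparser+ollama_chat route.
print(get_model_string("OMNI: Ollama llama3", "OMNI"))
# -> "omniparser+ollama_chat/llama3"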

View File

@@ -2,19 +2,25 @@
UI Components for the Gradio interface
"""
import os
import asyncio
import logging
import json
import logging
import os
import platform
from pathlib import Path
from typing import Dict, List, Optional, Any, cast
from typing import Any, Dict, List, Optional, cast
import gradio as gr
from gradio.components.chatbot import MetadataDict
from .app import (
load_settings, save_settings, create_agent, get_model_string,
get_ollama_models, global_agent, global_computer
create_agent,
get_model_string,
get_ollama_models,
global_agent,
global_computer,
load_settings,
save_settings,
)
# Global messages array to maintain conversation history
@@ -23,15 +29,15 @@ global_messages = []
def create_gradio_ui() -> gr.Blocks:
"""Create a Gradio UI for the Computer-Use Agent."""
# Load settings
saved_settings = load_settings()
# Check for API keys
openai_api_key = os.environ.get("OPENAI_API_KEY", "")
anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY", "")
cua_api_key = os.environ.get("CUA_API_KEY", "")
# Model choices
openai_models = ["OpenAI: Computer-Use Preview"]
anthropic_models = [
@@ -43,10 +49,10 @@ def create_gradio_ui() -> gr.Blocks:
omni_models = [
"OMNI: OpenAI GPT-4o",
"OMNI: OpenAI GPT-4o mini",
"OMNI: Claude 3.7 Sonnet (20250219)",
"OMNI: Claude 3.5 Sonnet (20241022)"
"OMNI: Claude 3.7 Sonnet (20250219)",
"OMNI: Claude 3.5 Sonnet (20241022)",
]
# Check if API keys are available
has_openai_key = bool(openai_api_key)
has_anthropic_key = bool(anthropic_api_key)
@@ -59,15 +65,20 @@ def create_gradio_ui() -> gr.Blocks:
# Detect platform
is_mac = platform.system().lower() == "darwin"
# Format model choices
provider_to_models = {
"OPENAI": openai_models,
"ANTHROPIC": anthropic_models,
"OMNI": omni_models + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
"UITARS": ([
"huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
] if is_mac else []) + ["Custom model (OpenAI compatible API)"],
"UITARS": (
[
"huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
]
if is_mac
else []
)
+ ["Custom model (OpenAI compatible API)"],
}
# Apply saved settings
@@ -82,7 +93,9 @@ def create_gradio_ui() -> gr.Blocks:
elif initial_loop == "ANTHROPIC":
initial_model = anthropic_models[0] if anthropic_models else "No models available"
else: # OMNI
initial_model = omni_models[0] if omni_models else "Custom model (OpenAI compatible API)"
initial_model = (
omni_models[0] if omni_models else "Custom model (OpenAI compatible API)"
)
initial_custom_model = saved_settings.get("custom_model", "Qwen2.5-VL-7B-Instruct")
initial_provider_base_url = saved_settings.get("provider_base_url", "http://localhost:1234/v1")
@@ -96,16 +109,27 @@ def create_gradio_ui() -> gr.Blocks:
"Open Safari, search for 'macOS automation tools', and save the first three results as bookmarks",
"Configure SSH keys and set up a connection to a remote server",
]
def generate_python_code(agent_loop_choice, model_name, tasks, recent_images=3, save_trajectory=True, computer_os="linux", computer_provider="cloud", container_name="", cua_cloud_api_key="", max_budget=None):
def generate_python_code(
agent_loop_choice,
model_name,
tasks,
recent_images=3,
save_trajectory=True,
computer_os="linux",
computer_provider="cloud",
container_name="",
cua_cloud_api_key="",
max_budget=None,
):
"""Generate Python code for the current configuration and tasks."""
tasks_str = ""
for task in tasks:
if task and task.strip():
tasks_str += f' "{task}",\n'
model_string = get_model_string(model_name, agent_loop_choice)
computer_args = []
if computer_os != "macos":
computer_args.append(f'os_type="{computer_os}"')
@@ -115,14 +139,14 @@ def create_gradio_ui() -> gr.Blocks:
computer_args.append(f'name="{container_name}"')
if cua_cloud_api_key:
computer_args.append(f'api_key="{cua_cloud_api_key}"')
computer_args_str = ", ".join(computer_args)
if computer_args_str:
computer_args_str = f"({computer_args_str})"
else:
computer_args_str = "()"
code = f'''import asyncio
code = f"""import asyncio
from computer import Computer
from agent import ComputerAgent
@@ -131,22 +155,22 @@ async def main():
agent = ComputerAgent(
model="{model_string}",
tools=[computer],
only_n_most_recent_images={recent_images},'''
only_n_most_recent_images={recent_images},"""
if save_trajectory:
code += '''
trajectory_dir="trajectories",'''
code += """
trajectory_dir="trajectories","""
if max_budget:
code += f'''
max_trajectory_budget={{"max_budget": {max_budget}, "raise_error": True}},'''
code += '''
code += f"""
max_trajectory_budget={{"max_budget": {max_budget}, "raise_error": True}},"""
code += """
)
'''
"""
if tasks_str:
code += f'''
code += f"""
# Prompts for the computer-use agent
tasks = [
{tasks_str.rstrip()}
@@ -158,23 +182,23 @@ async def main():
async for result in agent.run(messages):
for item in result["output"]:
if item["type"] == "message":
print(item["content"][0]["text"])'''
print(item["content"][0]["text"])"""
else:
code += f'''
code += """
# Execute a single task
task = "Search for information about CUA on GitHub"
print(f"Executing task: {{task}}")
messages = [{{"role": "user", "content": task}}]
print(f"Executing task: {task}")
messages = [{"role": "user", "content": task}]
async for result in agent.run(messages):
for item in result["output"]:
if item["type"] == "message":
print(item["content"][0]["text"])'''
print(item["content"][0]["text"])"""
code += '''
code += """
if __name__ == "__main__":
asyncio.run(main())'''
asyncio.run(main())"""
return code
# Create the Gradio interface
@@ -199,11 +223,11 @@ if __name__ == "__main__":
value=generate_python_code(initial_loop, "gpt-4o", []),
interactive=False,
)
with gr.Accordion("Computer Configuration", open=True):
is_windows = platform.system().lower() == "windows"
is_mac = platform.system().lower() == "darwin"
providers = ["cloud", "localhost", "docker"]
if is_mac:
providers += ["lume"]
@@ -227,30 +251,30 @@ if __name__ == "__main__":
value=computer_choices[0],
info="Select the operating system for the computer",
)
computer_provider = gr.Radio(
choices=providers,
label="Provider",
value="lume" if is_mac else "cloud",
info="Select the computer provider",
)
container_name = gr.Textbox(
label="Container Name",
placeholder="Enter container name (optional)",
value=os.environ.get("CUA_CONTAINER_NAME", ""),
info="Optional name for the container",
)
cua_cloud_api_key = gr.Textbox(
label="CUA Cloud API Key",
placeholder="Enter your CUA Cloud API key",
value=os.environ.get("CUA_API_KEY", ""),
type="password",
info="Required for cloud provider",
visible=(not has_cua_key)
visible=(not has_cua_key),
)
with gr.Accordion("Agent Configuration", open=True):
agent_loop = gr.Dropdown(
choices=["OPENAI", "ANTHROPIC", "OMNI", "UITARS"],
@@ -267,90 +291,113 @@ if __name__ == "__main__":
value=openai_models[0] if openai_models else "No models available",
info="Select OpenAI model",
interactive=True,
visible=(initial_loop == "OPENAI")
visible=(initial_loop == "OPENAI"),
)
anthropic_model_choice = gr.Dropdown(
choices=anthropic_models,
label="Anthropic Model",
value=anthropic_models[0] if anthropic_models else "No models available",
value=(
anthropic_models[0] if anthropic_models else "No models available"
),
info="Select Anthropic model",
interactive=True,
visible=(initial_loop == "ANTHROPIC")
visible=(initial_loop == "ANTHROPIC"),
)
omni_model_choice = gr.Dropdown(
choices=omni_models + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
choices=omni_models
+ ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
label="OMNI Model",
value=omni_models[0] if omni_models else "Custom model (OpenAI compatible API)",
value=(
omni_models[0]
if omni_models
else "Custom model (OpenAI compatible API)"
),
info="Select OMNI model or choose a custom model option",
interactive=True,
visible=(initial_loop == "OMNI")
visible=(initial_loop == "OMNI"),
)
uitars_model_choice = gr.Dropdown(
choices=provider_to_models.get("UITARS", ["No models available"]),
label="UITARS Model",
value=provider_to_models.get("UITARS", ["No models available"])[0] if provider_to_models.get("UITARS") else "No models available",
value=(
provider_to_models.get("UITARS", ["No models available"])[0]
if provider_to_models.get("UITARS")
else "No models available"
),
info="Select UITARS model",
interactive=True,
visible=(initial_loop == "UITARS")
visible=(initial_loop == "UITARS"),
)
model_choice = gr.Textbox(visible=False)
# API key inputs
with gr.Group(visible=not has_openai_key and (initial_loop == "OPENAI" or initial_loop == "OMNI")) as openai_key_group:
with gr.Group(
visible=not has_openai_key
and (initial_loop == "OPENAI" or initial_loop == "OMNI")
) as openai_key_group:
openai_api_key_input = gr.Textbox(
label="OpenAI API Key",
placeholder="Enter your OpenAI API key",
value=os.environ.get("OPENAI_API_KEY", ""),
interactive=True,
type="password",
info="Required for OpenAI models"
info="Required for OpenAI models",
)
with gr.Group(visible=not has_anthropic_key and (initial_loop == "ANTHROPIC" or initial_loop == "OMNI")) as anthropic_key_group:
with gr.Group(
visible=not has_anthropic_key
and (initial_loop == "ANTHROPIC" or initial_loop == "OMNI")
) as anthropic_key_group:
anthropic_api_key_input = gr.Textbox(
label="Anthropic API Key",
placeholder="Enter your Anthropic API key",
value=os.environ.get("ANTHROPIC_API_KEY", ""),
interactive=True,
type="password",
info="Required for Anthropic models"
info="Required for Anthropic models",
)
# API key handlers
def set_openai_api_key(key):
if key and key.strip():
os.environ["OPENAI_API_KEY"] = key.strip()
print(f"DEBUG - Set OpenAI API key environment variable")
print("DEBUG - Set OpenAI API key environment variable")
return key
def set_anthropic_api_key(key):
if key and key.strip():
os.environ["ANTHROPIC_API_KEY"] = key.strip()
print(f"DEBUG - Set Anthropic API key environment variable")
print("DEBUG - Set Anthropic API key environment variable")
return key
openai_api_key_input.change(
fn=set_openai_api_key,
inputs=[openai_api_key_input],
outputs=[openai_api_key_input],
queue=False
queue=False,
)
anthropic_api_key_input.change(
fn=set_anthropic_api_key,
inputs=[anthropic_api_key_input],
outputs=[anthropic_api_key_input],
queue=False
queue=False,
)
# UI update function
def update_ui(loop=None, openai_model=None, anthropic_model=None, omni_model=None, uitars_model=None):
def update_ui(
loop=None,
openai_model=None,
anthropic_model=None,
omni_model=None,
uitars_model=None,
):
loop = loop or agent_loop.value
model_value = None
if loop == "OPENAI" and openai_model:
model_value = openai_model
@@ -360,21 +407,37 @@ if __name__ == "__main__":
model_value = omni_model
elif loop == "UITARS" and uitars_model:
model_value = uitars_model
openai_visible = (loop == "OPENAI")
anthropic_visible = (loop == "ANTHROPIC")
omni_visible = (loop == "OMNI")
uitars_visible = (loop == "UITARS")
show_openai_key = not has_openai_key and (loop == "OPENAI" or (loop == "OMNI" and model_value and "OpenAI" in model_value and "Custom" not in model_value))
show_anthropic_key = not has_anthropic_key and (loop == "ANTHROPIC" or (loop == "OMNI" and model_value and "Claude" in model_value and "Custom" not in model_value))
openai_visible = loop == "OPENAI"
anthropic_visible = loop == "ANTHROPIC"
omni_visible = loop == "OMNI"
uitars_visible = loop == "UITARS"
show_openai_key = not has_openai_key and (
loop == "OPENAI"
or (
loop == "OMNI"
and model_value
and "OpenAI" in model_value
and "Custom" not in model_value
)
)
show_anthropic_key = not has_anthropic_key and (
loop == "ANTHROPIC"
or (
loop == "OMNI"
and model_value
and "Claude" in model_value
and "Custom" not in model_value
)
)
is_custom_openai_api = model_value == "Custom model (OpenAI compatible API)"
is_custom_ollama = model_value == "Custom model (ollama)"
is_any_custom = is_custom_openai_api or is_custom_ollama
model_choice_value = model_value if model_value else ""
return [
gr.update(visible=openai_visible),
gr.update(visible=anthropic_visible),
@@ -385,15 +448,18 @@ if __name__ == "__main__":
gr.update(visible=is_any_custom),
gr.update(visible=is_custom_openai_api),
gr.update(visible=is_custom_openai_api),
gr.update(value=model_choice_value)
gr.update(value=model_choice_value),
]
# Custom model inputs
custom_model = gr.Textbox(
label="Custom Model Name",
placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct or llama3)",
value=initial_custom_model,
visible=(initial_model == "Custom model (OpenAI compatible API)" or initial_model == "Custom model (ollama)"),
visible=(
initial_model == "Custom model (OpenAI compatible API)"
or initial_model == "Custom model (ollama)"
),
interactive=True,
)
@@ -413,36 +479,56 @@ if __name__ == "__main__":
interactive=True,
type="password",
)
# Provider visibility update function
def update_provider_visibility(provider):
"""Update visibility of container name and API key based on selected provider."""
is_localhost = provider == "localhost"
return [
gr.update(visible=not is_localhost), # container_name
gr.update(visible=not is_localhost and not has_cua_key) # cua_cloud_api_key
gr.update(
visible=not is_localhost and not has_cua_key
), # cua_cloud_api_key
]
# Connect provider change event
computer_provider.change(
fn=update_provider_visibility,
inputs=[computer_provider],
outputs=[container_name, cua_cloud_api_key],
queue=False
queue=False,
)
# Connect UI update events
for dropdown in [agent_loop, omni_model_choice, uitars_model_choice, openai_model_choice, anthropic_model_choice]:
for dropdown in [
agent_loop,
omni_model_choice,
uitars_model_choice,
openai_model_choice,
anthropic_model_choice,
]:
dropdown.change(
fn=update_ui,
inputs=[agent_loop, openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice],
outputs=[
openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice,
openai_key_group, anthropic_key_group,
custom_model, provider_base_url, provider_api_key,
model_choice
inputs=[
agent_loop,
openai_model_choice,
anthropic_model_choice,
omni_model_choice,
uitars_model_choice,
],
queue=False
outputs=[
openai_model_choice,
anthropic_model_choice,
omni_model_choice,
uitars_model_choice,
openai_key_group,
anthropic_key_group,
custom_model,
provider_base_url,
provider_api_key,
model_choice,
],
queue=False,
)
save_trajectory = gr.Checkbox(
@@ -461,7 +547,7 @@ if __name__ == "__main__":
info="Number of recent images to keep in context",
interactive=True,
)
max_budget = gr.Number(
label="Max Budget ($)",
value=lambda: None,
@@ -479,9 +565,7 @@ if __name__ == "__main__":
)
chatbot_history = gr.Chatbot(type="messages")
msg = gr.Textbox(
placeholder="Ask me to perform tasks in a virtual environment"
)
msg = gr.Textbox(placeholder="Ask me to perform tasks in a virtual environment")
clear = gr.Button("Clear")
cancel_button = gr.Button("Cancel", variant="stop")
@@ -498,11 +582,23 @@ if __name__ == "__main__":
global global_agent
if global_agent:
print("DEBUG - Cancelling agent task")
history.append(gr.ChatMessage(role="assistant", content="Task cancelled by user", metadata={"title": "❌ Cancelled"}))
history.append(
gr.ChatMessage(
role="assistant",
content="Task cancelled by user",
metadata={"title": "❌ Cancelled"},
)
)
else:
history.append(gr.ChatMessage(role="assistant", content="No active agent task to cancel", metadata={"title": " Info"}))
history.append(
gr.ChatMessage(
role="assistant",
content="No active agent task to cancel",
metadata={"title": " Info"},
)
)
return history
# Process response function
async def process_response(
history,
@@ -542,10 +638,13 @@ if __name__ == "__main__":
model_choice_value = uitars_model_value
else:
model_choice_value = "No models available"
# Determine if this is a custom model selection
is_custom_model_selected = model_choice_value in ["Custom model (OpenAI compatible API)", "Custom model (ollama)"]
is_custom_model_selected = model_choice_value in [
"Custom model (OpenAI compatible API)",
"Custom model (ollama)",
]
# Determine the model name string to analyze
if is_custom_model_selected:
model_string_to_analyze = custom_model_value
@@ -583,13 +682,19 @@ if __name__ == "__main__":
model_string=model_string,
save_trajectory=save_traj,
only_n_most_recent_images=recent_imgs,
custom_model_name=custom_model_value if is_custom_model_selected else None,
custom_model_name=(
custom_model_value if is_custom_model_selected else None
),
computer_os=computer_os,
computer_provider=computer_provider,
computer_name=container_name,
computer_api_key=cua_cloud_api_key,
verbosity=logging.DEBUG,
max_trajectory_budget=max_budget_value if max_budget_value and max_budget_value > 0 else None,
max_trajectory_budget=(
max_budget_value
if max_budget_value and max_budget_value > 0
else None
),
)
if global_agent is None:
@@ -605,7 +710,7 @@ if __name__ == "__main__":
# Add user message to global history
global global_messages
global_messages.append({"role": "user", "content": last_user_message})
# Stream responses from the agent
async for result in global_agent.run(global_messages):
global_messages += result.get("output", [])
@@ -613,18 +718,20 @@ if __name__ == "__main__":
# from pprint import pprint
# pprint(result)
# print(f"DEBUG - Agent response ------- END")
# Process the result output
for item in result.get("output", []):
if item.get("type") == "message":
content = item.get("content", [])
for content_part in content:
if content_part.get("text"):
history.append(gr.ChatMessage(
role=item.get("role", "assistant"),
content=content_part.get("text", ""),
metadata=content_part.get("metadata", {})
))
history.append(
gr.ChatMessage(
role=item.get("role", "assistant"),
content=content_part.get("text", ""),
metadata=content_part.get("metadata", {}),
)
)
elif item.get("type") == "computer_call":
action = item.get("action", {})
action_type = action.get("type", "")
@@ -632,43 +739,52 @@ if __name__ == "__main__":
action_title = f"🛠️ Performing {action_type}"
if action.get("x") and action.get("y"):
action_title += f" at ({action['x']}, {action['y']})"
history.append(gr.ChatMessage(
role="assistant",
content=f"```json\n{json.dumps(action)}\n```",
metadata={"title": action_title}
))
history.append(
gr.ChatMessage(
role="assistant",
content=f"```json\n{json.dumps(action)}\n```",
metadata={"title": action_title},
)
)
elif item.get("type") == "function_call":
function_name = item.get("name", "")
arguments = item.get("arguments", "{}")
history.append(gr.ChatMessage(
role="assistant",
content=f"🔧 Calling function: {function_name}\n```json\n{arguments}\n```",
metadata={"title": f"Function Call: {function_name}"}
))
history.append(
gr.ChatMessage(
role="assistant",
content=f"🔧 Calling function: {function_name}\n```json\n{arguments}\n```",
metadata={"title": f"Function Call: {function_name}"},
)
)
elif item.get("type") == "function_call_output":
output = item.get("output", "")
history.append(gr.ChatMessage(
role="assistant",
content=f"📤 Function output:\n```\n{output}\n```",
metadata={"title": "Function Output"}
))
history.append(
gr.ChatMessage(
role="assistant",
content=f"📤 Function output:\n```\n{output}\n```",
metadata={"title": "Function Output"},
)
)
elif item.get("type") == "computer_call_output":
output = item.get("output", {}).get("image_url", "")
image_markdown = f"![Computer output]({output})"
history.append(gr.ChatMessage(
role="assistant",
content=image_markdown,
metadata={"title": "🖥️ Computer Output"}
))
history.append(
gr.ChatMessage(
role="assistant",
content=image_markdown,
metadata={"title": "🖥️ Computer Output"},
)
)
yield history
except Exception as e:
import traceback
traceback.print_exc()
history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
yield history
# Connect the submit button
submit_event = msg.submit(
fn=chat_submit,
@@ -706,44 +822,77 @@ if __name__ == "__main__":
global global_messages
global_messages.clear()
return None
clear.click(clear_chat, None, chatbot_history, queue=False)
# Connect cancel button
cancel_button.click(
cancel_agent_task,
[chatbot_history],
[chatbot_history],
queue=False
cancel_agent_task, [chatbot_history], [chatbot_history], queue=False
)
# Code display update function
def update_code_display(agent_loop, model_choice_val, custom_model_val, chat_history, recent_images_val, save_trajectory_val, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget_val):
def update_code_display(
agent_loop,
model_choice_val,
custom_model_val,
chat_history,
recent_images_val,
save_trajectory_val,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
max_budget_val,
):
messages = []
if chat_history:
for msg in chat_history:
if isinstance(msg, dict) and msg.get("role") == "user":
messages.append(msg.get("content", ""))
return generate_python_code(
agent_loop,
model_choice_val or custom_model_val or "gpt-4o",
messages,
agent_loop,
model_choice_val or custom_model_val or "gpt-4o",
messages,
recent_images_val,
save_trajectory_val,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
max_budget_val
max_budget_val,
)
# Update code display when configuration changes
for component in [agent_loop, model_choice, custom_model, chatbot_history, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget]:
for component in [
agent_loop,
model_choice,
custom_model,
chatbot_history,
recent_images,
save_trajectory,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
max_budget,
]:
component.change(
update_code_display,
inputs=[agent_loop, model_choice, custom_model, chatbot_history, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget],
outputs=[code_display]
inputs=[
agent_loop,
model_choice,
custom_model,
chatbot_history,
recent_images,
save_trajectory,
computer_os,
computer_provider,
container_name,
cua_cloud_api_key,
max_budget,
],
outputs=[code_display],
)
return demo

View File

@@ -5,26 +5,30 @@ This directory contains benchmarks designed to test agent providers in the Compu
## Overview
The benchmark system evaluates models on GUI grounding tasks, specifically click prediction accuracy. It supports both of the following (see the sketch after this list):
- **Computer Agent SDK providers** (using model strings like `"huggingface-local/HelloKKMe/GTA1-7B"`)
- **Reference agent implementations** (custom model classes implementing the `ModelProtocol`)
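
Both kinds can be benchmarked side by side. Roughly speaking, a provider string is wrapped in a `ComputerAgent` and queried through `predict_click`, while a reference class implements `ModelProtocol` directly; this is essentially what `ModelWrapper` in `utils.py` does. A minimal sketch of the two call paths (the model names are only examples):

```python
from agent.agent import ComputerAgent
from models.gta1 import GTA1Model


async def demo(image, image_b64: str, instruction: str):
    # Provider string: delegated to the Computer Agent SDK
    agent = ComputerAgent(model="huggingface-local/HelloKKMe/GTA1-7B")
    coords_a = await agent.predict_click(instruction=instruction, image_b64=image_b64)

    # Reference implementation: a class implementing ModelProtocol
    model = GTA1Model("HelloKKMe/GTA1-7B")
    await model.load_model()
    coords_b = await model.predict_click(image, instruction)
    await model.unload_model()
    return coords_a, coords_b
```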
## Available Benchmarks
### 1. ScreenSpot-v2 (`ss-v2.py`)
- **Dataset**: ScreenSpot-v2 (click-only GUI grounding)
- **Format**: Standard resolution screenshots
- **Task**: Predict click coordinates given an instruction and image
- **Metrics**: Accuracy, Error Rate, Timing, VRAM usage
### 2. ScreenSpot-Pro (`ss-pro.py`)
### 2. ScreenSpot-Pro (`ss-pro.py`)
- **Dataset**: ScreenSpot-Pro (high-resolution click-only GUI grounding)
- **Format**: High-resolution screenshots
- **Task**: Predict click coordinates given an instruction and image
- **Metrics**: Accuracy, Error Rate, Timing, VRAM usage
### 3. Interactive Testing (`interactive.py`)
- **Real-time testing**: Take screenshots and visualize model predictions
- **Commands**:
- **Commands**:
- Type instruction → test all models on last screenshot
- `screenshot` → take screenshot
- `models` → list available models
@@ -34,14 +38,16 @@ The benchmark system evaluates models on GUI grounding tasks, specifically click
## Running Benchmarks
### 1. Configure Models
Edit `utils.py` to specify which models you want to test in `get_available_models()`.
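
As a rough illustration, a trimmed-down `get_available_models()` can mix provider strings and reference classes like this (the specific entries are just examples):

```python
from models.gta1 import GTA1Model


def get_available_models():
    return [
        # Computer Agent SDK provider strings
        "huggingface-local/HelloKKMe/GTA1-7B",
        "openai/computer-use-preview+openai/gpt-4o-mini",
        # Reference implementations (ModelProtocol classes)
        GTA1Model("HelloKKMe/GTA1-7B"),
    ]
```

Models are loaded and unloaded one at a time during a run, so several large local models can be listed without holding them all in memory at once.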
### 2. Run Benchmark
```bash
# ScreenSpot-v2 benchmark
python ss-v2.py --samples 50
# ScreenSpot-Pro benchmark
# ScreenSpot-Pro benchmark
python ss-pro.py --samples 50
# Interactive testing
@@ -51,6 +57,7 @@ python interactive.py
## Output
### Console Output
```
Model Results:
Accuracy: 85.50% (171/200)
@@ -59,10 +66,11 @@ Model Results:
```
### Generated Files
- **Markdown Report**: `*_results.md` with detailed results tables
- **Visualizations**: `output/` directory with prediction visualizations
- **Interactive Output**: `interactive_output/` for interactive session results
## Contributing
To add a new reference model, follow the instructions in [contrib.md](contrib.md).
To add a new reference model, follow the instructions in [contrib.md](contrib.md).

View File

@@ -17,29 +17,29 @@ class YourModelName(ModelProtocol):
def __init__(self, model_path: str):
self.model_path = model_path
self._model = None
@property
def model_name(self) -> str:
return self.model_path
async def load_model(self) -> None:
"""Load the model into memory."""
# Your model loading logic here
pass
async def unload_model(self) -> None:
"""Unload the model from memory."""
# Your model cleanup logic here
pass
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates for the given image and instruction.
Args:
image: PIL Image to analyze
instruction: Text instruction describing what to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -56,7 +56,7 @@ def get_available_models() -> List[Union[str, ModelProtocol]]:
models = [
# Computer Agent SDK providers
"huggingface-local/HelloKKMe/GTA1-7B",
# Reference implementations
GTA1Model("HelloKKMe/GTA1-7B"),
YourModelName("path/to/your/model"), # Add your model here
@@ -79,6 +79,7 @@ This will help you verify that your model loads correctly and produces reasonabl
Here's a complete example of adding a hypothetical "MyVisionModel":
1. **Create `models/my_vision_model.py`:**
```python
import torch
from transformers import AutoModel, AutoProcessor
@@ -91,11 +92,11 @@ class MyVisionModel(ModelProtocol):
self.model_path = model_path
self.model = None
self.processor = None
@property
def model_name(self) -> str:
return f"MyVisionModel({self.model_path})"
async def load_model(self) -> None:
"""Load the model and processor."""
self.processor = AutoProcessor.from_pretrained(self.model_path)
@@ -104,7 +105,7 @@ class MyVisionModel(ModelProtocol):
torch_dtype=torch.float16,
device_map="auto"
)
async def unload_model(self) -> None:
"""Clean up model resources."""
del self.model
@@ -112,7 +113,7 @@ class MyVisionModel(ModelProtocol):
self.model = None
self.processor = None
torch.cuda.empty_cache()
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
"""Predict click coordinates."""
try:
@@ -122,19 +123,19 @@ class MyVisionModel(ModelProtocol):
images=image,
return_tensors="pt"
)
# Run inference
with torch.no_grad():
outputs = self.model(**inputs)
# Extract coordinates (model-specific logic)
x, y = self._extract_coordinates(outputs)
return (int(x), int(y))
except Exception as e:
print(f"Prediction failed: {e}")
return None
def _extract_coordinates(self, outputs):
"""Extract x, y coordinates from model outputs."""
# Your model-specific coordinate extraction logic
@@ -142,6 +143,7 @@ class MyVisionModel(ModelProtocol):
```
2. **Update `models/__init__.py`:**
```python
from .gta1 import GTA1Model
from .my_vision_model import MyVisionModel
@@ -150,6 +152,7 @@ __all__ = ["GTA1Model", "MyVisionModel"]
```
3. **Update `utils.py`:**
```python
from models import GTA1Model, MyVisionModel
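
# Then register it in get_available_models() alongside the existing entries
# (a sketch; keep whichever models you actually want to benchmark):
def get_available_models():
    return [
        "huggingface-local/HelloKKMe/GTA1-7B",
        GTA1Model("HelloKKMe/GTA1-7B"),
        MyVisionModel("path/to/your/model"),
    ]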

View File

@@ -9,60 +9,56 @@ Models are loaded/unloaded one at a time to avoid memory issues.
import asyncio
import os
from datetime import datetime
from typing import List, Dict, Any
from typing import Any, Dict, List
from utils import (
ModelWrapper,
take_screenshot,
get_available_models,
save_prediction_visualization,
get_available_models
take_screenshot,
)
async def predict_with_all_models(image, instruction: str, models) -> List[Dict[str, Any]]:
"""
Predict click coordinates with all models sequentially.
Args:
image: PIL Image to analyze
instruction: Instruction text
models: List of model instances
Returns:
List of prediction results
"""
predictions = []
for model in models:
model_wrapper = ModelWrapper(model)
print(f"\n🔄 Loading {model_wrapper.model_name}...")
try:
# Load model
await model_wrapper.load_model()
# Predict
coords = await model_wrapper.predict_click(image, instruction)
predictions.append({
'model_name': model_wrapper.model_name,
'coords': coords,
'error': None
})
predictions.append(
{"model_name": model_wrapper.model_name, "coords": coords, "error": None}
)
if coords:
print(f"{model_wrapper.model_name}: ({coords[0]}, {coords[1]})")
else:
print(f"{model_wrapper.model_name}: No prediction")
except Exception as e:
print(f"{model_wrapper.model_name}: ERROR - {str(e)}")
predictions.append({
'model_name': model_wrapper.model_name,
'coords': None,
'error': str(e)
})
predictions.append(
{"model_name": model_wrapper.model_name, "coords": None, "error": str(e)}
)
finally:
# Always unload model to free memory
try:
@@ -70,7 +66,7 @@ async def predict_with_all_models(image, instruction: str, models) -> List[Dict[
print(f"🗑️ Unloaded {model_wrapper.model_name}")
except Exception as e:
print(f"⚠️ Error unloading {model_wrapper.model_name}: {e}")
return predictions
@@ -103,87 +99,91 @@ async def main():
Main interactive loop.
"""
print_header()
# Get available models
models = get_available_models()
print_models(models)
# Create output directory for visualizations
output_dir = "interactive_output"
os.makedirs(output_dir, exist_ok=True)
session_count = 0
last_screenshot = None
screenshot_timestamp = None
while True:
try:
# Get user input
print(f"\n{'='*40}")
user_input = input("🎯 Enter instruction (or command): ").strip()
if not user_input:
continue
# Handle commands
if user_input.lower() in ['quit', 'exit', 'q']:
if user_input.lower() in ["quit", "exit", "q"]:
print("👋 Goodbye!")
break
elif user_input.lower() == 'models':
elif user_input.lower() == "models":
print_models(models)
continue
elif user_input.lower() == 'screenshot':
elif user_input.lower() == "screenshot":
print("📸 Taking screenshot...")
try:
last_screenshot = take_screenshot()
screenshot_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
screenshot_path = os.path.join(output_dir, f"screenshot_{screenshot_timestamp}.png")
screenshot_path = os.path.join(
output_dir, f"screenshot_{screenshot_timestamp}.png"
)
last_screenshot.save(screenshot_path)
print(f"✅ Screenshot captured and saved to: {screenshot_path}")
print(f"📝 Ready for instructions! Screenshot size: {last_screenshot.size}")
except Exception as e:
print(f"❌ Error taking screenshot: {e}")
continue
# Handle instruction input
if last_screenshot is None:
print("⚠️ No screenshot available! Please take a screenshot first using 'screenshot' command.")
print(
"⚠️ No screenshot available! Please take a screenshot first using 'screenshot' command."
)
continue
session_count += 1
print(f"\n🎯 Session {session_count}: '{user_input}'")
print(f"📷 Using screenshot from: {screenshot_timestamp}")
# Predict with all models using last screenshot
print(f"\n🤖 Testing {len(models)} models on screenshot...")
predictions = await predict_with_all_models(last_screenshot, user_input, models)
# Display results summary
print(f"\n📊 Results Summary:")
print("\n📊 Results Summary:")
print("-" * 50)
for pred in predictions:
if pred['coords']:
if pred["coords"]:
print(f"{pred['model_name']}: ({pred['coords'][0]}, {pred['coords'][1]})")
elif pred['error']:
elif pred["error"]:
print(f"{pred['model_name']}: ERROR - {pred['error']}")
else:
print(f"{pred['model_name']}: No prediction")
# Save visualization
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
vis_filename = f"session_{session_count:03d}_{timestamp}.png"
vis_path = os.path.join(output_dir, vis_filename)
try:
save_prediction_visualization(last_screenshot, user_input, predictions, vis_path)
print(f"\n💾 Visualization saved to: {vis_path}")
except Exception as e:
print(f"⚠️ Error saving visualization: {e}")
print(f"\n✨ Session {session_count} completed!")
except KeyboardInterrupt:
print("\n\n👋 Interrupted by user. Goodbye!")
break

View File

@@ -2,34 +2,37 @@
Base protocol for benchmark models.
"""
from typing import Protocol, Optional, Tuple
from typing import Optional, Protocol, Tuple
from PIL import Image
class ModelProtocol(Protocol):
"""Protocol for benchmark models that can predict click coordinates."""
@property
def model_name(self) -> str:
"""Return the name of the model."""
...
async def load_model(self) -> None:
"""Load the model into memory."""
...
async def unload_model(self) -> None:
"""Unload the model from memory."""
...
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
async def predict_click(
self, image: Image.Image, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates for the given image and instruction.
Args:
image: PIL Image to analyze
instruction: Text instruction describing what to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""

View File

@@ -2,54 +2,51 @@
GTA1 model implementation for benchmarking.
"""
from typing import Optional, Tuple
from PIL import Image
import torch
import re
import gc
import re
from typing import Optional, Tuple
import torch
from PIL import Image
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from .base import ModelProtocol
class GTA1Model:
"""Ground truth GTA1 model implementation."""
def __init__(self, model_path: str = "HelloKKMe/GTA1-7B"):
self.model_path = model_path
self.model = None
self.processor = None
self.max_new_tokens = 32
self.system_prompt = '''
self.system_prompt = """
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
Output the coordinate pair exactly:
(x,y)
'''.strip()
""".strip()
@property
def model_name(self) -> str:
"""Return the name of the model."""
return f"GTA1-{self.model_path.split('/')[-1]}"
async def load_model(self) -> None:
"""Load the model into memory."""
if self.model is None:
print(f"Loading GTA1 model: {self.model_path}")
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
self.model_path,
torch_dtype=torch.bfloat16,
device_map="auto"
self.model_path, torch_dtype=torch.bfloat16, device_map="auto"
)
self.processor = AutoProcessor.from_pretrained(
self.model_path,
min_pixels=3136,
max_pixels=4096 * 2160
self.model_path, min_pixels=3136, max_pixels=4096 * 2160
)
print("GTA1 model loaded successfully")
async def unload_model(self) -> None:
"""Unload the model from memory."""
if self.model is not None:
@@ -62,23 +59,25 @@ Output the coordinate pair exactly:
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("GTA1 model unloaded")
def _extract_coordinates(self, raw_string: str) -> Tuple[int, int]:
"""Extract coordinates from model output."""
try:
matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
return tuple(map(int, map(float, matches[0]))) # type: ignore
return tuple(map(int, map(float, matches[0]))) # type: ignore
except:
return (0, 0)
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
async def predict_click(
self, image: Image.Image, instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates for the given image and instruction.
Args:
image: PIL Image to analyze
instruction: Text instruction describing what to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
@@ -87,76 +86,73 @@ Output the coordinate pair exactly:
assert self.processor is not None
assert self.model is not None
try:
width, height = image.width, image.height
# Resize image according to processor requirements
resized_height, resized_width = smart_resize(
image.height,
image.width,
factor=self.processor.image_processor.patch_size * self.processor.image_processor.merge_size,
factor=self.processor.image_processor.patch_size
* self.processor.image_processor.merge_size,
min_pixels=self.processor.image_processor.min_pixels,
max_pixels=self.processor.image_processor.max_pixels,
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height
# Prepare messages
system_message = {
"role": "system",
"content": self.system_prompt.format(height=resized_height, width=resized_width)
"content": self.system_prompt.format(height=resized_height, width=resized_width),
}
user_message = {
"role": "user",
"content": [
{"type": "image", "image": resized_image},
{"type": "text", "text": instruction}
]
{"type": "text", "text": instruction},
],
}
# Process inputs
image_inputs, video_inputs = process_vision_info([system_message, user_message]) # type: ignore
image_inputs, video_inputs = process_vision_info([system_message, user_message]) # type: ignore
text = self.processor.apply_chat_template(
[system_message, user_message],
tokenize=False,
add_generation_prompt=True
[system_message, user_message], tokenize=False, add_generation_prompt=True
)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt"
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt",
)
inputs = inputs.to(self.model.device)
# Generate prediction
output_ids = self.model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=False,
temperature=1.0,
use_cache=True
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=False,
temperature=1.0,
use_cache=True,
)
generated_ids = [
output_ids[len(input_ids):]
output_ids[len(input_ids) :]
for input_ids, output_ids in zip(inputs.input_ids, output_ids)
]
output_text = self.processor.batch_decode(
generated_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=True
generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)[0]
# Extract and rescale coordinates
pred_x, pred_y = self._extract_coordinates(output_text)
pred_x = int(pred_x * scale_x)
pred_y = int(pred_y * scale_y)
return (pred_x, pred_y)
except Exception as e:
print(f"Error in GTA1 prediction: {e}")
return None

View File

@@ -15,103 +15,106 @@ from typing import Optional
from datasets import load_dataset
from tqdm import tqdm
from utils import (
ModelWrapper,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
ModelWrapper,
get_available_models,
get_gpu_memory
get_gpu_memory,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
)
async def evaluate_model(model_wrapper: ModelWrapper, dataset, max_samples: Optional[int] = None) -> dict:
async def evaluate_model(
model_wrapper: ModelWrapper, dataset, max_samples: Optional[int] = None
) -> dict:
"""
Evaluate a model on the ScreenSpot-Pro dataset.
Args:
model_wrapper: ModelWrapper instance
dataset: ScreenSpot-Pro dataset (list of samples)
max_samples: Maximum number of samples to evaluate (None for all)
Returns:
Dictionary with evaluation results
"""
print(f"\nEvaluating model: {model_wrapper.model_name}")
# Load model
await model_wrapper.load_model()
total_samples = len(dataset)
if max_samples is not None:
total_samples = min(max_samples, total_samples)
correct_predictions = 0
error_predictions = 0
results = []
for i in tqdm(range(total_samples), desc=f"Evaluating {model_wrapper.model_name}"):
sample = dataset[i]
# Extract sample data
image = sample['image']
instruction = sample['instruction']
bbox = sample['bbox'] # [x1, y1, x2, y2]
sample_id = sample['img_filename']
image = sample["image"]
instruction = sample["instruction"]
bbox = sample["bbox"] # [x1, y1, x2, y2]
sample_id = sample["img_filename"]
# Predict click coordinates with timing
start_time = time.time()
click_coords = await model_wrapper.predict_click(image, instruction)
prediction_time = time.time() - start_time
# Check if prediction is correct
is_correct = is_click_in_bbox(click_coords, bbox)
if is_correct:
correct_predictions += 1
results.append({
'id': sample_id,
'instruction': instruction,
'bbox': bbox,
'predicted_coords': click_coords,
'is_correct': is_correct,
'failed': False,
'prediction_time': prediction_time
})
results.append(
{
"id": sample_id,
"instruction": instruction,
"bbox": bbox,
"predicted_coords": click_coords,
"is_correct": is_correct,
"failed": False,
"prediction_time": prediction_time,
}
)
# Unload model
await model_wrapper.unload_model()
# Calculate metrics
accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
error_rate = error_predictions / total_samples if total_samples > 0 else 0.0
# Calculate timing statistics
successful_times = [r['prediction_time'] for r in results if not r['failed']]
successful_times = [r["prediction_time"] for r in results if not r["failed"]]
avg_prediction_time = sum(successful_times) / len(successful_times) if successful_times else 0.0
median_prediction_time = statistics.median(successful_times) if successful_times else 0.0
min_prediction_time = min(successful_times) if successful_times else 0.0
max_prediction_time = max(successful_times) if successful_times else 0.0
# Get VRAM statistics
vram_stats = model_wrapper.get_vram_stats()
return {
'model_name': model_wrapper.model_name,
'total_samples': total_samples,
'correct_predictions': correct_predictions,
'failed_predictions': error_predictions,
'accuracy': accuracy,
'failure_rate': error_rate,
'avg_prediction_time': avg_prediction_time,
'median_prediction_time': median_prediction_time,
'min_prediction_time': min_prediction_time,
'max_prediction_time': max_prediction_time,
'vram_max_mb': vram_stats['max_mb'],
'vram_avg_mb': vram_stats['avg_mb'],
'results': results
"model_name": model_wrapper.model_name,
"total_samples": total_samples,
"correct_predictions": correct_predictions,
"failed_predictions": error_predictions,
"accuracy": accuracy,
"failure_rate": error_rate,
"avg_prediction_time": avg_prediction_time,
"median_prediction_time": median_prediction_time,
"min_prediction_time": min_prediction_time,
"max_prediction_time": max_prediction_time,
"vram_max_mb": vram_stats["max_mb"],
"vram_avg_mb": vram_stats["avg_mb"],
"results": results,
}
@@ -120,42 +123,44 @@ async def main():
Main function to run the benchmark.
"""
# Parse command line arguments
parser = argparse.ArgumentParser(description='ScreenSpot-Pro Benchmark Script')
parser.add_argument('--samples', type=int, default=300,
help='Number of samples to evaluate (default: 300)')
parser.add_argument('--seed', type=int, default=42,
help='Random seed for shuffling (default: 42)')
parser = argparse.ArgumentParser(description="ScreenSpot-Pro Benchmark Script")
parser.add_argument(
"--samples", type=int, default=300, help="Number of samples to evaluate (default: 300)"
)
parser.add_argument(
"--seed", type=int, default=42, help="Random seed for shuffling (default: 42)"
)
args = parser.parse_args()
# Set random seed
random.seed(args.seed)
# Load dataset
print("Loading ScreenSpot-Pro dataset...")
ds = load_dataset("lmms-lab/ScreenSpot-Pro")
dataset = ds['train'] # type: ignore
dataset = ds["train"] # type: ignore
# Convert to list to support indexing
dataset_list = list(dataset)
print(f"Dataset loaded: {len(dataset_list)} samples")
# Shuffle dataset with seed
random.shuffle(dataset_list)
print(f"Dataset shuffled with seed {args.seed}")
# Get available models
models = get_available_models()
# Evaluation settings
max_samples = args.samples # Use command line argument
# Run evaluations
all_results = []
for model in models:
model_wrapper = ModelWrapper(model)
result = await evaluate_model(model_wrapper, dataset_list, max_samples)
all_results.append(result)
# Print summary
print(f"\n{result['model_name']} Results:")
print(f" Accuracy: {result['accuracy']*100:.2f}%")
@@ -164,15 +169,17 @@ async def main():
print(f" Error Rate: {result['failure_rate']*100:.2f}%")
print(f" Avg Time: {result['avg_prediction_time']:.2f}s")
print(f" Median Time: {result['median_prediction_time']:.2f}s")
print(f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s")
print(
f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s"
)
print(f" VRAM Max: {result['vram_max_mb']:.1f}MB")
print(f" VRAM Avg: {result['vram_avg_mb']:.1f}MB")
# Print GPU memory info
gpu_memory = get_gpu_memory()
if gpu_memory and gpu_memory[0] > 0:
print(f" GPU Free Memory: {gpu_memory[0]:.1f}MB")
# Save results
if all_results:
save_results_to_markdown(all_results)
@@ -183,4 +190,4 @@ async def main():
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

View File

@@ -15,36 +15,37 @@ from typing import Optional
from datasets import load_dataset
from tqdm import tqdm
from utils import (
ModelWrapper,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
ModelWrapper,
get_available_models,
get_gpu_memory
get_gpu_memory,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
)
async def evaluate_model(model_wrapper: ModelWrapper, samples, max_samples: Optional[int] = None) -> dict:
async def evaluate_model(
model_wrapper: ModelWrapper, samples, max_samples: Optional[int] = None
) -> dict:
"""
Evaluate a model on any iterable of samples.
Args:
model_wrapper: ModelWrapper instance
samples: Iterable of dicts with keys: image, bbox, instruction
max_samples: Maximum number of samples to evaluate (None for all)
Returns:
Dictionary with evaluation results
"""
print(f"\nEvaluating model: {model_wrapper.model_name}")
# Load model
await model_wrapper.load_model()
# Convert to list if needed and limit samples
if hasattr(samples, '__len__'):
if hasattr(samples, "__len__"):
total_samples = len(samples)
if max_samples is not None:
total_samples = min(max_samples, total_samples)
@@ -55,69 +56,71 @@ async def evaluate_model(model_wrapper: ModelWrapper, samples, max_samples: Opti
if max_samples is not None:
sample_list = sample_list[:max_samples]
total_samples = len(sample_list)
correct_predictions = 0
error_predictions = 0
results = []
for i, sample in enumerate(tqdm(sample_list, desc=f"Evaluating {model_wrapper.model_name}")):
# Extract required data (only these 3 keys matter)
image = sample['image']
instruction = sample['instruction']
bbox = sample['bbox'] # [x1, y1, x2, y2]
image = sample["image"]
instruction = sample["instruction"]
bbox = sample["bbox"] # [x1, y1, x2, y2]
# Predict click coordinates with timing
start_time = time.time()
click_coords = await model_wrapper.predict_click(image, instruction)
prediction_time = time.time() - start_time
# Check if prediction is correct
is_correct = is_click_in_bbox(click_coords, bbox)
if is_correct:
correct_predictions += 1
results.append({
'sample_idx': i,
'instruction': instruction,
'bbox': bbox,
'predicted_coords': click_coords,
'is_correct': is_correct,
'failed': False,
'prediction_time': prediction_time
})
results.append(
{
"sample_idx": i,
"instruction": instruction,
"bbox": bbox,
"predicted_coords": click_coords,
"is_correct": is_correct,
"failed": False,
"prediction_time": prediction_time,
}
)
# Unload model
await model_wrapper.unload_model()
# Calculate metrics
accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
error_rate = error_predictions / total_samples if total_samples > 0 else 0.0
# Calculate timing statistics
successful_times = [r['prediction_time'] for r in results if not r['failed']]
successful_times = [r["prediction_time"] for r in results if not r["failed"]]
avg_prediction_time = sum(successful_times) / len(successful_times) if successful_times else 0.0
median_prediction_time = statistics.median(successful_times) if successful_times else 0.0
min_prediction_time = min(successful_times) if successful_times else 0.0
max_prediction_time = max(successful_times) if successful_times else 0.0
# Get VRAM statistics
vram_stats = model_wrapper.get_vram_stats()
return {
'model_name': model_wrapper.model_name,
'total_samples': total_samples,
'correct_predictions': correct_predictions,
'failed_predictions': error_predictions,
'accuracy': accuracy,
'failure_rate': error_rate,
'avg_prediction_time': avg_prediction_time,
'median_prediction_time': median_prediction_time,
'min_prediction_time': min_prediction_time,
'max_prediction_time': max_prediction_time,
'vram_max_mb': vram_stats['max_mb'],
'vram_avg_mb': vram_stats['avg_mb'],
'results': results
"model_name": model_wrapper.model_name,
"total_samples": total_samples,
"correct_predictions": correct_predictions,
"failed_predictions": error_predictions,
"accuracy": accuracy,
"failure_rate": error_rate,
"avg_prediction_time": avg_prediction_time,
"median_prediction_time": median_prediction_time,
"min_prediction_time": min_prediction_time,
"max_prediction_time": max_prediction_time,
"vram_max_mb": vram_stats["max_mb"],
"vram_avg_mb": vram_stats["avg_mb"],
"results": results,
}
@@ -126,56 +129,60 @@ async def main():
Main function to run the benchmark.
"""
# Parse command line arguments
parser = argparse.ArgumentParser(description='ScreenSpot-v2 Benchmark Script')
parser.add_argument('--samples', type=int, default=500,
help='Number of samples to evaluate (default: 500)')
parser.add_argument('--seed', type=int, default=42,
help='Random seed for shuffling (default: 42)')
parser = argparse.ArgumentParser(description="ScreenSpot-v2 Benchmark Script")
parser.add_argument(
"--samples", type=int, default=500, help="Number of samples to evaluate (default: 500)"
)
parser.add_argument(
"--seed", type=int, default=42, help="Random seed for shuffling (default: 42)"
)
args = parser.parse_args()
# Set random seed
random.seed(args.seed)
# Load dataset
print("Loading ScreenSpot-v2 dataset...")
ds = load_dataset("lmms-lab/ScreenSpot-v2")
dataset = ds['train'] # type: ignore
dataset = ds["train"] # type: ignore
# Convert to simple list of dicts with only required keys
samples = []
for item in dataset:
# Convert dataset item to dict if needed
item_dict = dict(item) if hasattr(item, 'keys') else item
item_dict = dict(item) if hasattr(item, "keys") else item
# Convert ScreenSpot-v2 bbox format [x, y, w, h] to [x1, y1, x2, y2]
bbox_xywh = item_dict['bbox'] # type: ignore
bbox_xywh = item_dict["bbox"] # type: ignore
x, y, w, h = bbox_xywh
bbox_xyxy = [x, y, x + w, y + h]
samples.append({
'image': item_dict['image'], # type: ignore
'instruction': item_dict['instruction'], # type: ignore
'bbox': bbox_xyxy
})
samples.append(
{
"image": item_dict["image"], # type: ignore
"instruction": item_dict["instruction"], # type: ignore
"bbox": bbox_xyxy,
}
)
print(f"Dataset loaded: {len(samples)} samples")
# Shuffle samples with seed
random.shuffle(samples)
print(f"Samples shuffled with seed {args.seed}")
# Get available models
models = get_available_models()
# Evaluation settings
max_samples = args.samples # Use command line argument
# Run evaluations
all_results = []
for model in models:
model_wrapper = ModelWrapper(model)
result = await evaluate_model(model_wrapper, samples, max_samples)
all_results.append(result)
# Print summary
print(f"\n{result['model_name']} Results:")
print(f" Accuracy: {result['accuracy']*100:.2f}%")
@@ -184,18 +191,22 @@ async def main():
print(f" Error Rate: {result['failure_rate']*100:.2f}%")
print(f" Avg Time: {result['avg_prediction_time']:.2f}s")
print(f" Median Time: {result['median_prediction_time']:.2f}s")
print(f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s")
print(
f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s"
)
print(f" VRAM Max: {result['vram_max_mb']:.1f}MB")
print(f" VRAM Avg: {result['vram_avg_mb']:.1f}MB")
# Print GPU memory info
gpu_memory = get_gpu_memory()
if gpu_memory and gpu_memory[0] > 0:
print(f" GPU Free Memory: {gpu_memory[0]:.1f}MB")
# Save results
if all_results:
save_results_to_markdown(all_results, "screenspot_v2_results.md", title="ScreenSpot-v2 Benchmark Results")
save_results_to_markdown(
all_results, "screenspot_v2_results.md", title="ScreenSpot-v2 Benchmark Results"
)
save_visualizations(all_results, samples)
print("\nBenchmark completed successfully!")
else:
@@ -203,4 +214,4 @@ async def main():
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

View File

@@ -4,38 +4,40 @@ Shared utilities for ScreenSpot-Pro benchmarking and interactive testing.
"""
import dotenv
dotenv.load_dotenv()
import asyncio
import base64
import gc
import os
import sys
import subprocess as sp
import statistics
import subprocess as sp
import sys
from datetime import datetime
from io import BytesIO
from typing import List, Union, Tuple, Optional
from typing import List, Optional, Tuple, Union
import torch
from PIL import Image, ImageDraw
from tqdm import tqdm
import gc
import torch
# Add parent directory to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from agent.agent import ComputerAgent
from models.base import ModelProtocol
def get_gpu_memory() -> List[int]:
"""
Get GPU memory usage using nvidia-smi.
Returns:
List of free memory values in MB for each GPU
"""
try:
command = "nvidia-smi --query-gpu=memory.free --format=csv"
memory_free_info = sp.check_output(command.split()).decode('ascii').split('\n')[:-1][1:]
memory_free_info = sp.check_output(command.split()).decode("ascii").split("\n")[:-1][1:]
memory_free_values = [int(x.split()[0]) for i, x in enumerate(memory_free_info)]
return memory_free_values
except (sp.CalledProcessError, FileNotFoundError, IndexError):
@@ -51,39 +53,34 @@ def get_gpu_memory() -> List[int]:
def get_vram_usage() -> dict:
"""
Get current VRAM usage statistics.
Returns:
Dictionary with VRAM usage info (in MB)
"""
if torch.cuda.is_available():
device = torch.cuda.current_device()
allocated = torch.cuda.memory_allocated(device) / 1024 / 1024 # Convert to MB
reserved = torch.cuda.memory_reserved(device) / 1024 / 1024 # Convert to MB
reserved = torch.cuda.memory_reserved(device) / 1024 / 1024 # Convert to MB
total = torch.cuda.get_device_properties(device).total_memory / 1024 / 1024
return {
'allocated_mb': allocated,
'reserved_mb': reserved,
'total_mb': total,
'free_mb': total - reserved
"allocated_mb": allocated,
"reserved_mb": reserved,
"total_mb": total,
"free_mb": total - reserved,
}
else:
return {
'allocated_mb': 0.0,
'reserved_mb': 0.0,
'total_mb': 0.0,
'free_mb': 0.0
}
return {"allocated_mb": 0.0, "reserved_mb": 0.0, "total_mb": 0.0, "free_mb": 0.0}
def get_available_models() -> List[Union[str, ModelProtocol]]:
"""
Get list of available models for testing.
Returns:
List of model strings and model classes
"""
local_provider = "huggingface-local/" # Options: huggingface-local/ or mlx/
# from models.gta1 import GTA1Model
models = [
@@ -94,42 +91,41 @@ def get_available_models() -> List[Union[str, ModelProtocol]]:
# f"{local_provider}HelloKKMe/GTA1-32B",
"openai/computer-use-preview+openai/gpt-4o-mini",
"anthropic/claude-opus-4-20250514+openai/gpt-4o-mini",
# === Reference model classes ===
# GTA1Model("HelloKKMe/GTA1-7B"),
# GTA1Model("HelloKKMe/GTA1-32B"),
# GTA1Model("HelloKKMe/GTA1-32B"),
]
return models
def is_click_in_bbox(click_coords: Optional[Tuple[int, int]], bbox: List[int]) -> bool:
"""
Check if click coordinates are within the bounding box.
Args:
click_coords: (x, y) coordinates or None
bbox: [x1, y1, x2, y2] bounding box
Returns:
True if click is within bbox, False otherwise
"""
if click_coords is None:
return False
x, y = click_coords
x1, y1, x2, y2 = bbox
return x1 <= x <= x2 and y1 <= y <= y2
def image_to_base64(image: Image.Image) -> str:
"""
Convert PIL Image to base64 string.
Args:
image: PIL Image
Returns:
Base64 encoded image string
"""
@@ -142,213 +138,252 @@ class ModelWrapper:
"""
Wrapper to provide unified interface for both ComputerAgent and custom models.
"""
def __init__(self, model: Union[str, ModelProtocol]):
self.model = model
self.is_computer_agent = isinstance(model, str)
self.agent: Optional[ComputerAgent] = None
self.vram_usage_history: List[float] = [] # Track VRAM usage over time
if self.is_computer_agent:
self.model_name = str(model)
else:
self.model_name = f"{model.__class__.__name__}('{getattr(model, 'model_name', 'unknown')}')"
self.model_name = (
f"{model.__class__.__name__}('{getattr(model, 'model_name', 'unknown')}')"
)
async def load_model(self) -> None:
"""Load the model."""
if self.is_computer_agent:
self.agent = ComputerAgent(model=str(self.model))
else:
await self.model.load_model() # type: ignore
await self.model.load_model() # type: ignore
# Record initial VRAM usage after loading
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
self.vram_usage_history.append(vram_info["allocated_mb"])
async def unload_model(self) -> None:
"""Unload the model."""
if not self.is_computer_agent:
await self.model.unload_model() # type: ignore
await self.model.unload_model() # type: ignore
else:
del self.agent
self.agent = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Record VRAM usage after unloading
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
self.vram_usage_history.append(vram_info["allocated_mb"])
def get_vram_stats(self) -> dict:
"""Get VRAM usage statistics for this model."""
if not self.vram_usage_history:
return {'max_mb': 0.0, 'avg_mb': 0.0}
return {"max_mb": 0.0, "avg_mb": 0.0}
return {
'max_mb': max(self.vram_usage_history),
'avg_mb': sum(self.vram_usage_history) / len(self.vram_usage_history)
"max_mb": max(self.vram_usage_history),
"avg_mb": sum(self.vram_usage_history) / len(self.vram_usage_history),
}
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
async def predict_click(
self, image: Image.Image, instruction: str
) -> Optional[Tuple[int, int]]:
"""Predict click coordinates."""
# Record VRAM usage before prediction
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
self.vram_usage_history.append(vram_info["allocated_mb"])
if self.is_computer_agent:
if self.agent is None:
await self.load_model()
if self.agent is not None:
image_b64 = image_to_base64(image)
result = await self.agent.predict_click(instruction=instruction, image_b64=image_b64)
result = await self.agent.predict_click(
instruction=instruction, image_b64=image_b64
)
# Record VRAM usage after prediction
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
self.vram_usage_history.append(vram_info["allocated_mb"])
return result
return None
else:
result = await self.model.predict_click(image, instruction) # type: ignore
result = await self.model.predict_click(image, instruction) # type: ignore
# Record VRAM usage after prediction
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
self.vram_usage_history.append(vram_info["allocated_mb"])
return result
def save_results_to_markdown(all_results: List[dict],output_file: str = "screenspot_pro_results.md", title: str = "ScreenSpot-Pro Benchmark Results") -> None:
def save_results_to_markdown(
all_results: List[dict],
output_file: str = "screenspot_pro_results.md",
title: str = "ScreenSpot-Pro Benchmark Results",
) -> None:
"""
Save evaluation results to a markdown table.
Args:
all_results: List of evaluation results for each model
output_file: Output markdown file path
"""
with open(output_file, 'w', encoding='utf-8') as f:
with open(output_file, "w", encoding="utf-8") as f:
f.write(f"# {title}\n\n")
f.write(f"**Evaluation Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
# Summary table
f.write("## Summary\n\n")
f.write("| Model | Total Samples | Correct | Errors | Accuracy | Error Rate | Avg Time (s) | Median Time (s) | Time Range (s) | VRAM Max (GB) | VRAM Avg (GB) |\n")
f.write("|-------|---------------|---------|--------|----------|------------|--------------|-----------------|----------------|---------------|---------------|\n")
f.write(
"| Model | Total Samples | Correct | Errors | Accuracy | Error Rate | Avg Time (s) | Median Time (s) | Time Range (s) | VRAM Max (GB) | VRAM Avg (GB) |\n"
)
f.write(
"|-------|---------------|---------|--------|----------|------------|--------------|-----------------|----------------|---------------|---------------|\n"
)
for result in all_results:
model_name = result['model_name']
total = result['total_samples']
correct = result['correct_predictions']
errors = result['failed_predictions']
accuracy = result['accuracy'] * 100
error_rate = result['failure_rate'] * 100
avg_time = result.get('avg_prediction_time', 0.0)
median_time = result.get('median_prediction_time', 0.0)
min_time = result.get('min_prediction_time', 0.0)
max_time = result.get('max_prediction_time', 0.0)
model_name = result["model_name"]
total = result["total_samples"]
correct = result["correct_predictions"]
errors = result["failed_predictions"]
accuracy = result["accuracy"] * 100
error_rate = result["failure_rate"] * 100
avg_time = result.get("avg_prediction_time", 0.0)
median_time = result.get("median_prediction_time", 0.0)
min_time = result.get("min_prediction_time", 0.0)
max_time = result.get("max_prediction_time", 0.0)
time_range = f"{min_time:.2f} - {max_time:.2f}"
vram_max = result.get('vram_max_mb', 0.0) / 1024
vram_avg = result.get('vram_avg_mb', 0.0) / 1024
f.write(f"| {model_name} | {total} | {correct} | {errors} | {accuracy:.2f}% | {error_rate:.2f}% | {avg_time:.2f} | {median_time:.2f} | {time_range} | {vram_max:.1f} | {vram_avg:.1f} |\n")
vram_max = result.get("vram_max_mb", 0.0) / 1024
vram_avg = result.get("vram_avg_mb", 0.0) / 1024
f.write(
f"| {model_name} | {total} | {correct} | {errors} | {accuracy:.2f}% | {error_rate:.2f}% | {avg_time:.2f} | {median_time:.2f} | {time_range} | {vram_max:.1f} | {vram_avg:.1f} |\n"
)
# Detailed results for each model
for result in all_results:
f.write(f"\n## {result['model_name']} - Detailed Results\n\n")
f.write("| Sample Index | Instruction | BBox | Predicted | Correct | Error | Time (s) |\n")
f.write(
"| Sample Index | Instruction | BBox | Predicted | Correct | Error | Time (s) |\n"
)
f.write("|-----------|-------------|------|-----------|---------|-------|----------|\n")
for sample_result in result['results'][:10]: # Show first 10 samples
sample_idx = sample_result['sample_idx']
instruction = sample_result['instruction'][:50] + "..." if len(sample_result['instruction']) > 50 else sample_result['instruction']
bbox = str(sample_result['bbox'])
predicted = str(sample_result['predicted_coords']) if sample_result['predicted_coords'] else "None"
correct = "PASS" if sample_result['is_correct'] else "FAIL"
error = "YES" if sample_result['failed'] else "NO"
pred_time = sample_result.get('prediction_time', 0.0)
f.write(f"| {sample_idx} | {instruction} | {bbox} | {predicted} | {correct} | {error} | {pred_time:.2f} |\n")
if len(result['results']) > 10:
for sample_result in result["results"][:10]: # Show first 10 samples
sample_idx = sample_result["sample_idx"]
instruction = (
sample_result["instruction"][:50] + "..."
if len(sample_result["instruction"]) > 50
else sample_result["instruction"]
)
bbox = str(sample_result["bbox"])
predicted = (
str(sample_result["predicted_coords"])
if sample_result["predicted_coords"]
else "None"
)
correct = "PASS" if sample_result["is_correct"] else "FAIL"
error = "YES" if sample_result["failed"] else "NO"
pred_time = sample_result.get("prediction_time", 0.0)
f.write(
f"| {sample_idx} | {instruction} | {bbox} | {predicted} | {correct} | {error} | {pred_time:.2f} |\n"
)
if len(result["results"]) > 10:
f.write(f"\n*Showing first 10 of {len(result['results'])} samples*\n")
print(f"\nResults saved to: {output_file}")
def save_visualizations(all_results: List[dict], samples, output_dir: str = "output") -> None:
"""
Save visualizations of predicted coordinates vs bboxes to an output folder.
Args:
all_results: List of evaluation results for each model
samples: List of sample dicts with image, bbox, instruction keys
output_dir: Output directory path
"""
os.makedirs(output_dir, exist_ok=True)
for result in all_results:
model_name = result['model_name'].replace('/', '_').replace('\\', '_')
model_name = result["model_name"].replace("/", "_").replace("\\", "_")
model_dir = os.path.join(output_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
print(f"Saving visualizations for {result['model_name']}...")
# Save first 10 samples for visualization
for i, sample_result in enumerate(tqdm(result['results'][:10], desc=f"Saving {model_name} visualizations")):
for i, sample_result in enumerate(
tqdm(result["results"][:10], desc=f"Saving {model_name} visualizations")
):
# Get sample data using index
sample_idx = sample_result['sample_idx']
sample_idx = sample_result["sample_idx"]
if sample_idx < len(samples):
sample = samples[sample_idx]
image = sample['image'].copy() # Make a copy to avoid modifying original
image = sample["image"].copy() # Make a copy to avoid modifying original
else:
print(f"Warning: Could not find sample at index {sample_idx}")
continue
bbox = sample_result['bbox']
predicted_coords = sample_result['predicted_coords']
is_correct = sample_result['is_correct']
bbox = sample_result["bbox"]
predicted_coords = sample_result["predicted_coords"]
is_correct = sample_result["is_correct"]
# Draw on image
draw = ImageDraw.Draw(image)
# Draw bounding box (ground truth) in green
x1, y1, x2, y2 = bbox
draw.rectangle([x1, y1, x2, y2], outline="green", width=3)
draw.text((x1, y1-20), "Ground Truth", fill="green")
draw.text((x1, y1 - 20), "Ground Truth", fill="green")
# Draw predicted click in red or blue
if predicted_coords is not None:
px, py = predicted_coords
color = "blue" if is_correct else "red"
# Draw crosshair
crosshair_size = 15
draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=3)
draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=3)
draw.text((px+10, py-20), f"Predicted ({px},{py})", fill=color)
draw.line(
[(px - crosshair_size, py), (px + crosshair_size, py)], fill=color, width=3
)
draw.line(
[(px, py - crosshair_size), (px, py + crosshair_size)], fill=color, width=3
)
draw.text((px + 10, py - 20), f"Predicted ({px},{py})", fill=color)
# Add status text
status = "CORRECT" if is_correct else "INCORRECT"
status_color = "blue" if is_correct else "red"
draw.text((10, 10), f"Status: {status}", fill=status_color)
draw.text((10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black")
draw.text(
(10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black"
)
# Save image
filename = f"sample_{i+1:02d}_idx{sample_idx}_{status.lower()}.png"
filepath = os.path.join(model_dir, filename)
image.save(filepath)
print(f"Visualizations saved to: {model_dir}")
def save_prediction_visualization(image: Image.Image, instruction: str, predictions: List[dict],
output_file: str = "interactive_prediction.png") -> None:
def save_prediction_visualization(
image: Image.Image,
instruction: str,
predictions: List[dict],
output_file: str = "interactive_prediction.png",
) -> None:
"""
Save visualization of multiple model predictions on a single image.
Args:
image: PIL Image to visualize
instruction: Instruction text
@@ -358,32 +393,32 @@ def save_prediction_visualization(image: Image.Image, instruction: str, predicti
# Create a copy of the image
vis_image = image.copy()
draw = ImageDraw.Draw(vis_image)
# Colors for different models
colors = ["red", "blue", "orange", "purple", "brown", "pink", "gray", "olive"]
# Draw predictions
for i, pred in enumerate(predictions):
color = colors[i % len(colors)]
model_name = pred['model_name']
coords = pred.get('coords')
error = pred.get('error')
model_name = pred["model_name"]
coords = pred.get("coords")
error = pred.get("error")
if coords is not None:
px, py = coords
# Draw crosshair
crosshair_size = 20
draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=4)
draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=4)
draw.line([(px - crosshair_size, py), (px + crosshair_size, py)], fill=color, width=4)
draw.line([(px, py - crosshair_size), (px, py + crosshair_size)], fill=color, width=4)
# Draw model name
draw.text((px+15, py+15), f"{model_name}: ({px},{py})", fill=color)
draw.text((px + 15, py + 15), f"{model_name}: ({px},{py})", fill=color)
else:
# Draw error text
draw.text((10, 50 + i*20), f"{model_name}: ERROR - {error}", fill=color)
draw.text((10, 50 + i * 20), f"{model_name}: ERROR - {error}", fill=color)
# Add instruction at the top
draw.text((10, 10), f"Instruction: {instruction}", fill="black")
# Save image
vis_image.save(output_file)
print(f"Prediction visualization saved to: {output_file}")
@@ -392,12 +427,13 @@ def save_prediction_visualization(image: Image.Image, instruction: str, predicti
def take_screenshot() -> Image.Image:
"""
Take a screenshot of the current screen.
Returns:
PIL Image of the screenshot
"""
try:
import pyautogui
screenshot = pyautogui.screenshot()
return screenshot
except ImportError:
@@ -406,4 +442,3 @@ def take_screenshot() -> Image.Image:
except Exception as e:
print(f"Error taking screenshot: {e}")
raise

View File

@@ -9,58 +9,61 @@ from agent import ComputerAgent
from computer import Computer
from computer.helpers import sandboxed
@sandboxed()
def read_file(location: str) -> str:
"""Read contents of a file
Parameters
----------
location : str
Path to the file to read
Returns
-------
str
Contents of the file or error message
"""
try:
with open(location, 'r') as f:
with open(location, "r") as f:
return f.read()
except Exception as e:
return f"Error reading file: {str(e)}"
def save_note(content: str, filename: str = "note.txt") -> str:
"""Save content to a note file
Parameters
----------
content : str
Content to save to the file
filename : str, optional
Name of the file to save to (default is "note.txt")
Returns
-------
str
Success or error message
"""
try:
with open(filename, 'w') as f:
with open(filename, "w") as f:
f.write(content)
return f"Saved note to {filename}"
except Exception as e:
return f"Error saving note: {str(e)}"
def calculate(a: int, b: int) -> int:
"""Calculate the sum of two integers
Parameters
----------
a : int
First integer
b : int
Second integer
Returns
-------
int
@@ -68,15 +71,18 @@ def calculate(a: int, b: int) -> int:
"""
return a + b
async def main():
"""Example usage of ComputerAgent with different models"""
# Example 1: Using Claude with computer and custom tools
print("=== Example 1: Claude with Computer ===")
import os
import dotenv
import json
import os
import dotenv
dotenv.load_dotenv()
assert os.getenv("CUA_CONTAINER_NAME") is not None, "CUA_CONTAINER_NAME is not set"
@@ -86,38 +92,37 @@ async def main():
os_type="linux",
provider_type="cloud",
name=os.getenv("CUA_CONTAINER_NAME") or "",
api_key=os.getenv("CUA_API_KEY") or ""
api_key=os.getenv("CUA_API_KEY") or "",
) as computer:
agent = ComputerAgent(
# Supported models:
# == OpenAI CUA (computer-use-preview) ==
model="openai/computer-use-preview",
# == Anthropic CUA (Claude > 3.5) ==
# model="anthropic/claude-opus-4-20250514",
# model="anthropic/claude-opus-4-20250514",
# model="anthropic/claude-sonnet-4-20250514",
# model="anthropic/claude-3-7-sonnet-20250219",
# model="anthropic/claude-3-5-sonnet-20241022",
# == UI-TARS ==
# model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
# TODO: add local mlx provider
# model="mlx-community/UI-TARS-1.5-7B-6bit",
# model="ollama_chat/0000/ui-tars-1.5-7b",
# == Omniparser + Any LLM ==
# model="omniparser+..."
# model="omniparser+anthropic/claude-opus-4-20250514",
tools=[computer],
only_n_most_recent_images=3,
verbosity=logging.INFO,
trajectory_dir="trajectories",
use_prompt_caching=True,
max_trajectory_budget={ "max_budget": 1.0, "raise_error": True, "reset_after_each_run": False },
max_trajectory_budget={
"max_budget": 1.0,
"raise_error": True,
"reset_after_each_run": False,
},
)
history = []
while True:
user_input = input("> ")
@@ -143,5 +148,6 @@ async def main():
# elif item["type"] == "function_call_output":
# print("===>", item["output"])
if __name__ == "__main__":
asyncio.run(main())
asyncio.run(main())

View File

@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
[project]
name = "cua-agent"
version = "0.4.0"
version = "0.4.35"
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
readme = "README.md"
authors = [
@@ -29,6 +29,11 @@ requires-python = ">=3.12"
[project.optional-dependencies]
openai = []
anthropic = []
qwen = [
"qwen-vl-utils",
"qwen-agent",
"Pillow>=10.0.0",
]
omni = [
"cua-som>=0.1.0,<0.2.0",
]
@@ -49,7 +54,7 @@ glm45v-hf = [
opencua-hf = [
"accelerate",
"torch",
"transformers==4.53.0",
"transformers>=4.53.0",
"tiktoken>=0.11.0",
"blobfile>=3.0.0"
]
@@ -60,6 +65,11 @@ internvl-hf = [
"einops",
"timm"
]
moondream3 = [
"accelerate",
"torch",
"transformers>=4.55.0"
]
ui = [
"gradio>=5.23.3",
"python-dotenv>=1.0.1",
@@ -68,7 +78,10 @@ cli = [
"yaspin>=3.1.0",
]
hud = [
"hud-python==0.4.26",
"hud-python==0.4.52",
]
gemini = [
"google-genai>=1.41.0",
]
all = [
# uitars requirements
@@ -88,7 +101,13 @@ all = [
# cli requirements
"yaspin>=3.1.0",
# hud requirements
"hud-python==0.4.26",
"hud-python==0.4.52",
# gemini requirements
"google-genai>=1.41.0",
# qwen requirements
"qwen-vl-utils",
"qwen-agent",
"Pillow>=10.0.0",
]
[tool.uv]
@@ -98,4 +117,4 @@ constraint-dependencies = ["fastrtc>0.43.0", "mlx-audio>0.2.3"]
distribution = true
[tool.pdm.build]
includes = ["agent/"]
includes = ["agent/"]