Merge branch 'main' into feat/docs/init

This commit is contained in:
Morgan Dean
2025-07-31 18:04:02 +01:00
119 changed files with 7393 additions and 13344 deletions

View File

@@ -50,7 +50,7 @@
# 🚀 Quick Start
Read our guide on getting started with a Computer-Use Agent:
[Computer-Use Agent Quickstart](https://docs.trycua.com/home/guides/usage-guide)
[Computer-Use Agent Quickstart](https://trycua.com/docs/guides/usage-guide)
Get started using Cua services on your machine:
[Cua Usage Guide](https://docs.trycua.com/home/guides/cua-usage-guide)

View File

@@ -8,7 +8,7 @@ import signal
from computer import Computer, VMProviderType
# Import the unified agent class and types
from agent import ComputerAgent, LLMProvider, LLM, AgentLoop
from agent import ComputerAgent
# Import utility functions
from utils import load_dotenv_files, handle_sigint
@@ -19,8 +19,8 @@ logger = logging.getLogger(__name__)
async def run_agent_example():
"""Run example of using the ComputerAgent with OpenAI and Omni provider."""
print("\n=== Example: ComputerAgent with OpenAI and Omni provider ===")
"""Run example of using the ComputerAgent with different models."""
print("\n=== Example: ComputerAgent with different models ===")
try:
# Create a local macOS computer
@@ -37,28 +37,37 @@ async def run_agent_example():
# provider_type=VMProviderType.CLOUD,
# )
# Create Computer instance with async context manager
# Create ComputerAgent with new API
agent = ComputerAgent(
computer=computer,
loop=AgentLoop.OPENAI,
# loop=AgentLoop.ANTHROPIC,
# loop=AgentLoop.UITARS,
# loop=AgentLoop.OMNI,
model=LLM(provider=LLMProvider.OPENAI), # No model name for Operator CUA
# model=LLM(provider=LLMProvider.OPENAI, name="gpt-4o"),
# model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"),
# model=LLM(provider=LLMProvider.OLLAMA, name="gemma3:4b-it-q4_K_M"),
# model=LLM(provider=LLMProvider.MLXVLM, name="mlx-community/UI-TARS-1.5-7B-4bit"),
# model=LLM(
# provider=LLMProvider.OAICOMPAT,
# name="gemma-3-12b-it",
# provider_base_url="http://localhost:1234/v1", # LM Studio local endpoint
# ),
save_trajectory=True,
# Supported models:
# == OpenAI CUA (computer-use-preview) ==
model="openai/computer-use-preview",
# == Anthropic CUA (Claude > 3.5) ==
# model="anthropic/claude-opus-4-20250514",
# model="anthropic/claude-sonnet-4-20250514",
# model="anthropic/claude-3-7-sonnet-20250219",
# model="anthropic/claude-3-5-sonnet-20240620",
# == UI-TARS ==
# model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
# model="mlx/mlx-community/UI-TARS-1.5-7B-6bit",
# model="ollama_chat/0000/ui-tars-1.5-7b",
# == Omniparser + Any LLM ==
# model="omniparser+anthropic/claude-opus-4-20250514",
# model="omniparser+ollama_chat/gemma3:12b-it-q4_K_M",
tools=[computer],
only_n_most_recent_images=3,
verbosity=logging.DEBUG,
trajectory_dir="trajectories",
use_prompt_caching=True,
max_trajectory_budget=1.0,
)
# Example tasks to demonstrate the agent
tasks = [
"Look for a repository named trycua/cua on GitHub.",
"Check the open issues, open the most recent one and read it.",
@@ -68,43 +77,35 @@ async def run_agent_example():
"Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.",
]
# Use message-based conversation history
history = []
for i, task in enumerate(tasks):
print(f"\nExecuting task {i}/{len(tasks)}: {task}")
async for result in agent.run(task):
print("Response ID: ", result.get("id"))
# Print detailed usage information
usage = result.get("usage")
if usage:
print("\nUsage Details:")
print(f" Input Tokens: {usage.get('input_tokens')}")
if "input_tokens_details" in usage:
print(f" Input Tokens Details: {usage.get('input_tokens_details')}")
print(f" Output Tokens: {usage.get('output_tokens')}")
if "output_tokens_details" in usage:
print(f" Output Tokens Details: {usage.get('output_tokens_details')}")
print(f" Total Tokens: {usage.get('total_tokens')}")
print("Response Text: ", result.get("text"))
# Print tools information
tools = result.get("tools")
if tools:
print("\nTools:")
print(tools)
# Print reasoning and tool call outputs
outputs = result.get("output", [])
for output in outputs:
output_type = output.get("type")
if output_type == "reasoning":
print("\nReasoning Output:")
print(output)
elif output_type == "computer_call":
print("\nTool Call Output:")
print(output)
print(f"\n✅ Task {i+1}/{len(tasks)} completed: {task}")
print(f"\nExecuting task {i+1}/{len(tasks)}: {task}")
# Add user message to history
history.append({"role": "user", "content": task})
# Run agent with conversation history
async for result in agent.run(history, stream=False):
# Add agent outputs to history
history += result.get("output", [])
# Print output for debugging
for item in result.get("output", []):
if item.get("type") == "message":
content = item.get("content", [])
for content_part in content:
if content_part.get("text"):
print(f"Agent: {content_part.get('text')}")
elif item.get("type") == "computer_call":
action = item.get("action", {})
action_type = action.get("type", "")
print(f"Computer Action: {action_type}({action})")
elif item.get("type") == "computer_call_output":
print("Computer Output: [Screenshot/Result]")
print(f"✅ Task {i+1}/{len(tasks)} completed: {task}")
except Exception as e:
logger.error(f"Error in run_agent_example: {e}")

View File

@@ -2,8 +2,8 @@
<h1>
<div class="image-wrapper" style="display: inline-block;">
<picture>
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="../../img/logo_white.png" style="display: block; margin: auto;">
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="../../img/logo_black.png" style="display: block; margin: auto;">
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="../../../img/logo_white.png" style="display: block; margin: auto;">
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="../../../img/logo_black.png" style="display: block; margin: auto;">
<img alt="Shows my svg">
</picture>
</div>
@@ -15,208 +15,367 @@
</h1>
</div>
**cua-agent** is a general Computer-Use framework for running multi-app agentic workflows targeting macOS and Linux sandboxes created with Cua, supporting local (Ollama) and cloud model providers (OpenAI, Anthropic, Groq, DeepSeek, Qwen).
**cua-agent** is a general Computer-Use framework with liteLLM integration for running agentic workflows on macOS, Windows, and Linux sandboxes. It provides a unified interface for computer-use agents across multiple LLM providers, with an advanced callback system for extensibility.
### Get started with Agent
## Features
<div align="center">
<img src="../../img/agent.png"/>
</div>
- **Safe Computer-Use/Tool-Use**: Uses the Computer SDK for sandboxed desktops
- **Multi-Agent Support**: Anthropic Claude, OpenAI computer-use-preview, UI-TARS, Omniparser + any LLM
- **Multi-API Support**: Take advantage of liteLLM's support for 100+ LLMs / model APIs, including local models (`huggingface-local/`, `ollama_chat/`, `mlx/`)
- **Cross-Platform**: Works on Windows, macOS, and Linux with cloud and local computer instances
- **Extensible Callbacks**: Built-in support for image retention, cache control, PII anonymization, budget limits, and trajectory tracking
## Install
```bash
pip install "cua-agent[all]"
# or install specific loop providers
pip install "cua-agent[openai]" # OpenAI Cua Loop
pip install "cua-agent[anthropic]" # Anthropic Cua Loop
pip install "cua-agent[uitars]" # UI-Tars support
pip install "cua-agent[omni]" # Cua Loop based on OmniParser (includes Ollama for local models)
pip install "cua-agent[ui]" # Gradio UI for the agent
pip install "cua-agent[uitars-mlx]" # MLX UI-Tars support
# or install specific providers
pip install "cua-agent[openai]" # OpenAI computer-use-preview support
pip install "cua-agent[anthropic]" # Anthropic Claude support
pip install "cua-agent[omni]" # Omniparser + any LLM support
pip install "cua-agent[uitars]" # UI-TARS
pip install "cua-agent[uitars-mlx]" # UI-TARS + MLX support
pip install "cua-agent[uitars-hf]" # UI-TARS + Huggingface support
pip install "cua-agent[ui]" # Gradio UI support
```
## Run
```python
async with Computer() as macos_computer:
# Create agent with loop and provider
agent = ComputerAgent(
computer=macos_computer,
loop=AgentLoop.OPENAI,
model=LLM(provider=LLMProvider.OPENAI)
# or
# loop=AgentLoop.ANTHROPIC,
# model=LLM(provider=LLMProvider.ANTHROPIC)
# or
# loop=AgentLoop.OMNI,
# model=LLM(provider=LLMProvider.OLLAMA, name="gemma3")
# or
# loop=AgentLoop.UITARS,
# model=LLM(provider=LLMProvider.OAICOMPAT, name="ByteDance-Seed/UI-TARS-1.5-7B", provider_base_url="https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1")
)
tasks = [
"Look for a repository named trycua/cua on GitHub.",
"Check the open issues, open the most recent one and read it.",
"Clone the repository in users/lume/projects if it doesn't exist yet.",
"Open the repository with an app named Cursor (on the dock, black background and white cube icon).",
"From Cursor, open Composer if not already open.",
"Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.",
]
for i, task in enumerate(tasks):
print(f"\nExecuting task {i}/{len(tasks)}: {task}")
async for result in agent.run(task):
print(result)
print(f"\n✅ Task {i+1}/{len(tasks)} completed: {task}")
```
Refer to these notebooks for step-by-step guides on how to use the Computer-Use Agent (CUA):
- [Agent Notebook](../../notebooks/agent_nb.ipynb) - Complete examples and workflows
## Using the Gradio UI
The agent includes a Gradio-based user interface for easier interaction.
<div align="center">
<img src="../../img/agent_gradio_ui.png"/>
</div>
To use it:
```bash
# Install with Gradio support
pip install "cua-agent[ui]"
```
### Create a simple launcher script
## Quick Start
```python
# launch_ui.py
from agent.ui.gradio.app import create_gradio_ui
import asyncio
import os
from agent import ComputerAgent
from computer import Computer
app = create_gradio_ui()
app.launch(share=False)
async def main():
# Set up computer instance
async with Computer(
os_type="linux",
provider_type="cloud",
name=os.getenv("CUA_CONTAINER_NAME"),
api_key=os.getenv("CUA_API_KEY")
) as computer:
# Create agent
agent = ComputerAgent(
model="anthropic/claude-3-5-sonnet-20241022",
tools=[computer],
only_n_most_recent_images=3,
trajectory_dir="trajectories",
max_trajectory_budget=5.0 # $5 budget limit
)
# Run agent
messages = [{"role": "user", "content": "Take a screenshot and tell me what you see"}]
async for result in agent.run(messages):
for item in result["output"]:
if item["type"] == "message":
print(item["content"][0]["text"])
if __name__ == "__main__":
asyncio.run(main())
```
### Setting up API Keys
## Supported Models
For the Gradio UI to show available models, you need to set API keys as environment variables:
```bash
# For OpenAI models
export OPENAI_API_KEY=your_openai_key_here
# For Anthropic models
export ANTHROPIC_API_KEY=your_anthropic_key_here
# Launch with both keys set
OPENAI_API_KEY=your_key ANTHROPIC_API_KEY=your_key python launch_ui.py
```
### Anthropic Claude (Computer Use API)
```python
model="anthropic/claude-3-5-sonnet-20241022"
model="anthropic/claude-3-5-sonnet-20240620"
model="anthropic/claude-opus-4-20250514"
model="anthropic/claude-sonnet-4-20250514"
```
Without these environment variables, the UI will show "No models available" for the corresponding providers, but you can still use local models with the OMNI loop provider.
### OpenAI Computer Use Preview
```python
model="openai/computer-use-preview"
```
### Using Local Models
### UI-TARS (Local or Huggingface Inference)
```python
model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B"
model="ollama_chat/0000/ui-tars-1.5-7b"
```
You can use local models with the OMNI loop provider by selecting "Custom model..." from the dropdown. The default provider URL is set to `http://localhost:1234/v1`, which works with LM Studio.
### Omniparser + Any LLM
```python
model="omniparser+ollama_chat/mistral-small3.2"
model="omniparser+vertex_ai/gemini-pro"
model="omniparser+anthropic/claude-3-5-sonnet-20241022"
model="omniparser+openai/gpt-4o"
```
If you're using a different local model server:
- vLLM: `http://localhost:8000/v1`
- LocalAI: `http://localhost:8080/v1`
- Ollama with OpenAI compat API: `http://localhost:11434/v1`
## Custom Tools
The Gradio UI provides:
- Selection of different agent loops (OpenAI, Anthropic, OMNI)
- Model selection for each provider
- Configuration of agent parameters
- Chat interface for interacting with the agent
### Using UI-TARS
The UI-TARS models are available in two forms:
1. **MLX UI-TARS models** (Default): These models run locally using the MLXVLM provider
- `mlx-community/UI-TARS-1.5-7B-4bit` (default) - 4-bit quantized version
- `mlx-community/UI-TARS-1.5-7B-6bit` - 6-bit quantized version for higher quality
```python
agent = ComputerAgent(
computer=macos_computer,
loop=AgentLoop.UITARS,
model=LLM(provider=LLMProvider.MLXVLM, name="mlx-community/UI-TARS-1.5-7B-4bit")
)
```
2. **OpenAI-compatible UI-TARS**: For using the original ByteDance model
- If you want to use the original ByteDance UI-TARS model via an OpenAI-compatible API, follow the [deployment guide](https://github.com/bytedance/UI-TARS/blob/main/README_deploy.md)
- This will give you a provider URL like `https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1` which you can use in the code or Gradio UI:
```python
agent = ComputerAgent(
computer=macos_computer,
loop=AgentLoop.UITARS,
model=LLM(provider=LLMProvider.OAICOMPAT, name="tgi",
provider_base_url="https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1")
)
```
## Agent Loops
The `cua-agent` package provides four agent loop variations, based on different CUA model providers and techniques:
| Agent Loop | Supported Models | Description | Set-Of-Marks |
|:-----------|:-----------------|:------------|:-------------|
| `AgentLoop.OPENAI` | • `computer-use-preview` | Uses the OpenAI Operator CUA model | Not Required |
| `AgentLoop.ANTHROPIC` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219` | Uses Anthropic Computer-Use | Not Required |
| `AgentLoop.UITARS` | • `mlx-community/UI-TARS-1.5-7B-4bit` (default)<br>• `mlx-community/UI-TARS-1.5-7B-6bit`<br>• `ByteDance-Seed/UI-TARS-1.5-7B` (via OpenAI-compatible endpoint) | Uses UI-TARS models with the MLXVLM (default) or OAICOMPAT providers | Not Required |
| `AgentLoop.OMNI` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219`<br>• `gpt-4.5-preview`<br>• `gpt-4o`<br>• `gpt-4`<br>• `phi4`<br>• `phi4-mini`<br>• `gemma3`<br>• `...`<br>• `Any Ollama or OpenAI-compatible model` | Uses OmniParser for element pixel-detection (Set-of-Marks) and any VLM for UI grounding and reasoning | OmniParser |
## AgentResponse
The `AgentResponse` class represents the structured output returned after each agent turn. It contains the agent's response, reasoning, tool usage, and other metadata. The response format aligns with the new [OpenAI Agent SDK specification](https://platform.openai.com/docs/api-reference/responses) for better consistency across different agent loops.
Define custom tools using decorated functions:
```python
async for result in agent.run(task):
print("Response ID: ", result.get("id"))
from computer.helpers import sandboxed
# Print detailed usage information
usage = result.get("usage")
if usage:
print("\nUsage Details:")
print(f" Input Tokens: {usage.get('input_tokens')}")
if "input_tokens_details" in usage:
print(f" Input Tokens Details: {usage.get('input_tokens_details')}")
print(f" Output Tokens: {usage.get('output_tokens')}")
if "output_tokens_details" in usage:
print(f" Output Tokens Details: {usage.get('output_tokens_details')}")
print(f" Total Tokens: {usage.get('total_tokens')}")
@sandboxed()
def read_file(location: str) -> str:
"""Read contents of a file
Parameters
----------
location : str
Path to the file to read
Returns
-------
str
Contents of the file or error message
"""
try:
with open(location, 'r') as f:
return f.read()
except Exception as e:
return f"Error reading file: {str(e)}"
print("Response Text: ", result.get("text"))
def calculate(a: int, b: int) -> int:
"""Calculate the sum of two integers"""
return a + b
# Print tools information
tools = result.get("tools")
if tools:
print("\nTools:")
print(tools)
# Print reasoning and tool call outputs
outputs = result.get("output", [])
for output in outputs:
output_type = output.get("type")
if output_type == "reasoning":
print("\nReasoning Output:")
print(output)
elif output_type == "computer_call":
print("\nTool Call Output:")
print(output)
# Use with agent
agent = ComputerAgent(
model="anthropic/claude-3-5-sonnet-20241022",
tools=[computer, read_file, calculate]
)
```
**Note on Settings Persistence:**
## Callbacks System
* The Gradio UI automatically saves your configuration (Agent Loop, Model Choice, Custom Base URL, Save Trajectory state, Recent Images count) to a file named `.gradio_settings.json` in the project's root directory when you successfully run a task.
* This allows your preferences to persist between sessions.
* API keys entered into the custom provider field are **not** saved in this file for security reasons. Manage API keys using environment variables (e.g., `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`) or a `.env` file.
* It's recommended to add `.gradio_settings.json` to your `.gitignore` file.
The agent provides a comprehensive callback system for extending functionality:
### Built-in Callbacks
```python
from agent.callbacks import (
ImageRetentionCallback,
TrajectorySaverCallback,
BudgetManagerCallback,
LoggingCallback
)
agent = ComputerAgent(
model="anthropic/claude-3-5-sonnet-20241022",
tools=[computer],
callbacks=[
ImageRetentionCallback(only_n_most_recent_images=3),
TrajectorySaverCallback(trajectory_dir="trajectories"),
BudgetManagerCallback(max_budget=10.0, raise_error=True),
LoggingCallback(level=logging.INFO)
]
)
```
### Custom Callbacks
```python
from agent.callbacks.base import AsyncCallbackHandler
class CustomCallback(AsyncCallbackHandler):
async def on_llm_start(self, messages):
"""Preprocess messages before LLM call"""
# Add custom preprocessing logic
return messages
async def on_llm_end(self, messages):
"""Postprocess messages after LLM call"""
# Add custom postprocessing logic
return messages
async def on_usage(self, usage):
"""Track usage information"""
print(f"Tokens used: {usage.total_tokens}")
```
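To attach a custom handler, pass an instance via the `callbacks` parameter (the model string here is just an example):
```python
agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    tools=[computer],
    callbacks=[CustomCallback()],
)
```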
## Budget Management
Control costs with built-in budget management:
```python
# Simple budget limit
agent = ComputerAgent(
model="anthropic/claude-3-5-sonnet-20241022",
max_trajectory_budget=5.0 # $5 limit
)
# Advanced budget configuration
agent = ComputerAgent(
model="anthropic/claude-3-5-sonnet-20241022",
max_trajectory_budget={
"max_budget": 10.0,
"raise_error": True, # Raise error when exceeded
"reset_after_each_run": False # Persistent across runs
}
)
```
## Trajectory Management
Save and replay agent conversations:
```python
agent = ComputerAgent(
model="anthropic/claude-3-5-sonnet-20241022",
trajectory_dir="trajectories", # Auto-save trajectories
tools=[computer]
)
# Trajectories are saved with:
# - Complete conversation history
# - Usage statistics and costs
# - Timestamps and metadata
# - Screenshots and computer actions
```
## Configuration Options
### ComputerAgent Parameters
- `model`: Model identifier (required)
- `tools`: List of computer objects and decorated functions
- `callbacks`: List of callback handlers for extensibility
- `only_n_most_recent_images`: Limit recent images to prevent context overflow
- `verbosity`: Logging level (logging.INFO, logging.DEBUG, etc.)
- `trajectory_dir`: Directory to save conversation trajectories
- `max_retries`: Maximum API call retries (default: 3)
- `screenshot_delay`: Delay between actions and screenshots (default: 0.5s)
- `use_prompt_caching`: Enable prompt caching for supported models
- `max_trajectory_budget`: Budget limit configuration
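A sketch combining these parameters; the values are illustrative, not recommendations:
```python
agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",  # required
    tools=[computer],                              # computer objects + decorated functions
    callbacks=[],                                  # extra handlers beyond the built-ins
    only_n_most_recent_images=3,
    verbosity=logging.INFO,
    trajectory_dir="trajectories",
    max_retries=3,
    screenshot_delay=0.5,
    use_prompt_caching=False,
    max_trajectory_budget={"max_budget": 5.0, "raise_error": True},
)
```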
### Environment Variables
```bash
# Computer instance (cloud)
export CUA_CONTAINER_NAME="your-container-name"
export CUA_API_KEY="your-cua-api-key"
# LLM API keys
export ANTHROPIC_API_KEY="your-anthropic-key"
export OPENAI_API_KEY="your-openai-key"
```
## Advanced Usage
### Streaming Responses
```python
async for result in agent.run(messages, stream=True):
# Process streaming chunks
for item in result["output"]:
if item["type"] == "message":
print(item["content"][0]["text"], end="", flush=True)
elif item["type"] == "computer_call":
action = item["action"]
print(f"\n[Action: {action['type']}]")
```
### Interactive Chat Loop
```python
history = []
while True:
user_input = input("> ")
if user_input.lower() in ['quit', 'exit']:
break
history.append({"role": "user", "content": user_input})
async for result in agent.run(history):
history += result["output"]
# Display assistant responses
for item in result["output"]:
if item["type"] == "message":
print(item["content"][0]["text"])
```
### Error Handling
```python
try:
async for result in agent.run(messages):
# Process results
pass
except BudgetExceededError:  # from agent.callbacks.budget_manager
print("Budget limit exceeded")
except Exception as e:
print(f"Agent error: {e}")
```
## API Reference
### ComputerAgent.run()
```python
async def run(
self,
messages: Messages,
stream: bool = False,
**kwargs
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Run the agent with the given messages.
Args:
messages: List of message dictionaries
stream: Whether to stream the response
**kwargs: Additional arguments
Returns:
AsyncGenerator that yields response chunks
"""
```
### Message Format
```python
messages = [
{
"role": "user",
"content": "Take a screenshot and describe what you see"
},
{
"role": "assistant",
"content": "I'll take a screenshot for you."
}
]
```
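Content may also be a list of parts for multimodal input. A hedged sketch, based on the OpenAI-style `text` and `image_url` parts handled elsewhere in this codebase:
```python
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What app is in the foreground?"},
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
        ],
    }
]
```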
### Response Format
```python
{
"output": [
{
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "I can see..."}]
},
{
"type": "computer_call",
"action": {"type": "screenshot"},
"call_id": "call_123"
},
{
"type": "computer_call_output",
"call_id": "call_123",
"output": {"image_url": "data:image/png;base64,..."}
}
],
"usage": {
"prompt_tokens": 150,
"completion_tokens": 75,
"total_tokens": 225,
"response_cost": 0.01,
}
}
```
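Run-level spend can be tallied from the `usage` chunks; a minimal sketch using the fields shown above:
```python
total_cost = 0.0
async for result in agent.run(messages):
    total_cost += result.get("usage", {}).get("response_cost", 0.0)
print(f"Total cost: ${total_cost:.4f}")
```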
## License
MIT License - see LICENSE file for details.

View File

@@ -1,12 +1,27 @@
"""CUA (Computer Use) Agent for AI-driven computer interaction."""
"""
agent - Decorator-based Computer Use Agent with liteLLM integration
"""
import sys
import logging
import sys
__version__ = "0.1.0"
from .decorators import agent_loop
from .agent import ComputerAgent
from .types import Messages, AgentResponse
# Initialize logging
logger = logging.getLogger("agent")
# Import loops to register them
from . import loops
__all__ = [
"agent_loop",
"ComputerAgent",
"Messages",
"AgentResponse"
]
__version__ = "0.4.0"
logger = logging.getLogger(__name__)
# Initialize telemetry when the package is imported
try:
@@ -18,7 +33,7 @@ try:
)
# Import set_dimension from our own telemetry module
from .core.telemetry import set_dimension
from .telemetry import set_dimension
# Check if telemetry is enabled
if is_telemetry_enabled():
@@ -47,9 +62,3 @@ except ImportError as e:
except Exception as e:
# Other issues with telemetry
logger.warning(f"Error initializing telemetry: {e}")
from .core.types import LLMProvider, LLM
from .core.factory import AgentLoop
from .core.agent import ComputerAgent
__all__ = ["AgentLoop", "LLMProvider", "LLM", "ComputerAgent"]

View File

@@ -0,0 +1,21 @@
"""
Entry point for running agent CLI module.
Usage:
python -m agent.cli <model_string>
"""
import sys
import asyncio
from .cli import main
if __name__ == "__main__":
# Check if 'cli' is specified as the module
if len(sys.argv) > 1 and sys.argv[1] == "cli":
# Remove 'cli' from arguments and run CLI
sys.argv.pop(1)
asyncio.run(main())
else:
print("Usage: python -m agent.cli <model_string>")
print("Example: python -m agent.cli openai/computer-use-preview")
sys.exit(1)

View File

@@ -0,0 +1,9 @@
"""
Adapters package for agent - Custom LLM adapters for LiteLLM
"""
from .huggingfacelocal_adapter import HuggingFaceLocalAdapter
__all__ = [
"HuggingFaceLocalAdapter",
]
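For reference, `ComputerAgent` registers this adapter with liteLLM's custom provider map; a minimal standalone sketch of the same wiring (model name as used elsewhere in this commit):
```python
import litellm
from agent.adapters import HuggingFaceLocalAdapter

# Route the "huggingface-local/" prefix to the local adapter
litellm.custom_provider_map = [
    {"provider": "huggingface-local", "custom_handler": HuggingFaceLocalAdapter(device="auto")}
]

response = litellm.completion(
    model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
    messages=[{"role": "user", "content": "Describe this screen."}],
)
```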

View File

@@ -0,0 +1,229 @@
import asyncio
import warnings
from typing import Iterator, AsyncIterator, Dict, List, Any, Optional
from litellm.types.utils import GenericStreamingChunk, ModelResponse
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion
# Try to import HuggingFace dependencies
try:
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
HF_AVAILABLE = True
except ImportError:
HF_AVAILABLE = False
class HuggingFaceLocalAdapter(CustomLLM):
"""HuggingFace Local Adapter for running vision-language models locally."""
def __init__(self, device: str = "auto", **kwargs):
"""Initialize the adapter.
Args:
device: Device to load model on ("auto", "cuda", "cpu", etc.)
**kwargs: Additional arguments
"""
super().__init__()
self.device = device
self.models = {} # Cache for loaded models
self.processors = {} # Cache for loaded processors
def _load_model_and_processor(self, model_name: str):
"""Load model and processor if not already cached.
Args:
model_name: Name of the model to load
Returns:
Tuple of (model, processor)
"""
if model_name not in self.models:
# Load model
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map=self.device,
attn_implementation="sdpa"
)
# Load processor
processor = AutoProcessor.from_pretrained(model_name)
# Cache them
self.models[model_name] = model
self.processors[model_name] = processor
return self.models[model_name], self.processors[model_name]
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Convert OpenAI format messages to HuggingFace format.
Args:
messages: Messages in OpenAI format
Returns:
Messages in HuggingFace format
"""
converted_messages = []
for message in messages:
converted_message = {
"role": message["role"],
"content": []
}
content = message.get("content", [])
if isinstance(content, str):
# Simple text content
converted_message["content"].append({
"type": "text",
"text": content
})
elif isinstance(content, list):
# Multi-modal content
for item in content:
if item.get("type") == "text":
converted_message["content"].append({
"type": "text",
"text": item.get("text", "")
})
elif item.get("type") == "image_url":
# Convert image_url format to image format
image_url = item.get("image_url", {}).get("url", "")
converted_message["content"].append({
"type": "image",
"image": image_url
})
converted_messages.append(converted_message)
return converted_messages
def _generate(self, **kwargs) -> str:
"""Generate response using the local HuggingFace model.
Args:
**kwargs: Keyword arguments containing messages and model info
Returns:
Generated text response
"""
if not HF_AVAILABLE:
raise ImportError(
"HuggingFace transformers dependencies not found. "
"Please install with: pip install \"cua-agent[uitars-hf]\""
)
# Extract messages and model from kwargs
messages = kwargs.get('messages', [])
model_name = kwargs.get('model', 'ByteDance-Seed/UI-TARS-1.5-7B')
max_new_tokens = kwargs.get('max_tokens', 128)
# Warn about ignored kwargs
ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
if ignored_kwargs:
warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")
# Load model and processor
model, processor = self._load_model_and_processor(model_name)
# Convert messages to HuggingFace format
hf_messages = self._convert_messages(messages)
# Apply chat template and tokenize
inputs = processor.apply_chat_template(
hf_messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt"
)
# Move inputs to the same device as model
if torch.cuda.is_available() and self.device != "cpu":
inputs = inputs.to("cuda")
# Generate response
with torch.no_grad():
generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
# Trim input tokens from output
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
# Decode output
output_text = processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False
)
return output_text[0] if output_text else ""
def completion(self, *args, **kwargs) -> ModelResponse:
"""Synchronous completion method.
Returns:
ModelResponse with generated text
"""
generated_text = self._generate(**kwargs)
return completion(
model=f"huggingface-local/{kwargs['model']}",
mock_response=generated_text,
)
async def acompletion(self, *args, **kwargs) -> ModelResponse:
"""Asynchronous completion method.
Returns:
ModelResponse with generated text
"""
# Run _generate in thread pool to avoid blocking
generated_text = await asyncio.to_thread(self._generate, **kwargs)
return await acompletion(
model=f"huggingface-local/{kwargs['model']}",
mock_response=generated_text,
)
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
"""Synchronous streaming method.
Returns:
Iterator of GenericStreamingChunk
"""
generated_text = self._generate(**kwargs)
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": generated_text,
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
yield generic_streaming_chunk
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
"""Asynchronous streaming method.
Returns:
AsyncIterator of GenericStreamingChunk
"""
# Run _generate in thread pool to avoid blocking
generated_text = await asyncio.to_thread(self._generate, **kwargs)
generic_streaming_chunk: GenericStreamingChunk = {
"finish_reason": "stop",
"index": 0,
"is_finished": True,
"text": generated_text,
"tool_use": None,
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
}
yield generic_streaming_chunk

View File

@@ -0,0 +1,594 @@
"""
ComputerAgent - Main agent class that selects and runs agent loops
"""
import asyncio
from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set
from litellm.responses.utils import Usage
from .types import Messages, Computer
from .decorators import find_agent_loop
from .computer_handler import OpenAIComputerHandler, acknowledge_safety_check_callback, check_blocklisted_url
import json
import litellm
import litellm.utils
import inspect
from .adapters import HuggingFaceLocalAdapter
from .callbacks import (
ImageRetentionCallback,
LoggingCallback,
TrajectorySaverCallback,
BudgetManagerCallback,
TelemetryCallback,
)
def get_json(obj: Any, max_depth: int = 10) -> Any:
def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
if seen is None:
seen = set()
# Use model_dump() if available
if hasattr(o, 'model_dump'):
return o.model_dump()
# Check depth limit
if depth > max_depth:
return f"<max_depth_exceeded:{max_depth}>"
# Check for circular references using object id
obj_id = id(o)
if obj_id in seen:
return f"<circular_reference:{type(o).__name__}>"
# Handle Computer objects
if hasattr(o, '__class__') and 'computer' in getattr(o, '__class__').__name__.lower():
return f"<computer:{o.__class__.__name__}>"
# Handle objects with __dict__
if hasattr(o, '__dict__'):
seen.add(obj_id)
try:
result = {}
for k, v in o.__dict__.items():
if v is not None:
# Recursively serialize with updated depth and seen set
serialized_value = custom_serializer(v, depth + 1, seen.copy())
result[k] = serialized_value
return result
finally:
seen.discard(obj_id)
# Handle common types that might contain nested objects
elif isinstance(o, dict):
seen.add(obj_id)
try:
return {
k: custom_serializer(v, depth + 1, seen.copy())
for k, v in o.items()
if v is not None
}
finally:
seen.discard(obj_id)
elif isinstance(o, (list, tuple, set)):
seen.add(obj_id)
try:
return [
custom_serializer(item, depth + 1, seen.copy())
for item in o
if item is not None
]
finally:
seen.discard(obj_id)
# For basic types that json.dumps can handle
elif isinstance(o, (str, int, float, bool)) or o is None:
return o
# Fallback to string representation
else:
return str(o)
def remove_nones(obj: Any) -> Any:
if isinstance(obj, dict):
return {k: remove_nones(v) for k, v in obj.items() if v is not None}
elif isinstance(obj, list):
return [remove_nones(item) for item in obj if item is not None]
return obj
# Serialize with circular reference and depth protection
serialized = custom_serializer(obj)
# Convert to JSON string and back to ensure JSON compatibility
json_str = json.dumps(serialized)
parsed = json.loads(json_str)
# Final cleanup of any remaining None values
return remove_nones(parsed)
def sanitize_message(msg: Any) -> Any:
"""Return a copy of the message with image_url omitted for computer_call_output messages."""
if msg.get("type") == "computer_call_output":
output = msg.get("output", {})
if isinstance(output, dict):
sanitized = msg.copy()
sanitized["output"] = {**output, "image_url": "[omitted]"}
return sanitized
return msg
class ComputerAgent:
"""
Main agent class that automatically selects the appropriate agent loop
based on the model and executes tool calls.
"""
def __init__(
self,
model: str,
tools: Optional[List[Any]] = None,
custom_loop: Optional[Callable] = None,
only_n_most_recent_images: Optional[int] = None,
callbacks: Optional[List[Any]] = None,
verbosity: Optional[int] = None,
trajectory_dir: Optional[str] = None,
max_retries: Optional[int] = 3,
screenshot_delay: Optional[float | int] = 0.5,
use_prompt_caching: Optional[bool] = False,
max_trajectory_budget: Optional[float | dict] = None,
telemetry_enabled: Optional[bool] = True,
**kwargs
):
"""
Initialize ComputerAgent.
Args:
model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
tools: List of tools (computer objects, decorated functions, etc.)
custom_loop: Custom agent loop function to use instead of auto-selection
only_n_most_recent_images: If set, only keep the N most recent images in message history. Adds ImageRetentionCallback automatically.
callbacks: List of AsyncCallbackHandler instances for preprocessing/postprocessing
verbosity: Logging level (logging.DEBUG, logging.INFO, etc.). If set, adds LoggingCallback automatically
trajectory_dir: If set, saves trajectory data (screenshots, responses) to this directory. Adds TrajectorySaverCallback automatically.
max_retries: Maximum number of retries for failed API calls
screenshot_delay: Delay before screenshots in seconds
use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with anthropic providers.
max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded
telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
**kwargs: Additional arguments passed to the agent loop
"""
self.model = model
self.tools = tools or []
self.custom_loop = custom_loop
self.only_n_most_recent_images = only_n_most_recent_images
self.callbacks = callbacks or []
self.verbosity = verbosity
self.trajectory_dir = trajectory_dir
self.max_retries = max_retries
self.screenshot_delay = screenshot_delay
self.use_prompt_caching = use_prompt_caching
self.telemetry_enabled = telemetry_enabled
self.kwargs = kwargs
# == Add built-in callbacks ==
# Add telemetry callback if telemetry_enabled is set
if self.telemetry_enabled:
if isinstance(self.telemetry_enabled, bool):
self.callbacks.append(TelemetryCallback(self))
else:
self.callbacks.append(TelemetryCallback(self, **self.telemetry_enabled))
# Add logging callback if verbosity is set
if self.verbosity is not None:
self.callbacks.append(LoggingCallback(level=self.verbosity))
# Add image retention callback if only_n_most_recent_images is set
if self.only_n_most_recent_images:
self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
# Add trajectory saver callback if trajectory_dir is set
if self.trajectory_dir:
self.callbacks.append(TrajectorySaverCallback(self.trajectory_dir))
# Add budget manager if max_trajectory_budget is set
if max_trajectory_budget:
if isinstance(max_trajectory_budget, dict):
self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
else:
self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
# == Enable local model providers w/ LiteLLM ==
# Register local model providers
hf_adapter = HuggingFaceLocalAdapter(
device="auto"
)
litellm.custom_provider_map = [
{"provider": "huggingface-local", "custom_handler": hf_adapter}
]
# == Initialize computer agent ==
# Find the appropriate agent loop
if custom_loop:
self.agent_loop = custom_loop
self.agent_loop_info = None
else:
loop_info = find_agent_loop(model)
if not loop_info:
raise ValueError(f"No agent loop found for model: {model}")
self.agent_loop = loop_info.func
self.agent_loop_info = loop_info
self.tool_schemas = []
self.computer_handler = None
async def _initialize_computers(self):
"""Initialize computer objects"""
if not self.tool_schemas:
for tool in self.tools:
if hasattr(tool, '_initialized') and not tool._initialized:
await tool.run()
# Process tools and create tool schemas
self.tool_schemas = self._process_tools()
# Find computer tool and create interface adapter
computer_handler = None
for schema in self.tool_schemas:
if schema["type"] == "computer":
computer_handler = OpenAIComputerHandler(schema["computer"].interface)
break
self.computer_handler = computer_handler
def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
"""Process input messages and create schemas for the agent loop"""
if isinstance(input, str):
return [{"role": "user", "content": input}]
return [get_json(msg) for msg in input]
def _process_tools(self) -> List[Dict[str, Any]]:
"""Process tools and create schemas for the agent loop"""
schemas = []
for tool in self.tools:
# Check if it's a computer object (has interface attribute)
if hasattr(tool, 'interface'):
# This is a computer tool - will be handled by agent loop
schemas.append({
"type": "computer",
"computer": tool
})
elif callable(tool):
# Use litellm.utils.function_to_dict to extract schema from docstring
try:
function_schema = litellm.utils.function_to_dict(tool)
schemas.append({
"type": "function",
"function": function_schema
})
except Exception as e:
print(f"Warning: Could not process tool {tool}: {e}")
else:
print(f"Warning: Unknown tool type: {tool}")
return schemas
def _get_tool(self, name: str) -> Optional[Callable]:
"""Get a tool by name"""
for tool in self.tools:
if hasattr(tool, '__name__') and tool.__name__ == name:
return tool
elif hasattr(tool, 'func') and tool.func.__name__ == name:
return tool
return None
# ============================================================================
# AGENT RUN LOOP LIFECYCLE HOOKS
# ============================================================================
async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Initialize run tracking by calling callbacks."""
for callback in self.callbacks:
if hasattr(callback, 'on_run_start'):
await callback.on_run_start(kwargs, old_items)
async def _on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
"""Finalize run tracking by calling callbacks."""
for callback in self.callbacks:
if hasattr(callback, 'on_run_end'):
await callback.on_run_end(kwargs, old_items, new_items)
async def _on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
"""Check if run should continue by calling callbacks."""
for callback in self.callbacks:
if hasattr(callback, 'on_run_continue'):
should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
if not should_continue:
return False
return True
async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Prepare messages for the LLM call by applying callbacks."""
result = messages
for callback in self.callbacks:
if hasattr(callback, 'on_llm_start'):
result = await callback.on_llm_start(result)
return result
async def _on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Postprocess messages after the LLM call by applying callbacks."""
result = messages
for callback in self.callbacks:
if hasattr(callback, 'on_llm_end'):
result = await callback.on_llm_end(result)
return result
async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
"""Called when responses are received."""
for callback in self.callbacks:
if hasattr(callback, 'on_responses'):
await callback.on_responses(get_json(kwargs), get_json(responses))
async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a computer call is about to start."""
for callback in self.callbacks:
if hasattr(callback, 'on_computer_call_start'):
await callback.on_computer_call_start(get_json(item))
async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
"""Called when a computer call has completed."""
for callback in self.callbacks:
if hasattr(callback, 'on_computer_call_end'):
await callback.on_computer_call_end(get_json(item), get_json(result))
async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a function call is about to start."""
for callback in self.callbacks:
if hasattr(callback, 'on_function_call_start'):
await callback.on_function_call_start(get_json(item))
async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
"""Called when a function call has completed."""
for callback in self.callbacks:
if hasattr(callback, 'on_function_call_end'):
await callback.on_function_call_end(get_json(item), get_json(result))
async def _on_text(self, item: Dict[str, Any]) -> None:
"""Called when a text message is encountered."""
for callback in self.callbacks:
if hasattr(callback, 'on_text'):
await callback.on_text(get_json(item))
async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
"""Called when an LLM API call is about to start."""
for callback in self.callbacks:
if hasattr(callback, 'on_api_start'):
await callback.on_api_start(get_json(kwargs))
async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""Called when an LLM API call has completed."""
for callback in self.callbacks:
if hasattr(callback, 'on_api_end'):
await callback.on_api_end(get_json(kwargs), get_json(result))
async def _on_usage(self, usage: Dict[str, Any]) -> None:
"""Called when usage information is received."""
for callback in self.callbacks:
if hasattr(callback, 'on_usage'):
await callback.on_usage(get_json(usage))
async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
"""Called when a screenshot is taken."""
for callback in self.callbacks:
if hasattr(callback, 'on_screenshot'):
await callback.on_screenshot(screenshot, name)
# ============================================================================
# AGENT OUTPUT PROCESSING
# ============================================================================
async def _handle_item(self, item: Any, computer: Optional[Computer] = None) -> List[Dict[str, Any]]:
"""Handle each item; may cause a computer action + screenshot."""
item_type = item.get("type", None)
if item_type == "message":
await self._on_text(item)
# # Print messages
# if item.get("content"):
# for content_item in item.get("content"):
# if content_item.get("text"):
# print(content_item.get("text"))
return []
if item_type == "computer_call":
await self._on_computer_call_start(item)
if not computer:
raise ValueError("Computer handler is required for computer calls")
# Perform computer actions
action = item.get("action")
action_type = action.get("type")
# Extract action arguments (all fields except 'type')
action_args = {k: v for k, v in action.items() if k != "type"}
# print(f"{action_type}({action_args})")
# Execute the computer action
computer_method = getattr(computer, action_type, None)
if computer_method:
await computer_method(**action_args)
else:
print(f"Unknown computer action: {action_type}")
return []
# Take screenshot after action
if self.screenshot_delay and self.screenshot_delay > 0:
await asyncio.sleep(self.screenshot_delay)
screenshot_base64 = await computer.screenshot()
await self._on_screenshot(screenshot_base64, "screenshot_after")
# Handle safety checks
pending_checks = item.get("pending_safety_checks", [])
acknowledged_checks = []
for check in pending_checks:
check_message = check.get("message", str(check))
if acknowledge_safety_check_callback(check_message):
acknowledged_checks.append(check)
else:
raise ValueError(f"Safety check failed: {check_message}")
# Create call output
call_output = {
"type": "computer_call_output",
"call_id": item.get("call_id"),
"acknowledged_safety_checks": acknowledged_checks,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_base64}",
},
}
# Additional URL safety checks for browser environments
if await computer.get_environment() == "browser":
current_url = await computer.get_current_url()
call_output["output"]["current_url"] = current_url
check_blocklisted_url(current_url)
result = [call_output]
await self._on_computer_call_end(item, result)
return result
if item_type == "function_call":
await self._on_function_call_start(item)
# Perform function call
function = self._get_tool(item.get("name"))
if not function:
raise ValueError(f"Function {item.get("name")} not found")
args = json.loads(item.get("arguments"))
# Execute function - use asyncio.to_thread for non-async functions
if inspect.iscoroutinefunction(function):
result = await function(**args)
else:
result = await asyncio.to_thread(function, **args)
# Create function call output
call_output = {
"type": "function_call_output",
"call_id": item.get("call_id"),
"output": str(result),
}
result = [call_output]
await self._on_function_call_end(item, result)
return result
return []
# ============================================================================
# MAIN AGENT LOOP
# ============================================================================
async def run(
self,
messages: Messages,
stream: bool = False,
**kwargs
) -> AsyncGenerator[Dict[str, Any], None]:
"""
Run the agent with the given messages using Computer protocol handler pattern.
Args:
messages: List of message dictionaries
stream: Whether to stream the response
**kwargs: Additional arguments
Returns:
AsyncGenerator that yields response chunks
"""
await self._initialize_computers()
# Merge kwargs
merged_kwargs = {**self.kwargs, **kwargs}
old_items = self._process_input(messages)
new_items = []
# Initialize run tracking
run_kwargs = {
"messages": messages,
"stream": stream,
"model": self.model,
"agent_loop": self.agent_loop.__name__,
**merged_kwargs
}
await self._on_run_start(run_kwargs, old_items)
while not new_items or new_items[-1].get("role") != "assistant":
# Lifecycle hook: Check if we should continue based on callbacks (e.g., budget manager)
should_continue = await self._on_run_continue(run_kwargs, old_items, new_items)
if not should_continue:
break
# Lifecycle hook: Prepare messages for the LLM call
# Use cases:
# - PII anonymization
# - Image retention policy
combined_messages = old_items + new_items
preprocessed_messages = await self._on_llm_start(combined_messages)
loop_kwargs = {
"messages": preprocessed_messages,
"model": self.model,
"tools": self.tool_schemas,
"stream": False,
"computer_handler": self.computer_handler,
"max_retries": self.max_retries,
"use_prompt_caching": self.use_prompt_caching,
**merged_kwargs
}
# Run agent loop iteration
result = await self.agent_loop(
**loop_kwargs,
_on_api_start=self._on_api_start,
_on_api_end=self._on_api_end,
_on_usage=self._on_usage,
_on_screenshot=self._on_screenshot,
)
result = get_json(result)
# Lifecycle hook: Postprocess messages after the LLM call
# Use cases:
# - PII deanonymization (if you want tool calls to see PII)
result["output"] = await self._on_llm_end(result.get("output", []))
await self._on_responses(loop_kwargs, result)
# Yield agent response
yield result
# Add agent response to new_items
new_items += result.get("output")
# Handle computer actions
for item in result.get("output"):
partial_items = await self._handle_item(item, self.computer_handler)
new_items += partial_items
# Yield partial response
yield {
"output": partial_items,
"usage": Usage(
prompt_tokens=0,
completion_tokens=0,
total_tokens=0,
)
}
await self._on_run_end(loop_kwargs, old_items, new_items)

View File

@@ -0,0 +1,19 @@
"""
Callback system for ComputerAgent preprocessing and postprocessing hooks.
"""
from .base import AsyncCallbackHandler
from .image_retention import ImageRetentionCallback
from .logging import LoggingCallback
from .trajectory_saver import TrajectorySaverCallback
from .budget_manager import BudgetManagerCallback
from .telemetry import TelemetryCallback
__all__ = [
"AsyncCallbackHandler",
"ImageRetentionCallback",
"LoggingCallback",
"TrajectorySaverCallback",
"BudgetManagerCallback",
"TelemetryCallback",
]

View File

@@ -0,0 +1,153 @@
"""
Base callback handler interface for ComputerAgent preprocessing and postprocessing hooks.
"""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union
class AsyncCallbackHandler(ABC):
"""
Base class for async callback handlers that can preprocess messages before
the agent loop and postprocess output after the agent loop.
"""
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Called at the start of an agent run loop."""
pass
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
"""Called at the end of an agent run loop."""
pass
async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
"""Called during agent run loop to determine if execution should continue.
Args:
kwargs: Run arguments
old_items: Original messages
new_items: New messages generated during run
Returns:
True to continue execution, False to stop
"""
return True
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Called before messages are sent to the agent loop.
Args:
messages: List of message dictionaries to preprocess
Returns:
List of preprocessed message dictionaries
"""
return messages
async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Called after the agent loop returns output.
Args:
output: List of output message dictionaries to postprocess
Returns:
List of postprocessed output dictionaries
"""
return output
async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
"""
Called when a computer call is about to start.
Args:
item: The computer call item dictionary
"""
pass
async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
"""
Called when a computer call has completed.
Args:
item: The computer call item dictionary
result: The result of the computer call
"""
pass
async def on_function_call_start(self, item: Dict[str, Any]) -> None:
"""
Called when a function call is about to start.
Args:
item: The function call item dictionary
"""
pass
async def on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
"""
Called when a function call has completed.
Args:
item: The function call item dictionary
result: The result of the function call
"""
pass
async def on_text(self, item: Dict[str, Any]) -> None:
"""
Called when a text message is encountered.
Args:
item: The message item dictionary
"""
pass
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
"""
Called when an API call is about to start.
Args:
kwargs: The kwargs being passed to the API call
"""
pass
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""
Called when an API call has completed.
Args:
kwargs: The kwargs that were passed to the API call
result: The result of the API call
"""
pass
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""
Called when usage information is received.
Args:
usage: The usage information
"""
pass
async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
"""
Called when a screenshot is taken.
Args:
screenshot: The screenshot image
name: The name of the screenshot
"""
pass
async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
"""
Called when responses are received.
Args:
kwargs: The kwargs being passed to the agent loop
responses: The responses received
"""
pass

View File

@@ -0,0 +1,44 @@
from typing import Dict, List, Any
from .base import AsyncCallbackHandler
class BudgetExceededError(Exception):
"""Exception raised when budget is exceeded."""
pass
class BudgetManagerCallback(AsyncCallbackHandler):
"""Budget manager callback that tracks usage costs and can stop execution when budget is exceeded."""
def __init__(self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False):
"""
Initialize BudgetManagerCallback.
Args:
max_budget: Maximum budget allowed
reset_after_each_run: Whether to reset budget after each run
raise_error: Whether to raise an error when budget is exceeded
"""
self.max_budget = max_budget
self.reset_after_each_run = reset_after_each_run
self.raise_error = raise_error
self.total_cost = 0.0
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Reset budget if configured to do so."""
if self.reset_after_each_run:
self.total_cost = 0.0
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""Track usage costs."""
if "response_cost" in usage:
self.total_cost += usage["response_cost"]
async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
"""Check if budget allows continuation."""
if self.total_cost >= self.max_budget:
if self.raise_error:
raise BudgetExceededError(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
else:
print(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
return False
return True

View File

@@ -0,0 +1,139 @@
"""
Image retention callback handler that limits the number of recent images in message history.
"""
from typing import List, Dict, Any, Optional
from .base import AsyncCallbackHandler
class ImageRetentionCallback(AsyncCallbackHandler):
"""
Callback handler that applies image retention policy to limit the number
of recent images in message history to prevent context window overflow.
"""
def __init__(self, only_n_most_recent_images: Optional[int] = None):
"""
Initialize the image retention callback.
Args:
only_n_most_recent_images: If set, only keep the N most recent images in message history
"""
self.only_n_most_recent_images = only_n_most_recent_images
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Apply image retention policy to messages before sending to agent loop.
Args:
messages: List of message dictionaries
Returns:
List of messages with image retention policy applied
"""
if self.only_n_most_recent_images is None:
return messages
return self._apply_image_retention(messages)
def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Apply image retention policy to keep only the N most recent images.
Removes computer_call_output items with image_url and their corresponding computer_call items,
keeping only the most recent N image pairs based on only_n_most_recent_images setting.
Args:
messages: List of message dictionaries
Returns:
Filtered list of messages with image retention applied
"""
if self.only_n_most_recent_images is None:
return messages
# First pass: Assign call_id to reasoning items based on the next computer_call
messages_with_call_ids = []
for i, msg in enumerate(messages):
msg_copy = msg.copy() if isinstance(msg, dict) else msg
# If this is a reasoning item without a call_id, find the next computer_call
if (msg_copy.get("type") == "reasoning" and
not msg_copy.get("call_id")):
# Look ahead for the next computer_call
for j in range(i + 1, len(messages)):
next_msg = messages[j]
if (next_msg.get("type") == "computer_call" and
next_msg.get("call_id")):
msg_copy["call_id"] = next_msg.get("call_id")
break
messages_with_call_ids.append(msg_copy)
# Find all computer_call_output items with images and their call_ids
image_call_ids = []
for msg in reversed(messages_with_call_ids): # Process in reverse to get most recent first
if (msg.get("type") == "computer_call_output" and
isinstance(msg.get("output"), dict) and
"image_url" in msg.get("output", {})):
call_id = msg.get("call_id")
if call_id and call_id not in image_call_ids:
image_call_ids.append(call_id)
if len(image_call_ids) >= self.only_n_most_recent_images:
break
        # Keep the most recent N image call_ids (collected newest-first above; order is irrelevant in a set)
        keep_call_ids = set(image_call_ids[:self.only_n_most_recent_images])
# Filter messages: remove computer_call, computer_call_output, and reasoning for old images
filtered_messages = []
for msg in messages_with_call_ids:
msg_type = msg.get("type")
call_id = msg.get("call_id")
# Remove old computer_call items
if msg_type == "computer_call" and call_id not in keep_call_ids:
# Check if this call_id corresponds to an image call
has_image_output = any(
m.get("type") == "computer_call_output" and
m.get("call_id") == call_id and
isinstance(m.get("output"), dict) and
"image_url" in m.get("output", {})
for m in messages_with_call_ids
)
if has_image_output:
continue # Skip this computer_call
# Remove old computer_call_output items with images
if (msg_type == "computer_call_output" and
call_id not in keep_call_ids and
isinstance(msg.get("output"), dict) and
"image_url" in msg.get("output", {})):
continue # Skip this computer_call_output
# Remove old reasoning items that are paired with removed computer calls
if (msg_type == "reasoning" and
call_id and call_id not in keep_call_ids):
# Check if this call_id corresponds to an image call that's being removed
has_image_output = any(
m.get("type") == "computer_call_output" and
m.get("call_id") == call_id and
isinstance(m.get("output"), dict) and
"image_url" in m.get("output", {})
for m in messages_with_call_ids
)
if has_image_output:
continue # Skip this reasoning item
filtered_messages.append(msg)
# Clean up: Remove call_id from reasoning items before returning
final_messages = []
for msg in filtered_messages:
if msg.get("type") == "reasoning" and "call_id" in msg:
# Create a copy without call_id for reasoning items
cleaned_msg = {k: v for k, v in msg.items() if k != "call_id"}
final_messages.append(cleaned_msg)
else:
final_messages.append(msg)
return final_messages
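A self-contained sketch of the retention behavior on a toy history; only the import path is an assumption:

```python
import asyncio

from agent.callbacks import ImageRetentionCallback  # import path assumed

messages = [
    {"type": "computer_call", "call_id": "a", "action": {"type": "screenshot"}},
    {"type": "computer_call_output", "call_id": "a",
     "output": {"type": "input_image", "image_url": "data:image/png;base64,OLD"}},
    {"type": "computer_call", "call_id": "b", "action": {"type": "screenshot"}},
    {"type": "computer_call_output", "call_id": "b",
     "output": {"type": "input_image", "image_url": "data:image/png;base64,NEW"}},
]

retention = ImageRetentionCallback(only_n_most_recent_images=1)
filtered = asyncio.run(retention.on_llm_start(messages))
assert [m["call_id"] for m in filtered] == ["b", "b"]  # only the newest pair survives
```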

View File

@@ -0,0 +1,247 @@
"""
Logging callback for ComputerAgent that provides configurable logging of agent lifecycle events.
"""
import json
import logging
from typing import Dict, List, Any, Optional, Union
from .base import AsyncCallbackHandler
def sanitize_image_urls(data: Any) -> Any:
"""
Recursively search for 'image_url' keys and set their values to '[omitted]'.
Args:
data: Any data structure (dict, list, or primitive type)
Returns:
A deep copy of the data with all 'image_url' values replaced with '[omitted]'
"""
if isinstance(data, dict):
# Create a copy of the dictionary
sanitized = {}
for key, value in data.items():
if key == "image_url":
sanitized[key] = "[omitted]"
else:
# Recursively sanitize the value
sanitized[key] = sanitize_image_urls(value)
return sanitized
elif isinstance(data, list):
# Recursively sanitize each item in the list
return [sanitize_image_urls(item) for item in data]
else:
# For primitive types (str, int, bool, None, etc.), return as-is
return data
class LoggingCallback(AsyncCallbackHandler):
"""
Callback handler that logs agent lifecycle events with configurable verbosity.
Logging levels:
- DEBUG: All events including API calls, message preprocessing, and detailed outputs
- INFO: Major lifecycle events (start/end, messages, outputs)
- WARNING: Only warnings and errors
- ERROR: Only errors
"""
def __init__(self, logger: Optional[logging.Logger] = None, level: int = logging.INFO):
"""
Initialize the logging callback.
Args:
logger: Logger instance to use. If None, creates a logger named 'agent.ComputerAgent'
level: Logging level (logging.DEBUG, logging.INFO, etc.)
"""
self.logger = logger or logging.getLogger('agent.ComputerAgent')
self.level = level
# Set up logger if it doesn't have handlers
if not self.logger.handlers:
handler = logging.StreamHandler()
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
handler.setFormatter(formatter)
self.logger.addHandler(handler)
        self.logger.setLevel(level)
        # Accumulated usage across a run; reset again in on_run_start
        self.total_usage: Dict[str, Any] = {}
def _update_usage(self, usage: Dict[str, Any]) -> None:
"""Update total usage statistics."""
def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
for key, value in source.items():
if isinstance(value, dict):
if key not in target:
target[key] = {}
add_dicts(target[key], value)
else:
if key not in target:
target[key] = 0
target[key] += value
add_dicts(self.total_usage, usage)
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Called before the run starts."""
self.total_usage = {}
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""Called when usage information is received."""
self._update_usage(usage)
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
"""Called after the run ends."""
def format_dict(d, indent=0):
lines = []
prefix = f" - {' ' * indent}"
for key, value in d.items():
if isinstance(value, dict):
lines.append(f"{prefix}{key}:")
lines.extend(format_dict(value, indent + 1))
elif isinstance(value, float):
lines.append(f"{prefix}{key}: ${value:.4f}")
else:
lines.append(f"{prefix}{key}: {value}")
return lines
formatted_output = "\n".join(format_dict(self.total_usage))
self.logger.info(f"Total usage:\n{formatted_output}")
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Called before LLM processing starts."""
if self.logger.isEnabledFor(logging.INFO):
self.logger.info(f"LLM processing started with {len(messages)} messages")
if self.logger.isEnabledFor(logging.DEBUG):
sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
self.logger.debug(f"LLM input messages: {json.dumps(sanitized_messages, indent=2)}")
return messages
async def on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Called after LLM processing ends."""
if self.logger.isEnabledFor(logging.DEBUG):
sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
self.logger.debug(f"LLM output: {json.dumps(sanitized_messages, indent=2)}")
return messages
async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a computer call starts."""
action = item.get("action", {})
action_type = action.get("type", "unknown")
action_args = {k: v for k, v in action.items() if k != "type"}
# INFO level logging for the action
self.logger.info(f"Computer: {action_type}({action_args})")
# DEBUG level logging for full details
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Computer call started: {json.dumps(action, indent=2)}")
async def on_computer_call_end(self, item: Dict[str, Any], result: Any) -> None:
"""Called when a computer call ends."""
if self.logger.isEnabledFor(logging.DEBUG):
action = item.get("action", "unknown")
self.logger.debug(f"Computer call completed: {json.dumps(action, indent=2)}")
if result:
sanitized_result = sanitize_image_urls(result)
self.logger.debug(f"Computer call result: {json.dumps(sanitized_result, indent=2)}")
async def on_function_call_start(self, item: Dict[str, Any]) -> None:
"""Called when a function call starts."""
name = item.get("name", "unknown")
arguments = item.get("arguments", "{}")
# INFO level logging for the function call
self.logger.info(f"Function: {name}({arguments})")
# DEBUG level logging for full details
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Function call started: {name}")
async def on_function_call_end(self, item: Dict[str, Any], result: Any) -> None:
"""Called when a function call ends."""
# INFO level logging for function output (similar to function_call_output)
if result:
# Handle both list and direct result formats
if isinstance(result, list) and len(result) > 0:
output = result[0].get("output", str(result)) if isinstance(result[0], dict) else str(result[0])
else:
output = str(result)
# Truncate long outputs
if len(output) > 100:
output = output[:100] + "..."
self.logger.info(f"Output: {output}")
# DEBUG level logging for full details
if self.logger.isEnabledFor(logging.DEBUG):
name = item.get("name", "unknown")
self.logger.debug(f"Function call completed: {name}")
if result:
self.logger.debug(f"Function call result: {json.dumps(result, indent=2)}")
async def on_text(self, item: Dict[str, Any]) -> None:
"""Called when a text message is encountered."""
# Get the role to determine if it's Agent or User
role = item.get("role", "unknown")
content_items = item.get("content", [])
# Process content items to build display text
text_parts = []
for content_item in content_items:
content_type = content_item.get("type", "output_text")
if content_type == "output_text":
text_content = content_item.get("text", "")
if not text_content.strip():
text_parts.append("[empty]")
else:
# Truncate long text and add ellipsis
if len(text_content) > 2048:
text_parts.append(text_content[:2048] + "...")
else:
text_parts.append(text_content)
else:
# Non-text content, show as [type]
text_parts.append(f"[{content_type}]")
# Join all text parts
display_text = ''.join(text_parts) if text_parts else "[empty]"
# Log with appropriate level and format
if role == "assistant":
self.logger.info(f"Agent: {display_text}")
elif role == "user":
self.logger.info(f"User: {display_text}")
else:
# Fallback for unknown roles, use debug level
if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(f"Text message ({role}): {display_text}")
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
"""Called when an API call is about to start."""
if self.logger.isEnabledFor(logging.DEBUG):
model = kwargs.get("model", "unknown")
self.logger.debug(f"API call starting for model: {model}")
# Log sanitized messages if present
if "messages" in kwargs:
sanitized_messages = sanitize_image_urls(kwargs["messages"])
self.logger.debug(f"API call messages: {json.dumps(sanitized_messages, indent=2)}")
elif "input" in kwargs:
sanitized_input = sanitize_image_urls(kwargs["input"])
self.logger.debug(f"API call input: {json.dumps(sanitized_input, indent=2)}")
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""Called when an API call has completed."""
if self.logger.isEnabledFor(logging.DEBUG):
model = kwargs.get("model", "unknown")
self.logger.debug(f"API call completed for model: {model}")
self.logger.debug(f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}")
async def on_screenshot(self, item: Union[str, bytes], name: str = "screenshot") -> None:
"""Called when a screenshot is taken."""
if self.logger.isEnabledFor(logging.DEBUG):
image_size = len(item) / 1024
self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")

View File

@@ -0,0 +1,259 @@
"""
PII anonymization callback handler using Microsoft Presidio for text and image redaction.
"""
# Defer annotation evaluation so presidio-only names used in type hints
# (e.g. RecognizerResult) do not raise NameError when presidio is absent
from __future__ import annotations

from typing import List, Dict, Any, Optional, Tuple
from .base import AsyncCallbackHandler
import base64
import io
import logging
try:
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine, DeanonymizeEngine
from presidio_anonymizer.entities import RecognizerResult, OperatorConfig
from presidio_image_redactor import ImageRedactorEngine
from PIL import Image
PRESIDIO_AVAILABLE = True
except ImportError:
PRESIDIO_AVAILABLE = False
logger = logging.getLogger(__name__)
class PIIAnonymizationCallback(AsyncCallbackHandler):
"""
Callback handler that anonymizes PII in text and images using Microsoft Presidio.
This handler:
1. Anonymizes PII in messages before sending to the agent loop
2. Deanonymizes PII in tool calls and message outputs after the agent loop
3. Redacts PII from images in computer_call_output messages
"""
def __init__(
self,
anonymize_text: bool = True,
anonymize_images: bool = True,
entities_to_anonymize: Optional[List[str]] = None,
anonymization_operator: str = "replace",
image_redaction_color: Tuple[int, int, int] = (255, 192, 203) # Pink
):
"""
Initialize the PII anonymization callback.
Args:
anonymize_text: Whether to anonymize text content
anonymize_images: Whether to redact images
entities_to_anonymize: List of entity types to anonymize (None for all)
anonymization_operator: Presidio operator to use ("replace", "mask", "redact", etc.)
image_redaction_color: RGB color for image redaction
"""
if not PRESIDIO_AVAILABLE:
raise ImportError(
"Presidio is not available. Install with: "
"pip install presidio-analyzer presidio-anonymizer presidio-image-redactor"
)
self.anonymize_text = anonymize_text
self.anonymize_images = anonymize_images
self.entities_to_anonymize = entities_to_anonymize
self.anonymization_operator = anonymization_operator
self.image_redaction_color = image_redaction_color
# Initialize Presidio engines
self.analyzer = AnalyzerEngine()
self.anonymizer = AnonymizerEngine()
self.deanonymizer = DeanonymizeEngine()
self.image_redactor = ImageRedactorEngine()
# Store anonymization mappings for deanonymization
self.anonymization_mappings: Dict[str, Any] = {}
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Anonymize PII in messages before sending to agent loop.
Args:
messages: List of message dictionaries
Returns:
List of messages with PII anonymized
"""
if not self.anonymize_text and not self.anonymize_images:
return messages
anonymized_messages = []
for msg in messages:
anonymized_msg = await self._anonymize_message(msg)
anonymized_messages.append(anonymized_msg)
return anonymized_messages
async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Deanonymize PII in tool calls and message outputs after agent loop.
Args:
output: List of output dictionaries
Returns:
List of output with PII deanonymized for tool calls
"""
if not self.anonymize_text:
return output
deanonymized_output = []
for item in output:
# Only deanonymize tool calls and computer_call messages
if item.get("type") in ["computer_call", "computer_call_output"]:
deanonymized_item = await self._deanonymize_item(item)
deanonymized_output.append(deanonymized_item)
else:
deanonymized_output.append(item)
return deanonymized_output
async def _anonymize_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Anonymize PII in a single message."""
msg_copy = message.copy()
# Anonymize text content
if self.anonymize_text:
msg_copy = await self._anonymize_text_content(msg_copy)
# Redact images in computer_call_output
if self.anonymize_images and msg_copy.get("type") == "computer_call_output":
msg_copy = await self._redact_image_content(msg_copy)
return msg_copy
async def _anonymize_text_content(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Anonymize text content in a message."""
msg_copy = message.copy()
# Handle content array
content = msg_copy.get("content", [])
if isinstance(content, str):
anonymized_text, _ = await self._anonymize_text(content)
msg_copy["content"] = anonymized_text
elif isinstance(content, list):
anonymized_content = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
text = item.get("text", "")
anonymized_text, _ = await self._anonymize_text(text)
item_copy = item.copy()
item_copy["text"] = anonymized_text
anonymized_content.append(item_copy)
else:
anonymized_content.append(item)
msg_copy["content"] = anonymized_content
return msg_copy
async def _redact_image_content(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Redact PII from images in computer_call_output messages."""
msg_copy = message.copy()
output = msg_copy.get("output", {})
if isinstance(output, dict) and "image_url" in output:
try:
# Extract base64 image data
image_url = output["image_url"]
if image_url.startswith("data:image/"):
# Parse data URL
header, data = image_url.split(",", 1)
image_data = base64.b64decode(data)
# Load image with PIL
image = Image.open(io.BytesIO(image_data))
# Redact PII from image
redacted_image = self.image_redactor.redact(image, self.image_redaction_color)
# Convert back to base64
buffer = io.BytesIO()
redacted_image.save(buffer, format="PNG")
redacted_data = base64.b64encode(buffer.getvalue()).decode()
# Update image URL
output_copy = output.copy()
output_copy["image_url"] = f"data:image/png;base64,{redacted_data}"
msg_copy["output"] = output_copy
except Exception as e:
logger.warning(f"Failed to redact image: {e}")
return msg_copy
async def _deanonymize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
"""Deanonymize PII in tool calls and computer outputs."""
item_copy = item.copy()
        # Handle computer_call actions (the schema key is "action", as used by the other callbacks)
        if item.get("type") == "computer_call":
            args = item_copy.get("action", {})
if isinstance(args, dict):
deanonymized_args = {}
for key, value in args.items():
if isinstance(value, str):
deanonymized_value, _ = await self._deanonymize_text(value)
deanonymized_args[key] = deanonymized_value
else:
deanonymized_args[key] = value
item_copy["args"] = deanonymized_args
return item_copy
async def _anonymize_text(self, text: str) -> Tuple[str, List[RecognizerResult]]:
"""Anonymize PII in text and return the anonymized text and results."""
if not text.strip():
return text, []
try:
# Analyze text for PII
analyzer_results = self.analyzer.analyze(
text=text,
entities=self.entities_to_anonymize,
language="en"
)
if not analyzer_results:
return text, []
# Anonymize the text
anonymized_result = self.anonymizer.anonymize(
text=text,
analyzer_results=analyzer_results,
operators={entity_type: OperatorConfig(self.anonymization_operator)
for entity_type in set(result.entity_type for result in analyzer_results)}
)
# Store mapping for deanonymization
mapping_key = str(hash(text))
self.anonymization_mappings[mapping_key] = {
"original": text,
"anonymized": anonymized_result.text,
"results": analyzer_results
}
return anonymized_result.text, analyzer_results
except Exception as e:
logger.warning(f"Failed to anonymize text: {e}")
return text, []
async def _deanonymize_text(self, text: str) -> Tuple[str, bool]:
"""Attempt to deanonymize text using stored mappings."""
try:
# Look for matching anonymized text in mappings
for mapping_key, mapping in self.anonymization_mappings.items():
if mapping["anonymized"] == text:
return mapping["original"], True
# If no mapping found, return original text
return text, False
except Exception as e:
logger.warning(f"Failed to deanonymize text: {e}")
return text, False
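A usage sketch, assuming the presidio extras (and their spaCy model) are installed and the import path below:

```python
import asyncio

from agent.callbacks import PIIAnonymizationCallback  # import path assumed

pii = PIIAnonymizationCallback(anonymize_images=False)  # text-only anonymization
msgs = [{"role": "user", "content": [{"type": "text", "text": "Mail me at jane@example.com"}]}]
redacted = asyncio.run(pii.on_llm_start(msgs))
# Presidio's "replace" operator swaps each match for its entity label,
# e.g. "Mail me at <EMAIL_ADDRESS>"
```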

View File

@@ -0,0 +1,210 @@
"""
Telemetry callback handler for Computer-Use Agent (cua-agent)
"""
import time
import uuid
from typing import List, Dict, Any, Optional, Union
from .base import AsyncCallbackHandler
from ..telemetry import (
record_event,
is_telemetry_enabled,
set_dimension,
SYSTEM_INFO,
)
class TelemetryCallback(AsyncCallbackHandler):
"""
Telemetry callback handler for Computer-Use Agent (cua-agent)
Tracks agent usage, performance metrics, and optionally trajectory data.
"""
def __init__(
self,
agent,
log_trajectory: bool = False
):
"""
Initialize telemetry callback.
Args:
agent: The ComputerAgent instance
log_trajectory: Whether to log full trajectory items (opt-in)
"""
self.agent = agent
self.log_trajectory = log_trajectory
# Generate session/run IDs
self.session_id = str(uuid.uuid4())
self.run_id = None
# Track timing and metrics
self.run_start_time = None
self.step_count = 0
self.step_start_time = None
self.total_usage = {
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"response_cost": 0.0
}
# Record agent initialization
if is_telemetry_enabled():
self._record_agent_initialization()
def _record_agent_initialization(self) -> None:
"""Record agent type/model and session initialization."""
agent_info = {
"session_id": self.session_id,
"agent_type": self.agent.agent_loop.__name__ if hasattr(self.agent, 'agent_loop') else 'unknown',
"model": getattr(self.agent, 'model', 'unknown'),
**SYSTEM_INFO
}
# Set session-level dimensions
set_dimension("session_id", self.session_id)
set_dimension("agent_type", agent_info["agent_type"])
set_dimension("model", agent_info["model"])
record_event("agent_session_start", agent_info)
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Called at the start of an agent run loop."""
if not is_telemetry_enabled():
return
self.run_id = str(uuid.uuid4())
self.run_start_time = time.time()
self.step_count = 0
# Calculate input context size
input_context_size = self._calculate_context_size(old_items)
run_data = {
"session_id": self.session_id,
"run_id": self.run_id,
"start_time": self.run_start_time,
"input_context_size": input_context_size,
"num_existing_messages": len(old_items)
}
# Log trajectory if opted in
if self.log_trajectory:
trajectory = self._extract_trajectory(old_items)
if trajectory:
run_data["uploaded_trajectory"] = trajectory
set_dimension("run_id", self.run_id)
record_event("agent_run_start", run_data)
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
"""Called at the end of an agent run loop."""
if not is_telemetry_enabled() or not self.run_start_time:
return
run_duration = time.time() - self.run_start_time
run_data = {
"session_id": self.session_id,
"run_id": self.run_id,
"end_time": time.time(),
"duration_seconds": run_duration,
"num_steps": self.step_count,
"total_usage": self.total_usage.copy()
}
# Log trajectory if opted in
if self.log_trajectory:
trajectory = self._extract_trajectory(new_items)
if trajectory:
run_data["uploaded_trajectory"] = trajectory
record_event("agent_run_end", run_data)
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""Called when usage information is received."""
if not is_telemetry_enabled():
return
# Accumulate usage stats
self.total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
self.total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
self.total_usage["total_tokens"] += usage.get("total_tokens", 0)
self.total_usage["response_cost"] += usage.get("response_cost", 0.0)
# Record individual usage event
usage_data = {
"session_id": self.session_id,
"run_id": self.run_id,
"step": self.step_count,
**usage
}
record_event("agent_usage", usage_data)
async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
"""Called when responses are received."""
if not is_telemetry_enabled():
return
self.step_count += 1
step_duration = None
if self.step_start_time:
step_duration = time.time() - self.step_start_time
self.step_start_time = time.time()
step_data = {
"session_id": self.session_id,
"run_id": self.run_id,
"step": self.step_count,
"timestamp": self.step_start_time
}
if step_duration is not None:
step_data["duration_seconds"] = step_duration
record_event("agent_step", step_data)
def _calculate_context_size(self, items: List[Dict[str, Any]]) -> int:
"""Calculate approximate context size in tokens/characters."""
total_size = 0
for item in items:
if item.get("type") == "message" and "content" in item:
content = item["content"]
if isinstance(content, str):
total_size += len(content)
elif isinstance(content, list):
for part in content:
if isinstance(part, dict) and "text" in part:
total_size += len(part["text"])
elif "content" in item and isinstance(item["content"], str):
total_size += len(item["content"])
return total_size
def _extract_trajectory(self, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Extract trajectory items that should be logged."""
trajectory = []
for item in items:
# Include user messages, assistant messages, reasoning, computer calls, and computer outputs
if (
item.get("role") == "user" or # User inputs
(item.get("type") == "message" and item.get("role") == "assistant") or # Model outputs
item.get("type") == "reasoning" or # Reasoning traces
item.get("type") == "computer_call" or # Computer actions
item.get("type") == "computer_call_output" # Computer outputs
):
# Create a copy of the item with timestamp
trajectory_item = item.copy()
trajectory_item["logged_at"] = time.time()
trajectory.append(trajectory_item)
return trajectory
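A construction sketch; the handler takes the agent instance itself, so in practice the agent wires it up internally (import path assumed):

```python
from agent.callbacks import TelemetryCallback  # import path assumed

# agent: an existing ComputerAgent instance; trajectory upload stays opt-in
telemetry = TelemetryCallback(agent, log_trajectory=False)
```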

View File

@@ -0,0 +1,305 @@
"""
Trajectory saving callback handler for ComputerAgent.
"""
import os
import json
import uuid
from datetime import datetime
import base64
from pathlib import Path
from typing import List, Dict, Any, Optional, Union
try:
    from typing import override  # Python 3.12+
except ImportError:
    from typing_extensions import override
from PIL import Image, ImageDraw
import io
from .base import AsyncCallbackHandler
def sanitize_image_urls(data: Any) -> Any:
"""
Recursively search for 'image_url' keys and set their values to '[omitted]'.
Args:
data: Any data structure (dict, list, or primitive type)
Returns:
A deep copy of the data with all 'image_url' values replaced with '[omitted]'
"""
if isinstance(data, dict):
# Create a copy of the dictionary
sanitized = {}
for key, value in data.items():
if key == "image_url":
sanitized[key] = "[omitted]"
else:
# Recursively sanitize the value
sanitized[key] = sanitize_image_urls(value)
return sanitized
elif isinstance(data, list):
# Recursively sanitize each item in the list
return [sanitize_image_urls(item) for item in data]
else:
# For primitive types (str, int, bool, None, etc.), return as-is
return data
class TrajectorySaverCallback(AsyncCallbackHandler):
"""
Callback handler that saves agent trajectories to disk.
Saves each run as a separate trajectory with unique ID, and each turn
within the trajectory gets its own folder with screenshots and responses.
"""
def __init__(self, trajectory_dir: str):
"""
Initialize trajectory saver.
Args:
trajectory_dir: Base directory to save trajectories
"""
self.trajectory_dir = Path(trajectory_dir)
self.trajectory_id: Optional[str] = None
self.current_turn: int = 0
self.current_artifact: int = 0
self.model: Optional[str] = None
self.total_usage: Dict[str, Any] = {}
# Ensure trajectory directory exists
self.trajectory_dir.mkdir(parents=True, exist_ok=True)
def _get_turn_dir(self) -> Path:
"""Get the directory for the current turn."""
if not self.trajectory_id:
raise ValueError("Trajectory not initialized - call _on_run_start first")
# format: trajectory_id/turn_000
turn_dir = self.trajectory_dir / self.trajectory_id / f"turn_{self.current_turn:03d}"
turn_dir.mkdir(parents=True, exist_ok=True)
return turn_dir
def _save_artifact(self, name: str, artifact: Union[str, bytes, Dict[str, Any]]) -> None:
"""Save an artifact to the current turn directory."""
turn_dir = self._get_turn_dir()
if isinstance(artifact, bytes):
# format: turn_000/0000_name.png
artifact_filename = f"{self.current_artifact:04d}_{name}"
artifact_path = turn_dir / f"{artifact_filename}.png"
with open(artifact_path, "wb") as f:
f.write(artifact)
else:
# format: turn_000/0000_name.json
artifact_filename = f"{self.current_artifact:04d}_{name}"
artifact_path = turn_dir / f"{artifact_filename}.json"
with open(artifact_path, "w") as f:
json.dump(sanitize_image_urls(artifact), f, indent=2)
self.current_artifact += 1
def _update_usage(self, usage: Dict[str, Any]) -> None:
"""Update total usage statistics."""
def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
for key, value in source.items():
if isinstance(value, dict):
if key not in target:
target[key] = {}
add_dicts(target[key], value)
else:
if key not in target:
target[key] = 0
target[key] += value
add_dicts(self.total_usage, usage)
@override
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
"""Initialize trajectory tracking for a new run."""
model = kwargs.get("model", "unknown")
model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
if "+" in model:
model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
# id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
now = datetime.now()
self.trajectory_id = f"{now.strftime('%Y-%m-%d')}_{model_name_short}_{now.strftime('%H%M%S')}_{str(uuid.uuid4())[:4]}"
self.current_turn = 0
self.current_artifact = 0
self.model = model
self.total_usage = {}
# Create trajectory directory
trajectory_path = self.trajectory_dir / self.trajectory_id
trajectory_path.mkdir(parents=True, exist_ok=True)
# Save trajectory metadata
metadata = {
"trajectory_id": self.trajectory_id,
"created_at": str(uuid.uuid1().time),
"status": "running",
"kwargs": kwargs,
}
with open(trajectory_path / "metadata.json", "w") as f:
json.dump(metadata, f, indent=2)
@override
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
"""Finalize run tracking by updating metadata with completion status, usage, and new items."""
if not self.trajectory_id:
return
# Update metadata with completion status, total usage, and new items
trajectory_path = self.trajectory_dir / self.trajectory_id
metadata_path = trajectory_path / "metadata.json"
# Read existing metadata
if metadata_path.exists():
with open(metadata_path, "r") as f:
metadata = json.load(f)
else:
metadata = {}
# Update metadata with completion info
metadata.update({
"status": "completed",
"completed_at": str(uuid.uuid1().time),
"total_usage": self.total_usage,
"new_items": sanitize_image_urls(new_items),
"total_turns": self.current_turn
})
# Save updated metadata
with open(metadata_path, "w") as f:
json.dump(metadata, f, indent=2)
@override
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
if not self.trajectory_id:
return
self._save_artifact("api_start", { "kwargs": kwargs })
@override
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
"""Save API call result."""
if not self.trajectory_id:
return
self._save_artifact("api_result", { "kwargs": kwargs, "result": result })
@override
    async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
        """Save a screenshot."""
        if not self.trajectory_id:
            return
        if isinstance(screenshot, str):
            screenshot = base64.b64decode(screenshot)
        self._save_artifact(name, screenshot)
@override
async def on_usage(self, usage: Dict[str, Any]) -> None:
"""Called when usage information is received."""
self._update_usage(usage)
@override
async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
"""Save responses to the current turn directory and update usage statistics."""
if not self.trajectory_id:
return
        # Save responses (_save_artifact creates the turn directory itself)
        response_data = {
"timestamp": str(uuid.uuid1().time),
"model": self.model,
"kwargs": kwargs,
"response": responses
}
self._save_artifact("agent_response", response_data)
# Increment turn counter
self.current_turn += 1
def _draw_crosshair_on_image(self, image_bytes: bytes, x: int, y: int) -> bytes:
"""
Draw a red dot and crosshair at the specified coordinates on the image.
Args:
image_bytes: The original image as bytes
x: X coordinate for the crosshair
y: Y coordinate for the crosshair
Returns:
Modified image as bytes with red dot and crosshair
"""
# Open the image
image = Image.open(io.BytesIO(image_bytes))
draw = ImageDraw.Draw(image)
# Draw crosshair lines (red, 2px thick)
crosshair_size = 20
line_width = 2
color = "red"
# Horizontal line
draw.line([(x - crosshair_size, y), (x + crosshair_size, y)], fill=color, width=line_width)
# Vertical line
draw.line([(x, y - crosshair_size), (x, y + crosshair_size)], fill=color, width=line_width)
# Draw center dot (filled circle)
dot_radius = 3
draw.ellipse([(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color)
# Convert back to bytes
output = io.BytesIO()
image.save(output, format='PNG')
return output.getvalue()
@override
async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
"""
Called when a computer call has completed.
Saves screenshots and computer call output.
"""
if not self.trajectory_id:
return
self._save_artifact("computer_call_result", { "item": item, "result": result })
# Check if action has x/y coordinates and there's a screenshot in the result
action = item.get("action", {})
if "x" in action and "y" in action:
# Look for screenshot in the result
for result_item in result:
if (result_item.get("type") == "computer_call_output" and
result_item.get("output", {}).get("type") == "input_image"):
image_url = result_item["output"]["image_url"]
# Extract base64 image data
if image_url.startswith("data:image/"):
# Format: data:image/png;base64,<base64_data>
base64_data = image_url.split(",", 1)[1]
else:
# Assume it's just base64 data
base64_data = image_url
try:
# Decode the image
image_bytes = base64.b64decode(base64_data)
# Draw crosshair at the action coordinates
annotated_image = self._draw_crosshair_on_image(
image_bytes,
int(action["x"]),
int(action["y"])
)
# Save as screenshot_action
self._save_artifact("screenshot_action", annotated_image)
except Exception as e:
# If annotation fails, just log and continue
print(f"Failed to annotate screenshot: {e}")
break # Only process the first screenshot found
# Increment turn counter
self.current_turn += 1
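A sketch of the resulting on-disk layout; the directory and file names below are illustrative, derived from the ID and artifact formats above (import path assumed):

```python
from agent.callbacks import TrajectorySaverCallback  # import path assumed

saver = TrajectorySaverCallback(trajectory_dir="trajectories")
# Illustrative layout after one run of openai/computer-use-preview:
# trajectories/
#   2025-07-31_computer-use-pre_180402_ab12/   # date_model_time_uuid[:4]
#     metadata.json                            # kwargs, status, total usage
#     turn_000/
#       0000_api_start.json
#       0001_api_result.json
#       0002_screenshot_action.png             # crosshair at the click point
```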

View File

@@ -0,0 +1,359 @@
"""
CLI chat interface for agent - Computer Use Agent
Usage:
python -m agent.cli <model_string>
Examples:
python -m agent.cli openai/computer-use-preview
python -m agent.cli anthropic/claude-3-5-sonnet-20241022
python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
"""
try:
import asyncio
import argparse
import os
import sys
import json
from typing import List, Dict, Any
import dotenv
from yaspin import yaspin
except ImportError:
if __name__ == "__main__":
raise ImportError(
"CLI dependencies not found. "
"Please install with: pip install \"cua-agent[cli]\""
)
# Load environment variables
dotenv.load_dotenv()
# Color codes for terminal output
class Colors:
RESET = '\033[0m'
BOLD = '\033[1m'
DIM = '\033[2m'
# Text colors
RED = '\033[31m'
GREEN = '\033[32m'
YELLOW = '\033[33m'
BLUE = '\033[34m'
MAGENTA = '\033[35m'
CYAN = '\033[36m'
WHITE = '\033[37m'
GRAY = '\033[90m'
# Background colors
BG_RED = '\033[41m'
BG_GREEN = '\033[42m'
BG_YELLOW = '\033[43m'
BG_BLUE = '\033[44m'
def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = False, end: str = "\n", right: str = ""):
"""Print colored text to terminal with optional right-aligned text."""
prefix = ""
if bold:
prefix += Colors.BOLD
if dim:
prefix += Colors.DIM
if color:
prefix += color
if right:
# Get terminal width (default to 80 if unable to determine)
try:
import shutil
terminal_width = shutil.get_terminal_size().columns
        except Exception:
            terminal_width = 80
        # Leave a one-column right margin
        terminal_width -= 1
# Calculate padding needed
# Account for ANSI escape codes not taking visual space
visible_left_len = len(text)
visible_right_len = len(right)
padding = terminal_width - visible_left_len - visible_right_len
if padding > 0:
output = f"{prefix}{text}{' ' * padding}{right}{Colors.RESET}"
else:
# If not enough space, just put a single space between
output = f"{prefix}{text} {right}{Colors.RESET}"
else:
output = f"{prefix}{text}{Colors.RESET}"
print(output, end=end)
def print_action(action_type: str, details: Dict[str, Any], total_cost: float):
"""Print computer action with nice formatting."""
# Format action details
args_str = ""
if action_type == "click" and "x" in details and "y" in details:
args_str = f"({details['x']}, {details['y']})"
elif action_type == "type" and "text" in details:
text = details["text"]
if len(text) > 50:
text = text[:47] + "..."
args_str = f'"{text}"'
elif action_type == "key" and "key" in details:
args_str = f"'{details['key']}'"
elif action_type == "scroll" and "x" in details and "y" in details:
args_str = f"({details['x']}, {details['y']})"
if total_cost > 0:
print_colored(f"🛠️ {action_type}{args_str}", dim=True, right=f"💸 ${total_cost:.2f}")
else:
print_colored(f"🛠️ {action_type}{args_str}", dim=True)
def print_welcome(model: str, agent_loop: str, container_name: str):
"""Print welcome message."""
print_colored(f"Connected to {container_name} ({model}, {agent_loop})")
print_colored("Type 'exit' to quit.", dim=True)
async def ainput(prompt: str = ""):
return await asyncio.to_thread(input, prompt)
async def chat_loop(agent, model: str, container_name: str, initial_prompt: str = "", show_usage: bool = True):
"""Main chat loop with the agent."""
print_welcome(model, agent.agent_loop.__name__, container_name)
history = []
if initial_prompt:
history.append({"role": "user", "content": initial_prompt})
    total_cost = 0.0
    while True:
        if not history or history[-1].get("role") != "user":
# Get user input with prompt
print_colored("> ", end="")
user_input = await ainput()
if user_input.lower() in ['exit', 'quit', 'q']:
print_colored("\n👋 Goodbye!")
break
if not user_input:
continue
# Add user message to history
history.append({"role": "user", "content": user_input})
# Stream responses from the agent with spinner
with yaspin(text="Thinking...", spinner="line", attrs=["dark"]) as spinner:
spinner.hide()
async for result in agent.run(history):
# Add agent responses to history
history.extend(result.get("output", []))
if show_usage:
total_cost += result.get("usage", {}).get("response_cost", 0)
# Process and display the output
for item in result.get("output", []):
if item.get("type") == "message":
# Display agent text response
content = item.get("content", [])
for content_part in content:
if content_part.get("text"):
text = content_part.get("text", "").strip()
if text:
spinner.hide()
print_colored(text)
elif item.get("type") == "computer_call":
# Display computer action
action = item.get("action", {})
action_type = action.get("type", "")
if action_type:
spinner.hide()
print_action(action_type, action, total_cost)
spinner.text = f"Performing {action_type}..."
spinner.show()
elif item.get("type") == "function_call":
# Display function call
function_name = item.get("name", "")
spinner.hide()
print_colored(f"🔧 Calling function: {function_name}", dim=True)
spinner.text = f"Calling {function_name}..."
spinner.show()
elif item.get("type") == "function_call_output":
# Display function output (dimmed)
output = item.get("output", "")
if output and len(output.strip()) > 0:
spinner.hide()
print_colored(f"📤 {output}", dim=True)
spinner.hide()
if show_usage and total_cost > 0:
print_colored(f"Total cost: ${total_cost:.2f}", dim=True)
async def main():
"""Main CLI function."""
parser = argparse.ArgumentParser(
description="CUA Agent CLI - Interactive computer use assistant",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
python -m agent.cli openai/computer-use-preview
python -m agent.cli anthropic/claude-3-5-sonnet-20241022
python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
python -m agent.cli huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B
"""
)
parser.add_argument(
"model",
help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')"
)
parser.add_argument(
"--images",
type=int,
default=3,
help="Number of recent images to keep in context (default: 3)"
)
parser.add_argument(
"--trajectory",
action="store_true",
help="Save trajectory for debugging"
)
parser.add_argument(
"--budget",
type=float,
help="Maximum budget for the session (in dollars)"
)
parser.add_argument(
"--verbose",
action="store_true",
help="Enable verbose logging"
)
parser.add_argument(
"-p", "--prompt",
type=str,
help="Initial prompt to send to the agent. Leave blank for interactive mode."
)
parser.add_argument(
"-c", "--cache",
action="store_true",
help="Tell the API to enable caching"
)
parser.add_argument(
"-u", "--usage",
action="store_true",
help="Show total cost of the agent runs"
)
args = parser.parse_args()
# Check for required environment variables
container_name = os.getenv("CUA_CONTAINER_NAME")
cua_api_key = os.getenv("CUA_API_KEY")
# Prompt for missing environment variables
if not container_name:
print_colored("CUA_CONTAINER_NAME not set.", dim=True)
print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
container_name = input("Enter your CUA container name: ").strip()
if not container_name:
print_colored("❌ Container name is required.")
sys.exit(1)
if not cua_api_key:
print_colored("CUA_API_KEY not set.", dim=True)
cua_api_key = input("Enter your CUA API key: ").strip()
if not cua_api_key:
print_colored("❌ API key is required.")
sys.exit(1)
    # Check for provider-specific API keys based on model
    provider_api_keys = {
        "openai/": "OPENAI_API_KEY",
        "anthropic/": "ANTHROPIC_API_KEY",
    }
    # For composed models (e.g. "omniparser+anthropic/..."), the LLM after
    # the "+" determines which provider key is required
    model_to_check = args.model.split("+")[-1]
    # Find matching provider and check for API key
    for prefix, env_var in provider_api_keys.items():
        if model_to_check.startswith(prefix):
if not os.getenv(env_var):
print_colored(f"{env_var} not set.", dim=True)
api_key = input(f"Enter your {env_var.replace('_', ' ').title()}: ").strip()
if not api_key:
print_colored(f"{env_var.replace('_', ' ').title()} is required.")
sys.exit(1)
# Set the environment variable for the session
os.environ[env_var] = api_key
break
# Import here to avoid import errors if dependencies are missing
try:
from agent import ComputerAgent
from computer import Computer
except ImportError as e:
print_colored(f"❌ Import error: {e}", Colors.RED, bold=True)
print_colored("Make sure agent and computer libraries are installed.", Colors.YELLOW)
sys.exit(1)
# Create computer instance
async with Computer(
os_type="linux",
provider_type="cloud",
name=container_name,
api_key=cua_api_key
) as computer:
# Create agent
agent_kwargs = {
"model": args.model,
"tools": [computer],
"verbosity": 20 if args.verbose else 30, # DEBUG vs WARNING
}
if args.images > 0:
agent_kwargs["only_n_most_recent_images"] = args.images
if args.trajectory:
agent_kwargs["trajectory_dir"] = "trajectories"
if args.budget:
agent_kwargs["max_trajectory_budget"] = {
"max_budget": args.budget,
"raise_error": True,
"reset_after_each_run": False
}
if args.cache:
agent_kwargs["use_prompt_caching"] = True
agent = ComputerAgent(**agent_kwargs)
# Start chat loop
await chat_loop(agent, args.model, container_name, args.prompt, args.usage)
if __name__ == "__main__":
try:
asyncio.run(main())
except (KeyboardInterrupt, EOFError) as _:
print_colored("\n\n👋 Goodbye!")

View File

@@ -0,0 +1,107 @@
"""
Computer handler implementation for OpenAI computer-use-preview protocol.
"""
import base64
from typing import Dict, List, Any, Literal
from .types import Computer
class OpenAIComputerHandler:
"""Computer handler that implements the Computer protocol using the computer interface."""
def __init__(self, computer_interface):
"""Initialize with a computer interface (from tool schema)."""
self.interface = computer_interface
async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
"""Get the current environment type."""
# For now, return a default - this could be enhanced to detect actual environment
return "windows"
async def get_dimensions(self) -> tuple[int, int]:
"""Get screen dimensions as (width, height)."""
screen_size = await self.interface.get_screen_size()
return screen_size["width"], screen_size["height"]
async def screenshot(self) -> str:
"""Take a screenshot and return as base64 string."""
screenshot_bytes = await self.interface.screenshot()
return base64.b64encode(screenshot_bytes).decode('utf-8')
async def click(self, x: int, y: int, button: str = "left") -> None:
"""Click at coordinates with specified button."""
if button == "left":
await self.interface.left_click(x, y)
elif button == "right":
await self.interface.right_click(x, y)
else:
# Default to left click for unknown buttons
await self.interface.left_click(x, y)
async def double_click(self, x: int, y: int) -> None:
"""Double click at coordinates."""
await self.interface.double_click(x, y)
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
"""Scroll at coordinates with specified scroll amounts."""
await self.interface.move_cursor(x, y)
await self.interface.scroll(scroll_x, scroll_y)
async def type(self, text: str) -> None:
"""Type text."""
await self.interface.type_text(text)
async def wait(self, ms: int = 1000) -> None:
"""Wait for specified milliseconds."""
import asyncio
await asyncio.sleep(ms / 1000.0)
async def move(self, x: int, y: int) -> None:
"""Move cursor to coordinates."""
await self.interface.move_cursor(x, y)
async def keypress(self, keys: List[str]) -> None:
"""Press key combination."""
if len(keys) == 1:
await self.interface.press_key(keys[0])
else:
# Handle key combinations
await self.interface.hotkey(*keys)
async def drag(self, path: List[Dict[str, int]]) -> None:
"""Drag along specified path."""
if not path:
return
# Start drag from first point
start = path[0]
await self.interface.mouse_down(start["x"], start["y"])
# Move through path
for point in path[1:]:
await self.interface.move_cursor(point["x"], point["y"])
# End drag at last point
end = path[-1]
await self.interface.mouse_up(end["x"], end["y"])
async def get_current_url(self) -> str:
"""Get current URL (for browser environments)."""
# This would need to be implemented based on the specific browser interface
# For now, return empty string
return ""
def acknowledge_safety_check_callback(message: str) -> bool:
"""Safety check callback for user acknowledgment."""
response = input(
f"Safety Check Warning: {message}\nDo you want to acknowledge and proceed? (y/n): "
).lower()
return response.strip() == "y"
def check_blocklisted_url(url: str) -> None:
"""Check if URL is blocklisted (placeholder implementation)."""
# This would contain actual URL checking logic
pass
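A usage sketch for the handler; the module path and the `Computer(...).interface` attribute are assumptions:

```python
import asyncio

from computer import Computer
from agent.computers import OpenAIComputerHandler  # module path assumed

async def main():
    async with Computer(os_type="linux", provider_type="cloud",
                        name="my-container", api_key="sk-...") as computer:
        handler = OpenAIComputerHandler(computer.interface)  # .interface assumed
        w, h = await handler.get_dimensions()
        await handler.click(w // 2, h // 2)   # routed to interface.left_click
        png_b64 = await handler.screenshot()  # base64-encoded PNG string

asyncio.run(main())
```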

View File

@@ -1,27 +0,0 @@
"""Core agent components."""
from .factory import BaseLoop
from .messages import (
StandardMessageManager,
ImageRetentionConfig,
)
from .callbacks import (
CallbackManager,
CallbackHandler,
BaseCallbackManager,
ContentCallback,
ToolCallback,
APICallback,
)
__all__ = [
"BaseLoop",
"CallbackManager",
"CallbackHandler",
"StandardMessageManager",
"ImageRetentionConfig",
"BaseCallbackManager",
"ContentCallback",
"ToolCallback",
"APICallback",
]

View File

@@ -1,210 +0,0 @@
"""Main entry point for computer agents."""
import asyncio
import logging
import os
from typing import AsyncGenerator, Optional
from computer import Computer
from .types import LLM, AgentLoop
from .types import AgentResponse
from .factory import LoopFactory
from .provider_config import DEFAULT_MODELS, ENV_VARS
logger = logging.getLogger(__name__)
class ComputerAgent:
"""A computer agent that can perform automated tasks using natural language instructions."""
def __init__(
self,
computer: Computer,
model: LLM,
loop: AgentLoop,
max_retries: int = 3,
screenshot_dir: Optional[str] = None,
log_dir: Optional[str] = None,
api_key: Optional[str] = None,
save_trajectory: bool = True,
trajectory_dir: str = "trajectories",
only_n_most_recent_images: Optional[int] = None,
verbosity: int = logging.INFO,
):
"""Initialize the ComputerAgent.
Args:
computer: Computer instance. If not provided, one will be created with default settings.
max_retries: Maximum number of retry attempts.
screenshot_dir: Directory to save screenshots.
log_dir: Directory to save logs (set to None to disable logging to files).
model: LLM object containing provider and model name. Takes precedence over provider/model_name.
provider: The AI provider to use (e.g., LLMProvider.ANTHROPIC). Only used if model is None.
api_key: The API key for the provider. If not provided, will look for environment variable.
model_name: The model name to use. Only used if model is None.
save_trajectory: Whether to save the trajectory.
trajectory_dir: Directory to save the trajectory.
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests.
verbosity: Logging level.
"""
# Basic agent configuration
self.max_retries = max_retries
self.computer = computer
self.queue = asyncio.Queue()
self.screenshot_dir = screenshot_dir
self.log_dir = log_dir
self._retry_count = 0
self._initialized = False
self._in_context = False
# Set logging level
logger.setLevel(verbosity)
# Setup logging
if self.log_dir:
os.makedirs(self.log_dir, exist_ok=True)
logger.info(f"Created logs directory: {self.log_dir}")
# Setup screenshots directory
if self.screenshot_dir:
os.makedirs(self.screenshot_dir, exist_ok=True)
logger.info(f"Created screenshots directory: {self.screenshot_dir}")
# Use the provided LLM object
self.provider = model.provider
actual_model_name = model.name or DEFAULT_MODELS.get(self.provider, "")
self.provider_base_url = getattr(model, "provider_base_url", None)
# Ensure we have a valid model name
if not actual_model_name:
actual_model_name = DEFAULT_MODELS.get(self.provider, "")
if not actual_model_name:
raise ValueError(
f"No model specified for provider {self.provider} and no default found"
)
# Get API key from environment if not provided
actual_api_key = api_key or os.environ.get(ENV_VARS[self.provider], "")
# Ollama and OpenAI-compatible APIs typically don't require an API key
if (
not actual_api_key
and str(self.provider) not in ["ollama", "oaicompat"]
and ENV_VARS[self.provider] != "none"
):
raise ValueError(f"No API key provided for {self.provider}")
# Create the appropriate loop using the factory
try:
# Let the factory create the appropriate loop with needed components
self._loop = LoopFactory.create_loop(
loop_type=loop,
provider=self.provider,
computer=self.computer,
model_name=actual_model_name,
api_key=actual_api_key,
save_trajectory=save_trajectory,
trajectory_dir=trajectory_dir,
only_n_most_recent_images=only_n_most_recent_images,
provider_base_url=self.provider_base_url,
)
except ValueError as e:
logger.error(f"Failed to create loop: {str(e)}")
raise
# Initialize the message manager from the loop
self.message_manager = self._loop.message_manager
logger.info(
f"ComputerAgent initialized with provider: {self.provider}, model: {actual_model_name}"
)
async def __aenter__(self):
"""Initialize the agent when used as a context manager."""
logger.info("Entering ComputerAgent context")
self._in_context = True
# In case the computer wasn't initialized
try:
# Initialize the computer only if not already initialized
logger.info("Checking if computer is already initialized...")
if not self.computer._initialized:
logger.info("Initializing computer in __aenter__...")
# Use the computer's __aenter__ directly instead of calling run()
await self.computer.__aenter__()
logger.info("Computer initialized in __aenter__")
else:
logger.info("Computer already initialized, skipping initialization")
except Exception as e:
logger.error(f"Error initializing computer in __aenter__: {str(e)}")
raise
await self.initialize()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Cleanup agent resources if needed."""
logger.info("Cleaning up agent resources")
self._in_context = False
# Do any necessary cleanup
# We're not shutting down the computer here as it might be shared
# Just log that we're exiting
if exc_type:
logger.error(f"Exiting agent context with error: {exc_type.__name__}: {exc_val}")
else:
logger.info("Exiting agent context normally")
# If we have a queue, make sure to signal it's done
if hasattr(self, "queue") and self.queue:
await self.queue.put(None) # Signal that we're done
async def initialize(self) -> None:
"""Initialize the agent and its components."""
if not self._initialized:
# Always initialize the computer if available
if self.computer and not self.computer._initialized:
await self.computer.run()
self._initialized = True
async def run(self, task: str) -> AsyncGenerator[AgentResponse, None]:
"""Run a task using the computer agent.
Args:
task: Task description
Yields:
Agent response format
"""
try:
logger.info(f"Running task: {task}")
logger.info(
f"Message history before task has {len(self.message_manager.messages)} messages"
)
# Initialize the computer if needed
if not self._initialized:
await self.initialize()
# Add task as a user message using the message manager
self.message_manager.add_user_message([{"type": "text", "text": task}])
logger.info(
f"Added task message. Message history now has {len(self.message_manager.messages)} messages"
)
# Pass properly formatted messages to the loop
if self._loop is None:
logger.error("Loop not initialized properly")
yield {"error": "Loop not initialized properly"}
return
# Execute the task and yield results
async for result in self._loop.run(self.message_manager.messages):
yield result
except Exception as e:
logger.error(f"Error in agent run method: {str(e)}")
yield {
"role": "assistant",
"content": f"Error: {str(e)}",
"metadata": {"title": "❌ Error"},
}

View File

@@ -1,217 +0,0 @@
"""Base loop definitions."""
import logging
import asyncio
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Optional
from computer import Computer
from .messages import StandardMessageManager, ImageRetentionConfig
from .types import AgentResponse
from .experiment import ExperimentManager
from .callbacks import CallbackManager, CallbackHandler
logger = logging.getLogger(__name__)
class BaseLoop(ABC):
"""Base class for agent loops that handle message processing and tool execution."""
def __init__(
self,
computer: Computer,
model: str,
api_key: str,
max_tokens: int = 4096,
max_retries: int = 3,
retry_delay: float = 1.0,
base_dir: Optional[str] = "trajectories",
save_trajectory: bool = True,
only_n_most_recent_images: Optional[int] = 2,
callback_handlers: Optional[List[CallbackHandler]] = None,
**kwargs,
):
"""Initialize base agent loop.
Args:
computer: Computer instance to control
model: Model name to use
api_key: API key for provider
max_tokens: Maximum tokens to generate
max_retries: Maximum number of retries
retry_delay: Delay between retries in seconds
base_dir: Base directory for saving experiment data
save_trajectory: Whether to save trajectory data
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
**kwargs: Additional provider-specific arguments
"""
self.computer = computer
self.model = model
self.api_key = api_key
self.max_tokens = max_tokens
self.max_retries = max_retries
self.retry_delay = retry_delay
self.base_dir = base_dir
self.save_trajectory = save_trajectory
self.only_n_most_recent_images = only_n_most_recent_images
self._kwargs = kwargs
# Initialize message manager
self.message_manager = StandardMessageManager(
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
)
# Initialize experiment manager
if self.save_trajectory and self.base_dir:
self.experiment_manager = ExperimentManager(
base_dir=self.base_dir,
only_n_most_recent_images=only_n_most_recent_images,
)
# Track directories for convenience
self.run_dir = self.experiment_manager.run_dir
self.current_turn_dir = self.experiment_manager.current_turn_dir
else:
self.experiment_manager = None
self.run_dir = None
self.current_turn_dir = None
# Initialize basic tracking
self.turn_count = 0
# Initialize callback manager
self.callback_manager = CallbackManager(handlers=callback_handlers or [])
async def initialize(self) -> None:
"""Initialize both the API client and computer interface with retries."""
for attempt in range(self.max_retries):
try:
logger.info(
f"Starting initialization (attempt {attempt + 1}/{self.max_retries})..."
)
# Initialize API client
await self.initialize_client()
logger.info("Initialization complete.")
return
except Exception as e:
if attempt < self.max_retries - 1:
logger.warning(
f"Initialization failed (attempt {attempt + 1}/{self.max_retries}): {str(e)}. Retrying..."
)
await asyncio.sleep(self.retry_delay)
else:
logger.error(
f"Initialization failed after {self.max_retries} attempts: {str(e)}"
)
raise RuntimeError(f"Failed to initialize: {str(e)}")
###########################################
# ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES
###########################################
@abstractmethod
async def initialize_client(self) -> None:
"""Initialize the API client and any provider-specific components.
This method must be implemented by subclasses to set up
provider-specific clients and tools.
"""
raise NotImplementedError
@abstractmethod
def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
"""Run the agent loop with provided messages.
Args:
messages: List of message objects
Returns:
An async generator that yields agent responses
"""
raise NotImplementedError
@abstractmethod
async def cancel(self) -> None:
"""Cancel the currently running agent loop task.
This method should stop any ongoing processing in the agent loop
and clean up resources appropriately.
"""
raise NotImplementedError
###########################################
# EXPERIMENT AND TRAJECTORY MANAGEMENT
###########################################
def _setup_experiment_dirs(self) -> None:
"""Setup the experiment directory structure."""
if self.experiment_manager:
# Use the experiment manager to set up directories
self.experiment_manager.setup_experiment_dirs()
# Update local tracking variables
self.run_dir = self.experiment_manager.run_dir
self.current_turn_dir = self.experiment_manager.current_turn_dir
def _create_turn_dir(self) -> None:
"""Create a new directory for the current turn."""
if self.experiment_manager:
# Use the experiment manager to create the turn directory
self.experiment_manager.create_turn_dir()
# Update local tracking variables
self.current_turn_dir = self.experiment_manager.current_turn_dir
self.turn_count = self.experiment_manager.turn_count
def _log_api_call(
self, call_type: str, request: Any, response: Any = None, error: Optional[Exception] = None
) -> None:
"""Log API call details to file.
Preserves provider-specific formats for requests and responses to ensure
accurate logging for debugging and analysis purposes.
Args:
call_type: Type of API call (e.g., 'request', 'response', 'error')
request: The API request data in provider-specific format
response: Optional API response data in provider-specific format
error: Optional error information
"""
if self.experiment_manager:
# Use the experiment manager to log the API call
provider = getattr(self, "provider", "unknown")
provider_str = str(provider) if provider else "unknown"
self.experiment_manager.log_api_call(
call_type=call_type,
request=request,
provider=provider_str,
model=self.model,
response=response,
error=error,
)
def _save_screenshot(self, img_base64: str, action_type: str = "") -> None:
"""Save a screenshot to the experiment directory.
Args:
img_base64: Base64 encoded screenshot
action_type: Type of action that triggered the screenshot
"""
if self.experiment_manager:
self.experiment_manager.save_screenshot(img_base64, action_type)
###########################################
# EVENT HOOKS / CALLBACKS
###########################################
async def handle_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[dict] = None) -> None:
"""Process a screenshot through callback managers
Args:
screenshot_base64: Base64 encoded screenshot
action_type: Type of action that triggered the screenshot
"""
if hasattr(self, 'callback_manager'):
await self.callback_manager.on_screenshot(screenshot_base64, action_type, parsed_screen)

View File

@@ -1,200 +0,0 @@
"""Callback handlers for agent."""
import json
import logging
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Dict, List, Optional, Protocol
logger = logging.getLogger(__name__)
class ContentCallback(Protocol):
"""Protocol for content callbacks."""
def __call__(self, content: Dict[str, Any]) -> None: ...
class ToolCallback(Protocol):
"""Protocol for tool callbacks."""
def __call__(self, result: Any, tool_id: str) -> None: ...
class APICallback(Protocol):
"""Protocol for API callbacks."""
def __call__(self, request: Any, response: Any, error: Optional[Exception] = None) -> None: ...
class ScreenshotCallback(Protocol):
"""Protocol for screenshot callbacks."""
def __call__(self, screenshot_base64: str, action_type: str = "") -> Optional[str]: ...
class BaseCallbackManager(ABC):
"""Base class for callback managers."""
def __init__(
self,
content_callback: ContentCallback,
tool_callback: ToolCallback,
api_callback: APICallback,
):
"""Initialize the callback manager.
Args:
content_callback: Callback for content updates
tool_callback: Callback for tool execution results
api_callback: Callback for API interactions
"""
self.content_callback = content_callback
self.tool_callback = tool_callback
self.api_callback = api_callback
@abstractmethod
def on_content(self, content: Any) -> None:
"""Handle content updates."""
raise NotImplementedError
@abstractmethod
def on_tool_result(self, result: Any, tool_id: str) -> None:
"""Handle tool execution results."""
raise NotImplementedError
@abstractmethod
def on_api_interaction(
self,
request: Any,
response: Any,
error: Optional[Exception] = None
) -> None:
"""Handle API interactions."""
raise NotImplementedError
class CallbackManager:
"""Manager for callback handlers."""
def __init__(self, handlers: Optional[List["CallbackHandler"]] = None):
"""Initialize with optional handlers.
Args:
handlers: List of callback handlers
"""
self.handlers = handlers or []
def add_handler(self, handler: "CallbackHandler") -> None:
"""Add a callback handler.
Args:
handler: Callback handler to add
"""
self.handlers.append(handler)
async def on_action_start(self, action: str, **kwargs) -> None:
"""Called when an action starts.
Args:
action: Action name
**kwargs: Additional data
"""
for handler in self.handlers:
await handler.on_action_start(action, **kwargs)
async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
"""Called when an action ends.
Args:
action: Action name
success: Whether the action was successful
**kwargs: Additional data
"""
for handler in self.handlers:
await handler.on_action_end(action, success, **kwargs)
async def on_error(self, error: Exception, **kwargs) -> None:
"""Called when an error occurs.
Args:
error: Exception that occurred
**kwargs: Additional data
"""
for handler in self.handlers:
await handler.on_error(error, **kwargs)
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[dict] = None) -> None:
"""Called when a screenshot is taken.
Args:
screenshot_base64: Base64 encoded screenshot
action_type: Type of action that triggered the screenshot
parsed_screen: Optional output from parsing the screenshot
"""
for handler in self.handlers:
await handler.on_screenshot(screenshot_base64, action_type, parsed_screen)
class CallbackHandler(ABC):
"""Base class for callback handlers."""
@abstractmethod
async def on_action_start(self, action: str, **kwargs) -> None:
"""Called when an action starts.
Args:
action: Action name
**kwargs: Additional data
"""
pass
@abstractmethod
async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
"""Called when an action ends.
Args:
action: Action name
success: Whether the action was successful
**kwargs: Additional data
"""
pass
@abstractmethod
async def on_error(self, error: Exception, **kwargs) -> None:
"""Called when an error occurs.
Args:
error: Exception that occurred
**kwargs: Additional data
"""
pass
@abstractmethod
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[dict] = None) -> None:
"""Called when a screenshot is taken.
Args:
screenshot_base64: Base64 encoded screenshot
action_type: Type of action that triggered the screenshot
parsed_screen: Optional output from parsing the screenshot
"""
pass
class DefaultCallbackHandler(CallbackHandler):
"""Default implementation of CallbackHandler with no-op methods.
This class implements all abstract methods from CallbackHandler,
allowing subclasses to override only the methods they need.
"""
async def on_action_start(self, action: str, **kwargs) -> None:
"""Default no-op implementation."""
pass
async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
"""Default no-op implementation."""
pass
async def on_error(self, error: Exception, **kwargs) -> None:
"""Default no-op implementation."""
pass
async def on_screenshot(self, screenshot_base64: str, action_type: str = "") -> None:
"""Default no-op implementation."""
pass
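A minimal usage sketch for the callback API above (the import path and handler name are assumptions for illustration): subclass DefaultCallbackHandler, override only the hooks you need, and register the instance with a CallbackManager.
import asyncio
from agent.callbacks import CallbackManager, DefaultCallbackHandler  # assumed import path
class LoggingHandler(DefaultCallbackHandler):
    """Hypothetical handler that prints action lifecycle events."""
    async def on_action_start(self, action: str, **kwargs) -> None:
        print(f"starting: {action}")
    async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
        print(f"finished: {action} (success={success})")
manager = CallbackManager(handlers=[LoggingHandler()])
asyncio.run(manager.on_action_start("click"))  # prints "starting: click"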

View File

@@ -1,249 +0,0 @@
"""Core experiment management for agents."""
import os
import logging
import base64
from io import BytesIO
from datetime import datetime
from typing import Any, Dict, List, Optional
from PIL import Image
import json
import re
logger = logging.getLogger(__name__)
class ExperimentManager:
"""Manages experiment directories and logging for the agent."""
def __init__(
self,
base_dir: Optional[str] = None,
only_n_most_recent_images: Optional[int] = None,
):
"""Initialize the experiment manager.
Args:
base_dir: Base directory for saving experiment data
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
"""
self.base_dir = base_dir
self.only_n_most_recent_images = only_n_most_recent_images
self.run_dir = None
self.current_turn_dir = None
self.turn_count = 0
self.screenshot_count = 0
# Track all screenshots for potential API request inclusion
self.screenshot_paths = []
# Set up experiment directories if base_dir is provided
if self.base_dir:
self.setup_experiment_dirs()
def setup_experiment_dirs(self) -> None:
"""Setup the experiment directory structure."""
if not self.base_dir:
return
# Create base experiments directory if it doesn't exist
os.makedirs(self.base_dir, exist_ok=True)
# Create timestamped run directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
self.run_dir = os.path.join(self.base_dir, timestamp)
os.makedirs(self.run_dir, exist_ok=True)
logger.info(f"Created run directory: {self.run_dir}")
# Create first turn directory
self.create_turn_dir()
def create_turn_dir(self) -> None:
"""Create a new directory for the current turn."""
if not self.run_dir:
logger.warning("Cannot create turn directory: run_dir not set")
return
# Increment turn counter
self.turn_count += 1
# Create turn directory with padded number
turn_name = f"turn_{self.turn_count:03d}"
self.current_turn_dir = os.path.join(self.run_dir, turn_name)
os.makedirs(self.current_turn_dir, exist_ok=True)
logger.info(f"Created turn directory: {self.current_turn_dir}")
def sanitize_log_data(self, data: Any) -> Any:
"""Sanitize log data by replacing large binary data with placeholders.
Args:
data: Data to sanitize
Returns:
Sanitized copy of the data
"""
if isinstance(data, dict):
result = {}
for k, v in data.items():
# Special handling for 'data' field in Anthropic message source
if k == "data" and isinstance(v, str) and len(v) > 1000:
result[k] = f"[BASE64_DATA_LENGTH_{len(v)}]"
# Special handling for the 'media_type' key which indicates we're in an image block
elif k == "media_type" and "image" in str(v):
result[k] = v
# If we're in an image block, look for a sibling 'data' field with base64 content
if (
"data" in result
and isinstance(result["data"], str)
and len(result["data"]) > 1000
):
result["data"] = f"[BASE64_DATA_LENGTH_{len(result['data'])}]"
else:
result[k] = self.sanitize_log_data(v)
return result
elif isinstance(data, list):
return [self.sanitize_log_data(item) for item in data]
elif isinstance(data, str) and len(data) > 1000 and "base64" in data.lower():
return f"[BASE64_DATA_LENGTH_{len(data)}]"
else:
return data
def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
"""Save a screenshot to the experiment directory.
Args:
img_base64: Base64 encoded screenshot
action_type: Type of action that triggered the screenshot
Returns:
Path to the saved screenshot or None if there was an error
"""
if not self.current_turn_dir:
return None
try:
# Increment screenshot counter
self.screenshot_count += 1
# Sanitize action_type to ensure valid filename
# Replace characters that are not safe for filenames
sanitized_action = ""
if action_type:
# Replace invalid filename characters with underscores
sanitized_action = re.sub(r'[\\/*?:"<>|]', "_", action_type)
# Limit the length to avoid excessively long filenames
sanitized_action = sanitized_action[:50]
# Create a descriptive filename
timestamp = int(datetime.now().timestamp() * 1000)
action_suffix = f"_{sanitized_action}" if sanitized_action else ""
filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"
# Save directly to the turn directory
filepath = os.path.join(self.current_turn_dir, filename)
# Save the screenshot
img_data = base64.b64decode(img_base64)
with open(filepath, "wb") as f:
f.write(img_data)
# Keep track of the file path
self.screenshot_paths.append(filepath)
return filepath
except Exception as e:
logger.error(f"Error saving screenshot: {str(e)}")
return None
def save_action_visualization(
self, img: Image.Image, action_name: str, details: str = ""
) -> str:
"""Save a visualization of an action.
Args:
img: Image to save
action_name: Name of the action
details: Additional details about the action
Returns:
Path to the saved image
"""
if not self.current_turn_dir:
return ""
try:
# Create a descriptive filename
timestamp = int(datetime.now().timestamp() * 1000)
details_suffix = f"_{details}" if details else ""
filename = f"vis_{action_name}{details_suffix}_{timestamp}.png"
# Save directly to the turn directory
filepath = os.path.join(self.current_turn_dir, filename)
# Save the image
img.save(filepath)
# Keep track of the file path
self.screenshot_paths.append(filepath)
return filepath
except Exception as e:
logger.error(f"Error saving action visualization: {str(e)}")
return ""
def log_api_call(
self,
call_type: str,
request: Any,
provider: str = "unknown",
model: str = "unknown",
response: Any = None,
error: Optional[Exception] = None,
) -> None:
"""Log API call details to file.
Args:
call_type: Type of API call (request, response, error)
request: Request data
provider: API provider name
model: Model name
response: Response data (for response logs)
error: Error information (for error logs)
"""
if not self.current_turn_dir:
logger.warning("Cannot log API call: current_turn_dir not set")
return
try:
# Create a timestamp for the log file
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
# Create filename based on log type
filename = f"api_call_{timestamp}_{call_type}.json"
filepath = os.path.join(self.current_turn_dir, filename)
# Sanitize data before logging
sanitized_request = self.sanitize_log_data(request)
sanitized_response = self.sanitize_log_data(response) if response is not None else None
# Prepare log data
log_data = {
"timestamp": timestamp,
"provider": provider,
"model": model,
"type": call_type,
"request": sanitized_request,
}
if sanitized_response is not None:
log_data["response"] = sanitized_response
if error is not None:
log_data["error"] = str(error)
# Write to file
with open(filepath, "w") as f:
json.dump(log_data, f, indent=2, default=str)
logger.info(f"Logged API {call_type} to {filepath}")
except Exception as e:
logger.error(f"Error logging API call: {str(e)}")

View File

@@ -1,122 +0,0 @@
"""Base agent loop implementation."""
import logging
import importlib.util
from typing import Dict, Optional, Type, TYPE_CHECKING, Any, cast, Callable, Awaitable
from computer import Computer
from .types import AgentLoop
from .base import BaseLoop
logger = logging.getLogger(__name__)
class LoopFactory:
"""Factory class for creating agent loops."""
# Registry to store loop implementations
_loop_registry: Dict[AgentLoop, Type[BaseLoop]] = {}
@classmethod
def create_loop(
cls,
loop_type: AgentLoop,
api_key: str,
model_name: str,
computer: Computer,
provider: Any = None,
save_trajectory: bool = True,
trajectory_dir: str = "trajectories",
only_n_most_recent_images: Optional[int] = None,
acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
provider_base_url: Optional[str] = None,
) -> BaseLoop:
"""Create and return an appropriate loop instance based on type."""
if loop_type == AgentLoop.ANTHROPIC:
# Lazy import AnthropicLoop only when needed
try:
from ..providers.anthropic.loop import AnthropicLoop
except ImportError:
raise ImportError(
"The 'anthropic' provider is not installed. "
"Install it with 'pip install cua-agent[anthropic]'"
)
return AnthropicLoop(
api_key=api_key,
model=model_name,
computer=computer,
save_trajectory=save_trajectory,
base_dir=trajectory_dir,
only_n_most_recent_images=only_n_most_recent_images,
)
elif loop_type == AgentLoop.OPENAI:
# Lazy import OpenAILoop only when needed
try:
from ..providers.openai.loop import OpenAILoop
except ImportError:
raise ImportError(
"The 'openai' provider is not installed. "
"Install it with 'pip install cua-agent[openai]'"
)
return OpenAILoop(
api_key=api_key,
model=model_name,
computer=computer,
save_trajectory=save_trajectory,
base_dir=trajectory_dir,
only_n_most_recent_images=only_n_most_recent_images,
acknowledge_safety_check_callback=acknowledge_safety_check_callback,
)
elif loop_type == AgentLoop.OMNI:
# Lazy import OmniLoop and related classes only when needed
try:
from ..providers.omni.loop import OmniLoop
from ..providers.omni.parser import OmniParser
from .types import LLMProvider
except ImportError:
raise ImportError(
"The 'omni' provider is not installed. "
"Install it with 'pip install cua-agent[all]'"
)
if provider is None:
raise ValueError("Provider is required for OMNI loop type")
# We know provider is the correct type at this point, so cast it
provider_instance = cast(LLMProvider, provider)
return OmniLoop(
provider=provider_instance,
api_key=api_key,
model=model_name,
computer=computer,
save_trajectory=save_trajectory,
base_dir=trajectory_dir,
only_n_most_recent_images=only_n_most_recent_images,
parser=OmniParser(),
provider_base_url=provider_base_url,
)
elif loop_type == AgentLoop.UITARS:
# Lazy import UITARSLoop only when needed
try:
from ..providers.uitars.loop import UITARSLoop
except ImportError:
raise ImportError(
"The 'uitars' provider is not installed. "
"Install it with 'pip install cua-agent[all]'"
)
return UITARSLoop(
api_key=api_key,
model=model_name,
computer=computer,
save_trajectory=save_trajectory,
base_dir=trajectory_dir,
only_n_most_recent_images=only_n_most_recent_images,
provider_base_url=provider_base_url,
provider=provider,
)
else:
raise ValueError(f"Unsupported loop type: {loop_type}")

View File

@@ -1,332 +0,0 @@
"""Message handling utilities for agent."""
import logging
import json
from typing import Any, Dict, List, Optional, Union, Tuple
from dataclasses import dataclass
import re
logger = logging.getLogger(__name__)
@dataclass
class ImageRetentionConfig:
"""Configuration for image retention in messages."""
num_images_to_keep: Optional[int] = None
min_removal_threshold: int = 1
enable_caching: bool = True
def should_retain_images(self) -> bool:
"""Check if image retention is enabled."""
return self.num_images_to_keep is not None and self.num_images_to_keep > 0
class StandardMessageManager:
"""Manages messages in a standardized OpenAI format across different providers."""
def __init__(self, config: Optional[ImageRetentionConfig] = None):
"""Initialize message manager.
Args:
config: Configuration for image retention
"""
self.messages: List[Dict[str, Any]] = []
self.config = config or ImageRetentionConfig()
def add_user_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
"""Add a user message.
Args:
content: Message content (text or multimodal content)
"""
self.messages.append({"role": "user", "content": content})
def add_assistant_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
"""Add an assistant message.
Args:
content: Message content (text or multimodal content)
"""
self.messages.append({"role": "assistant", "content": content})
def add_system_message(self, content: str) -> None:
"""Add a system message.
Args:
content: System message content
"""
self.messages.append({"role": "system", "content": content})
def get_messages(self) -> List[Dict[str, Any]]:
"""Get all messages in standard format.
This method applies image retention policy if configured.
Returns:
List of messages
"""
# If image retention is configured, apply it
if self.config.num_images_to_keep is not None:
return self._apply_image_retention(self.messages)
return self.messages
def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Apply image retention policy to messages.
Args:
messages: List of messages
Returns:
List of messages with image retention applied
"""
if not self.config.num_images_to_keep:
return messages
# Find messages with images (both user messages and tool call outputs)
image_messages = []
for msg in messages:
has_image = False
# Check user messages with images
if msg["role"] == "user" and isinstance(msg["content"], list):
has_image = any(
item.get("type") == "image_url" or item.get("type") == "image"
for item in msg["content"]
)
# Check assistant messages with tool calls that have images
elif msg["role"] == "assistant" and isinstance(msg["content"], list):
for item in msg["content"]:
if item.get("type") == "tool_result" and "base64_image" in item:
has_image = True
break
if has_image:
image_messages.append(msg)
# If we don't have more images than the limit, return all messages
if len(image_messages) <= self.config.num_images_to_keep:
return messages
# Get the most recent N images to keep
images_to_keep = image_messages[-self.config.num_images_to_keep :]
images_to_remove = image_messages[: -self.config.num_images_to_keep]
# Create a new message list, removing images from older messages
result = []
for msg in messages:
if msg in images_to_remove:
# Remove images from this message but keep the text content
if msg["role"] == "user" and isinstance(msg["content"], list):
# Keep only text content, remove images
new_content = [
item for item in msg["content"]
if item.get("type") not in ["image_url", "image"]
]
if new_content: # Only add if there's still content
result.append({"role": msg["role"], "content": new_content})
elif msg["role"] == "assistant" and isinstance(msg["content"], list):
# Remove base64_image from tool_result items
new_content = []
for item in msg["content"]:
if item.get("type") == "tool_result" and "base64_image" in item:
# Create a copy without the base64_image
new_item = {k: v for k, v in item.items() if k != "base64_image"}
new_content.append(new_item)
else:
new_content.append(item)
result.append({"role": msg["role"], "content": new_content})
else:
# For other message types, keep as is
result.append(msg)
else:
result.append(msg)
return result
def to_anthropic_format(
self, messages: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], str]:
"""Convert standard OpenAI format messages to Anthropic format.
Args:
messages: List of messages in OpenAI format
Returns:
Tuple containing (anthropic_messages, system_content)
"""
result = []
system_content = ""
# Process messages in order to maintain conversation flow
previous_assistant_tool_use_ids = set()  # Track tool_use_ids in the previous assistant message
for i, msg in enumerate(messages):
role = msg.get("role", "")
content = msg.get("content", "")
if role == "system":
# Collect system messages for later use
system_content += content + "\n"
continue
if role == "assistant":
# Track tool_use_ids in this assistant message for the next user message
previous_assistant_tool_use_ids = set()
if isinstance(content, list):
for item in content:
if (
isinstance(item, dict)
and item.get("type") == "tool_use"
and "id" in item
):
previous_assistant_tool_use_ids.add(item["id"])
logger.info(
f"Tool use IDs in assistant message #{i}: {previous_assistant_tool_use_ids}"
)
if role in ["user", "assistant"]:
anthropic_msg = {"role": role}
# Convert content based on type
if isinstance(content, str):
# Simple text content
anthropic_msg["content"] = [{"type": "text", "text": content}]
elif isinstance(content, list):
# Convert complex content
anthropic_content = []
for item in content:
item_type = item.get("type", "")
if item_type == "text":
anthropic_content.append({"type": "text", "text": item.get("text", "")})
elif item_type == "image_url":
# Convert OpenAI image format to Anthropic
image_url = item.get("image_url", {}).get("url", "")
if image_url.startswith("data:"):
# Extract base64 data and media type
match = re.match(r"data:(.+);base64,(.+)", image_url)
if match:
media_type, data = match.groups()
anthropic_content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": data,
},
}
)
else:
# Regular URL
anthropic_content.append(
{
"type": "image",
"source": {
"type": "url",
"url": image_url,
},
}
)
elif item_type == "tool_use":
# Always include tool_use blocks
anthropic_content.append(item)
elif item_type == "tool_result":
# Check if this is a user message AND if the tool_use_id exists in the previous assistant message
tool_use_id = item.get("tool_use_id")
# Only include tool_result if it references a tool_use from the immediately preceding assistant message
if (
role == "user"
and tool_use_id
and tool_use_id in previous_assistant_tool_use_ids
):
anthropic_content.append(item)
logger.info(
f"Including tool_result with tool_use_id: {tool_use_id}"
)
else:
# Convert to text to preserve information
logger.warning(
f"Converting tool_result to text. Tool use ID {tool_use_id} not found in previous assistant message"
)
content_text = "Tool Result: "
if "content" in item:
if isinstance(item["content"], list):
for content_item in item["content"]:
if (
isinstance(content_item, dict)
and content_item.get("type") == "text"
):
content_text += content_item.get("text", "")
elif isinstance(item["content"], str):
content_text += item["content"]
anthropic_content.append({"type": "text", "text": content_text})
anthropic_msg["content"] = anthropic_content
result.append(anthropic_msg)
return result, system_content
def from_anthropic_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Convert Anthropic format messages to standard OpenAI format.
Args:
messages: List of messages in Anthropic format
Returns:
List of messages in OpenAI format
"""
result = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", [])
if role in ["user", "assistant"]:
openai_msg = {"role": role}
# Simple case: single text block
if len(content) == 1 and content[0].get("type") == "text":
openai_msg["content"] = content[0].get("text", "")
else:
# Complex case: multiple blocks or non-text
openai_content = []
for item in content:
item_type = item.get("type", "")
if item_type == "text":
openai_content.append({"type": "text", "text": item.get("text", "")})
elif item_type == "image":
# Convert Anthropic image to OpenAI format
source = item.get("source", {})
if source.get("type") == "base64":
media_type = source.get("media_type", "image/png")
data = source.get("data", "")
openai_content.append(
{
"type": "image_url",
"image_url": {"url": f"data:{media_type};base64,{data}"},
}
)
else:
# URL
openai_content.append(
{
"type": "image_url",
"image_url": {"url": source.get("url", "")},
}
)
elif item_type in ["tool_use", "tool_result"]:
# Pass through tool-related content
openai_content.append(item)
openai_msg["content"] = openai_content
result.append(openai_msg)
return result
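A small round-trip sketch of the conversion helpers above (import path assumed):
from agent.core.messages import StandardMessageManager, ImageRetentionConfig  # assumed path
mm = StandardMessageManager(ImageRetentionConfig(num_images_to_keep=1))
mm.add_system_message("You are a computer-use agent.")
mm.add_user_message("Open the browser.")
anthropic_msgs, system = mm.to_anthropic_format(mm.get_messages())
print(system.strip())     # "You are a computer-use agent."
print(anthropic_msgs[0])  # {'role': 'user', 'content': [{'type': 'text', 'text': 'Open the browser.'}]}
roundtrip = mm.from_anthropic_format(anthropic_msgs)
print(roundtrip[0])       # back to {'role': 'user', 'content': 'Open the browser.'}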

View File

@@ -1,21 +0,0 @@
"""Provider-specific configurations and constants."""
from .types import LLMProvider
# Default models for different providers
DEFAULT_MODELS = {
LLMProvider.OPENAI: "gpt-4o",
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
LLMProvider.OAICOMPAT: "Qwen2.5-VL-7B-Instruct",
LLMProvider.MLXVLM: "mlx-community/UI-TARS-1.5-7B-4bit",
}
# Map providers to their environment variable names
ENV_VARS = {
LLMProvider.OPENAI: "OPENAI_API_KEY",
LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
LLMProvider.OLLAMA: "none",
LLMProvider.OAICOMPAT: "none", # OpenAI-compatible API typically doesn't require an API key
LLMProvider.MLXVLM: "none", # MLX VLM typically doesn't require an API key
}

View File

@@ -1,142 +0,0 @@
"""Agent telemetry for tracking anonymous usage and feature usage."""
import logging
import os
import platform
import sys
from typing import Dict, Any, Callable
# Import the core telemetry module
TELEMETRY_AVAILABLE = False
# Local fallbacks in case core telemetry isn't available
def _noop(*args: Any, **kwargs: Any) -> None:
"""No-op function for when telemetry is not available."""
pass
# Define default functions with unique names to avoid shadowing
_default_record_event = _noop
_default_increment_counter = _noop
_default_set_dimension = _noop
_default_get_telemetry_client = lambda: None
_default_flush = _noop
_default_is_telemetry_enabled = lambda: False
_default_is_telemetry_globally_disabled = lambda: True
# Set the actual functions to the defaults initially
record_event = _default_record_event
increment_counter = _default_increment_counter
set_dimension = _default_set_dimension
get_telemetry_client = _default_get_telemetry_client
flush = _default_flush
is_telemetry_enabled = _default_is_telemetry_enabled
is_telemetry_globally_disabled = _default_is_telemetry_globally_disabled
logger = logging.getLogger("agent.telemetry")
try:
# Import from core telemetry
from core.telemetry import (
record_event as core_record_event,
increment as core_increment,
get_telemetry_client as core_get_telemetry_client,
flush as core_flush,
is_telemetry_enabled as core_is_telemetry_enabled,
is_telemetry_globally_disabled as core_is_telemetry_globally_disabled,
)
# Override the default functions with actual implementations
record_event = core_record_event
get_telemetry_client = core_get_telemetry_client
flush = core_flush
is_telemetry_enabled = core_is_telemetry_enabled
is_telemetry_globally_disabled = core_is_telemetry_globally_disabled
def increment_counter(counter_name: str, value: int = 1) -> None:
"""Wrapper for increment to maintain backward compatibility."""
if is_telemetry_enabled():
core_increment(counter_name, value)
def set_dimension(name: str, value: Any) -> None:
"""Set a dimension that will be attached to all events."""
logger.debug(f"Setting dimension {name}={value}")
TELEMETRY_AVAILABLE = True
logger.info("Successfully imported telemetry")
except ImportError as e:
logger.warning(f"Could not import telemetry: {e}")
logger.debug("Telemetry not available, using no-op functions")
# Get system info once to use in telemetry
SYSTEM_INFO = {
"os": platform.system().lower(),
"os_version": platform.release(),
"python_version": platform.python_version(),
}
def enable_telemetry() -> bool:
"""Enable telemetry if available.
Returns:
bool: True if telemetry was successfully enabled, False otherwise
"""
global TELEMETRY_AVAILABLE, record_event, increment_counter, get_telemetry_client, flush, is_telemetry_enabled, is_telemetry_globally_disabled
# Check if globally disabled using core function
if TELEMETRY_AVAILABLE and is_telemetry_globally_disabled():
logger.info("Telemetry is globally disabled via environment variable - cannot enable")
return False
# Already enabled
if TELEMETRY_AVAILABLE:
return True
# Try to import and enable
try:
from core.telemetry import (
record_event,
increment,
get_telemetry_client,
flush,
is_telemetry_globally_disabled,
)
# Check again after import
if is_telemetry_globally_disabled():
logger.info("Telemetry is globally disabled via environment variable - cannot enable")
return False
TELEMETRY_AVAILABLE = True
logger.info("Telemetry successfully enabled")
return True
except ImportError as e:
logger.warning(f"Could not enable telemetry: {e}")
return False
def is_telemetry_enabled() -> bool:
"""Check if telemetry is enabled.
Returns:
bool: True if telemetry is enabled, False otherwise
"""
# Use the core function if available, otherwise use our local flag
if TELEMETRY_AVAILABLE:
from core.telemetry import is_telemetry_enabled as core_is_enabled
return core_is_enabled()
return False
def record_agent_initialization() -> None:
"""Record when an agent instance is initialized."""
if TELEMETRY_AVAILABLE and is_telemetry_enabled():
record_event("agent_initialized", SYSTEM_INFO)
# Set dimensions that will be attached to all events
set_dimension("os", SYSTEM_INFO["os"])
set_dimension("os_version", SYSTEM_INFO["os_version"])
set_dimension("python_version", SYSTEM_INFO["python_version"])

View File

@@ -1,32 +0,0 @@
"""Tool-related type definitions."""
from enum import StrEnum
from typing import Dict, Any, Optional
from pydantic import BaseModel, ConfigDict
class ToolInvocationState(StrEnum):
"""States for tool invocation."""
CALL = 'call'
PARTIAL_CALL = 'partial-call'
RESULT = 'result'
class ToolInvocation(BaseModel):
"""Tool invocation type."""
model_config = ConfigDict(extra='forbid')
state: Optional[str] = None
toolCallId: str
toolName: Optional[str] = None
args: Optional[Dict[str, Any]] = None
class ClientAttachment(BaseModel):
"""Client attachment type."""
name: str
contentType: str
url: str
class ToolResult(BaseModel):
"""Result of a tool execution."""
model_config = ConfigDict(extra='forbid')
output: Optional[str] = None
error: Optional[str] = None
metadata: Optional[Dict[str, Any]] = None

View File

@@ -1,21 +0,0 @@
"""Core tools package."""
from .base import BaseTool, ToolResult, ToolError, ToolFailure, CLIResult
from .bash import BaseBashTool
from .collection import ToolCollection
from .computer import BaseComputerTool
from .edit import BaseEditTool
from .manager import BaseToolManager
__all__ = [
"BaseTool",
"ToolResult",
"ToolError",
"ToolFailure",
"CLIResult",
"BaseBashTool",
"BaseComputerTool",
"BaseEditTool",
"ToolCollection",
"BaseToolManager",
]

View File

@@ -1,74 +0,0 @@
"""Abstract base classes for tools that can be used with any provider."""
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, fields, replace
from typing import Any, Dict
class BaseTool(metaclass=ABCMeta):
"""Abstract base class for provider-agnostic tools."""
name: str
@abstractmethod
async def __call__(self, **kwargs) -> Any:
"""Executes the tool with the given arguments."""
...
@abstractmethod
def to_params(self) -> Dict[str, Any]:
"""Convert tool to provider-specific API parameters.
Returns:
Dictionary with tool parameters specific to the LLM provider
"""
raise NotImplementedError
@dataclass(kw_only=True, frozen=True)
class ToolResult:
"""Represents the result of a tool execution."""
output: str | None = None
error: str | None = None
base64_image: str | None = None
system: str | None = None
content: list[dict] | None = None
def __bool__(self):
return any(getattr(self, field.name) for field in fields(self))
def __add__(self, other: "ToolResult"):
def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True):
if field and other_field:
if concatenate:
return field + other_field
raise ValueError("Cannot combine tool results")
return field or other_field
return ToolResult(
output=combine_fields(self.output, other.output),
error=combine_fields(self.error, other.error),
base64_image=combine_fields(self.base64_image, other.base64_image, False),
system=combine_fields(self.system, other.system),
content=self.content or other.content, # Use first non-None content
)
def replace(self, **kwargs):
"""Returns a new ToolResult with the given fields replaced."""
return replace(self, **kwargs)
class CLIResult(ToolResult):
"""A ToolResult that can be rendered as a CLI output."""
class ToolFailure(ToolResult):
"""A ToolResult that represents a failure."""
class ToolError(Exception):
"""Raised when a tool encounters an error."""
def __init__(self, message):
self.message = message
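A quick sketch of how ToolResult instances compose via __add__ and replace (import path assumed):
from agent.core.tools.base import ToolResult  # assumed import path
a = ToolResult(output="$ ls\n")
b = ToolResult(output="README.md\n", system="shell")
combined = a + b
print(bool(combined))    # True: at least one field is set
print(combined.output)   # "$ ls\nREADME.md\n" (string fields concatenate)
print(combined.replace(error="timeout").error)  # frozen dataclass, so replace returns a copy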

View File

@@ -1,52 +0,0 @@
"""Abstract base bash/shell tool implementation."""
import asyncio
import logging
from abc import abstractmethod
from typing import Any, Dict, Tuple
from computer.computer import Computer
from .base import BaseTool, ToolResult
class BaseBashTool(BaseTool):
"""Base class for bash/shell command execution tools across different providers."""
name = "bash"
logger = logging.getLogger(__name__)
computer: Computer
def __init__(self, computer: Computer):
"""Initialize the BashTool.
Args:
computer: Computer instance, may be used for related operations
"""
self.computer = computer
async def run_command(self, command: str) -> Tuple[int, str, str]:
"""Run a shell command and return exit code, stdout, and stderr.
Args:
command: Shell command to execute
Returns:
Tuple containing (exit_code, stdout, stderr)
"""
try:
process = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await process.communicate()
return process.returncode or 0, stdout.decode(), stderr.decode()
except Exception as e:
self.logger.error(f"Error running command: {str(e)}")
return 1, "", str(e)
@abstractmethod
async def __call__(self, **kwargs) -> ToolResult:
"""Execute the tool with the provided arguments."""
raise NotImplementedError
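A concrete subclass might look like this sketch (import paths assumed; the to_params schema is a stand-in):
import asyncio
from computer import Computer
from agent.core.tools.bash import BaseBashTool  # assumed import paths
from agent.core.tools.base import ToolResult
class SimpleBashTool(BaseBashTool):
    """Hypothetical concrete bash tool."""
    def to_params(self):
        return {"name": self.name, "description": "Run a shell command"}
    async def __call__(self, **kwargs) -> ToolResult:
        code, out, err = await self.run_command(kwargs["command"])
        return ToolResult(output=out, error=err or None)
print(asyncio.run(SimpleBashTool(Computer())(command="echo hello")).output)  # "hello\n"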

View File

@@ -1,46 +0,0 @@
"""Collection classes for managing multiple tools."""
from typing import Any, Dict, List, Type
from .base import (
BaseTool,
ToolError,
ToolFailure,
ToolResult,
)
class ToolCollection:
"""A collection of tools that can be used with any provider."""
def __init__(self, *tools: BaseTool):
self.tools = tools
self.tool_map = {tool.name: tool for tool in tools}
def to_params(self) -> List[Dict[str, Any]]:
"""Convert all tools to provider-specific parameters.
Returns:
List of dictionaries with tool parameters
"""
return [tool.to_params() for tool in self.tools]
async def run(self, *, name: str, tool_input: Dict[str, Any]) -> ToolResult:
"""Run a tool with the given input.
Args:
name: Name of the tool to run
tool_input: Input parameters for the tool
Returns:
Result of the tool execution
"""
tool = self.tool_map.get(name)
if not tool:
return ToolFailure(error=f"Tool {name} is invalid")
try:
return await tool(**tool_input)
except ToolError as e:
return ToolFailure(error=e.message)
except Exception as e:
return ToolFailure(error=f"Unexpected error in tool {name}: {str(e)}")

View File

@@ -1,113 +0,0 @@
"""Abstract base computer tool implementation."""
import asyncio
import base64
import io
import logging
from abc import abstractmethod
from typing import Any, Dict, Optional, Tuple
from PIL import Image
from computer.computer import Computer
from .base import BaseTool, ToolError, ToolResult
class BaseComputerTool(BaseTool):
"""Base class for computer interaction tools across different providers."""
name = "computer"
logger = logging.getLogger(__name__)
width: Optional[int] = None
height: Optional[int] = None
display_num: Optional[int] = None
computer: Computer
_screenshot_delay = 1.0 # Default delay for most platforms
_scaling_enabled = True
def __init__(self, computer: Computer):
"""Initialize the ComputerTool.
Args:
computer: Computer instance for screen interactions
"""
self.computer = computer
async def initialize_dimensions(self):
"""Initialize screen dimensions from the computer interface."""
display_size = await self.computer.interface.get_screen_size()
self.width = display_size["width"]
self.height = display_size["height"]
self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
@property
def options(self) -> Dict[str, Any]:
"""Get the options for the tool.
Returns:
Dictionary with tool options
"""
if self.width is None or self.height is None:
raise RuntimeError(
"Screen dimensions not initialized. Call initialize_dimensions() first."
)
return {
"display_width_px": self.width,
"display_height_px": self.height,
"display_number": self.display_num,
}
async def resize_screenshot_if_needed(self, screenshot: bytes) -> bytes:
"""Resize a screenshot to match the expected dimensions.
Args:
screenshot: Raw screenshot data
Returns:
Resized screenshot data
"""
if self.width is None or self.height is None:
raise ToolError("Screen dimensions not initialized")
try:
img = Image.open(io.BytesIO(screenshot))
if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
img = img.convert("RGB")
# Resize if dimensions don't match
if img.size != (self.width, self.height):
self.logger.info(
f"Scaling image from {img.size} to {self.width}x{self.height} to match screen dimensions"
)
img = img.resize((self.width, self.height), Image.Resampling.LANCZOS)
# Save back to bytes
buffer = io.BytesIO()
img.save(buffer, format="PNG")
return buffer.getvalue()
return screenshot
except Exception as e:
self.logger.error(f"Error during screenshot resizing: {str(e)}")
raise ToolError(f"Failed to resize screenshot: {str(e)}")
async def screenshot(self) -> ToolResult:
"""Take a screenshot and return it as a ToolResult with base64-encoded image.
Returns:
ToolResult with the screenshot
"""
try:
screenshot = await self.computer.interface.screenshot()
screenshot = await self.resize_screenshot_if_needed(screenshot)
return ToolResult(base64_image=base64.b64encode(screenshot).decode())
except Exception as e:
self.logger.error(f"Error taking screenshot: {str(e)}")
return ToolResult(error=f"Failed to take screenshot: {str(e)}")
@abstractmethod
async def __call__(self, **kwargs) -> ToolResult:
"""Execute the tool with the provided arguments."""
raise NotImplementedError

View File

@@ -1,67 +0,0 @@
"""Abstract base edit tool implementation."""
import asyncio
import logging
import os
from abc import abstractmethod
from pathlib import Path
from typing import Any, Dict, Optional
from computer.computer import Computer
from .base import BaseTool, ToolError, ToolResult
class BaseEditTool(BaseTool):
"""Base class for text editor tools across different providers."""
name = "edit"
logger = logging.getLogger(__name__)
computer: Computer
def __init__(self, computer: Computer):
"""Initialize the EditTool.
Args:
computer: Computer instance, may be used for related operations
"""
self.computer = computer
async def read_file(self, path: str) -> str:
"""Read a file and return its contents.
Args:
path: Path to the file to read
Returns:
File contents as a string
"""
try:
path_obj = Path(path)
if not path_obj.exists():
raise ToolError(f"File does not exist: {path}")
return path_obj.read_text()
except Exception as e:
self.logger.error(f"Error reading file: {str(e)}")
raise ToolError(f"Failed to read file: {str(e)}")
async def write_file(self, path: str, content: str) -> None:
"""Write content to a file.
Args:
path: Path to the file to write
content: Content to write to the file
"""
try:
path_obj = Path(path)
# Create parent directories if they don't exist
path_obj.parent.mkdir(parents=True, exist_ok=True)
path_obj.write_text(content)
except Exception as e:
self.logger.error(f"Error writing file: {str(e)}")
raise ToolError(f"Failed to write file: {str(e)}")
@abstractmethod
async def __call__(self, **kwargs) -> ToolResult:
"""Execute the tool with the provided arguments."""
raise NotImplementedError

View File

@@ -1,56 +0,0 @@
"""Tool manager for initializing and running tools."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from computer.computer import Computer
from .base import BaseTool, ToolResult
from .collection import ToolCollection
class BaseToolManager(ABC):
"""Base class for tool managers across different providers."""
def __init__(self, computer: Computer):
"""Initialize the tool manager.
Args:
computer: Computer instance for computer-related tools
"""
self.computer = computer
self.tools: ToolCollection | None = None
@abstractmethod
def _initialize_tools(self) -> ToolCollection:
"""Initialize all available tools."""
...
async def initialize(self) -> None:
"""Initialize tool-specific requirements and create tool collection."""
await self._initialize_tools_specific()
self.tools = self._initialize_tools()
@abstractmethod
async def _initialize_tools_specific(self) -> None:
"""Initialize provider-specific tool requirements."""
...
@abstractmethod
def get_tool_params(self) -> List[Dict[str, Any]]:
"""Get tool parameters for API calls."""
...
async def execute_tool(self, name: str, tool_input: Dict[str, Any]) -> ToolResult:
"""Execute a tool with the given input.
Args:
name: Name of the tool to execute
tool_input: Input parameters for the tool
Returns:
Result of the tool execution
"""
if self.tools is None:
raise RuntimeError("Tools not initialized. Call initialize() first.")
return await self.tools.run(name=name, tool_input=tool_input)

View File

@@ -1,88 +0,0 @@
"""Core type definitions."""
from typing import Any, Dict, List, Optional, TypedDict, Union
from enum import StrEnum
from dataclasses import dataclass
class AgentLoop(StrEnum):
"""Enumeration of available loop types."""
ANTHROPIC = "anthropic" # Anthropic implementation
OMNI = "omni" # OmniLoop implementation
OPENAI = "openai" # OpenAI implementation
OLLAMA = "ollama" # OLLAMA implementation
UITARS = "uitars" # UI-TARS implementation
# Add more loop types as needed
class LLMProvider(StrEnum):
"""Supported LLM providers."""
ANTHROPIC = "anthropic"
OPENAI = "openai"
OLLAMA = "ollama"
OAICOMPAT = "oaicompat"
MLXVLM= "mlxvlm"
@dataclass
class LLM:
"""Configuration for LLM model and provider."""
provider: LLMProvider
name: Optional[str] = None
provider_base_url: Optional[str] = None
def __post_init__(self):
"""Set default model name if not provided."""
if self.name is None:
from .provider_config import DEFAULT_MODELS
self.name = DEFAULT_MODELS.get(self.provider)
# Set default provider URL if none provided
if self.provider_base_url is None and self.provider == LLMProvider.OAICOMPAT:
# Default for vLLM
self.provider_base_url = "http://localhost:8000/v1"
# Common alternatives:
# - LM Studio: "http://localhost:1234/v1"
# - LocalAI: "http://localhost:8080/v1"
# - Ollama with OpenAI compatible API: "http://localhost:11434/v1"
# For backward compatibility
LLMModel = LLM
Model = LLM
class AgentResponse(TypedDict, total=False):
"""Agent response format."""
id: str
object: str
created_at: int
status: str
error: Optional[str]
incomplete_details: Optional[Any]
instructions: Optional[Any]
max_output_tokens: Optional[int]
model: str
output: List[Dict[str, Any]]
parallel_tool_calls: bool
previous_response_id: Optional[str]
reasoning: Dict[str, str]
store: bool
temperature: float
text: Dict[str, Dict[str, str]]
tool_choice: str
tools: List[Dict[str, Union[str, int]]]
top_p: float
truncation: str
usage: Dict[str, Any]
user: Optional[str]
metadata: Dict[str, Any]
response: Dict[str, List[Dict[str, Any]]]
# Additional fields for error responses
role: str
content: Union[str, List[Dict[str, Any]]]
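The defaulting behavior in LLM.__post_init__ can be seen in this sketch (import path assumed):
from agent.core.types import LLM, LLMProvider  # assumed import path
m = LLM(provider=LLMProvider.OPENAI)
print(m.name)  # "gpt-4o", filled in from DEFAULT_MODELS
local = LLM(provider=LLMProvider.OAICOMPAT)
print(local.name)               # "Qwen2.5-VL-7B-Instruct"
print(local.provider_base_url)  # "http://localhost:8000/v1" (vLLM default)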

View File

@@ -1,197 +0,0 @@
"""Core visualization utilities for agents."""
import logging
import base64
from typing import Dict, Tuple
from PIL import Image, ImageDraw
from io import BytesIO
logger = logging.getLogger(__name__)
def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
"""Visualize a click action by drawing a circle on the screenshot.
Args:
x: X coordinate of the click
y: Y coordinate of the click
img_base64: Base64-encoded screenshot
Returns:
PIL Image with visualization
"""
try:
# Decode the base64 image
image_data = base64.b64decode(img_base64)
img = Image.open(BytesIO(image_data))
# Create a copy to draw on
draw_img = img.copy()
draw = ImageDraw.Draw(draw_img)
# Draw a circle at the click location
radius = 15
draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], outline="red", width=3)
# Draw crosshairs
line_length = 20
draw.line([(x - line_length, y), (x + line_length, y)], fill="red", width=3)
draw.line([(x, y - line_length), (x, y + line_length)], fill="red", width=3)
return draw_img
except Exception as e:
logger.error(f"Error visualizing click: {str(e)}")
# Return a blank image as fallback
return Image.new("RGB", (800, 600), "white")
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
"""Visualize a scroll action by drawing arrows on the screenshot.
Args:
direction: Direction of scroll ('up' or 'down')
clicks: Number of scroll clicks
img_base64: Base64-encoded screenshot
Returns:
PIL Image with visualization
"""
try:
# Decode the base64 image
image_data = base64.b64decode(img_base64)
img = Image.open(BytesIO(image_data))
# Create a copy to draw on
draw_img = img.copy()
draw = ImageDraw.Draw(draw_img)
# Calculate parameters for visualization
width, height = img.size
center_x = width // 2
# Draw arrows to indicate scrolling
arrow_length = min(100, height // 4)
arrow_width = 30
num_arrows = min(clicks, 3) # Don't draw too many arrows
# Calculate starting position
if direction == "down":
start_y = height // 3
arrow_dir = 1 # Down
else:
start_y = height * 2 // 3
arrow_dir = -1 # Up
# Draw the arrows
for i in range(num_arrows):
y_pos = start_y + (i * arrow_length * arrow_dir * 0.7)
arrow_top = (center_x, y_pos)
arrow_bottom = (center_x, y_pos + arrow_length * arrow_dir)
# Draw the main line
draw.line([arrow_top, arrow_bottom], fill="red", width=5)
# Draw the arrowhead
arrowhead_size = 20
if direction == "down":
draw.line(
[
(center_x - arrow_width // 2, arrow_bottom[1] - arrowhead_size),
arrow_bottom,
(center_x + arrow_width // 2, arrow_bottom[1] - arrowhead_size),
],
fill="red",
width=5,
)
else:
draw.line(
[
(center_x - arrow_width // 2, arrow_bottom[1] + arrowhead_size),
arrow_bottom,
(center_x + arrow_width // 2, arrow_bottom[1] + arrowhead_size),
],
fill="red",
width=5,
)
return draw_img
except Exception as e:
logger.error(f"Error visualizing scroll: {str(e)}")
# Return a blank image as fallback
return Image.new("RGB", (800, 600), "white")
def calculate_element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]:
"""Calculate the center point of a UI element.
Args:
bbox: Bounding box dictionary with x1, y1, x2, y2 coordinates (0-1 normalized)
width: Screen width in pixels
height: Screen height in pixels
Returns:
(x, y) tuple with pixel coordinates
"""
center_x = int((bbox["x1"] + bbox["x2"]) / 2 * width)
center_y = int((bbox["y1"] + bbox["y2"]) / 2 * height)
return center_x, center_y
class VisualizationHelper:
"""Helper class for visualizing agent actions."""
def __init__(self, agent):
"""Initialize visualization helper.
Args:
agent: Reference to the agent that will use this helper
"""
self.agent = agent
def visualize_action(self, x: int, y: int, img_base64: str) -> None:
"""Visualize a click action by drawing on the screenshot."""
if (
not self.agent.save_trajectory
or not hasattr(self.agent, "experiment_manager")
or not self.agent.experiment_manager
):
return
try:
# Use the visualization utility
img = visualize_click(x, y, img_base64)
# Save the visualization
self.agent.experiment_manager.save_action_visualization(img, "click", f"x{x}_y{y}")
except Exception as e:
logger.error(f"Error visualizing action: {str(e)}")
def visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
"""Visualize a scroll action by drawing arrows on the screenshot."""
if (
not self.agent.save_trajectory
or not hasattr(self.agent, "experiment_manager")
or not self.agent.experiment_manager
):
return
try:
# Use the visualization utility
img = visualize_scroll(direction, clicks, img_base64)
# Save the visualization
self.agent.experiment_manager.save_action_visualization(
img, "scroll", f"{direction}_{clicks}"
)
except Exception as e:
logger.error(f"Error visualizing scroll: {str(e)}")
def save_action_visualization(
self, img: Image.Image, action_name: str, details: str = ""
) -> str:
"""Save a visualization of an action."""
if hasattr(self.agent, "experiment_manager") and self.agent.experiment_manager:
return self.agent.experiment_manager.save_action_visualization(
img, action_name, details
)
return ""

View File

@@ -0,0 +1,90 @@
"""
Decorators for the agent package: the agent_loop registration decorator
"""
import asyncio
import inspect
from typing import Dict, List, Any, Callable, Optional
from functools import wraps
from .types import AgentLoopInfo
# Global registry
_agent_loops: List[AgentLoopInfo] = []
def agent_loop(models: str, priority: int = 0):
"""
Decorator to register an agent loop function.
Args:
models: Regex pattern to match supported models
priority: Priority for loop selection (higher = more priority)
"""
def decorator(func: Callable):
# Validate function signature
sig = inspect.signature(func)
required_params = {'messages', 'model'}
func_params = set(sig.parameters.keys())
if not required_params.issubset(func_params):
missing = required_params - func_params
raise ValueError(f"Agent loop function must have parameters: {missing}")
# Register the loop
loop_info = AgentLoopInfo(
func=func,
models_regex=models,
priority=priority
)
_agent_loops.append(loop_info)
# Sort by priority (highest first)
_agent_loops.sort(key=lambda x: x.priority, reverse=True)
@wraps(func)
async def wrapper(*args, **kwargs):
# Run the wrapped function in a task and relay its result through a queue so it can be cancelled
queue = asyncio.Queue()
task = None
try:
# Create a task that can be cancelled
async def run_loop():
try:
result = await func(*args, **kwargs)
await queue.put(('result', result))
except Exception as e:
await queue.put(('error', e))
task = asyncio.create_task(run_loop())
# Wait for result or cancellation
event_type, data = await queue.get()
if event_type == 'error':
raise data
return data
except asyncio.CancelledError:
if task:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
raise
return wrapper
return decorator
def get_agent_loops() -> List[AgentLoopInfo]:
"""Get all registered agent loops"""
return _agent_loops.copy()
def find_agent_loop(model: str) -> Optional[AgentLoopInfo]:
"""Find the best matching agent loop for a model"""
for loop_info in _agent_loops:
if loop_info.matches_model(model):
return loop_info
return None
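Registering a custom loop then reduces to this sketch (import path assumed; the model pattern and loop body are illustrative):
from agent.decorators import agent_loop, find_agent_loop  # assumed import path
@agent_loop(models=r"demo/.*", priority=1)
async def demo_loop(messages, model, **kwargs):
    """Hypothetical loop that echoes the last message."""
    return {"model": model, "output": messages[-1]}
print(find_agent_loop("demo/echo-1") is not None)  # True: the regex matched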

View File

@@ -0,0 +1,11 @@
"""
Agent loop implementations for the agent package
"""
# Import the loops to register them
from . import anthropic
from . import openai
from . import uitars
from . import omniparser
__all__ = ["anthropic", "openai", "uitars", "omniparser"]

File diff suppressed because it is too large

View File

@@ -0,0 +1,339 @@
"""
Omniparser (set-of-marks) agent loop implementation using liteLLM
"""
import asyncio
import json
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
import litellm
import inspect
import base64
from ..decorators import agent_loop
from ..types import Messages, AgentResponse, Tools
SOM_TOOL_SCHEMA = {
"type": "function",
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment"
],
"description": "The action to perform"
},
"element_id": {
"type": "integer",
"description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
},
"start_element_id": {
"type": "integer",
"description": "The ID of the element to start dragging from (required for drag action)"
},
"end_element_id": {
"type": "integer",
"description": "The ID of the element to drag to (required for drag action)"
},
"text": {
"type": "string",
"description": "The text to type (required for type action)"
},
"keys": {
"type": "string",
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
},
"button": {
"type": "string",
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
},
},
"required": [
"action"
]
}
}
OMNIPARSER_AVAILABLE = False
try:
from som import OmniParser
OMNIPARSER_AVAILABLE = True
except ImportError:
pass
OMNIPARSER_SINGLETON = None
def get_parser():
global OMNIPARSER_SINGLETON
if OMNIPARSER_SINGLETON is None:
OMNIPARSER_SINGLETON = OmniParser()
return OMNIPARSER_SINGLETON
def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Get the last computer_call_output message from a messages list.
Args:
messages: List of messages to search through
Returns:
The last computer_call_output message dict, or None if not found
"""
for message in reversed(messages):
if isinstance(message, dict) and message.get("type") == "computer_call_output":
return message
return None
def _prepare_tools_for_omniparser(tool_schemas: List[Dict[str, Any]]) -> Tuple[Tools, dict]:
"""Prepare tools for OpenAI API format"""
omniparser_tools = []
id2xy = dict()
for schema in tool_schemas:
if schema["type"] == "computer":
omniparser_tools.append(SOM_TOOL_SCHEMA)
if "id2xy" in schema:
id2xy = schema["id2xy"]
else:
schema["id2xy"] = id2xy
elif schema["type"] == "function":
# Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
# Schema should be: {type, name, description, parameters}
omniparser_tools.append({ "type": "function", **schema["function"] })
return omniparser_tools, id2xy
async def replace_function_with_computer_call(item: Dict[str, Any], id2xy: Dict[int, Tuple[float, float]]):
item_type = item.get("type")
def _get_xy(element_id: Optional[int]) -> Union[Tuple[float, float], Tuple[None, None]]:
if element_id is None:
return (None, None)
return id2xy.get(element_id, (None, None))
if item_type == "function_call":
fn_name = item.get("name")
fn_args = json.loads(item.get("arguments", "{}"))
item_id = item.get("id")
call_id = item.get("call_id")
if fn_name == "computer":
action = fn_args.get("action")
element_id = fn_args.get("element_id")
start_element_id = fn_args.get("start_element_id")
end_element_id = fn_args.get("end_element_id")
text = fn_args.get("text")
keys = fn_args.get("keys")
button = fn_args.get("button")
scroll_x = fn_args.get("scroll_x")
scroll_y = fn_args.get("scroll_y")
x, y = _get_xy(element_id)
start_x, start_y = _get_xy(start_element_id)
end_x, end_y = _get_xy(end_element_id)
action_args = {
"type": action,
"x": x,
"y": y,
"start_x": start_x,
"start_y": start_y,
"end_x": end_x,
"end_y": end_y,
"text": text,
"keys": keys,
"button": button,
"scroll_x": scroll_x,
"scroll_y": scroll_y
}
# Remove None values to keep the JSON clean
action_args = {k: v for k, v in action_args.items() if v is not None}
return [{
"type": "computer_call",
"action": action_args,
"id": item_id,
"call_id": call_id,
"status": "completed"
}]
return [item]
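# Hedged round-trip sketch (element ID 3 and its coordinates are hypothetical):
# a model-emitted `computer` function call is rewritten into a coordinate-based
# computer_call via the id2xy table built from the last OmniParser pass.
async def _example_replace_function_with_computer_call():
    item = {
        "type": "function_call",
        "name": "computer",
        "arguments": json.dumps({"action": "click", "element_id": 3}),
        "id": "fc_1",
        "call_id": "call_1",
    }
    [converted] = await replace_function_with_computer_call(item, {3: (120.0, 48.0)})
    assert converted["action"] == {"type": "click", "x": 120.0, "y": 48.0}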
async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[Tuple[float, float], int]):
"""
Convert computer_call back to function_call format.
Also handles computer_call_output -> function_call_output conversion.
Args:
item: The item to convert
xy2id: Mapping from (x, y) coordinates to element IDs
"""
item_type = item.get("type")
def _get_element_id(x: Optional[float], y: Optional[float]) -> Optional[int]:
"""Get element ID from coordinates, return None if coordinates are None"""
if x is None or y is None:
return None
return xy2id.get((x, y))
if item_type == "computer_call":
action_data = item.get("action", {})
# Extract coordinates and convert back to element IDs
element_id = _get_element_id(action_data.get("x"), action_data.get("y"))
start_element_id = _get_element_id(action_data.get("start_x"), action_data.get("start_y"))
end_element_id = _get_element_id(action_data.get("end_x"), action_data.get("end_y"))
# Build function arguments
fn_args = {
"action": action_data.get("type"),
"element_id": element_id,
"start_element_id": start_element_id,
"end_element_id": end_element_id,
"text": action_data.get("text"),
"keys": action_data.get("keys"),
"button": action_data.get("button"),
"scroll_x": action_data.get("scroll_x"),
"scroll_y": action_data.get("scroll_y")
}
# Remove None values to keep the JSON clean
fn_args = {k: v for k, v in fn_args.items() if v is not None}
return [{
"type": "function_call",
"name": "computer",
"arguments": json.dumps(fn_args),
"id": item.get("id"),
"call_id": item.get("call_id"),
"status": "completed",
            # Human-readable fallback so providers that expect message content still see the action
            "content": f"Used tool: {action_data.get('type')}({json.dumps(fn_args)})"
}]
elif item_type == "computer_call_output":
# Simple conversion: computer_call_output -> function_call_output
return [{
"type": "function_call_output",
"call_id": item.get("call_id"),
"content": [item.get("output")],
"id": item.get("id"),
"status": "completed"
}]
return [item]
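# Hedged inverse sketch: coordinates map back to element IDs via xy2id (the
# reverse of id2xy), so earlier turns can be replayed to the LLM as plain
# function calls. Keys must match the exact floats stored during grounding.
async def _example_replace_computer_call_with_function():
    item = {
        "type": "computer_call",
        "action": {"type": "click", "x": 120.0, "y": 48.0, "button": "left"},
        "id": "cc_1",
        "call_id": "call_1",
    }
    [converted] = await replace_computer_call_with_function(item, {(120.0, 48.0): 3})
    assert converted["name"] == "computer"
    assert json.loads(converted["arguments"])["element_id"] == 3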
@agent_loop(models=r"omniparser\+.*|omni\+.*", priority=10)
async def omniparser_loop(
messages: Messages,
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
"""
OpenAI computer-use-preview agent loop using liteLLM responses.
Supports OpenAI's computer use preview models.
"""
if not OMNIPARSER_AVAILABLE:
raise ValueError("omniparser loop requires som to be installed. Install it with `pip install cua-som`.")
tools = tools or []
llm_model = model.split('+')[-1]
# Prepare tools for OpenAI API
openai_tools, id2xy = _prepare_tools_for_omniparser(tools)
# Find last computer_call_output
last_computer_call_output = get_last_computer_call_output(messages)
if last_computer_call_output:
image_url = last_computer_call_output.get("output", {}).get("image_url", "")
image_data = image_url.split(",")[-1]
if image_data:
parser = get_parser()
result = parser.parse(image_data)
if _on_screenshot:
await _on_screenshot(result.annotated_image_base64, "annotated_image")
for element in result.elements:
id2xy[element.id] = ((element.bbox.x1 + element.bbox.x2) / 2, (element.bbox.y1 + element.bbox.y2) / 2)
# handle computer calls -> function calls
new_messages = []
for message in messages:
if not isinstance(message, dict):
message = message.__dict__
new_messages += await replace_computer_call_with_function(message, id2xy)
messages = new_messages
# Prepare API call kwargs
api_kwargs = {
"model": llm_model,
"input": messages,
"tools": openai_tools if openai_tools else None,
"stream": stream,
"reasoning": {"summary": "concise"},
"truncation": "auto",
"num_retries": max_retries,
**kwargs
}
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
# Use liteLLM responses
response = await litellm.aresponses(**api_kwargs)
# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
# Extract usage information
response.usage = {
**response.usage.model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(response.usage)
# handle som function calls -> xy computer calls
new_output = []
for i in range(len(response.output)):
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy)
response.output = new_output
return response
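# Hedged usage sketch (the grounding model string and `my_computer` are
# assumptions; an Anthropic key is needed at runtime): the "omniparser+" prefix
# routes here, and everything after the "+" is passed to liteLLM as-is.
async def _example_omniparser_loop(my_computer):
    return await omniparser_loop(
        messages=[{"role": "user", "content": "Open the settings menu"}],
        model="omniparser+anthropic/claude-3-5-sonnet-20240620",
        tools=[{"type": "computer", "computer": my_computer}],
    )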

View File

@@ -0,0 +1,95 @@
"""
OpenAI computer-use-preview agent loop implementation using liteLLM
"""
import asyncio
import json
from typing import Dict, List, Any, AsyncGenerator, Union, Optional
import litellm
from ..decorators import agent_loop
from ..types import Messages, AgentResponse, Tools
def _map_computer_tool_to_openai(computer_tool: Any) -> Dict[str, Any]:
"""Map a computer tool to OpenAI's computer-use-preview tool schema"""
return {
"type": "computer_use_preview",
"display_width": getattr(computer_tool, 'display_width', 1024),
"display_height": getattr(computer_tool, 'display_height', 768),
"environment": getattr(computer_tool, 'environment', "linux") # mac, windows, linux, browser
}
def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools:
"""Prepare tools for OpenAI API format"""
openai_tools = []
for schema in tool_schemas:
if schema["type"] == "computer":
# Map computer tool to OpenAI format
openai_tools.append(_map_computer_tool_to_openai(schema["computer"]))
elif schema["type"] == "function":
# Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
# Schema should be: {type, name, description, parameters}
openai_tools.append({ "type": "function", **schema["function"] })
return openai_tools
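# Hedged shape sketch (display values are assumptions): a computer tool becomes
# OpenAI's hosted computer_use_preview tool, while function tools pass through
# in the {type, name, description, parameters} layout liteLLM expects.
def _example_prepare_tools_for_openai():
    class _FakeComputer:
        display_width, display_height, environment = 1920, 1080, "mac"
    tools = _prepare_tools_for_openai([{"type": "computer", "computer": _FakeComputer()}])
    assert tools[0]["type"] == "computer_use_preview"
    assert tools[0]["display_width"] == 1920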
@agent_loop(models=r".*computer-use-preview.*", priority=10)
async def openai_computer_use_loop(
messages: Messages,
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
"""
OpenAI computer-use-preview agent loop using liteLLM responses.
Supports OpenAI's computer use preview models.
"""
tools = tools or []
# Prepare tools for OpenAI API
openai_tools = _prepare_tools_for_openai(tools)
# Prepare API call kwargs
api_kwargs = {
"model": model,
"input": messages,
"tools": openai_tools if openai_tools else None,
"stream": stream,
"reasoning": {"summary": "concise"},
"truncation": "auto",
"num_retries": max_retries,
**kwargs
}
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
# Use liteLLM responses
response = await litellm.aresponses(**api_kwargs)
# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
# Extract usage information
response.usage = {
**response.usage.model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(response.usage)
return response
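# Hedged usage sketch (`my_computer` is an assumption; requires an OpenAI API
# key): any model name matching ".*computer-use-preview.*" routes to this loop.
async def _example_openai_computer_use_loop(my_computer):
    return await openai_computer_use_loop(
        messages=[{"role": "user", "content": "Take a screenshot"}],
        model="openai/computer-use-preview",
        tools=[{"type": "computer", "computer": my_computer}],
    )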

View File

@@ -0,0 +1,688 @@
"""
UITARS agent loop implementation using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B
"""
import asyncio
import json
import base64
import math
import re
import ast
from typing import Dict, List, Any, AsyncGenerator, Union, Optional
from io import BytesIO
from PIL import Image
import litellm
from litellm.types.utils import ModelResponse
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
from litellm.responses.utils import Usage
from openai.types.responses.response_computer_tool_call_param import ActionType, ResponseComputerToolCallParam
from openai.types.responses.response_input_param import ComputerCallOutput
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam, Summary
from ..decorators import agent_loop
from ..types import Messages, AgentResponse, Tools
from ..responses import (
make_reasoning_item,
make_output_text_item,
make_click_item,
make_double_click_item,
make_drag_item,
make_keypress_item,
make_scroll_item,
make_type_item,
make_wait_item,
make_input_image_item
)
# Constants from reference code
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
FINISH_WORD = "finished"
WAIT_WORD = "wait"
ENV_FAIL_WORD = "error_env"
CALL_USER = "call_user"
# Action space prompt for UITARS
UITARS_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
"""
UITARS_PROMPT_TEMPLATE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
{action_space}
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
def round_by_factor(number: float, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: float, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: float, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
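# Hedged worked example: a 1920x1080 frame is ~2.07 MP, inside the pixel budget,
# so each side simply snaps to the nearest multiple of IMAGE_FACTOR (28).
def _example_smart_resize():
    h, w = smart_resize(1080, 1920)
    assert h % IMAGE_FACTOR == 0 and w % IMAGE_FACTOR == 0
    assert (h, w) == (1092, 1932)  # round(1080/28)*28, round(1920/28)*28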
def escape_single_quotes(text):
"""Escape single quotes in text for safe string formatting."""
pattern = r"(?<!\\)'"
return re.sub(pattern, r"\\'", text)
def parse_action(action_str):
"""Parse action string into structured format."""
try:
node = ast.parse(action_str, mode='eval')
if not isinstance(node, ast.Expression):
raise ValueError("Not an expression")
call = node.body
if not isinstance(call, ast.Call):
raise ValueError("Not a function call")
# Get function name
if isinstance(call.func, ast.Name):
func_name = call.func.id
elif isinstance(call.func, ast.Attribute):
func_name = call.func.attr
else:
func_name = None
# Get keyword arguments
kwargs = {}
for kw in call.keywords:
key = kw.arg
if isinstance(kw.value, ast.Constant):
value = kw.value.value
elif isinstance(kw.value, ast.Str): # Compatibility with older Python
value = kw.value.s
else:
value = None
kwargs[key] = value
return {
'function': func_name,
'args': kwargs
}
except Exception as e:
print(f"Failed to parse action '{action_str}': {e}")
return None
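# Hedged parse sketch: the model's action string is parsed as a Python call
# expression, yielding the function name plus keyword arguments.
def _example_parse_action():
    parsed = parse_action("click(start_box='(512,384)')")
    assert parsed == {"function": "click", "args": {"start_box": "(512,384)"}}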
def parse_uitars_response(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
"""Parse UITARS model response into structured actions."""
text = text.strip()
# Extract thought
thought = None
if text.startswith("Thought:"):
thought_match = re.search(r"Thought: (.+?)(?=\s*Action:|$)", text, re.DOTALL)
if thought_match:
thought = thought_match.group(1).strip()
# Extract action
if "Action:" not in text:
raise ValueError("No Action found in response")
action_str = text.split("Action:")[-1].strip()
# Handle special case for type actions
if "type(content" in action_str:
def escape_quotes(match):
return match.group(1)
pattern = r"type\(content='(.*?)'\)"
content = re.sub(pattern, escape_quotes, action_str)
action_str = escape_single_quotes(content)
action_str = "type(content='" + action_str + "')"
# Parse the action
parsed_action = parse_action(action_str.replace("\n", "\\n").lstrip())
if parsed_action is None:
raise ValueError(f"Action can't parse: {action_str}")
action_type = parsed_action["function"]
params = parsed_action["args"]
# Process parameters
action_inputs = {}
for param_name, param in params.items():
if param == "":
continue
param = str(param).lstrip()
action_inputs[param_name.strip()] = param
# Handle coordinate parameters
if "start_box" in param_name or "end_box" in param_name:
# Parse coordinates like '(x,y)' or '(x1,y1,x2,y2)'
numbers = param.replace("(", "").replace(")", "").split(",")
            float_numbers = [float(num.strip()) / 1000 for num in numbers]  # UI-TARS emits coords in 0-1000 space; normalize to 0-1
if len(float_numbers) == 2:
# Single point, duplicate for box format
float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
action_inputs[param_name.strip()] = str(float_numbers)
return [{
"thought": thought,
"action_type": action_type,
"action_inputs": action_inputs,
"text": text
}]
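# Hedged parse sketch: UI-TARS emits coordinates in a 0-1000 space; they come
# back normalized to 0-1, with single points duplicated into box form.
def _example_parse_uitars_response():
    text = "Thought: I should click the button.\nAction: click(start_box='(500,300)')"
    [parsed] = parse_uitars_response(text, 1024, 768)
    assert parsed["action_type"] == "click"
    assert parsed["action_inputs"]["start_box"] == "[0.5, 0.3, 0.5, 0.3]"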
def convert_to_computer_actions(parsed_responses: List[Dict[str, Any]], image_width: int, image_height: int) -> List[ResponseComputerToolCallParam | ResponseOutputMessageParam]:
"""Convert parsed UITARS responses to computer actions."""
computer_actions = []
for response in parsed_responses:
action_type = response.get("action_type")
action_inputs = response.get("action_inputs", {})
if action_type == "finished":
finished_text = action_inputs.get("content", "Task completed successfully.")
computer_actions.append(make_output_text_item(finished_text))
break
elif action_type == "wait":
computer_actions.append(make_wait_item())
elif action_type == "call_user":
computer_actions.append(make_output_text_item("I need assistance from the user to proceed with this task."))
elif action_type in ["click", "left_single"]:
start_box = action_inputs.get("start_box")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
computer_actions.append(make_click_item(x, y, "left"))
elif action_type == "double_click":
start_box = action_inputs.get("start_box")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
computer_actions.append(make_double_click_item(x, y))
elif action_type == "right_click":
start_box = action_inputs.get("start_box")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
computer_actions.append(make_click_item(x, y, "right"))
elif action_type == "type":
content = action_inputs.get("content", "")
computer_actions.append(make_type_item(content))
elif action_type == "hotkey":
key = action_inputs.get("key", "")
keys = key.split()
computer_actions.append(make_keypress_item(keys))
elif action_type == "press":
key = action_inputs.get("key", "")
computer_actions.append(make_keypress_item([key]))
elif action_type == "scroll":
start_box = action_inputs.get("start_box")
direction = action_inputs.get("direction", "down")
if start_box:
coords = eval(start_box)
x = int((coords[0] + coords[2]) / 2 * image_width)
y = int((coords[1] + coords[3]) / 2 * image_height)
else:
x, y = image_width // 2, image_height // 2
scroll_y = 5 if "up" in direction.lower() else -5
computer_actions.append(make_scroll_item(x, y, 0, scroll_y))
elif action_type == "drag":
start_box = action_inputs.get("start_box")
end_box = action_inputs.get("end_box")
if start_box and end_box:
start_coords = eval(start_box)
end_coords = eval(end_box)
start_x = int((start_coords[0] + start_coords[2]) / 2 * image_width)
start_y = int((start_coords[1] + start_coords[3]) / 2 * image_height)
end_x = int((end_coords[0] + end_coords[2]) / 2 * image_width)
end_y = int((end_coords[1] + end_coords[3]) / 2 * image_height)
path = [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}]
computer_actions.append(make_drag_item(path))
return computer_actions
def pil_to_base64(image: Image.Image) -> str:
"""Convert PIL image to base64 string."""
buffer = BytesIO()
image.save(buffer, format="PNG")
return base64.b64encode(buffer.getvalue()).decode("utf-8")
def process_image_for_uitars(image_data: str, max_pixels: int = MAX_PIXELS, min_pixels: int = MIN_PIXELS) -> tuple[Image.Image, int, int]:
"""Process image for UITARS model input."""
# Decode base64 image
if image_data.startswith('data:image'):
image_data = image_data.split(',')[1]
image_bytes = base64.b64decode(image_data)
image = Image.open(BytesIO(image_bytes))
original_width, original_height = image.size
# Resize image according to UITARS requirements
if image.width * image.height > max_pixels:
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
width = int(image.width * resize_factor)
height = int(image.height * resize_factor)
image = image.resize((width, height))
if image.width * image.height < min_pixels:
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
width = math.ceil(image.width * resize_factor)
height = math.ceil(image.height * resize_factor)
image = image.resize((width, height))
if image.mode != "RGB":
image = image.convert("RGB")
return image, original_width, original_height
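# Hedged sketch: an oversized synthetic frame is shrunk under MAX_PIXELS while
# the original dimensions are returned for mapping actions back to the screen.
def _example_process_image_for_uitars():
    big = Image.new("RGB", (5000, 4000))
    encoded = "data:image/png;base64," + pil_to_base64(big)
    resized, orig_w, orig_h = process_image_for_uitars(encoded)
    assert (orig_w, orig_h) == (5000, 4000)
    assert resized.width * resized.height <= MAX_PIXELS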
def sanitize_message(msg: Any) -> Any:
"""Return a copy of the message with image_url ommited within content parts"""
if isinstance(msg, dict):
result = {}
for key, value in msg.items():
if key == "content" and isinstance(value, list):
result[key] = [
{k: v for k, v in item.items() if k != "image_url"} if isinstance(item, dict) else item
for item in value
]
else:
result[key] = value
return result
elif isinstance(msg, list):
return [sanitize_message(item) for item in msg]
else:
return msg
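# Hedged sketch: bulky image payloads are dropped from content parts before a
# message is logged, while every other field survives intact.
def _example_sanitize_message():
    msg = {"role": "user", "content": [
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,AAAA"}},
        {"type": "text", "text": "hello"},
    ]}
    clean = sanitize_message(msg)
    assert clean["content"][0] == {"type": "image_url"}
    assert clean["content"][1]["text"] == "hello"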
def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any]]:
"""
Convert UITARS internal message format back to LiteLLM format.
This function processes reasoning, computer_call, and computer_call_output messages
and converts them to the appropriate LiteLLM assistant message format.
Args:
messages: List of UITARS internal messages
Returns:
List of LiteLLM formatted messages
"""
litellm_messages = []
current_assistant_content = []
for message in messages:
if isinstance(message, dict):
message_type = message.get("type")
if message_type == "reasoning":
# Extract reasoning text from summary
summary = message.get("summary", [])
if summary and isinstance(summary, list):
for summary_item in summary:
if isinstance(summary_item, dict) and summary_item.get("type") == "summary_text":
reasoning_text = summary_item.get("text", "")
if reasoning_text:
current_assistant_content.append(f"Thought: {reasoning_text}")
elif message_type == "computer_call":
# Convert computer action to UITARS action format
action = message.get("action", {})
action_type = action.get("type")
if action_type == "click":
x, y = action.get("x", 0), action.get("y", 0)
button = action.get("button", "left")
if button == "left":
action_text = f"Action: click(start_box='({x},{y})')"
elif button == "right":
action_text = f"Action: right_single(start_box='({x},{y})')"
else:
action_text = f"Action: click(start_box='({x},{y})')"
elif action_type == "double_click":
x, y = action.get("x", 0), action.get("y", 0)
action_text = f"Action: left_double(start_box='({x},{y})')"
elif action_type == "drag":
start_x, start_y = action.get("start_x", 0), action.get("start_y", 0)
end_x, end_y = action.get("end_x", 0), action.get("end_y", 0)
action_text = f"Action: drag(start_box='({start_x},{start_y})', end_box='({end_x},{end_y})')"
elif action_type == "key":
key = action.get("key", "")
action_text = f"Action: hotkey(key='{key}')"
elif action_type == "type":
text = action.get("text", "")
# Escape single quotes in the text
escaped_text = escape_single_quotes(text)
action_text = f"Action: type(content='{escaped_text}')"
elif action_type == "scroll":
x, y = action.get("x", 0), action.get("y", 0)
direction = action.get("direction", "down")
action_text = f"Action: scroll(start_box='({x},{y})', direction='{direction}')"
elif action_type == "wait":
action_text = "Action: wait()"
else:
# Fallback for unknown action types
action_text = f"Action: {action_type}({action})"
current_assistant_content.append(action_text)
# When we hit a computer_call_output, finalize the current assistant message
if current_assistant_content:
litellm_messages.append({
"role": "assistant",
"content": [{"type": "text", "text": "\n".join(current_assistant_content)}]
})
current_assistant_content = []
elif message_type == "computer_call_output":
# Add screenshot from computer call output
output = message.get("output", {})
if isinstance(output, dict) and output.get("type") == "input_image":
image_url = output.get("image_url", "")
if image_url:
litellm_messages.append({
"role": "user",
"content": [{"type": "image_url", "image_url": {"url": image_url}}]
})
            elif message.get("role") == "user":
                # User messages are intentionally skipped: the current instruction
                # and screenshot are injected into the prompt separately.
                pass
# Add any remaining assistant content
if current_assistant_content:
litellm_messages.append({
"role": "assistant",
"content": current_assistant_content
})
return litellm_messages
@agent_loop(models=r"(?i).*ui-?tars.*", priority=10)
async def uitars_loop(
messages: Messages,
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
"""
UITARS agent loop using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B model.
Supports UITARS vision-language models for computer control.
"""
tools = tools or []
# Create response items
response_items = []
# Find computer tool for screen dimensions
computer_tool = None
for tool_schema in tools:
if tool_schema["type"] == "computer":
computer_tool = tool_schema["computer"]
break
# Get screen dimensions
screen_width, screen_height = 1024, 768
if computer_tool:
try:
screen_width, screen_height = await computer_tool.get_dimensions()
        except Exception:
            # Keep the 1024x768 defaults if dimensions are unavailable
            pass
# Process messages to extract instruction and image
instruction = ""
image_data = None
# Convert messages to list if string
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
# Extract instruction and latest screenshot
for message in reversed(messages):
if isinstance(message, dict):
content = message.get("content", "")
# Handle different content formats
if isinstance(content, str):
if not instruction and message.get("role") == "user":
instruction = content
elif isinstance(content, list):
for item in content:
if isinstance(item, dict):
if item.get("type") == "text" and not instruction:
instruction = item.get("text", "")
elif item.get("type") == "image_url" and not image_data:
image_url = item.get("image_url", {})
if isinstance(image_url, dict):
image_data = image_url.get("url", "")
else:
image_data = image_url
# Also check for computer_call_output with screenshots
if message.get("type") == "computer_call_output" and not image_data:
output = message.get("output", {})
if isinstance(output, dict) and output.get("type") == "input_image":
image_data = output.get("image_url", "")
if instruction and image_data:
break
if not instruction:
instruction = "Help me complete this task by analyzing the screen and taking appropriate actions."
# Create prompt
user_prompt = UITARS_PROMPT_TEMPLATE.format(
instruction=instruction,
action_space=UITARS_ACTION_SPACE,
language="English"
)
# Convert conversation history to LiteLLM format
history_messages = convert_uitars_messages_to_litellm(messages)
# Prepare messages for liteLLM
litellm_messages = [
{
"role": "system",
"content": "You are a helpful assistant."
}
]
# Add current user instruction with screenshot
current_user_message = {
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
]
}
litellm_messages.append(current_user_message)
# Process image for UITARS
if not image_data:
# Take screenshot if none found in messages
if computer_handler:
            image_data = await computer_handler.screenshot()
            if _on_screenshot:
                await _on_screenshot(image_data, "screenshot_before")
# Add screenshot to output items so it can be retained in history
response_items.append(make_input_image_item(image_data))
else:
raise ValueError("No screenshot found in messages and no computer_handler provided")
processed_image, original_width, original_height = process_image_for_uitars(image_data)
encoded_image = pil_to_base64(processed_image)
# Add conversation history
if history_messages:
litellm_messages.extend(history_messages)
else:
litellm_messages.append({
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
]
})
# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": litellm_messages,
"max_tokens": kwargs.get("max_tokens", 500),
"temperature": kwargs.get("temperature", 0.0),
"do_sample": kwargs.get("temperature", 0.0) > 0.0,
"num_retries": max_retries,
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
}
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
# Call liteLLM with UITARS model
response = await litellm.acompletion(**api_kwargs)
# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
# Extract response content
response_content = response.choices[0].message.content.strip() # type: ignore
# Parse UITARS response
parsed_responses = parse_uitars_response(response_content, original_width, original_height)
# Convert to computer actions
computer_actions = convert_to_computer_actions(parsed_responses, original_width, original_height)
# Add computer actions to response items
thought = parsed_responses[0].get("thought", "")
if thought:
response_items.append(make_reasoning_item(thought))
response_items.extend(computer_actions)
# Extract usage information
response_usage = {
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(response_usage)
# Create agent response
agent_response = {
"output": response_items,
"usage": response_usage
}
return agent_response
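# Hedged usage sketch (`my_computer` is an assumption; the checkpoint must be
# served locally): any model name matching the case-insensitive "ui-tars"
# pattern routes here.
async def _example_uitars_loop(my_computer):
    return await uitars_loop(
        messages=[{"role": "user", "content": "Close the dialog"}],
        model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
        tools=[{"type": "computer", "computer": my_computer}],
        computer_handler=my_computer,
    )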

View File

@@ -1,4 +0,0 @@
"""Provider implementations for different AI services."""
# Import specific providers only when needed to avoid circular imports
__all__ = [] # Let each provider module handle its own exports

View File

@@ -1,6 +0,0 @@
"""Anthropic provider implementation."""
from .loop import AnthropicLoop
from .types import LLMProvider
__all__ = ["AnthropicLoop", "LLMProvider"]

View File

@@ -1,360 +0,0 @@
from typing import Any, List, Dict, cast
import httpx
import asyncio
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaToolUnionParam
from ..types import LLMProvider
from .logging import log_api_interaction
import random
import logging
logger = logging.getLogger(__name__)
class APIConnectionError(Exception):
"""Error raised when there are connection issues with the API."""
pass
class BaseAnthropicClient:
"""Base class for Anthropic API clients."""
MAX_RETRIES = 10
INITIAL_RETRY_DELAY = 1.0
MAX_RETRY_DELAY = 60.0
JITTER_FACTOR = 0.1
async def create_message(
self,
*,
messages: list[BetaMessageParam],
system: list[Any],
tools: list[BetaToolUnionParam],
max_tokens: int,
betas: list[str],
) -> BetaMessage:
"""Create a message using the Anthropic API."""
raise NotImplementedError
async def _make_api_call_with_retries(self, api_call):
"""Make an API call with exponential backoff retry logic.
Args:
api_call: Async function that makes the actual API call
Returns:
API response
Raises:
APIConnectionError: If all retries fail
"""
retry_count = 0
last_error = None
while retry_count < self.MAX_RETRIES:
try:
return await api_call()
except Exception as e:
last_error = e
retry_count += 1
if retry_count == self.MAX_RETRIES:
break
# Calculate delay with exponential backoff and jitter
delay = min(
self.INITIAL_RETRY_DELAY * (2 ** (retry_count - 1)), self.MAX_RETRY_DELAY
)
# Add jitter to avoid thundering herd
jitter = delay * self.JITTER_FACTOR * (2 * random.random() - 1)
final_delay = delay + jitter
logger.info(
f"Retrying request (attempt {retry_count}/{self.MAX_RETRIES}) "
f"in {final_delay:.2f} seconds after error: {str(e)}"
)
await asyncio.sleep(final_delay)
raise APIConnectionError(
f"Failed after {self.MAX_RETRIES} retries. " f"Last error: {str(last_error)}"
)
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: int = 4096
) -> Any:
"""Run the Anthropic API with the Claude model, supports interleaved tool calling.
Args:
messages: List of message objects
system: System prompt
max_tokens: Maximum tokens to generate
Returns:
API response
"""
# Add the tool_result check/fix logic here
fixed_messages = self._fix_missing_tool_results(messages)
# Get model name from concrete implementation if available
model_name = getattr(self, "model", "unknown model")
logger.info(f"Running Anthropic API call with model {model_name}")
retry_count = 0
while retry_count < self.MAX_RETRIES:
try:
# Call the Anthropic API through create_message which is implemented by subclasses
# Convert system str to the list format expected by create_message
system_list = [system]
# Convert message format if needed - concrete implementations may do further conversion
response = await self.create_message(
messages=cast(list[BetaMessageParam], fixed_messages),
system=system_list,
tools=[], # Tools are included in the messages
max_tokens=max_tokens,
betas=["tools-2023-12-13"],
)
logger.info(f"Anthropic API call successful")
return response
except Exception as e:
retry_count += 1
wait_time = self.INITIAL_RETRY_DELAY * (
2 ** (retry_count - 1)
) # Exponential backoff
logger.info(
f"Retrying request (attempt {retry_count}/{self.MAX_RETRIES}) in {wait_time:.2f} seconds after error: {str(e)}"
)
await asyncio.sleep(wait_time)
# If we get here, all retries failed
raise RuntimeError(f"Failed to call Anthropic API after {self.MAX_RETRIES} attempts")
def _fix_missing_tool_results(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Check for and fix any missing tool_result blocks after tool_use blocks.
Args:
messages: List of message objects
Returns:
Fixed messages with proper tool_result blocks
"""
fixed_messages = []
pending_tool_uses = {} # Map of tool_use IDs to their details
for i, message in enumerate(messages):
# Track any tool_use blocks in this message
if message.get("role") == "assistant" and "content" in message:
content = message.get("content", [])
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_use":
tool_id = block.get("id")
if tool_id:
pending_tool_uses[tool_id] = {
"name": block.get("name", ""),
"input": block.get("input", {}),
}
# Check if this message handles any pending tool_use blocks
if message.get("role") == "user" and "content" in message:
# Check for tool_result blocks in this message
content = message.get("content", [])
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_result":
tool_id = block.get("tool_use_id")
if tool_id in pending_tool_uses:
# This tool_result handles a pending tool_use
pending_tool_uses.pop(tool_id)
# Add the message to our fixed list
fixed_messages.append(message)
# If this is an assistant message with tool_use blocks and there are
# pending tool uses that need to be resolved before the next assistant message
if (
i + 1 < len(messages)
and message.get("role") == "assistant"
and messages[i + 1].get("role") == "assistant"
and pending_tool_uses
):
# We need to insert a user message with tool_results for all pending tool_uses
tool_results = []
for tool_id, tool_info in pending_tool_uses.items():
tool_results.append(
{
"type": "tool_result",
"tool_use_id": tool_id,
"content": {
"type": "error",
"message": "Tool execution was skipped or failed",
},
}
)
# Insert a synthetic user message with the tool results
if tool_results:
fixed_messages.append({"role": "user", "content": tool_results})
# Clear pending tools since we've added results for them
pending_tool_uses = {}
# Check if there are any remaining pending tool_uses at the end of the conversation
if pending_tool_uses and fixed_messages and fixed_messages[-1].get("role") == "assistant":
# Add a final user message with tool results for any pending tool_uses
tool_results = []
for tool_id, tool_info in pending_tool_uses.items():
tool_results.append(
{
"type": "tool_result",
"tool_use_id": tool_id,
"content": {
"type": "error",
"message": "Tool execution was skipped or failed",
},
}
)
if tool_results:
fixed_messages.append({"role": "user", "content": tool_results})
return fixed_messages
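# Hedged worked example of the backoff schedule (jitter ignored): delays double
# from INITIAL_RETRY_DELAY and cap at MAX_RETRY_DELAY, so the first seven
# retries wait 1, 2, 4, 8, 16, 32, then 60 seconds.
def _example_backoff_delays():
    delays = [
        min(BaseAnthropicClient.INITIAL_RETRY_DELAY * (2 ** (n - 1)),
            BaseAnthropicClient.MAX_RETRY_DELAY)
        for n in range(1, 8)
    ]
    assert delays == [1.0, 2.0, 4.0, 8.0, 16.0, 32.0, 60.0]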
class AnthropicDirectClient(BaseAnthropicClient):
"""Direct Anthropic API client implementation."""
def __init__(self, api_key: str, model: str):
self.model = model
self.client = Anthropic(api_key=api_key, http_client=self._create_http_client())
def _create_http_client(self) -> httpx.Client:
"""Create an HTTP client with appropriate settings."""
return httpx.Client(
verify=True,
timeout=httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0),
transport=httpx.HTTPTransport(
retries=3,
verify=True,
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
),
)
async def create_message(
self,
*,
messages: list[BetaMessageParam],
system: list[Any],
tools: list[BetaToolUnionParam],
max_tokens: int,
betas: list[str],
) -> BetaMessage:
"""Create a message using the direct Anthropic API with retry logic."""
async def api_call():
response = self.client.beta.messages.with_raw_response.create(
max_tokens=max_tokens,
messages=messages,
model=self.model,
system=system,
tools=tools,
betas=betas,
)
log_api_interaction(response.http_response.request, response.http_response, None)
return response.parse()
try:
return await self._make_api_call_with_retries(api_call)
except Exception as e:
log_api_interaction(None, None, e)
raise
class AnthropicVertexClient(BaseAnthropicClient):
"""Google Cloud Vertex AI implementation of Anthropic client."""
def __init__(self, model: str):
self.model = model
self.client = AnthropicVertex()
async def create_message(
self,
*,
messages: list[BetaMessageParam],
system: list[Any],
tools: list[BetaToolUnionParam],
max_tokens: int,
betas: list[str],
) -> BetaMessage:
"""Create a message using Vertex AI with retry logic."""
async def api_call():
response = self.client.beta.messages.with_raw_response.create(
max_tokens=max_tokens,
messages=messages,
model=self.model,
system=system,
tools=tools,
betas=betas,
)
log_api_interaction(response.http_response.request, response.http_response, None)
return response.parse()
try:
return await self._make_api_call_with_retries(api_call)
except Exception as e:
log_api_interaction(None, None, e)
raise
class AnthropicBedrockClient(BaseAnthropicClient):
"""AWS Bedrock implementation of Anthropic client."""
def __init__(self, model: str):
self.model = model
self.client = AnthropicBedrock()
async def create_message(
self,
*,
messages: list[BetaMessageParam],
system: list[Any],
tools: list[BetaToolUnionParam],
max_tokens: int,
betas: list[str],
) -> BetaMessage:
"""Create a message using AWS Bedrock with retry logic."""
async def api_call():
response = self.client.beta.messages.with_raw_response.create(
max_tokens=max_tokens,
messages=messages,
model=self.model,
system=system,
tools=tools,
betas=betas,
)
log_api_interaction(response.http_response.request, response.http_response, None)
return response.parse()
try:
return await self._make_api_call_with_retries(api_call)
except Exception as e:
log_api_interaction(None, None, e)
raise
class AnthropicClientFactory:
"""Factory for creating appropriate Anthropic client implementations."""
@staticmethod
def create_client(provider: LLMProvider, api_key: str, model: str) -> BaseAnthropicClient:
"""Create an appropriate client based on the provider."""
if provider == LLMProvider.ANTHROPIC:
return AnthropicDirectClient(api_key, model)
elif provider == LLMProvider.VERTEX:
return AnthropicVertexClient(model)
elif provider == LLMProvider.BEDROCK:
return AnthropicBedrockClient(model)
raise ValueError(f"Unsupported provider: {provider}")

View File

@@ -1,150 +0,0 @@
"""API logging functionality."""
import json
import logging
from datetime import datetime
from pathlib import Path
import httpx
from typing import Any
logger = logging.getLogger(__name__)
def _filter_base64_images(content: Any) -> Any:
"""Filter out base64 image data from content.
Args:
content: Content to filter
Returns:
Filtered content with base64 data replaced by placeholder
"""
if isinstance(content, dict):
filtered = {}
for key, value in content.items():
if (
isinstance(value, dict)
and value.get("type") == "image"
and value.get("source", {}).get("type") == "base64"
):
# Replace base64 data with placeholder
filtered[key] = {
**value,
"source": {
**value["source"],
"data": "<base64_image_data>"
}
}
else:
filtered[key] = _filter_base64_images(value)
return filtered
elif isinstance(content, list):
return [_filter_base64_images(item) for item in content]
return content
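# Hedged sketch: base64 image data is swapped for a placeholder so log files
# stay small, while the surrounding structure is preserved.
def _example_filter_base64_images():
    content = {"block": {"type": "image", "source": {"type": "base64", "data": "AAAA"}}}
    filtered = _filter_base64_images(content)
    assert filtered["block"]["source"]["data"] == "<base64_image_data>"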
def log_api_interaction(
request: httpx.Request | None,
response: httpx.Response | object | None,
error: Exception | None,
log_dir: Path = Path("/tmp/claude_logs")
) -> None:
"""Log API request, response, and any errors in a structured way.
Args:
request: The HTTP request if available
response: The HTTP response or response object
error: Any error that occurred
log_dir: Directory to store log files
"""
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
# Helper function to safely decode JSON content
def safe_json_decode(content):
if not content:
return None
try:
if isinstance(content, bytes):
return json.loads(content.decode())
elif isinstance(content, str):
return json.loads(content)
elif isinstance(content, dict):
return content
return None
except json.JSONDecodeError:
return {"error": "Could not decode JSON", "raw": str(content)}
# Process request content
request_content = None
if request and request.content:
request_content = safe_json_decode(request.content)
request_content = _filter_base64_images(request_content)
# Process response content
response_content = None
if response:
if isinstance(response, httpx.Response):
try:
response_content = response.json()
except json.JSONDecodeError:
response_content = {"error": "Could not decode JSON", "raw": response.text}
else:
response_content = safe_json_decode(response)
response_content = _filter_base64_images(response_content)
log_entry = {
"timestamp": timestamp,
"request": {
"method": request.method if request else None,
"url": str(request.url) if request else None,
"headers": dict(request.headers) if request else None,
"content": request_content,
} if request else None,
"response": {
"status_code": response.status_code if isinstance(response, httpx.Response) else None,
"headers": dict(response.headers) if isinstance(response, httpx.Response) else None,
"content": response_content,
} if response else None,
"error": {
"type": type(error).__name__ if error else None,
"message": str(error) if error else None,
} if error else None
}
# Log to file with timestamp in filename
log_dir.mkdir(exist_ok=True)
log_file = log_dir / f"claude_api_{timestamp.replace(' ', '_').replace(':', '-')}.json"
with open(log_file, 'w') as f:
json.dump(log_entry, f, indent=2)
# Also log a summary to the console
if error:
logger.error(f"API Error at {timestamp}: {error}")
else:
logger.info(
f"API Call at {timestamp}: "
f"{request.method if request else 'No request'} -> "
f"{response.status_code if isinstance(response, httpx.Response) else 'No response'}"
)
# Log if there are any images in the content
if response_content:
image_count = count_images(response_content)
if image_count > 0:
logger.info(f"Response contains {image_count} images")
def count_images(content: dict | list | Any) -> int:
"""Count the number of images in the content.
Args:
content: Content to search for images
Returns:
Number of images found
"""
if isinstance(content, dict):
if content.get("type") == "image":
return 1
return sum(count_images(v) for v in content.values())
elif isinstance(content, list):
return sum(count_images(item) for item in content)
return 0

View File

@@ -1,140 +0,0 @@
"""API call handling for Anthropic provider."""
import logging
import asyncio
from typing import List
from anthropic.types.beta import (
BetaMessage,
BetaMessageParam,
BetaTextBlockParam,
)
from .types import LLMProvider
from .prompts import SYSTEM_PROMPT
# Constants
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
logger = logging.getLogger(__name__)
class AnthropicAPIHandler:
"""Handles API calls to Anthropic's API with structured error handling and retries."""
def __init__(self, loop):
"""Initialize the API handler.
Args:
loop: Reference to the parent loop instance that provides context
"""
self.loop = loop
async def make_api_call(
self, messages: List[BetaMessageParam], system_prompt: str = SYSTEM_PROMPT
) -> BetaMessage:
"""Make API call to Anthropic with retry logic.
Args:
messages: List of messages to send to the API
system_prompt: System prompt to use (default: SYSTEM_PROMPT)
Returns:
API response
Raises:
RuntimeError: If API call fails after all retries
"""
if self.loop.client is None:
raise RuntimeError("Client not initialized. Call initialize_client() first.")
if self.loop.tool_manager is None:
raise RuntimeError("Tool manager not initialized. Call initialize_client() first.")
last_error = None
# Add detailed debug logging to examine messages
logger.info(f"Sending {len(messages)} messages to Anthropic API")
# Log tool use IDs and tool result IDs for debugging
tool_use_ids = set()
tool_result_ids = set()
for i, msg in enumerate(messages):
logger.info(f"Message {i}: role={msg.get('role')}")
if isinstance(msg.get("content"), list):
for content_block in msg.get("content", []):
if isinstance(content_block, dict):
block_type = content_block.get("type")
if block_type == "tool_use" and "id" in content_block:
tool_id = content_block.get("id")
tool_use_ids.add(tool_id)
logger.info(f" - Found tool_use with ID: {tool_id}")
elif block_type == "tool_result" and "tool_use_id" in content_block:
result_id = content_block.get("tool_use_id")
tool_result_ids.add(result_id)
logger.info(f" - Found tool_result referencing ID: {result_id}")
# Check for mismatches
missing_tool_uses = tool_result_ids - tool_use_ids
if missing_tool_uses:
logger.warning(
f"Found tool_result IDs without matching tool_use IDs: {missing_tool_uses}"
)
for attempt in range(self.loop.max_retries):
try:
# Log request
request_data = {
"messages": messages,
"max_tokens": self.loop.max_tokens,
"system": system_prompt,
}
# Let ExperimentManager handle sanitization
self.loop._log_api_call("request", request_data)
# Setup betas and system
system = BetaTextBlockParam(
type="text",
text=system_prompt,
)
betas = [COMPUTER_USE_BETA_FLAG]
# Add prompt caching if enabled in the message manager's config
if self.loop.message_manager.config.enable_caching:
betas.append(PROMPT_CACHING_BETA_FLAG)
system["cache_control"] = {"type": "ephemeral"}
# Make API call
response = await self.loop.client.create_message(
messages=messages,
system=[system],
tools=self.loop.tool_manager.get_tool_params(),
max_tokens=self.loop.max_tokens,
betas=betas,
)
# Let ExperimentManager handle sanitization
self.loop._log_api_call("response", request_data, response)
return response
except Exception as e:
last_error = e
logger.error(
f"Error in API call (attempt {attempt + 1}/{self.loop.max_retries}): {str(e)}"
)
self.loop._log_api_call("error", {"messages": messages}, error=e)
if attempt < self.loop.max_retries - 1:
await asyncio.sleep(
self.loop.retry_delay * (attempt + 1)
) # Exponential backoff
continue
# If we get here, all retries failed
error_message = f"API call failed after {self.loop.max_retries} attempts"
if last_error:
error_message += f": {str(last_error)}"
logger.error(error_message)
raise RuntimeError(error_message)

View File

@@ -1,5 +0,0 @@
"""Anthropic callbacks package."""
from .manager import CallbackManager
__all__ = ["CallbackManager"]

View File

@@ -1,65 +0,0 @@
from typing import Callable, Protocol
import httpx
from anthropic.types.beta import BetaContentBlockParam
from ..tools import ToolResult
class APICallback(Protocol):
"""Protocol for API callbacks."""
def __call__(
self,
request: httpx.Request | None,
response: httpx.Response | object | None,
error: Exception | None,
) -> None: ...
class ContentCallback(Protocol):
"""Protocol for content callbacks."""
def __call__(self, content: BetaContentBlockParam) -> None: ...
class ToolCallback(Protocol):
"""Protocol for tool callbacks."""
def __call__(self, result: ToolResult, tool_id: str) -> None: ...
class CallbackManager:
"""Manages various callbacks for the agent system."""
def __init__(
self,
content_callback: ContentCallback,
tool_callback: ToolCallback,
api_callback: APICallback,
):
"""Initialize the callback manager.
Args:
content_callback: Callback for content updates
tool_callback: Callback for tool execution results
api_callback: Callback for API interactions
"""
self.content_callback = content_callback
self.tool_callback = tool_callback
self.api_callback = api_callback
def on_content(self, content: BetaContentBlockParam) -> None:
"""Handle content updates."""
self.content_callback(content)
def on_tool_result(self, result: ToolResult, tool_id: str) -> None:
"""Handle tool execution results."""
self.tool_callback(result, tool_id)
def on_api_interaction(
self,
request: httpx.Request | None,
response: httpx.Response | object | None,
error: Exception | None,
) -> None:
"""Handle API interactions."""
self.api_callback(request, response, error)

View File

@@ -1,568 +0,0 @@
"""Anthropic-specific agent loop implementation."""
import logging
import asyncio
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, cast
from anthropic.types.beta import (
BetaMessage,
BetaMessageParam,
BetaTextBlock,
BetaContentBlockParam,
)
import base64
from datetime import datetime
# Computer
from computer import Computer
# Base imports
from ...core.base import BaseLoop
from ...core.messages import StandardMessageManager, ImageRetentionConfig
from ...core.types import AgentResponse
# Anthropic provider-specific imports
from .api.client import AnthropicClientFactory, BaseAnthropicClient
from .tools.manager import ToolManager
from .prompts import SYSTEM_PROMPT
from .types import LLMProvider
from .tools import ToolResult
from .utils import to_anthropic_format, to_agent_response_format
# Import the new modules we created
from .api_handler import AnthropicAPIHandler
from .response_handler import AnthropicResponseHandler
from .callbacks.manager import CallbackManager
# Constants
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
logger = logging.getLogger(__name__)
class AnthropicLoop(BaseLoop):
"""Anthropic-specific implementation of the agent loop.
This class extends BaseLoop to provide specialized support for Anthropic's Claude models
with their unique tool-use capabilities, custom message formatting, and
callback-driven approach to handling responses.
"""
###########################################
# INITIALIZATION AND CONFIGURATION
###########################################
def __init__(
self,
api_key: str,
computer: Computer,
model: str = "claude-3-7-sonnet-20250219",
only_n_most_recent_images: Optional[int] = 2,
base_dir: Optional[str] = "trajectories",
max_retries: int = 3,
retry_delay: float = 1.0,
save_trajectory: bool = True,
**kwargs,
):
"""Initialize the Anthropic loop.
Args:
api_key: Anthropic API key
model: Model name (fixed to claude-3-7-sonnet-20250219)
computer: Computer instance
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
base_dir: Base directory for saving experiment data
max_retries: Maximum number of retries for API calls
retry_delay: Delay between retries in seconds
save_trajectory: Whether to save trajectory data
"""
# Initialize base class with core config
super().__init__(
computer=computer,
model=model,
api_key=api_key,
max_retries=max_retries,
retry_delay=retry_delay,
base_dir=base_dir,
save_trajectory=save_trajectory,
only_n_most_recent_images=only_n_most_recent_images,
**kwargs,
)
# Initialize message manager
self.message_manager = StandardMessageManager(
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
)
# Anthropic-specific attributes
self.provider = LLMProvider.ANTHROPIC
self.client = None
self.retry_count = 0
self.tool_manager = None
self.callback_manager = None
self.queue = asyncio.Queue() # Initialize queue
self.loop_task = None # Store the loop task for cancellation
# Initialize handlers
self.api_handler = AnthropicAPIHandler(self)
self.response_handler = AnthropicResponseHandler(self)
###########################################
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
###########################################
async def initialize_client(self) -> None:
"""Initialize the Anthropic API client and tools.
Implements abstract method from BaseLoop to set up the Anthropic-specific
client, tool manager, message manager, and callback handlers.
"""
try:
logger.info(f"Initializing Anthropic client with model {self.model}...")
# Initialize client
self.client = AnthropicClientFactory.create_client(
provider=self.provider, api_key=self.api_key, model=self.model
)
# Initialize callback manager with our callback handlers
self.callback_manager = CallbackManager(
content_callback=self._handle_content,
tool_callback=self._handle_tool_result,
api_callback=self._handle_api_interaction,
)
# Initialize tool manager
self.tool_manager = ToolManager(self.computer)
await self.tool_manager.initialize()
logger.info(f"Initialized Anthropic client with model {self.model}")
except Exception as e:
logger.error(f"Error initializing Anthropic client: {str(e)}")
self.client = None
raise RuntimeError(f"Failed to initialize Anthropic client: {str(e)}")
###########################################
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
###########################################
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
"""Run the agent loop with provided messages.
Args:
messages: List of message objects in standard OpenAI format
Yields:
Agent response format
"""
try:
logger.info("Starting Anthropic loop run")
# Create queue for response streaming
queue = asyncio.Queue()
# Ensure client is initialized
if self.client is None or self.tool_manager is None:
logger.info("Initializing client...")
await self.initialize_client()
if self.client is None:
raise RuntimeError("Failed to initialize client")
logger.info("Client initialized successfully")
# Start loop in background task
self.loop_task = asyncio.create_task(self._run_loop(queue, messages))
# Process and yield messages as they arrive
while True:
try:
item = await queue.get()
if item is None: # Stop signal
break
yield item
queue.task_done()
except Exception as e:
logger.error(f"Error processing queue item: {str(e)}")
continue
# Wait for loop to complete
await self.loop_task
# Send completion message
yield {
"role": "assistant",
"content": "Task completed successfully.",
"metadata": {"title": "✅ Complete"},
}
except Exception as e:
logger.error(f"Error executing task: {str(e)}")
yield {
"role": "assistant",
"content": f"Error: {str(e)}",
"metadata": {"title": "❌ Error"},
}
async def cancel(self) -> None:
"""Cancel the currently running agent loop task.
This method stops the ongoing processing in the agent loop
by cancelling the loop_task if it exists and is running.
"""
if self.loop_task and not self.loop_task.done():
logger.info("Cancelling Anthropic loop task")
self.loop_task.cancel()
try:
# Wait for the task to be cancelled with a timeout
await asyncio.wait_for(self.loop_task, timeout=2.0)
except asyncio.TimeoutError:
logger.warning("Timeout while waiting for loop task to cancel")
except asyncio.CancelledError:
logger.info("Loop task cancelled successfully")
except Exception as e:
logger.error(f"Error while cancelling loop task: {str(e)}")
finally:
# Put None in the queue to signal any waiting consumers to stop
await self.queue.put(None)
logger.info("Anthropic loop task cancelled")
else:
logger.info("No active Anthropic loop task to cancel")
###########################################
# AGENT LOOP IMPLEMENTATION
###########################################
async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
"""Run the agent loop with provided messages.
Args:
queue: Queue for response streaming
messages: List of messages in standard OpenAI format
"""
try:
while True:
# Capture screenshot
try:
# Take screenshot - always returns raw PNG bytes
screenshot = await self.computer.interface.screenshot()
logger.info("Screenshot captured successfully")
# Convert PNG bytes to base64
base64_image = base64.b64encode(screenshot).decode("utf-8")
logger.info(f"Screenshot converted to base64 (size: {len(base64_image)} bytes)")
# Save screenshot if requested
if self.save_trajectory and self.experiment_manager:
try:
self._save_screenshot(base64_image, action_type="state")
logger.info("Screenshot saved to trajectory")
except Exception as e:
logger.error(f"Error saving screenshot: {str(e)}")
# Create screenshot message
screen_info_msg = {
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": base64_image,
},
}
],
}
# Add screenshot to messages
messages.append(screen_info_msg)
logger.info("Screenshot message added to conversation")
except Exception as e:
logger.error(f"Error capturing or processing screenshot: {str(e)}")
raise
# Create new turn directory for this API call
self._create_turn_dir()
# Apply image retention policy
self.message_manager.messages = messages.copy()
prepared_messages = self.message_manager.get_messages()
# Convert standard messages to Anthropic format using utility function
anthropic_messages, system_content = to_anthropic_format(prepared_messages)
# Use API handler to make API call with Anthropic format
response = await self.api_handler.make_api_call(
messages=cast(List[BetaMessageParam], anthropic_messages),
system_prompt=system_content or SYSTEM_PROMPT,
)
# Use response handler to handle the response and get new messages
new_messages, should_continue = await self.response_handler.handle_response(
response, messages
)
# Add new messages to the parent's message history
messages.extend(new_messages)
openai_compatible_response = await to_agent_response_format(
response,
messages,
model=self.model,
)
# Log standardized response for ease of parsing
self._log_api_call("agent_response", request=None, response=openai_compatible_response)
await queue.put(openai_compatible_response)
if not should_continue:
break
# Signal completion
await queue.put(None)
except Exception as e:
logger.error(f"Error in _run_loop: {str(e)}")
await queue.put(
{
"role": "assistant",
"content": f"Error in agent loop: {str(e)}",
"metadata": {"title": "❌ Error"},
}
)
await queue.put(None)
###########################################
# RESPONSE AND CALLBACK HANDLING
###########################################
async def _handle_response(self, response: BetaMessage, messages: List[Dict[str, Any]]) -> bool:
"""Handle a response from the Anthropic API.
Args:
response: The response from the Anthropic API
messages: The message history
Returns:
bool: Whether to continue the conversation
"""
try:
# Convert response to standard format
openai_compatible_response = await to_agent_response_format(
response,
messages,
model=self.model,
)
# Put the response on the queue
await self.queue.put(openai_compatible_response)
if self.callback_manager is None:
raise RuntimeError(
"Callback manager not initialized. Call initialize_client() first."
)
# Handle tool use blocks and collect ALL results before adding to messages
tool_result_content = []
has_tool_use = False
for content_block in response.content:
# Notify callback of content
self.callback_manager.on_content(cast(BetaContentBlockParam, content_block))
# Handle tool use - carefully check and access attributes
if hasattr(content_block, "type") and content_block.type == "tool_use":
has_tool_use = True
if self.tool_manager is None:
raise RuntimeError(
"Tool manager not initialized. Call initialize_client() first."
)
# Safely get attributes
tool_name = getattr(content_block, "name", "")
tool_input = getattr(content_block, "input", {})
tool_id = getattr(content_block, "id", "")
result = await self.tool_manager.execute_tool(
name=tool_name,
tool_input=cast(Dict[str, Any], tool_input),
)
# Create tool result
tool_result = self._make_tool_result(cast(ToolResult, result), tool_id)
tool_result_content.append(tool_result)
# Notify callback of tool result
self.callback_manager.on_tool_result(cast(ToolResult, result), tool_id)
# If we had any tool_use blocks, we MUST add the tool_result message
# even if there were errors or no actual results
if has_tool_use:
# If somehow we have no tool results but had tool uses, add synthetic error results
if not tool_result_content:
logger.warning(
"Had tool uses but no tool results, adding synthetic error results"
)
for content_block in response.content:
if hasattr(content_block, "type") and content_block.type == "tool_use":
tool_id = getattr(content_block, "id", "")
if tool_id:
tool_result_content.append(
{
"type": "tool_result",
"tool_use_id": tool_id,
"content": {
"type": "error",
"text": "Tool execution was skipped or failed",
},
"is_error": True,
}
)
# Add ALL tool results as a SINGLE user message
messages.append({"role": "user", "content": tool_result_content})
return True
else:
# No tool uses, we're done
self.callback_manager.on_content({"type": "text", "text": "<DONE>"})
return False
except Exception as e:
logger.error(f"Error handling response: {str(e)}")
messages.append(
{
"role": "assistant",
"content": f"Error: {str(e)}",
}
)
return False
def _response_to_blocks(self, response: BetaMessage) -> List[Dict[str, Any]]:
"""Convert Anthropic API response to standard blocks format.
Args:
response: API response message
Returns:
List of content blocks in standard format
"""
result = []
for block in response.content:
if isinstance(block, BetaTextBlock):
result.append({"type": "text", "text": block.text})
elif hasattr(block, "type") and block.type == "tool_use":
# Safely access attributes after confirming it's a tool_use
result.append(
{
"type": "tool_use",
"id": getattr(block, "id", ""),
"name": getattr(block, "name", ""),
"input": getattr(block, "input", {}),
}
)
else:
# For other block types, convert to dict
block_dict = {}
for key, value in vars(block).items():
if not key.startswith("_"):
block_dict[key] = value
result.append(block_dict)
return result
def _make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]:
"""Convert a tool result to standard format.
Args:
result: Tool execution result
tool_use_id: ID of the tool use
Returns:
Formatted tool result
"""
if result.content:
return {
"type": "tool_result",
"content": result.content,
"tool_use_id": tool_use_id,
"is_error": bool(result.error),
}
tool_result_content = []
is_error = False
if result.error:
is_error = True
tool_result_content = [
{
"type": "text",
"text": self._maybe_prepend_system_tool_result(result, result.error),
}
]
else:
if result.output:
tool_result_content.append(
{
"type": "text",
"text": self._maybe_prepend_system_tool_result(result, result.output),
}
)
if result.base64_image:
tool_result_content.append(
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{result.base64_image}"},
}
)
return {
"type": "tool_result",
"content": tool_result_content,
"tool_use_id": tool_use_id,
"is_error": is_error,
}
def _maybe_prepend_system_tool_result(self, result: ToolResult, result_text: str) -> str:
"""Prepend system information to tool result if available.
Args:
result: Tool execution result
result_text: Text to prepend to
Returns:
Text with system information prepended if available
"""
if result.system:
result_text = f"<s>{result.system}</s>\n{result_text}"
return result_text
###########################################
# CALLBACK HANDLERS
###########################################
def _handle_content(self, content):
"""Handle content updates from the assistant."""
if content.get("type") == "text":
text = content.get("text", "")
if text == "<DONE>":
return
logger.info(f"Assistant: {text}")
def _handle_tool_result(self, result, tool_id):
"""Handle tool execution results."""
if result.error:
logger.error(f"Tool {tool_id} error: {result.error}")
else:
logger.info(f"Tool {tool_id} output: {result.output}")
def _handle_api_interaction(
self, request: Any, response: Any, error: Optional[Exception]
) -> None:
"""Handle API interactions."""
if error:
logger.error(f"API error: {error}")
self._log_api_call("error", request, error=error)
else:
logger.debug(f"API request: {request}")
if response:
self._log_api_call("response", request, response)
else:
self._log_api_call("request", request)


@@ -1,23 +0,0 @@
"""System prompts for Anthropic provider."""
from datetime import datetime
import platform
today = datetime.today()
today = f"{today.strftime('%A, %B')} {today.day}, {today.year}"
SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
* You are utilising a macOS virtual machine on ARM architecture, with internet access and Safari as the default browser.
* You can feel free to install macOS applications with your bash tool. Use curl instead of wget.
* Using the bash tool you can start GUI applications. GUI apps launched with the bash tool will appear within your desktop environment, but they may take some time to show up. Take a screenshot to confirm they did.
* When using your bash tool with commands that are expected to output very large quantities of text, redirect into a tmp file and use str_replace_editor or `grep -n -B <lines before> -A <lines after> <query> <filename>` to confirm output.
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
* Your computer function calls take a while to run and return results. Where possible, chain multiple of these calls into a single function call request.
* The current date is {today}.
</SYSTEM_CAPABILITY>
<IMPORTANT>
* Plan at maximum 1 step each time, and evaluate the result of each step before proceeding. Hold back if you're not sure about the result of the step.
* If you're not sure about the location of an application, start the app using the bash tool.
* If the item you are looking at is a PDF and it seems that you want to read the entire document rather than paging through screenshots, determine its URL, use curl to download it, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool.
</IMPORTANT>"""


@@ -1,226 +0,0 @@
"""Response and tool handling for Anthropic provider."""
import logging
from typing import Any, Dict, List, Tuple, cast
from anthropic.types.beta import (
BetaMessage,
BetaTextBlock,
BetaContentBlockParam,
)
from .tools import ToolResult
logger = logging.getLogger(__name__)
class AnthropicResponseHandler:
"""Handles Anthropic API responses and tool execution results."""
def __init__(self, loop):
"""Initialize the response handler.
Args:
loop: Reference to the parent loop instance that provides context
"""
self.loop = loop
async def handle_response(
self, response: BetaMessage, messages: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], bool]:
"""Handle the Anthropic API response.
Args:
response: API response
messages: List of messages for context
Returns:
Tuple containing:
- List of new messages to be added
- Boolean indicating if the loop should continue
"""
try:
new_messages = []
# Convert response to parameter format
response_params = self.response_to_params(response)
# Collect all existing tool_use IDs from previous messages for validation
existing_tool_use_ids = set()
for msg in messages:
if msg.get("role") == "assistant" and isinstance(msg.get("content"), list):
for block in msg.get("content", []):
if (
isinstance(block, dict)
and block.get("type") == "tool_use"
and "id" in block
):
existing_tool_use_ids.add(block["id"])
# Also add new tool_use IDs from the current response
current_tool_use_ids = set()
for block in response_params:
if isinstance(block, dict) and block.get("type") == "tool_use" and "id" in block:
current_tool_use_ids.add(block["id"])
existing_tool_use_ids.add(block["id"])
logger.info(f"Existing tool_use IDs in conversation: {existing_tool_use_ids}")
logger.info(f"New tool_use IDs in current response: {current_tool_use_ids}")
# Create assistant message
new_messages.append(
{
"role": "assistant",
"content": response_params,
}
)
if self.loop.callback_manager is None:
raise RuntimeError(
"Callback manager not initialized. Call initialize_client() first."
)
# Handle tool use blocks and collect results
tool_result_content = []
for content_block in response_params:
# Notify callback of content
self.loop.callback_manager.on_content(cast(BetaContentBlockParam, content_block))
# Handle tool use
if content_block.get("type") == "tool_use":
if self.loop.tool_manager is None:
raise RuntimeError(
"Tool manager not initialized. Call initialize_client() first."
)
# Execute the tool
result = await self.loop.tool_manager.execute_tool(
name=content_block["name"],
tool_input=cast(Dict[str, Any], content_block["input"]),
)
# Verify the tool_use ID exists in the conversation (which it should now)
tool_use_id = content_block["id"]
if tool_use_id in existing_tool_use_ids:
# Create tool result and add to content
tool_result = self.make_tool_result(cast(ToolResult, result), tool_use_id)
tool_result_content.append(tool_result)
# Notify callback of tool result
self.loop.callback_manager.on_tool_result(
cast(ToolResult, result), content_block["id"]
)
else:
logger.warning(
f"Tool use ID {tool_use_id} not found in previous messages. Skipping tool result."
)
# If no tool results, we're done
if not tool_result_content:
# Signal completion
self.loop.callback_manager.on_content({"type": "text", "text": "<DONE>"})
return new_messages, False
# Add tool results as user message
new_messages.append({"content": tool_result_content, "role": "user"})
return new_messages, True
except Exception as e:
logger.error(f"Error handling response: {str(e)}")
new_messages.append(
{
"role": "assistant",
"content": f"Error: {str(e)}",
}
)
return new_messages, False
def response_to_params(
self,
response: BetaMessage,
) -> List[Dict[str, Any]]:
"""Convert API response to message parameters.
Args:
response: API response message
Returns:
List of content blocks
"""
result = []
for block in response.content:
if isinstance(block, BetaTextBlock):
result.append({"type": "text", "text": block.text})
else:
result.append(cast(Dict[str, Any], block.model_dump()))
return result
def make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]:
"""Convert a tool result to API format.
Args:
result: Tool execution result
tool_use_id: ID of the tool use
Returns:
Formatted tool result
"""
if result.content:
return {
"type": "tool_result",
"content": result.content,
"tool_use_id": tool_use_id,
"is_error": bool(result.error),
}
tool_result_content = []
is_error = False
if result.error:
is_error = True
tool_result_content = [
{
"type": "text",
"text": self.maybe_prepend_system_tool_result(result, result.error),
}
]
else:
if result.output:
tool_result_content.append(
{
"type": "text",
"text": self.maybe_prepend_system_tool_result(result, result.output),
}
)
if result.base64_image:
tool_result_content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": result.base64_image,
},
}
)
return {
"type": "tool_result",
"content": tool_result_content,
"tool_use_id": tool_use_id,
"is_error": is_error,
}
def maybe_prepend_system_tool_result(self, result: ToolResult, result_text: str) -> str:
"""Prepend system information to tool result if available.
Args:
result: Tool execution result
result_text: Text to prepend to
Returns:
Text with system information prepended if available
"""
if result.system:
result_text = f"<s>{result.system}</s>\n{result_text}"
return result_text
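# Illustrative usage sketch (not part of the original module): the helper
# above wraps any system note in <s>...</s> before the tool output. The
# handler's `loop` argument is unused by this particular method.
def _example_prepend_system() -> None:
    handler = AnthropicResponseHandler(loop=None)
    result = ToolResult(output="ok", system="ran in VM")
    text = handler.maybe_prepend_system_tool_result(result, result.output or "")
    assert text == "<s>ran in VM</s>\nok"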


@@ -1,33 +0,0 @@
"""Anthropic-specific tools for agent."""
from .base import (
BaseAnthropicTool,
ToolResult,
ToolError,
ToolFailure,
CLIResult,
AnthropicToolResult,
AnthropicToolError,
AnthropicToolFailure,
AnthropicCLIResult,
)
from .bash import BashTool
from .computer import ComputerTool
from .edit import EditTool
from .manager import ToolManager
__all__ = [
"BaseAnthropicTool",
"ToolResult",
"ToolError",
"ToolFailure",
"CLIResult",
"AnthropicToolResult",
"AnthropicToolError",
"AnthropicToolFailure",
"AnthropicCLIResult",
"BashTool",
"ComputerTool",
"EditTool",
"ToolManager",
]


@@ -1,88 +0,0 @@
"""Anthropic-specific tool base classes."""
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, fields, replace
from typing import Any, Dict
from ....core.tools.base import BaseTool
class BaseAnthropicTool(BaseTool, metaclass=ABCMeta):
"""Abstract base class for Anthropic-defined tools."""
def __init__(self):
"""Initialize the base Anthropic tool."""
# No specific initialization needed yet, but included for future extensibility
pass
@abstractmethod
async def __call__(self, **kwargs) -> Any:
"""Executes the tool with the given arguments."""
...
@abstractmethod
def to_params(self) -> Dict[str, Any]:
"""Convert tool to Anthropic-specific API parameters.
Returns:
Dictionary with tool parameters for Anthropic API
"""
raise NotImplementedError
@dataclass(kw_only=True, frozen=True)
class ToolResult:
"""Represents the result of a tool execution."""
output: str | None = None
error: str | None = None
base64_image: str | None = None
system: str | None = None
content: list[dict] | None = None
def __bool__(self):
return any(getattr(self, field.name) for field in fields(self))
def __add__(self, other: "ToolResult"):
def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True):
if field and other_field:
if concatenate:
return field + other_field
raise ValueError("Cannot combine tool results")
return field or other_field
return ToolResult(
output=combine_fields(self.output, other.output),
error=combine_fields(self.error, other.error),
base64_image=combine_fields(self.base64_image, other.base64_image, False),
system=combine_fields(self.system, other.system),
content=self.content or other.content, # Use first non-None content
)
def replace(self, **kwargs):
"""Returns a new ToolResult with the given fields replaced."""
return replace(self, **kwargs)
class CLIResult(ToolResult):
"""A ToolResult that can be rendered as a CLI output."""
class ToolFailure(ToolResult):
"""A ToolResult that represents a failure."""
class ToolError(Exception):
"""Raised when a tool encounters an error."""
    def __init__(self, message):
        super().__init__(message)
        self.message = message
# Re-export the core tool classes with Anthropic-specific names for backward compatibility
AnthropicToolResult = ToolResult
AnthropicToolError = ToolError
AnthropicToolFailure = ToolFailure
AnthropicCLIResult = CLIResult
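# Illustrative sketch (not part of the original module): ToolResult supports
# `+` so partial results can be merged; text fields concatenate, while a
# base64 image may only come from one side.
def _example_combine_results() -> None:
    merged = ToolResult(output="step 1\n") + ToolResult(output="step 2", system="macOS")
    assert merged.output == "step 1\nstep 2"
    assert merged.system == "macOS"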


@@ -1,66 +0,0 @@
import asyncio
from typing import ClassVar, Literal, Dict, Any
from computer.computer import Computer
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
from ....core.tools.bash import BaseBashTool
class BashTool(BaseBashTool, BaseAnthropicTool):
"""
A tool that allows the agent to run bash commands.
The tool parameters are defined by Anthropic and are not editable.
"""
name: ClassVar[Literal["bash"]] = "bash"
api_type: ClassVar[Literal["bash_20250124"]] = "bash_20250124"
_timeout: float = 120.0 # seconds
def __init__(self, computer: Computer):
"""Initialize the bash tool.
Args:
computer: Computer instance for executing commands
"""
# Initialize the base bash tool first
BaseBashTool.__init__(self, computer)
# Then initialize the Anthropic tool
BaseAnthropicTool.__init__(self)
# Initialize bash session
async def __call__(self, command: str | None = None, restart: bool = False, **kwargs):
"""Execute a bash command.
Args:
command: The command to execute
restart: Whether to restart the shell (not used with computer interface)
Returns:
Tool execution result
Raises:
ToolError: If command execution fails
"""
if restart:
return ToolResult(system="Restart not needed with computer interface.")
if command is None:
raise ToolError("no command provided.")
try:
async with asyncio.timeout(self._timeout):
result = await self.computer.interface.run_command(command)
return CLIResult(output=result.stdout or "", error=result.stderr or "")
except asyncio.TimeoutError as e:
raise ToolError(f"Command timed out after {self._timeout} seconds") from e
except Exception as e:
raise ToolError(f"Failed to execute command: {str(e)}")
def to_params(self) -> Dict[str, Any]:
"""Convert tool to API parameters.
Returns:
Dictionary with tool parameters
"""
return {"name": self.name, "type": self.api_type}


@@ -1,34 +0,0 @@
"""Collection classes for managing multiple tools."""
from typing import Any, cast
from anthropic.types.beta import BetaToolUnionParam
from .base import (
BaseAnthropicTool,
ToolError,
ToolFailure,
ToolResult,
)
class ToolCollection:
"""A collection of anthropic-defined tools."""
def __init__(self, *tools: BaseAnthropicTool):
self.tools = tools
self.tool_map = {tool.to_params()["name"]: tool for tool in tools}
def to_params(
self,
) -> list[BetaToolUnionParam]:
return cast(list[BetaToolUnionParam], [tool.to_params() for tool in self.tools])
async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
tool = self.tool_map.get(name)
if not tool:
return ToolFailure(error=f"Tool {name} is invalid")
try:
return await tool(**tool_input)
except ToolError as e:
return ToolFailure(error=e.message)
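# Illustrative sketch (not part of the original module): tools are looked up
# by the "name" field of their to_params() output, so dispatch is by name.
async def _example_dispatch(bash_tool: BaseAnthropicTool) -> None:
    tools = ToolCollection(bash_tool)
    result = await tools.run(name="bash", tool_input={"command": "echo hi"})
    assert isinstance(result, ToolResult)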


@@ -1,396 +0,0 @@
import asyncio
import base64
import io
import logging
from enum import StrEnum
from typing import Literal, TypedDict, Any, Dict
from PIL import Image
from computer.computer import Computer
from .base import BaseAnthropicTool, ToolError, ToolResult
from .run import run
from ....core.tools.computer import BaseComputerTool
TYPING_DELAY_MS = 12
TYPING_GROUP_SIZE = 50
Action = Literal[
"key",
"type",
"mouse_move",
"left_click",
"left_click_drag",
"right_click",
"middle_click",
"double_click",
"screenshot",
"cursor_position",
"scroll",
]
class Resolution(TypedDict):
width: int
height: int
class ScalingSource(StrEnum):
COMPUTER = "computer"
API = "api"
class ComputerToolOptions(TypedDict):
display_height_px: int
display_width_px: int
display_number: int | None
def chunks(s: str, chunk_size: int) -> list[str]:
return [s[i : i + chunk_size] for i in range(0, len(s), chunk_size)]
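# Illustrative sketch (not part of the original module): chunks() batches a
# long string so typed text can be sent in TYPING_GROUP_SIZE pieces.
def _example_chunks() -> None:
    assert chunks("abcdef", 4) == ["abcd", "ef"]
    assert chunks("hello world", TYPING_GROUP_SIZE) == ["hello world"]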
class ComputerTool(BaseComputerTool, BaseAnthropicTool):
"""
A tool that allows the agent to interact with the screen, keyboard, and mouse of the current macOS computer.
The tool parameters are defined by Anthropic and are not editable.
"""
name: Literal["computer"] = "computer"
api_type: Literal["computer_20250124"] = "computer_20250124"
width: int | None = None
height: int | None = None
display_num: int | None = None
computer: Computer # The CUA Computer instance
logger = logging.getLogger(__name__)
_screenshot_delay = 1.0 # macOS is generally faster than X11
_scaling_enabled = True
@property
def options(self) -> ComputerToolOptions:
if self.width is None or self.height is None:
raise RuntimeError(
"Screen dimensions not initialized. Call initialize_dimensions() first."
)
return {
"display_width_px": self.width,
"display_height_px": self.height,
"display_number": self.display_num,
}
def to_params(self) -> Dict[str, Any]:
"""Convert tool to API parameters.
Returns:
Dictionary with tool parameters
"""
return {"name": self.name, "type": self.api_type, **self.options}
def __init__(self, computer):
# Initialize the base computer tool first
BaseComputerTool.__init__(self, computer)
# Then initialize the Anthropic tool
BaseAnthropicTool.__init__(self)
# Additional initialization
self.width = None # Will be initialized from computer interface
self.height = None # Will be initialized from computer interface
self.display_num = None
async def initialize_dimensions(self):
"""Initialize screen dimensions from the computer interface."""
display_size = await self.computer.interface.get_screen_size()
self.width = display_size["width"]
self.height = display_size["height"]
assert isinstance(self.width, int) and isinstance(self.height, int)
self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
async def __call__(
self,
*,
action: Action,
text: str | None = None,
coordinate: tuple[int, int] | None = None,
**kwargs,
):
try:
# Ensure dimensions are initialized
if self.width is None or self.height is None:
await self.initialize_dimensions()
if self.width is None or self.height is None:
raise ToolError("Failed to initialize screen dimensions")
except Exception as e:
raise ToolError(f"Failed to initialize dimensions: {e}")
if action in ("mouse_move", "left_click_drag"):
if coordinate is None:
raise ToolError(f"coordinate is required for {action}")
if text is not None:
raise ToolError(f"text is not accepted for {action}")
if not isinstance(coordinate, (list, tuple)) or len(coordinate) != 2:
raise ToolError(f"{coordinate} must be a tuple of length 2")
if not all(isinstance(i, int) and i >= 0 for i in coordinate):
raise ToolError(f"{coordinate} must be a tuple of non-negative ints")
try:
x, y = coordinate
self.logger.info(f"Handling {action} action:")
self.logger.info(f" Coordinates: ({x}, {y})")
# Take pre-action screenshot to get current dimensions
pre_screenshot = await self.computer.interface.screenshot()
pre_img = Image.open(io.BytesIO(pre_screenshot))
# Scale image to match screen dimensions if needed
if pre_img.size != (self.width, self.height):
self.logger.info(
f"Scaling image from {pre_img.size} to {self.width}x{self.height} to match screen dimensions"
)
if not isinstance(self.width, int) or not isinstance(self.height, int):
raise ToolError("Screen dimensions must be integers")
size = (int(self.width), int(self.height))
pre_img = pre_img.resize(size, Image.Resampling.LANCZOS)
self.logger.info(f" Current dimensions: {pre_img.width}x{pre_img.height}")
if action == "mouse_move":
self.logger.info(f"Moving cursor to ({x}, {y})")
await self.computer.interface.move_cursor(x, y)
elif action == "left_click_drag":
# Get the start coordinate from kwargs
start_coordinate = kwargs.get("start_coordinate")
if not start_coordinate:
raise ToolError("start_coordinate is required for left_click_drag action")
start_x, start_y = start_coordinate
end_x, end_y = x, y
self.logger.info(f"Dragging from ({start_x}, {start_y}) to ({end_x}, {end_y})")
await self.computer.interface.move_cursor(start_x, start_y)
await self.computer.interface.drag_to(end_x, end_y)
# Wait briefly for any UI changes
await asyncio.sleep(0.5)
# Take post-action screenshot
post_screenshot = await self.computer.interface.screenshot()
post_img = Image.open(io.BytesIO(post_screenshot))
# Scale post-action image if needed
if post_img.size != (self.width, self.height):
self.logger.info(
f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}"
)
post_img = post_img.resize((self.width, self.height), Image.Resampling.LANCZOS)
buffer = io.BytesIO()
post_img.save(buffer, format="PNG")
post_screenshot = buffer.getvalue()
return ToolResult(
output=f"{'Moved cursor to' if action == 'mouse_move' else 'Dragged to'} {x},{y}",
base64_image=base64.b64encode(post_screenshot).decode(),
)
except Exception as e:
self.logger.error(f"Error during {action} action: {str(e)}")
raise ToolError(f"Failed to perform {action}: {str(e)}")
elif action in ("left_click", "right_click", "double_click"):
if coordinate:
x, y = coordinate
self.logger.info(f"Handling {action} action:")
self.logger.info(f" Coordinates: ({x}, {y})")
try:
# Perform the click action
if action == "left_click":
self.logger.info(f"Clicking at ({x}, {y})")
await self.computer.interface.move_cursor(x, y)
await self.computer.interface.left_click()
elif action == "right_click":
self.logger.info(f"Right clicking at ({x}, {y})")
await self.computer.interface.move_cursor(x, y)
await self.computer.interface.right_click()
elif action == "double_click":
self.logger.info(f"Double clicking at ({x}, {y})")
await self.computer.interface.move_cursor(x, y)
await self.computer.interface.double_click()
# Wait briefly for any UI changes
await asyncio.sleep(0.5)
return ToolResult(
output=f"Performed {action} at ({x}, {y})",
)
except Exception as e:
self.logger.error(f"Error during {action} action: {str(e)}")
raise ToolError(f"Failed to perform {action}: {str(e)}")
else:
try:
# Perform the click action
if action == "left_click":
self.logger.info("Performing left click at current position")
await self.computer.interface.left_click()
elif action == "right_click":
self.logger.info("Performing right click at current position")
await self.computer.interface.right_click()
elif action == "double_click":
self.logger.info("Performing double click at current position")
await self.computer.interface.double_click()
# Wait briefly for any UI changes
await asyncio.sleep(0.5)
return ToolResult(
output=f"Performed {action} at current position",
)
except Exception as e:
self.logger.error(f"Error during {action} action: {str(e)}")
raise ToolError(f"Failed to perform {action}: {str(e)}")
elif action in ("key", "type"):
if text is None:
raise ToolError(f"text is required for {action}")
if coordinate is not None:
raise ToolError(f"coordinate is not accepted for {action}")
if not isinstance(text, str):
raise ToolError(f"{text} must be a string")
try:
if action == "key":
# Special handling for page up/down on macOS
if text.lower() in ["pagedown", "page_down", "page down"]:
self.logger.info("Converting page down to fn+down for macOS")
await self.computer.interface.hotkey("fn", "down")
output_text = "fn+down"
elif text.lower() in ["pageup", "page_up", "page up"]:
self.logger.info("Converting page up to fn+up for macOS")
await self.computer.interface.hotkey("fn", "up")
output_text = "fn+up"
elif text == "fn+down":
self.logger.info("Using fn+down combination")
await self.computer.interface.hotkey("fn", "down")
output_text = text
elif text == "fn+up":
self.logger.info("Using fn+up combination")
await self.computer.interface.hotkey("fn", "up")
output_text = text
elif "+" in text:
# Handle hotkey combinations
keys = text.split("+")
self.logger.info(f"Pressing hotkey combination: {text}")
await self.computer.interface.hotkey(*keys)
output_text = text
else:
# Handle single key press
self.logger.info(f"Pressing key: {text}")
try:
await self.computer.interface.press_key(text)
output_text = text
except ValueError as e:
raise ToolError(f"Invalid key: {text}. {str(e)}")
# Wait briefly for UI changes
await asyncio.sleep(0.5)
return ToolResult(
output=f"Pressed key: {output_text}",
)
elif action == "type":
self.logger.info(f"Typing text: {text}")
await self.computer.interface.type_text(text)
# Wait briefly for UI changes
await asyncio.sleep(0.5)
return ToolResult(
output=f"Typed text: {text}",
)
except Exception as e:
self.logger.error(f"Error during {action} action: {str(e)}")
raise ToolError(f"Failed to perform {action}: {str(e)}")
elif action == "scroll":
# Implement scroll action
direction = kwargs.get("direction", "down")
amount = kwargs.get("amount", 10)
if direction not in ["up", "down"]:
raise ToolError(f"Invalid scroll direction: {direction}. Must be 'up' or 'down'.")
try:
if direction == "down":
# Scroll down (Page Down on macOS)
self.logger.info(f"Scrolling down, amount: {amount}")
await self.computer.interface.scroll_down(amount)
else:
# Scroll up (Page Up on macOS)
self.logger.info(f"Scrolling up, amount: {amount}")
await self.computer.interface.scroll_up(amount)
# Wait briefly for UI changes
await asyncio.sleep(0.5)
return ToolResult(
output=f"Scrolled {direction} by {amount} steps",
)
except Exception as e:
self.logger.error(f"Error during scroll action: {str(e)}")
raise ToolError(f"Failed to perform scroll: {str(e)}")
elif action == "screenshot":
# Take screenshot
return await self.screenshot()
elif action == "cursor_position":
pos = await self.computer.interface.get_cursor_position()
x, y = pos # Unpack the tuple
return ToolResult(output=f"X={int(x)},Y={int(y)}")
raise ToolError(f"Invalid action: {action}")
async def screenshot(self):
"""Take a screenshot and return it as a base64-encoded string."""
try:
screenshot = await self.computer.interface.screenshot()
img = Image.open(io.BytesIO(screenshot))
# Scale image if needed
if img.size != (self.width, self.height):
self.logger.info(f"Scaling image from {img.size} to {self.width}x{self.height}")
if not isinstance(self.width, int) or not isinstance(self.height, int):
raise ToolError("Screen dimensions must be integers")
size = (int(self.width), int(self.height))
img = img.resize(size, Image.Resampling.LANCZOS)
buffer = io.BytesIO()
img.save(buffer, format="PNG")
screenshot = buffer.getvalue()
return ToolResult(base64_image=base64.b64encode(screenshot).decode())
except Exception as e:
self.logger.error(f"Error taking screenshot: {str(e)}")
return ToolResult(error=f"Failed to take screenshot: {str(e)}")
async def shell(self, command: str, take_screenshot=False) -> ToolResult:
"""Run a shell command and return the output, error, and optionally a screenshot."""
try:
_, stdout, stderr = await run(command)
base64_image = None
if take_screenshot:
# delay to let things settle before taking a screenshot
await asyncio.sleep(self._screenshot_delay)
screenshot_result = await self.screenshot()
if screenshot_result.error:
return ToolResult(
output=stdout,
error=f"{stderr}\nScreenshot error: {screenshot_result.error}",
)
base64_image = screenshot_result.base64_image
return ToolResult(output=stdout, error=stderr, base64_image=base64_image)
except Exception as e:
return ToolResult(error=f"Shell command failed: {str(e)}")


@@ -1,326 +0,0 @@
from collections import defaultdict
from pathlib import Path
from typing import Literal, get_args, Dict, Any
from computer.computer import Computer
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
from ....core.tools.edit import BaseEditTool
from .run import maybe_truncate
Command = Literal[
"view",
"create",
"str_replace",
"insert",
"undo_edit",
]
SNIPPET_LINES: int = 4
class EditTool(BaseEditTool, BaseAnthropicTool):
"""
A filesystem editor tool that allows the agent to view, create, and edit files.
The tool parameters are defined by Anthropic and are not editable.
"""
api_type: Literal["text_editor_20250124"] = "text_editor_20250124"
name: Literal["str_replace_editor"] = "str_replace_editor"
_timeout: float = 30.0 # seconds
def __init__(self, computer: Computer):
"""Initialize the edit tool.
Args:
computer: Computer instance for file operations
"""
# Initialize the base edit tool first
BaseEditTool.__init__(self, computer)
# Then initialize the Anthropic tool
BaseAnthropicTool.__init__(self)
# Edit history for the current session
self.edit_history = defaultdict(list)
async def __call__(
self,
*,
command: Command,
path: str,
file_text: str | None = None,
view_range: list[int] | None = None,
old_str: str | None = None,
new_str: str | None = None,
insert_line: int | None = None,
**kwargs,
):
_path = Path(path)
await self.validate_path(command, _path)
if command == "view":
return await self.view(_path, view_range)
elif command == "create":
if file_text is None:
raise ToolError("Parameter `file_text` is required for command: create")
await self.write_file(_path, file_text)
self.edit_history[_path].append(file_text)
return ToolResult(output=f"File created successfully at: {_path}")
elif command == "str_replace":
if old_str is None:
raise ToolError("Parameter `old_str` is required for command: str_replace")
return await self.str_replace(_path, old_str, new_str)
elif command == "insert":
if insert_line is None:
raise ToolError("Parameter `insert_line` is required for command: insert")
if new_str is None:
raise ToolError("Parameter `new_str` is required for command: insert")
return await self.insert(_path, insert_line, new_str)
elif command == "undo_edit":
return await self.undo_edit(_path)
raise ToolError(
f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
)
async def validate_path(self, command: str, path: Path):
"""Check that the path/command combination is valid."""
        # Check if it's an absolute path
        if not path.is_absolute():
            suggested_path = Path("/") / path
raise ToolError(
f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
)
# Check if path exists using bash commands
try:
result = await self.computer.interface.run_command(
f'[ -e "{str(path)}" ] && echo "exists" || echo "not exists"'
)
exists = result.stdout.strip() == "exists"
if exists:
result = await self.computer.interface.run_command(
f'[ -d "{str(path)}" ] && echo "dir" || echo "file"'
)
is_dir = result.stdout.strip() == "dir"
else:
is_dir = False
# Check path validity
if not exists and command != "create":
raise ToolError(f"The path {path} does not exist. Please provide a valid path.")
if exists and command == "create":
raise ToolError(
f"File already exists at: {path}. Cannot overwrite files using command `create`."
)
if is_dir and command != "view":
raise ToolError(
f"The path {path} is a directory and only the `view` command can be used on directories"
)
except Exception as e:
raise ToolError(f"Failed to validate path: {str(e)}")
async def view(self, path: Path, view_range: list[int] | None = None):
"""Implement the view command"""
try:
# Check if path is a directory
result = await self.computer.interface.run_command(
f'[ -d "{str(path)}" ] && echo "dir" || echo "file"'
)
is_dir = result.stdout.strip() == "dir"
if is_dir:
if view_range:
raise ToolError(
"The `view_range` parameter is not allowed when `path` points to a directory."
)
# List directory contents using ls
result = await self.computer.interface.run_command(f'ls -la "{str(path)}"')
contents = result.stdout
if contents:
stdout = f"Here's the files and directories in {path}:\n{contents}\n"
else:
stdout = f"Directory {path} is empty\n"
return CLIResult(output=stdout)
# Read file content using cat
file_content = await self.read_file(path)
init_line = 1
if view_range:
if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
raise ToolError("Invalid `view_range`. It should be a list of two integers.")
file_lines = file_content.split("\n")
n_lines_file = len(file_lines)
init_line, final_line = view_range
if init_line < 1 or init_line > n_lines_file:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
)
if final_line > n_lines_file:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
)
if final_line != -1 and final_line < init_line:
raise ToolError(
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
)
if final_line == -1:
file_content = "\n".join(file_lines[init_line - 1 :])
else:
file_content = "\n".join(file_lines[init_line - 1 : final_line])
return CLIResult(output=self._make_output(file_content, str(path), init_line=init_line))
except Exception as e:
raise ToolError(f"Failed to view path: {str(e)}")
async def str_replace(self, path: Path, old_str: str, new_str: str | None):
"""Implement the str_replace command"""
# Read the file content
file_content = await self.read_file(path)
file_content = file_content.expandtabs()
old_str = old_str.expandtabs()
new_str = new_str.expandtabs() if new_str is not None else ""
# Check if old_str is unique in the file
occurrences = file_content.count(old_str)
if occurrences == 0:
raise ToolError(
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
)
elif occurrences > 1:
file_content_lines = file_content.split("\n")
lines = [idx + 1 for idx, line in enumerate(file_content_lines) if old_str in line]
raise ToolError(
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
)
# Replace old_str with new_str
new_file_content = file_content.replace(old_str, new_str)
# Write the new content to the file
await self.write_file(path, new_file_content)
# Save the content to history
self.edit_history[path].append(file_content)
# Create a snippet of the edited section
replacement_line = file_content.split(old_str)[0].count("\n")
start_line = max(0, replacement_line - SNIPPET_LINES)
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])
# Prepare the success message
success_msg = f"The file {path} has been edited. "
success_msg += self._make_output(snippet, f"a snippet of {path}", start_line + 1)
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
return CLIResult(output=success_msg)
async def insert(self, path: Path, insert_line: int, new_str: str):
"""Implement the insert command"""
file_text = await self.read_file(path)
file_text = file_text.expandtabs()
new_str = new_str.expandtabs()
file_text_lines = file_text.split("\n")
n_lines_file = len(file_text_lines)
if insert_line < 0 or insert_line > n_lines_file:
raise ToolError(
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
)
new_str_lines = new_str.split("\n")
new_file_text_lines = (
file_text_lines[:insert_line] + new_str_lines + file_text_lines[insert_line:]
)
snippet_lines = (
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
+ new_str_lines
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
)
new_file_text = "\n".join(new_file_text_lines)
snippet = "\n".join(snippet_lines)
await self.write_file(path, new_file_text)
self.edit_history[path].append(file_text)
success_msg = f"The file {path} has been edited. "
success_msg += self._make_output(
snippet, "a snippet of the edited file", max(1, insert_line - SNIPPET_LINES + 1)
)
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
return CLIResult(output=success_msg)
async def undo_edit(self, path: Path):
"""Implement the undo_edit command"""
if not self.edit_history[path]:
raise ToolError(f"No edit history found for {path}.")
old_text = self.edit_history[path].pop()
await self.write_file(path, old_text)
return CLIResult(
output=f"Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}"
)
async def read_file(self, path: Path) -> str:
"""Read the content of a file using cat command."""
try:
result = await self.computer.interface.run_command(f'cat "{str(path)}"')
if result.stderr: # If there's stderr output
raise ToolError(f"Error reading file: {result.stderr}")
return result.stdout
except Exception as e:
raise ToolError(f"Failed to read {path}: {str(e)}")
async def write_file(self, path: Path, content: str):
"""Write content to a file using echo and redirection."""
try:
# Create parent directories if they don't exist
parent = path.parent
if parent != Path("/"):
await self.computer.interface.run_command(f'mkdir -p "{str(parent)}"')
# Write content to file using echo and heredoc to preserve formatting
cmd = f"""cat > "{str(path)}" << 'EOFCUA'
{content}
EOFCUA"""
result = await self.computer.interface.run_command(cmd)
if result.stderr: # If there's stderr output
raise ToolError(f"Error writing file: {result.stderr}")
except Exception as e:
raise ToolError(f"Failed to write to {path}: {str(e)}")
def _make_output(
self,
file_content: str,
file_descriptor: str,
init_line: int = 1,
expand_tabs: bool = True,
) -> str:
"""Generate output for the CLI based on the content of a file."""
file_content = maybe_truncate(file_content)
if expand_tabs:
file_content = file_content.expandtabs()
file_content = "\n".join(
[f"{i + init_line:6}\t{line}" for i, line in enumerate(file_content.split("\n"))]
)
return (
f"Here's the result of running `cat -n` on {file_descriptor}:\n" + file_content + "\n"
)
def to_params(self) -> Dict[str, Any]:
"""Convert tool to API parameters.
Returns:
Dictionary with tool parameters
"""
return {
"name": self.name,
"type": self.api_type,
}
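# Illustrative usage sketch, assuming a connected `Computer` instance with a
# writable /tmp (the paths here are hypothetical):
async def _example_edit_tool(computer: Computer) -> None:
    tool = EditTool(computer)
    await tool(command="create", path="/tmp/demo.txt", file_text="hello\n")
    await tool(command="str_replace", path="/tmp/demo.txt", old_str="hello", new_str="hi")
    await tool(command="undo_edit", path="/tmp/demo.txt")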


@@ -1,54 +0,0 @@
from typing import Any, Dict, List, cast
from anthropic.types.beta import BetaToolUnionParam
from computer.computer import Computer
from ....core.tools import BaseToolManager, ToolResult
from ....core.tools.collection import ToolCollection
from .bash import BashTool
from .computer import ComputerTool
from .edit import EditTool
class ToolManager(BaseToolManager):
"""Manages Anthropic-specific tool initialization and execution."""
def __init__(self, computer: Computer):
"""Initialize the tool manager.
Args:
computer: Computer instance for computer-related tools
"""
super().__init__(computer)
# Initialize Anthropic-specific tools
self.computer_tool = ComputerTool(self.computer)
self.bash_tool = BashTool(self.computer)
self.edit_tool = EditTool(self.computer)
def _initialize_tools(self) -> ToolCollection:
"""Initialize all available tools."""
return ToolCollection(self.computer_tool, self.bash_tool, self.edit_tool)
async def _initialize_tools_specific(self) -> None:
"""Initialize Anthropic-specific tool requirements."""
await self.computer_tool.initialize_dimensions()
def get_tool_params(self) -> List[BetaToolUnionParam]:
"""Get tool parameters for Anthropic API calls."""
if self.tools is None:
raise RuntimeError("Tools not initialized. Call initialize() first.")
return cast(List[BetaToolUnionParam], self.tools.to_params())
async def execute_tool(self, name: str, tool_input: dict[str, Any]) -> ToolResult:
"""Execute a tool with the given input.
Args:
name: Name of the tool to execute
tool_input: Input parameters for the tool
Returns:
Result of the tool execution
"""
if self.tools is None:
raise RuntimeError("Tools not initialized. Call initialize() first.")
return await self.tools.run(name=name, tool_input=tool_input)
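# Illustrative usage sketch, assuming a connected `Computer` instance and that
# BaseToolManager exposes the initialize() referenced in the errors above:
async def _example_tool_manager(computer: Computer) -> None:
    manager = ToolManager(computer)
    await manager.initialize()
    result = await manager.execute_tool(name="bash", tool_input={"command": "ls"})
    print(result.output)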


@@ -1,42 +0,0 @@
"""Utility to run shell commands asynchronously with a timeout."""
import asyncio
TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
MAX_RESPONSE_LEN: int = 16000
def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
"""Truncate content and append a notice if content exceeds the specified length."""
return (
content
if not truncate_after or len(content) <= truncate_after
else content[:truncate_after] + TRUNCATED_MESSAGE
)
async def run(
cmd: str,
timeout: float | None = 120.0, # seconds
truncate_after: int | None = MAX_RESPONSE_LEN,
):
"""Run a shell command asynchronously with a timeout."""
process = await asyncio.create_subprocess_shell(
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
try:
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
return (
process.returncode or 0,
maybe_truncate(stdout.decode(), truncate_after=truncate_after),
maybe_truncate(stderr.decode(), truncate_after=truncate_after),
)
except asyncio.TimeoutError as exc:
try:
process.kill()
except ProcessLookupError:
pass
raise TimeoutError(
f"Command '{cmd}' timed out after {timeout} seconds"
) from exc
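# Illustrative sketch (not part of the original module): running a short
# command through the helper above.
async def _example_run() -> None:
    returncode, stdout, stderr = await run("echo hello", timeout=5.0)
    assert returncode == 0 and stdout.strip() == "hello"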


@@ -1,16 +0,0 @@
from enum import StrEnum
class LLMProvider(StrEnum):
"""Enum for supported API providers."""
ANTHROPIC = "anthropic"
BEDROCK = "bedrock"
VERTEX = "vertex"
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[LLMProvider, str] = {
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
LLMProvider.BEDROCK: "anthropic.claude-3-7-sonnet-20250219-v2:0",
LLMProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
}
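# Illustrative sketch (not part of the original module): resolving the default
# model name for a provider.
def _example_default_model() -> None:
    assert (
        PROVIDER_TO_DEFAULT_MODEL_NAME[LLMProvider.ANTHROPIC]
        == "claude-3-7-sonnet-20250219"
    )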


@@ -1,367 +0,0 @@
"""Utility functions for Anthropic message handling."""
import logging
import re
from typing import Any, Dict, List, Optional, Tuple, cast
from anthropic.types.beta import BetaMessage
from ...core.types import AgentResponse
from datetime import datetime
# Configure module logger
logger = logging.getLogger(__name__)
def to_anthropic_format(
messages: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], str]:
"""Convert standard OpenAI format messages to Anthropic format.
Args:
messages: List of messages in OpenAI format
Returns:
Tuple containing (anthropic_messages, system_content)
"""
result = []
system_content = ""
# Process messages in order to maintain conversation flow
previous_assistant_tool_use_ids = set() # Track tool_use_ids in the previous assistant message
for i, msg in enumerate(messages):
role = msg.get("role", "")
content = msg.get("content", "")
if role == "system":
# Collect system messages for later use
system_content += content + "\n"
continue
if role == "assistant":
# Track tool_use_ids in this assistant message for the next user message
previous_assistant_tool_use_ids = set()
if isinstance(content, list):
for item in content:
if isinstance(item, dict) and item.get("type") == "tool_use" and "id" in item:
previous_assistant_tool_use_ids.add(item["id"])
if role in ["user", "assistant"]:
anthropic_msg = {"role": role}
# Convert content based on type
if isinstance(content, str):
# Simple text content
anthropic_msg["content"] = [{"type": "text", "text": content}]
elif isinstance(content, list):
# Convert complex content
anthropic_content = []
for item in content:
item_type = item.get("type", "")
if item_type == "text":
anthropic_content.append({"type": "text", "text": item.get("text", "")})
elif item_type == "image_url":
# Convert OpenAI image format to Anthropic
image_url = item.get("image_url", {}).get("url", "")
if image_url.startswith("data:"):
# Extract base64 data and media type
match = re.match(r"data:(.+);base64,(.+)", image_url)
if match:
media_type, data = match.groups()
anthropic_content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": data,
},
}
)
else:
# Regular URL
anthropic_content.append(
{
"type": "image",
"source": {
"type": "url",
"url": image_url,
},
}
)
elif item_type == "tool_use":
# Always include tool_use blocks
anthropic_content.append(item)
elif item_type == "tool_result":
# Check if this is a user message AND if the tool_use_id exists in the previous assistant message
tool_use_id = item.get("tool_use_id")
# Only include tool_result if it references a tool_use from the immediately preceding assistant message
if (
role == "user"
and tool_use_id
and tool_use_id in previous_assistant_tool_use_ids
):
anthropic_content.append(item)
else:
content_text = "Tool Result: "
if "content" in item:
if isinstance(item["content"], list):
for content_item in item["content"]:
if (
isinstance(content_item, dict)
and content_item.get("type") == "text"
):
content_text += content_item.get("text", "")
elif isinstance(item["content"], str):
content_text += item["content"]
anthropic_content.append({"type": "text", "text": content_text})
anthropic_msg["content"] = anthropic_content
result.append(anthropic_msg)
return result, system_content
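# Illustrative sketch (not part of the original module): converting a minimal
# OpenAI-style history. System messages are folded into the returned system
# string; plain text becomes a single text block.
def _example_to_anthropic_format() -> None:
    msgs = [
        {"role": "system", "content": "Be brief."},
        {"role": "user", "content": "Hello"},
    ]
    anthropic_msgs, system = to_anthropic_format(msgs)
    assert system == "Be brief.\n"
    assert anthropic_msgs == [
        {"role": "user", "content": [{"type": "text", "text": "Hello"}]}
    ]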
def from_anthropic_format(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Convert Anthropic format messages to standard OpenAI format.
Args:
messages: List of messages in Anthropic format
Returns:
List of messages in OpenAI format
"""
result = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", [])
if role in ["user", "assistant"]:
openai_msg = {"role": role}
# Simple case: single text block
if len(content) == 1 and content[0].get("type") == "text":
openai_msg["content"] = content[0].get("text", "")
else:
# Complex case: multiple blocks or non-text
openai_content = []
for item in content:
item_type = item.get("type", "")
if item_type == "text":
openai_content.append({"type": "text", "text": item.get("text", "")})
elif item_type == "image":
# Convert Anthropic image to OpenAI format
source = item.get("source", {})
if source.get("type") == "base64":
media_type = source.get("media_type", "image/png")
data = source.get("data", "")
openai_content.append(
{
"type": "image_url",
"image_url": {"url": f"data:{media_type};base64,{data}"},
}
)
else:
# URL
openai_content.append(
{
"type": "image_url",
"image_url": {"url": source.get("url", "")},
}
)
elif item_type in ["tool_use", "tool_result"]:
# Pass through tool-related content
openai_content.append(item)
openai_msg["content"] = openai_content
result.append(openai_msg)
return result
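# Illustrative sketch (not part of the original module): simple text messages
# survive a round trip through both converters.
def _example_round_trip() -> None:
    openai_msgs = [{"role": "user", "content": "hi"}]
    anthropic_msgs, _ = to_anthropic_format(openai_msgs)
    assert from_anthropic_format(anthropic_msgs) == openai_msgs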
async def to_agent_response_format(
response: BetaMessage,
messages: List[Dict[str, Any]],
parsed_screen: Optional[dict] = None,
parser: Optional[Any] = None,
model: Optional[str] = None,
) -> AgentResponse:
"""Convert an Anthropic response to the standard agent response format.
Args:
response: The Anthropic API response (BetaMessage)
messages: List of messages in standard format
parsed_screen: Optional pre-parsed screen information
parser: Optional parser instance for coordinate calculation
model: Optional model name
Returns:
A response formatted according to the standard agent response format
"""
# Create unique IDs for this response
response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(response)}"
reasoning_id = f"rs_{response_id}"
action_id = f"cu_{response_id}"
call_id = f"call_{response_id}"
# Extract content and reasoning from Anthropic response
content = []
reasoning_text = None
action_details = None
for block in response.content:
if block.type == "text":
# Use the first text block as reasoning
if reasoning_text is None:
reasoning_text = block.text
content.append({"type": "text", "text": block.text})
elif block.type == "tool_use" and block.name == "computer":
try:
input_dict = cast(Dict[str, Any], block.input)
action = input_dict.get("action", "").lower()
# Extract coordinates from coordinate list if provided
coordinates = input_dict.get("coordinate", [100, 100])
x, y = coordinates if len(coordinates) == 2 else (100, 100)
if action == "screenshot":
action_details = {
"type": "screenshot",
}
elif action in ["click", "left_click", "right_click", "double_click"]:
action_details = {
"type": "click",
"button": "left" if action in ["click", "left_click"] else "right",
"double": action == "double_click",
"x": x,
"y": y,
}
elif action == "type":
action_details = {
"type": "type",
"text": input_dict.get("text", ""),
}
elif action == "key":
action_details = {
"type": "hotkey",
"keys": [input_dict.get("text", "")],
}
elif action == "scroll":
scroll_amount = input_dict.get("scroll_amount", 1)
scroll_direction = input_dict.get("scroll_direction", "down")
delta_y = scroll_amount if scroll_direction == "down" else -scroll_amount
action_details = {
"type": "scroll",
"x": x,
"y": y,
"delta_x": 0,
"delta_y": delta_y,
}
elif action == "move":
action_details = {
"type": "move",
"x": x,
"y": y,
}
except Exception as e:
logger.error(f"Error extracting action details: {str(e)}")
# Create output items with reasoning
output_items = []
if reasoning_text:
output_items.append(
{
"type": "reasoning",
"id": reasoning_id,
"summary": [
{
"type": "summary_text",
"text": reasoning_text,
}
],
}
)
# Add computer_call item with extracted or default action
computer_call = {
"type": "computer_call",
"id": action_id,
"call_id": call_id,
"action": action_details or {"type": "none", "description": "No action specified"},
"pending_safety_checks": [],
"status": "completed",
}
output_items.append(computer_call)
# Create the standard response format
standard_response = {
"id": response_id,
"object": "response",
"created_at": int(datetime.now().timestamp()),
"status": "completed",
"error": None,
"incomplete_details": None,
"instructions": None,
"max_output_tokens": None,
"model": model or "anthropic-default",
"output": output_items,
"parallel_tool_calls": True,
"previous_response_id": None,
"reasoning": {"effort": "medium", "generate_summary": "concise"},
"store": True,
"temperature": 1.0,
"text": {"format": {"type": "text"}},
"tool_choice": "auto",
"tools": [
{
"type": "computer_use_preview",
"display_height": 768,
"display_width": 1024,
"environment": "mac",
}
],
"top_p": 1.0,
"truncation": "auto",
"usage": {
"input_tokens": 0,
"input_tokens_details": {"cached_tokens": 0},
"output_tokens": 0,
"output_tokens_details": {"reasoning_tokens": 0},
"total_tokens": 0,
},
"user": None,
"metadata": {},
"response": {
"choices": [
{
"message": {
"role": "assistant",
"content": content,
"tool_calls": [],
},
"finish_reason": response.stop_reason or "stop",
}
]
},
}
# Add tool calls if present
tool_calls = []
for block in response.content:
if hasattr(block, "type") and block.type == "tool_use":
tool_calls.append(
{
"id": f"call_{block.id}",
"type": "function",
"function": {"name": block.name, "arguments": block.input},
}
)
if tool_calls:
standard_response["response"]["choices"][0]["message"]["tool_calls"] = tool_calls
return cast(AgentResponse, standard_response)
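# Illustrative sketch (not part of the original module): the converter only
# reads .content and .stop_reason, so a duck-typed stand-in for BetaMessage is
# enough to show that a text block becomes a "reasoning" output item.
async def _example_agent_response_format() -> None:
    class _TextBlock:
        type = "text"
        text = "Looking at the screen."

    class _FakeMessage:
        content = [_TextBlock()]
        stop_reason = "end_turn"

    resp = await to_agent_response_format(_FakeMessage(), messages=[], model="demo")  # type: ignore[arg-type]
    assert resp["output"][0]["type"] == "reasoning"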


@@ -1,8 +0,0 @@
"""Omni provider implementation."""
from ...core.types import LLMProvider
from .image_utils import (
decode_base64_image,
)
__all__ = ["LLMProvider", "decode_base64_image"]


@@ -1,42 +0,0 @@
"""API handling for Omni provider."""
import logging
from typing import Any, Dict, List
from .prompts import SYSTEM_PROMPT
logger = logging.getLogger(__name__)
class OmniAPIHandler:
"""Handler for Omni API calls."""
def __init__(self, loop):
"""Initialize the API handler.
Args:
loop: Parent loop instance
"""
self.loop = loop
async def make_api_call(
self, messages: List[Dict[str, Any]], system_prompt: str = SYSTEM_PROMPT
) -> Any:
"""Make an API call to the appropriate provider.
Args:
messages: List of messages in standard OpenAI format
system_prompt: System prompt to use
Returns:
API response
"""
if not self.loop._make_api_call:
raise RuntimeError("Loop does not have _make_api_call method")
try:
# Use the loop's _make_api_call method with standard messages
return await self.loop._make_api_call(messages=messages, system_prompt=system_prompt)
except Exception as e:
logger.error(f"Error making API call: {str(e)}")
raise


@@ -1,103 +0,0 @@
"""Anthropic API client implementation."""
import logging
from typing import Any, Dict, List, cast
import asyncio
from httpx import ConnectError, ReadTimeout
from anthropic import AsyncAnthropic
from anthropic.types import MessageParam
from .base import BaseOmniClient
logger = logging.getLogger(__name__)
class AnthropicClient(BaseOmniClient):
"""Client for making calls to Anthropic API."""
def __init__(self, api_key: str, model: str, max_retries: int = 3, retry_delay: float = 1.0):
"""Initialize the Anthropic client.
Args:
api_key: Anthropic API key
model: Anthropic model name (e.g. "claude-3-opus-20240229")
max_retries: Maximum number of retries for API calls
retry_delay: Base delay between retries in seconds
"""
if not model:
raise ValueError("Model name must be provided")
self.client = AsyncAnthropic(api_key=api_key)
self.model: str = model # Add explicit type annotation
self.max_retries = max_retries
self.retry_delay = retry_delay
def _convert_message_format(self, messages: List[Dict[str, Any]]) -> List[MessageParam]:
"""Convert messages from standard format to Anthropic format.
Args:
messages: Messages in standard format
Returns:
Messages in Anthropic format
"""
anthropic_messages = []
for message in messages:
# Skip messages with empty content
if not message.get("content"):
continue
if message["role"] == "user":
anthropic_messages.append({"role": "user", "content": message["content"]})
elif message["role"] == "assistant":
anthropic_messages.append({"role": "assistant", "content": message["content"]})
# Cast the list to the correct type expected by Anthropic
return cast(List[MessageParam], anthropic_messages)
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: int
) -> Any:
"""Run model with interleaved conversation format.
Args:
messages: List of messages to process
system: System prompt
max_tokens: Maximum tokens to generate
Returns:
Model response
"""
last_error = None
for attempt in range(self.max_retries):
try:
# Convert messages to Anthropic format
anthropic_messages = self._convert_message_format(messages)
response = await self.client.messages.create(
model=self.model,
max_tokens=max_tokens,
temperature=0,
system=system,
messages=anthropic_messages,
)
return response
except (ConnectError, ReadTimeout) as e:
last_error = e
logger.warning(
f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
)
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay * (attempt + 1))  # Linear backoff (delay grows with attempt number)
continue
except Exception as e:
logger.error(f"Unexpected error in Anthropic API call: {str(e)}")
raise RuntimeError(f"Anthropic API call failed: {str(e)}")
# If we get here, all retries failed
raise RuntimeError(f"Connection error after {self.max_retries} retries: {str(last_error)}")

View File

@@ -1,35 +0,0 @@
"""Base client implementation for Omni providers."""
import logging
from typing import Dict, List, Optional, Any, Tuple
logger = logging.getLogger(__name__)
class BaseOmniClient:
"""Base class for provider-specific clients."""
def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
"""Initialize base client.
Args:
api_key: Optional API key
model: Optional model name
"""
self.api_key = api_key
self.model = model
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
) -> Dict[str, Any]:
"""Run interleaved chat completion.
Args:
messages: List of message dicts
system: System prompt
max_tokens: Optional max tokens override
Returns:
Response dict
"""
raise NotImplementedError

View File

@@ -1,195 +0,0 @@
"""OpenAI-compatible client implementation."""
import os
import logging
from typing import Dict, List, Optional, Any
import aiohttp
import re
from .base import BaseOmniClient
logger = logging.getLogger(__name__)
# OpenAI-compatible client for the OmniLoop
class OAICompatClient(BaseOmniClient):
"""OpenAI-compatible API client implementation.
This client can be used with any service that implements the OpenAI API protocol, including:
- vLLM
- LM Studio
- LocalAI
- Ollama (with OpenAI compatibility)
- Text Generation WebUI
- Any other service with OpenAI API compatibility
"""
def __init__(
self,
api_key: Optional[str] = None,
model: str = "Qwen2.5-VL-7B-Instruct",
provider_base_url: Optional[str] = "http://localhost:8000/v1",
max_tokens: int = 4096,
temperature: float = 0.0,
):
"""Initialize the OpenAI-compatible client.
Args:
api_key: Not used for local endpoints, usually set to "EMPTY"
model: Model name to use
provider_base_url: API base URL. Typically in the format "http://localhost:PORT/v1"
Examples:
- vLLM: "http://localhost:8000/v1"
- LM Studio: "http://localhost:1234/v1"
- LocalAI: "http://localhost:8080/v1"
- Ollama: "http://localhost:11434/v1"
max_tokens: Maximum tokens to generate
temperature: Generation temperature
"""
super().__init__(api_key=api_key or "EMPTY", model=model)
self.api_key = api_key or "EMPTY" # Local endpoints typically don't require an API key
self.model = model
self.provider_base_url = (
provider_base_url or "http://localhost:8000/v1"
) # Use default if None
self.max_tokens = max_tokens
self.temperature = temperature
def _extract_base64_image(self, text: str) -> Optional[str]:
"""Extract base64 image data from an HTML img tag."""
pattern = r'data:image/[^;]+;base64,([^"]+)'
match = re.search(pattern, text)
return match.group(1) if match else None
def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Create a loggable version of messages with image data truncated."""
loggable_messages = []
for msg in messages:
if isinstance(msg.get("content"), list):
new_content = []
for content in msg["content"]:
if content.get("type") == "image":
new_content.append(
{"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
)
else:
new_content.append(content)
loggable_messages.append({"role": msg["role"], "content": new_content})
else:
loggable_messages.append(msg)
return loggable_messages
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
) -> Dict[str, Any]:
"""Run interleaved chat completion.
Args:
messages: List of message dicts
system: System prompt
max_tokens: Optional max tokens override
Returns:
Response dict
"""
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
final_messages = [
{
"role": "system",
"content": [{"type": "text", "text": system}],
}
]
# Process messages
for item in messages:
if isinstance(item, dict):
if isinstance(item["content"], list):
# Content is already in the correct format
final_messages.append(item)
else:
# Single string content, check for image
base64_img = self._extract_base64_image(item["content"])
if base64_img:
message = {
"role": item["role"],
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
}
],
}
else:
message = {
"role": item["role"],
"content": [{
"type": "text",
"text": item["content"]
}],
}
final_messages.append(message)
else:
# String content, check for image
base64_img = self._extract_base64_image(item)
if base64_img:
message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
}
],
}
else:
message = {"role": "user", "content": [{"type": "text", "text": item}]}
final_messages.append(message)
payload = {"model": self.model, "messages": final_messages, "temperature": self.temperature}
payload["max_tokens"] = max_tokens or self.max_tokens
try:
async with aiohttp.ClientSession() as session:
# Use default base URL if none provided
base_url = self.provider_base_url or "http://localhost:8000/v1"
# Check if the base URL already includes the chat/completions endpoint
endpoint_url = base_url
if not endpoint_url.endswith("/chat/completions"):
# If URL is RunPod format, make it OpenAI compatible
if endpoint_url.startswith("https://api.runpod.ai/v2/"):
# Extract RunPod endpoint ID
parts = endpoint_url.split("/")
if len(parts) >= 5:
runpod_id = parts[4]
endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
# If the URL ends with /v1, append /chat/completions
elif endpoint_url.endswith("/v1"):
endpoint_url = f"{endpoint_url}/chat/completions"
# If the URL doesn't end with /v1, make sure it has a proper structure
elif not endpoint_url.endswith("/"):
endpoint_url = f"{endpoint_url}/chat/completions"
else:
endpoint_url = f"{endpoint_url}chat/completions"
# Log the endpoint URL for debugging
logger.debug(f"Using endpoint URL: {endpoint_url}")
async with session.post(endpoint_url, headers=headers, json=payload) as response:
response_json = await response.json()
if response.status != 200:
error_msg = response_json.get("error", {}).get(
"message", str(response_json)
)
logger.error(f"Error in API call: {error_msg}")
raise Exception(f"API error: {error_msg}")
return response_json
except Exception as e:
logger.error(f"Error in API call: {str(e)}")
raise
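# Hedged usage sketch (illustrative, not part of the original file): pointing
# the client at a local LM Studio server. The port and model name are
# assumptions and should match whatever the local endpoint actually serves.
import asyncio

async def _demo() -> None:
    client = OAICompatClient(
        model="gemma-3-12b-it",
        provider_base_url="http://localhost:1234/v1",  # LM Studio default
    )
    response = await client.run_interleaved(
        messages=[{"role": "user", "content": "Say hello."}],
        system="You are a helpful assistant.",
    )
    print(response["choices"][0]["message"]["content"])

if __name__ == "__main__":
    asyncio.run(_demo())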

View File

@@ -1,122 +0,0 @@
"""Ollama API client implementation."""
import logging
from typing import Any, Dict, List, Optional, Tuple, cast
import asyncio
from httpx import ConnectError, ReadTimeout
from ollama import AsyncClient, Options
from .base import BaseOmniClient
logger = logging.getLogger(__name__)
class OllamaClient(BaseOmniClient):
"""Client for making calls to Ollama API."""
def __init__(self, api_key: str, model: str, max_retries: int = 3, retry_delay: float = 1.0):
"""Initialize the Ollama client.
Args:
api_key: Not used
model: Ollama model name (e.g. "gemma3:4b-it-q4_K_M")
max_retries: Maximum number of retries for API calls
retry_delay: Base delay between retries in seconds
"""
if not model:
raise ValueError("Model name must be provided")
self.client = AsyncClient(
host="http://localhost:11434",
)
self.model: str = model # Add explicit type annotation
self.max_retries = max_retries
self.retry_delay = retry_delay
def _convert_message_format(self, system: str, messages: List[Dict[str, Any]]) -> List[Any]:
"""Convert messages from standard format to Ollama format.
Args:
system: System prompt to prepend as the first message
messages: Messages in standard format
Returns:
Messages in Ollama format
"""
ollama_messages = []
# Add system message
ollama_messages.append(
{
"role": "system",
"content": system,
}
)
for message in messages:
# Skip messages with empty content
if not message.get("content"):
continue
# Standard-format content is a list of typed items; convert each one
# rather than only the first so multi-part messages are not dropped
for content in message.get("content", []):
if not isinstance(content, dict):
continue
if content.get("type", "") == "text":
ollama_messages.append({"role": message["role"], "content": content.get("text", "")})
elif content.get("type", "") == "image_url":
data = content.get("image_url", {}).get("url", "")
# Strip any data URL header (e.g. "data:image/png;base64,")
if data.startswith("data:image"):
data = data.split(",", 1)[1]
ollama_messages.append(
{"role": message["role"], "content": "Use this image", "images": [data]}
)
# Cast the list to the correct type expected by Ollama
return cast(List[Any], ollama_messages)
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: int
) -> Any:
"""Run model with interleaved conversation format.
Args:
messages: List of messages to process
system: System prompt
max_tokens: Not used
Returns:
Model response
"""
last_error = None
for attempt in range(self.max_retries):
try:
# Convert messages to Ollama format
ollama_messages = self._convert_message_format(system, messages)
response = await self.client.chat(
model=self.model,
options=Options(
temperature=0,
),
messages=ollama_messages,
format="json",
)
return response
except (ConnectError, ReadTimeout) as e:
last_error = e
logger.warning(
f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
)
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay * (attempt + 1))  # Linear backoff (delay grows with attempt number)
continue
except Exception as e:
logger.error(f"Unexpected error in Ollama API call: {str(e)}")
raise RuntimeError(f"Ollama API call failed: {str(e)}")
# If we get here, all retries failed
raise RuntimeError(f"Connection error after {self.max_retries} retries: {str(last_error)}")

View File

@@ -1,155 +0,0 @@
"""OpenAI client implementation."""
import os
import logging
from typing import Dict, List, Optional, Any
import aiohttp
import re
from datetime import datetime
from .base import BaseOmniClient
logger = logging.getLogger(__name__)
# OpenAI specific client for the OmniLoop
class OpenAIClient(BaseOmniClient):
"""OpenAI vision API client implementation."""
def __init__(
self,
api_key: Optional[str] = None,
model: str = "gpt-4o",
provider_base_url: str = "https://api.openai.com/v1",
max_tokens: int = 4096,
temperature: float = 0.0,
):
"""Initialize the OpenAI client.
Args:
api_key: OpenAI API key
model: Model to use
provider_base_url: API endpoint
max_tokens: Maximum tokens to generate
temperature: Generation temperature
"""
super().__init__(api_key=api_key, model=model)
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
if not self.api_key:
raise ValueError("No OpenAI API key provided")
self.model = model
self.provider_base_url = provider_base_url
self.max_tokens = max_tokens
self.temperature = temperature
def _extract_base64_image(self, text: str) -> Optional[str]:
"""Extract base64 image data from an HTML img tag."""
pattern = r'data:image/[^;]+;base64,([^"]+)'
match = re.search(pattern, text)
return match.group(1) if match else None
def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Create a loggable version of messages with image data truncated."""
loggable_messages = []
for msg in messages:
if isinstance(msg.get("content"), list):
new_content = []
for content in msg["content"]:
if content.get("type") == "image":
new_content.append(
{"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
)
else:
new_content.append(content)
loggable_messages.append({"role": msg["role"], "content": new_content})
else:
loggable_messages.append(msg)
return loggable_messages
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
) -> Dict[str, Any]:
"""Run interleaved chat completion.
Args:
messages: List of message dicts
system: System prompt
max_tokens: Optional max tokens override
Returns:
Response dict
"""
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
final_messages = [{"role": "system", "content": system}]
# Process messages
for item in messages:
if isinstance(item, dict):
if isinstance(item["content"], list):
# Content is already in the correct format
final_messages.append(item)
else:
# Single string content, check for image
base64_img = self._extract_base64_image(item["content"])
if base64_img:
message = {
"role": item["role"],
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
}
],
}
else:
message = {
"role": item["role"],
"content": [{"type": "text", "text": item["content"]}],
}
final_messages.append(message)
else:
# String content, check for image
base64_img = self._extract_base64_image(item)
if base64_img:
message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
}
],
}
else:
message = {"role": "user", "content": [{"type": "text", "text": item}]}
final_messages.append(message)
payload = {"model": self.model, "messages": final_messages, "temperature": self.temperature}
if "o1" in self.model or "o3-mini" in self.model:
payload["reasoning_effort"] = "low"
payload["max_completion_tokens"] = max_tokens or self.max_tokens
else:
payload["max_tokens"] = max_tokens or self.max_tokens
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.provider_base_url}/chat/completions", headers=headers, json=payload
) as response:
response_json = await response.json()
if response.status != 200:
error_msg = response_json.get("error", {}).get(
"message", str(response_json)
)
logger.error(f"Error in OpenAI API call: {error_msg}")
raise Exception(f"OpenAI API error: {error_msg}")
return response_json
except Exception as e:
logger.error(f"Error in OpenAI API call: {str(e)}")
raise
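# Hedged note (illustrative, not part of the original file): for reasoning
# models the payload assembled above swaps "max_tokens" for
# "max_completion_tokens" and adds "reasoning_effort", e.g.
#   {"model": "o3-mini", "messages": [...], "temperature": 0.0,
#    "reasoning_effort": "low", "max_completion_tokens": 4096}
# versus the plain "max_tokens" field for standard chat models.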

View File

@@ -1,25 +0,0 @@
import base64
def is_image_path(text: str) -> bool:
"""Check if a text string is an image file path.
Args:
text: Text string to check
Returns:
True if text ends with image extension, False otherwise
"""
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
# Compare case-insensitively so paths like "SHOT.PNG" are recognized
return text.lower().endswith(image_extensions)
def encode_image(image_path: str) -> str:
"""Encode image file to base64.
Args:
image_path: Path to image file
Returns:
Base64 encoded image string
"""
with open(image_path, "rb") as image_file:
return base64.b64encode(image_file.read()).decode("utf-8")
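# Hedged usage sketch (illustrative, not part of the original file): packaging
# a local screenshot as an OpenAI-style image message; "screenshot.png" is a
# placeholder path.
if __name__ == "__main__":
    path = "screenshot.png"
    if is_image_path(path):
        image_message = {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{encode_image(path)}"},
                }
            ],
        }
        print(image_message["content"][0]["type"])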

View File

@@ -1,34 +0,0 @@
"""Image processing utilities for the Cua provider."""
import base64
import logging
import re
from io import BytesIO
from typing import Optional, Tuple
from PIL import Image
logger = logging.getLogger(__name__)
def decode_base64_image(img_base64: str) -> Optional[Image.Image]:
"""Decode a base64 encoded image to a PIL Image.
Args:
img_base64: Base64 encoded image, may include data URL prefix
Returns:
PIL Image or None if decoding fails
"""
try:
# Remove data URL prefix if present
if img_base64.startswith("data:image"):
img_base64 = img_base64.split(",")[1]
# Decode base64 to bytes
img_data = base64.b64decode(img_base64)
# Convert bytes to PIL Image
return Image.open(BytesIO(img_data))
except Exception as e:
logger.error(f"Error decoding base64 image: {str(e)}")
return None

View File

@@ -1,990 +0,0 @@
"""Omni-specific agent loop implementation."""
import logging
from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator
import json
import re
import os
import asyncio
from httpx import ConnectError, ReadTimeout
from typing import cast
from .parser import OmniParser, ParseResult
from ...core.base import BaseLoop
from ...core.visualization import VisualizationHelper
from ...core.messages import StandardMessageManager, ImageRetentionConfig
from .utils import to_openai_agent_response_format
from ...core.types import AgentResponse
from computer import Computer
from ...core.types import LLMProvider
from .clients.openai import OpenAIClient
from .clients.anthropic import AnthropicClient
from .clients.ollama import OllamaClient
from .clients.oaicompat import OAICompatClient
from .prompts import SYSTEM_PROMPT
from .api_handler import OmniAPIHandler
from .tools.manager import ToolManager
from .tools import ToolResult
logger = logging.getLogger(__name__)
def extract_data(input_string: str, data_type: str) -> str:
"""Extract content from code blocks."""
pattern = f"```{data_type}" + r"(.*?)(```|$)"
matches = re.findall(pattern, input_string, re.DOTALL)
return matches[0][0].strip() if matches else input_string
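# Hedged examples (illustrative, not part of the original file) of
# extract_data behavior: it returns the body of a fenced block when one is
# found, and falls back to the raw input otherwise.
#   extract_data('```json\n{"Action": "None"}\n```', "json") -> '{"Action": "None"}'
#   extract_data('no fenced block here', "json") -> 'no fenced block here'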
class OmniLoop(BaseLoop):
"""Omni-specific implementation of the agent loop.
This class extends BaseLoop to provide support for multimodal models
from various providers (OpenAI, Anthropic, etc.) with UI parsing
and desktop automation capabilities.
"""
###########################################
# INITIALIZATION AND CONFIGURATION
###########################################
def __init__(
self,
parser: OmniParser,
provider: LLMProvider,
api_key: str,
model: str,
computer: Computer,
only_n_most_recent_images: Optional[int] = 2,
base_dir: Optional[str] = "trajectories",
max_retries: int = 3,
retry_delay: float = 1.0,
save_trajectory: bool = True,
provider_base_url: Optional[str] = None,
**kwargs,
):
"""Initialize the loop.
Args:
parser: Parser instance
provider: API provider
api_key: API key
model: Model name
computer: Computer instance
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
base_dir: Base directory for saving experiment data
max_retries: Maximum number of retries for API calls
retry_delay: Delay between retries in seconds
save_trajectory: Whether to save trajectory data
provider_base_url: Base URL for the API provider (used for OAICOMPAT)
"""
# Set parser and provider before initializing base class
self.parser = parser
self.provider = provider
self.provider_base_url = provider_base_url
# Initialize message manager with image retention config
self.message_manager = StandardMessageManager(
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
)
# Initialize base class (which will set up experiment manager)
super().__init__(
computer=computer,
model=model,
api_key=api_key,
max_retries=max_retries,
retry_delay=retry_delay,
base_dir=base_dir,
save_trajectory=save_trajectory,
only_n_most_recent_images=only_n_most_recent_images,
**kwargs,
)
# Set API client attributes
self.client = None
self.retry_count = 0
self.loop_task = None # Store the loop task for cancellation
# Initialize handlers
self.api_handler = OmniAPIHandler(loop=self)
self.viz_helper = VisualizationHelper(agent=self)
# Initialize tool manager
self.tool_manager = ToolManager(computer=computer, provider=provider)
logger.info("OmniLoop initialized with StandardMessageManager")
async def initialize(self) -> None:
"""Initialize the loop by setting up tools and clients."""
# Initialize base class
await super().initialize()
# Initialize tool manager with error handling
try:
logger.info("Initializing tool manager...")
await self.tool_manager.initialize()
logger.info("Tool manager initialized successfully.")
except Exception as e:
logger.error(f"Error initializing tool manager: {str(e)}")
logger.warning("Will attempt to initialize tools on first use.")
# Initialize API clients based on provider
if self.provider == LLMProvider.ANTHROPIC:
self.client = AnthropicClient(
api_key=self.api_key,
model=self.model,
)
elif self.provider == LLMProvider.OPENAI:
self.client = OpenAIClient(
api_key=self.api_key,
model=self.model,
)
elif self.provider == LLMProvider.OLLAMA:
self.client = OllamaClient(
api_key=self.api_key,
model=self.model,
)
elif self.provider == LLMProvider.OAICOMPAT:
self.client = OAICompatClient(
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
model=self.model,
provider_base_url=self.provider_base_url,
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
###########################################
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
###########################################
async def initialize_client(self) -> None:
"""Initialize the appropriate client based on provider.
Implements abstract method from BaseLoop to set up the specific
provider client (OpenAI, Anthropic, etc.).
"""
try:
logger.info(f"Initializing {self.provider} client with model {self.model}...")
if self.provider == LLMProvider.OPENAI:
self.client = OpenAIClient(api_key=self.api_key, model=self.model)
elif self.provider == LLMProvider.ANTHROPIC:
self.client = AnthropicClient(
api_key=self.api_key,
model=self.model,
max_retries=self.max_retries,
retry_delay=self.retry_delay,
)
elif self.provider == LLMProvider.OLLAMA:
self.client = OllamaClient(
api_key=self.api_key,
model=self.model,
)
elif self.provider == LLMProvider.OAICOMPAT:
self.client = OAICompatClient(
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
model=self.model,
provider_base_url=self.provider_base_url,
)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
logger.info(f"Initialized {self.provider} client with model {self.model}")
except Exception as e:
logger.error(f"Error initializing client: {str(e)}")
self.client = None
raise RuntimeError(f"Failed to initialize client: {str(e)}")
###########################################
# API CALL HANDLING
###########################################
async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any:
"""Make API call to provider with retry logic."""
# Create new turn directory for this API call
self._create_turn_dir()
request_data = None
last_error = None
for attempt in range(self.max_retries):
try:
# Ensure client is initialized
if self.client is None:
logger.info(
f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..."
)
await self.initialize_client()
if self.client is None:
raise RuntimeError("Failed to initialize client")
# Get messages in standard format from the message manager
self.message_manager.messages = messages.copy()
prepared_messages = self.message_manager.get_messages()
# Special handling for Anthropic
if self.provider == LLMProvider.ANTHROPIC:
# Convert to Anthropic format
anthropic_messages, anthropic_system = self.message_manager.to_anthropic_format(
prepared_messages
)
# Filter out any empty/invalid messages
filtered_messages = [
msg
for msg in anthropic_messages
if msg.get("role") in ["user", "assistant"]
]
# Ensure there's at least one message for Anthropic
if not filtered_messages:
logger.warning(
"No valid messages found for Anthropic API call. Adding a default user message."
)
filtered_messages = [
{
"role": "user",
"content": [
{"type": "text", "text": "Please help with this task."}
],
}
]
# Combine system prompts if needed
final_system_prompt = anthropic_system or system_prompt
# Log request
request_data = {
"messages": filtered_messages,
"max_tokens": self.max_tokens,
"system": final_system_prompt,
}
self._log_api_call("request", request_data)
# Make API call
response = await self.client.run_interleaved(
messages=filtered_messages,
system=final_system_prompt,
max_tokens=self.max_tokens,
)
else:
# For OpenAI and others, use standard format directly
# Log request
request_data = {
"messages": prepared_messages,
"max_tokens": self.max_tokens,
"system": system_prompt,
}
self._log_api_call("request", request_data)
# Make API call
response = await self.client.run_interleaved(
messages=prepared_messages,
system=system_prompt,
max_tokens=self.max_tokens,
)
# Log success response
self._log_api_call("response", request_data, response)
return response
except (ConnectError, ReadTimeout) as e:
last_error = e
logger.warning(
f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
)
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay * (attempt + 1))  # Linear backoff (delay grows with attempt number)
# Reset client on connection errors to force re-initialization
self.client = None
continue
except RuntimeError as e:
# Handle client initialization errors specifically
last_error = e
self._log_api_call("error", request_data, error=e)
logger.error(
f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
)
if attempt < self.max_retries - 1:
# Reset client to force re-initialization
self.client = None
await asyncio.sleep(self.retry_delay)
continue
except Exception as e:
# Log unexpected error
last_error = e
self._log_api_call("error", request_data, error=e)
logger.error(f"Unexpected error in API call: {str(e)}")
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay)
continue
# If we get here, all retries failed
error_message = f"API call failed after {self.max_retries} attempts"
if last_error:
error_message += f": {str(last_error)}"
logger.error(error_message)
raise RuntimeError(error_message)
###########################################
# RESPONSE AND ACTION HANDLING
###########################################
async def _handle_response(
self, response: Any, messages: List[Dict[str, Any]], parsed_screen: ParseResult
) -> Tuple[bool, bool]:
"""Handle API response.
Args:
response: API response
messages: List of messages to update
parsed_screen: Current parsed screen information
Returns:
Tuple of (should_continue, action_screenshot_saved)
"""
action_screenshot_saved = False
# Helper function to safely add assistant messages using the message manager
def add_assistant_message(content):
if isinstance(content, str):
# Convert string to proper format
formatted_content = [{"type": "text", "text": content}]
self.message_manager.add_assistant_message(formatted_content)
logger.info("Added formatted text assistant message")
elif isinstance(content, list):
# Already in proper format
self.message_manager.add_assistant_message(content)
logger.info("Added structured assistant message")
else:
# Default case - convert to string
formatted_content = [{"type": "text", "text": str(content)}]
self.message_manager.add_assistant_message(formatted_content)
logger.info("Added converted assistant message")
try:
# Step 1: Normalize response to standard format based on provider
standard_content = []
raw_text = None
# Convert response to standardized content based on provider
if self.provider == LLMProvider.ANTHROPIC:
if hasattr(response, "content") and isinstance(response.content, list):
# Convert Anthropic response to standard format
for block in response.content:
if hasattr(block, "type"):
if block.type == "text":
standard_content.append({"type": "text", "text": block.text})
# Store raw text for JSON parsing
if raw_text is None:
raw_text = block.text
else:
raw_text += "\n" + block.text
else:
# Add other block types
block_dict = {}
for key, value in vars(block).items():
if not key.startswith("_"):
block_dict[key] = value
standard_content.append(block_dict)
else:
logger.warning("Invalid Anthropic response format")
return True, action_screenshot_saved
elif self.provider == LLMProvider.OLLAMA:
try:
raw_text = response["message"]["content"]
standard_content = [{"type": "text", "text": raw_text}]
except (KeyError, TypeError, IndexError) as e:
logger.error(f"Invalid response format: {str(e)}")
return True, action_screenshot_saved
elif self.provider == LLMProvider.OAICOMPAT:
try:
# OpenAI-compatible response format
raw_text = response["choices"][0]["message"]["content"]
standard_content = [{"type": "text", "text": raw_text}]
except (KeyError, TypeError, IndexError) as e:
logger.error(f"Invalid response format: {str(e)}")
return True, action_screenshot_saved
else:
# Assume OpenAI or compatible format
try:
raw_text = response["choices"][0]["message"]["content"]
standard_content = [{"type": "text", "text": raw_text}]
except (KeyError, TypeError, IndexError) as e:
logger.error(f"Invalid response format: {str(e)}")
return True, action_screenshot_saved
# Step 2: Add the normalized response to message history
add_assistant_message(standard_content)
# Step 3: Extract JSON from the content for action execution
parsed_content = None
# If we have raw text, try to extract JSON from it
if raw_text:
# Try different approaches to extract JSON
try:
# First try to parse the whole content as JSON
parsed_content = json.loads(raw_text)
logger.info("Successfully parsed whole content as JSON")
except json.JSONDecodeError:
try:
# Try to find JSON block
json_content = extract_data(raw_text, "json")
parsed_content = json.loads(json_content)
logger.info("Successfully parsed JSON from code block")
except (json.JSONDecodeError, IndexError):
try:
# Look for a JSON object pattern (re is already imported at module level)
json_pattern = r"\{[^}]+\}"
json_match = re.search(json_pattern, raw_text)
if json_match:
json_str = json_match.group(0)
parsed_content = json.loads(json_str)
logger.info("Successfully parsed JSON from text")
else:
logger.error("No JSON found in content")
return True, action_screenshot_saved
except json.JSONDecodeError as e:
# Try to sanitize the JSON string and retry
try:
# Remove invalid control characters and parse again
sanitized_text = re.sub(r"[\x00-\x1F\x7F]", "", raw_text)
parsed_content = json.loads(sanitized_text)
logger.info(
"Successfully parsed JSON after sanitizing control characters"
)
except json.JSONDecodeError:
logger.error(f"Failed to parse JSON from text: {str(e)}")
return True, action_screenshot_saved
# Step 4: Process the parsed content if available
if parsed_content:
# Clean up Box ID format
if "Box ID" in parsed_content and isinstance(parsed_content["Box ID"], str):
parsed_content["Box ID"] = parsed_content["Box ID"].replace("Box #", "")
# Add any explanatory text as reasoning if not present
if "Explanation" not in parsed_content and raw_text:
# Extract any text before the JSON as reasoning
text_before_json = raw_text.split("{")[0].strip()
if text_before_json:
parsed_content["Explanation"] = text_before_json
# Log the parsed content for debugging
logger.info(f"Parsed content: {json.dumps(parsed_content, indent=2)}")
# Step 5: Execute the action
try:
# Execute action using the common helper method
should_continue, action_screenshot_saved = (
await self._execute_action_with_tools(
parsed_content, cast(ParseResult, parsed_screen)
)
)
# Check if task is complete
if parsed_content.get("Action") == "None":
return False, action_screenshot_saved
return should_continue, action_screenshot_saved
except Exception as e:
logger.error(f"Error executing action: {str(e)}")
# Update the last assistant message with error
error_message = [{"type": "text", "text": f"Error executing action: {str(e)}"}]
# Replace the last assistant message with the error
self.message_manager.add_assistant_message(error_message)
return False, action_screenshot_saved
return True, action_screenshot_saved
except Exception as e:
logger.error(f"Error handling response: {str(e)}")
# Add error message using the message manager
error_message = [{"type": "text", "text": f"Error: {str(e)}"}]
self.message_manager.add_assistant_message(error_message)
raise
###########################################
# SCREEN PARSING - IMPLEMENTING ABSTRACT METHOD
###########################################
async def _get_parsed_screen_som(self, save_screenshot: bool = True) -> ParseResult:
"""Get parsed screen information with Screen Object Model.
Extends the base class method to use the OmniParser to parse the screen
and extract UI elements.
Args:
save_screenshot: Whether to save the screenshot (set to False when screenshots will be saved elsewhere)
Returns:
ParseResult containing screen information and elements
"""
try:
# Use the parser's parse_screen method which handles the screenshot internally
parsed_screen = await self.parser.parse_screen(computer=self.computer)
# Log information about the parsed results
logger.info(
f"Parsed screen with {len(parsed_screen.elements) if parsed_screen.elements else 0} elements"
)
# Save screenshot if requested and if we have image data
if save_screenshot and self.save_trajectory and parsed_screen.annotated_image_base64:
try:
# Extract just the image data (remove data:image/png;base64, prefix)
img_data = parsed_screen.annotated_image_base64
if "," in img_data:
img_data = img_data.split(",")[1]
# Process screenshot through hooks and save if needed
await self.handle_screenshot(img_data, action_type="state", parsed_screen=parsed_screen)
# Save with a generic "state" action type to indicate this is the current screen state
self._save_screenshot(img_data, action_type="state")
except Exception as e:
logger.error(f"Error saving screenshot: {str(e)}")
return parsed_screen
except Exception as e:
logger.error(f"Error getting parsed screen: {str(e)}")
raise
def _get_system_prompt(self) -> str:
"""Get the system prompt for the model."""
return SYSTEM_PROMPT
###########################################
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
###########################################
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
"""Run the agent loop with provided messages.
Args:
messages: List of messages in standard OpenAI format
Yields:
Agent response format
"""
try:
logger.info(f"Starting OmniLoop run with {len(messages)} messages")
# Initialize the message manager with the provided messages
self.message_manager.messages = messages.copy()
# Create queue for response streaming
queue = asyncio.Queue()
# Start loop in background task
self.loop_task = asyncio.create_task(self._run_loop(queue, messages))
# Process and yield messages as they arrive
while True:
try:
item = await queue.get()
if item is None: # Stop signal
break
yield item
queue.task_done()
except Exception as e:
logger.error(f"Error processing queue item: {str(e)}")
continue
# Wait for loop to complete
await self.loop_task
# Send completion message
yield {
"role": "assistant",
"content": "Task completed successfully.",
"metadata": {"title": "✅ Complete"},
}
except Exception as e:
logger.error(f"Error in run method: {str(e)}")
yield {
"role": "assistant",
"content": f"Error: {str(e)}",
"metadata": {"title": "❌ Error"},
}
async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
"""Internal method to run the agent loop with provided messages.
Args:
queue: Queue to put responses into
messages: List of messages in standard OpenAI format
"""
# Continue running until explicitly told to stop
running = True
turn_created = False
# Track if an action-specific screenshot has been saved this turn
action_screenshot_saved = False
attempt = 0
max_attempts = 3
while running and attempt < max_attempts:
try:
# Create a new turn directory if it's not already created
if not turn_created:
self._create_turn_dir()
turn_created = True
# Ensure client is initialized
if self.client is None:
logger.info("Initializing client...")
await self.initialize_client()
if self.client is None:
raise RuntimeError("Failed to initialize client")
logger.info("Client initialized successfully")
# Get up-to-date screen information
parsed_screen = await self._get_parsed_screen_som()
# Process screen info and update messages in standard format
try:
# Get image from parsed screen
image = parsed_screen.annotated_image_base64 or None
if image:
# Save elements as JSON if we have a turn directory
if self.current_turn_dir and hasattr(parsed_screen, "elements"):
elements_path = os.path.join(self.current_turn_dir, "elements.json")
with open(elements_path, "w") as f:
# Convert elements to dicts for JSON serialization
elements_json = [
elem.model_dump() for elem in parsed_screen.elements
]
json.dump(elements_json, f, indent=2)
logger.info(f"Saved elements to {elements_path}")
# Remove data URL prefix if present
if "," in image:
image = image.split(",")[1]
# Add screenshot to message history using message manager
self.message_manager.add_user_message(
[
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image}"},
}
]
)
logger.info("Added screenshot to message history")
except Exception as e:
logger.error(f"Error processing screen info: {str(e)}")
raise
# Get system prompt
system_prompt = self._get_system_prompt()
# Make API call with retries using the APIHandler
response = await self.api_handler.make_api_call(
self.message_manager.messages, system_prompt
)
# Handle the response (may execute actions)
# Returns: (should_continue, action_screenshot_saved)
should_continue, new_screenshot_saved = await self._handle_response(
response, self.message_manager.messages, parsed_screen
)
# Update whether an action screenshot was saved this turn
action_screenshot_saved = action_screenshot_saved or new_screenshot_saved
# Create OpenAI-compatible response format using utility function
openai_compatible_response = await to_openai_agent_response_format(
response=response,
messages=self.message_manager.messages,
model=self.model,
parsed_screen=parsed_screen,
parser=self.parser
)
# Log standardized response for ease of parsing
self._log_api_call("agent_response", request=None, response=openai_compatible_response)
# Put the response in the queue
await queue.put(openai_compatible_response)
# Check if we should continue this conversation
running = should_continue
# Create a new turn directory if we're continuing
if running:
turn_created = False
# Reset attempt counter on success
attempt = 0
except Exception as e:
attempt += 1
error_msg = f"Error in _run_loop method (attempt {attempt}/{max_attempts}): {str(e)}"
logger.error(error_msg)
# If this is our last attempt, provide more info about the error
if attempt >= max_attempts:
logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")
await queue.put({
"role": "assistant",
"content": f"Error: {str(e)}",
"metadata": {"title": "❌ Error"},
})
# Create a brief delay before retrying
await asyncio.sleep(1)
finally:
# Signal that we're done
await queue.put(None)
async def cancel(self) -> None:
"""Cancel the currently running agent loop task.
This method stops the ongoing processing in the agent loop
by cancelling the loop_task if it exists and is running.
"""
if self.loop_task and not self.loop_task.done():
logger.info("Cancelling Omni loop task")
self.loop_task.cancel()
try:
# Wait for the task to be cancelled with a timeout
await asyncio.wait_for(self.loop_task, timeout=2.0)
except asyncio.TimeoutError:
logger.warning("Timeout while waiting for loop task to cancel")
except asyncio.CancelledError:
logger.info("Loop task cancelled successfully")
except Exception as e:
logger.error(f"Error while cancelling loop task: {str(e)}")
finally:
logger.info("Omni loop task cancelled")
else:
logger.info("No active Omni loop task to cancel")
async def process_model_response(self, response_text: str) -> Optional[Dict[str, Any]]:
"""Process model response to extract tool calls.
Args:
response_text: Model response text
Returns:
Extracted tool information, or None if no tool call was found
"""
try:
# Ensure tools are initialized before use
await self._ensure_tools_initialized()
# Look for tool use in the response
if "function_call" in response_text or "tool_use" in response_text:
# The extract_tool_call method should be implemented in the OmniAPIHandler
# For now, we'll just use a simple approach
# This will be replaced with the proper implementation
tool_info = None
if "function_call" in response_text:
# Extract function call params
try:
# Simple extraction - in real code this would be more robust
import json
import re
match = re.search(r'"function_call"\s*:\s*{([^}]+)}', response_text)
if match:
function_text = "{" + match.group(1) + "}"
tool_info = json.loads(function_text)
except Exception as e:
logger.error(f"Error extracting function call: {str(e)}")
if tool_info:
try:
# Execute the tool
result = await self.tool_manager.execute_tool(
name=tool_info.get("name"), tool_input=tool_info.get("arguments", {})
)
# Handle the result
return {"tool_result": result}
except Exception as e:
error_msg = (
f"Error executing tool '{tool_info.get('name', 'unknown')}': {str(e)}"
)
logger.error(error_msg)
return {"tool_result": ToolResult(error=error_msg)}
except Exception as e:
logger.error(f"Error processing tool call: {str(e)}")
return None
async def process_response_with_tools(
self, response_text: str, parsed_screen: Optional[ParseResult] = None
) -> Tuple[bool, str]:
"""Process model response and execute tools.
Args:
response_text: Model response text
parsed_screen: Current parsed screen information (optional)
Returns:
Tuple of (action_taken, observation)
"""
logger.info("Processing response with tools")
# Process the response to extract tool calls
tool_result = await self.process_model_response(response_text)
if tool_result and "tool_result" in tool_result:
# A tool was executed
result = tool_result["tool_result"]
if result.error:
return False, f"ERROR: {result.error}"
else:
return True, result.output or "Tool executed successfully"
# No action or tool call found
return False, "No action taken - no tool call detected in response"
###########################################
# UTILITY METHODS
###########################################
async def _ensure_tools_initialized(self) -> None:
"""Ensure the tool manager and tools are initialized before use."""
if not hasattr(self.tool_manager, "tools") or self.tool_manager.tools is None:
logger.info("Tools not initialized. Initializing now...")
await self.tool_manager.initialize()
logger.info("Tools initialized successfully.")
async def _execute_action_with_tools(
self, action_data: Dict[str, Any], parsed_screen: ParseResult
) -> Tuple[bool, bool]:
"""Execute an action using the tools-based approach.
Args:
action_data: Dictionary containing action details
parsed_screen: Current parsed screen information
Returns:
Tuple of (should_continue, action_screenshot_saved)
"""
action_screenshot_saved = False
action_type = None # Initialize for possible use in post-action screenshot
try:
# Extract the action
parsed_action = action_data.get("Action", "").lower()
# Only process if we have a valid action
if not parsed_action or parsed_action == "none":
return False, action_screenshot_saved
# Convert the parsed content to a format suitable for the tools system
tool_name = "computer" # Default to computer tool
tool_args = {"action": parsed_action}
# Add specific arguments based on action type
if parsed_action in ["left_click", "right_click", "double_click", "move_cursor"]:
# Calculate coordinates from Box ID using parser
try:
box_id = int(action_data["Box ID"])
x, y = await self.parser.calculate_click_coordinates(
box_id, cast(ParseResult, parsed_screen)
)
tool_args["x"] = x
tool_args["y"] = y
# Visualize action if screenshot is available
if parsed_screen and parsed_screen.annotated_image_base64:
img_data = parsed_screen.annotated_image_base64
# Remove data URL prefix if present
if img_data.startswith("data:image"):
img_data = img_data.split(",")[1]
# Save visualization for coordinate-based actions
self.viz_helper.visualize_action(x, y, img_data)
action_screenshot_saved = True
except (ValueError, KeyError) as e:
logger.error(f"Error processing Box ID: {str(e)}")
return False, action_screenshot_saved
elif parsed_action == "type_text":
tool_args["text"] = action_data.get("Value", "")
# For type_text, store the value in the action type for screenshot naming
action_type = f"type_{tool_args['text'][:20]}" # Truncate if too long
elif parsed_action == "press_key":
tool_args["key"] = action_data.get("Value", "")
action_type = f"press_{tool_args['key']}"
elif parsed_action == "hotkey":
value = action_data.get("Value", "")
if isinstance(value, list):
tool_args["keys"] = value
action_type = f"hotkey_{'_'.join(value)}"
else:
# Split string format like "command+space" into a list
keys = [k.strip() for k in value.lower().split("+")]
tool_args["keys"] = keys
action_type = f"hotkey_{value.replace('+', '_')}"
elif parsed_action in ["scroll_down", "scroll_up"]:
clicks = int(action_data.get("amount", 1))
tool_args["amount"] = clicks
action_type = f"scroll_{parsed_action.split('_')[1]}_{clicks}"
# Visualize scrolling if screenshot is available
if parsed_screen and parsed_screen.annotated_image_base64:
img_data = parsed_screen.annotated_image_base64
# Remove data URL prefix if present
if img_data.startswith("data:image"):
img_data = img_data.split(",")[1]
direction = "down" if parsed_action == "scroll_down" else "up"
# For scrolling, we save the visualization
self.viz_helper.visualize_scroll(direction, clicks, img_data)
action_screenshot_saved = True
# Ensure tools are initialized before use
await self._ensure_tools_initialized()
# Execute tool with prepared arguments
result = await self.tool_manager.execute_tool(name=tool_name, tool_input=tool_args)
# Take a new screenshot after the action if we haven't already saved one
if not action_screenshot_saved:
try:
# Get a new screenshot after the action
new_parsed_screen = await self._get_parsed_screen_som(save_screenshot=False)
if new_parsed_screen and new_parsed_screen.annotated_image_base64:
img_data = new_parsed_screen.annotated_image_base64
# Remove data URL prefix if present
if img_data.startswith("data:image"):
img_data = img_data.split(",")[1]
# Save with action type if defined, otherwise use the action name
if action_type:
self._save_screenshot(img_data, action_type=action_type)
else:
self._save_screenshot(img_data, action_type=parsed_action)
action_screenshot_saved = True
except Exception as screenshot_error:
logger.error(f"Error taking post-action screenshot: {str(screenshot_error)}")
# Continue the loop if the action is not "None"
return True, action_screenshot_saved
except Exception as e:
logger.error(f"Error executing action: {str(e)}")
# Update the last assistant message with error
error_message = [{"type": "text", "text": f"Error executing action: {str(e)}"}]
# Replace the last assistant message with the error
self.message_manager.add_assistant_message(error_message)
return False, action_screenshot_saved
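# Hedged usage sketch (illustrative, not part of the original file): driving
# the loop against a local OpenAI-compatible endpoint. The Computer
# configuration is elided, and the model name and endpoint are assumptions.
async def _demo() -> None:
    loop = OmniLoop(
        parser=OmniParser(),
        provider=LLMProvider.OAICOMPAT,
        api_key="EMPTY",
        model="gemma-3-12b-it",
        computer=Computer(),
        provider_base_url="http://localhost:1234/v1",
    )
    async for item in loop.run([{"role": "user", "content": "Open Safari."}]):
        print(item)

if __name__ == "__main__":
    asyncio.run(_demo())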

View File

@@ -1,307 +0,0 @@
"""Parser implementation for the Omni provider."""
import logging
from typing import Any, Dict, List, Optional, Tuple
import base64
import torch
# Import from the SOM package
from som import OmniParser as OmniDetectParser
from som.models import ParseResult, ParserMetadata
logger = logging.getLogger(__name__)
class OmniParser:
"""Parser for handling responses from multiple providers."""
# Class-level shared OmniDetectParser instance
_shared_parser = None
def __init__(self, force_device: Optional[str] = None):
"""Initialize the OmniParser.
Args:
force_device: Optional device to force for detection (cpu/cuda/mps)
"""
self.response_buffer = []
# Use shared parser if available, otherwise create a new one
if OmniParser._shared_parser is None:
logger.info("Initializing shared OmniDetectParser...")
# Determine the best device to use
device = force_device
if not device:
if torch.cuda.is_available():
device = "cuda"
elif (
hasattr(torch, "backends")
and hasattr(torch.backends, "mps")
and torch.backends.mps.is_available()
):
device = "mps"
else:
device = "cpu"
logger.info(f"Using device: {device} for OmniDetectParser")
self.detect_parser = OmniDetectParser(force_device=device)
# Preload the detection model to avoid repeated loading
try:
# Access the detector to trigger model loading
detector = self.detect_parser.detector
if detector.model is None:
logger.info("Preloading detection model...")
detector.load_model()
logger.info("Detection model preloaded successfully")
except Exception as e:
logger.error(f"Error preloading detection model: {str(e)}")
# Store as shared instance
OmniParser._shared_parser = self.detect_parser
else:
logger.info("Using existing shared OmniDetectParser")
self.detect_parser = OmniParser._shared_parser
async def parse_screen(self, computer: Any) -> ParseResult:
"""Parse a screenshot and extract screen information.
Args:
computer: Computer instance
Returns:
ParseResult with screen elements and image data
"""
try:
# Get screenshot from computer
logger.info("Taking screenshot...")
screenshot = await computer.interface.screenshot()
# Log screenshot info
logger.info(f"Screenshot type: {type(screenshot)}")
logger.info(f"Screenshot is bytes: {isinstance(screenshot, bytes)}")
logger.info(f"Screenshot is str: {isinstance(screenshot, str)}")
logger.info(f"Screenshot length: {len(screenshot) if screenshot else 0}")
# If screenshot is a string (likely base64), convert it to bytes
if isinstance(screenshot, str):
try:
screenshot = base64.b64decode(screenshot)
logger.info("Successfully converted base64 string to bytes")
logger.info(f"Decoded bytes length: {len(screenshot)}")
except Exception as e:
logger.error(f"Error decoding base64: {str(e)}")
logger.error(f"First 100 chars of screenshot string: {screenshot[:100]}")
# Pass screenshot to OmniDetectParser
logger.info("Passing screenshot to OmniDetectParser...")
parse_result = self.detect_parser.parse(
screenshot_data=screenshot, box_threshold=0.3, iou_threshold=0.1, use_ocr=True
)
logger.info("Screenshot parsed successfully")
logger.info(f"Parse result has {len(parse_result.elements)} elements")
# Log element IDs for debugging
for i, elem in enumerate(parse_result.elements):
logger.info(
f"Element {i+1} (ID: {elem.id}): {elem.type} with confidence {elem.confidence:.3f}"
)
return parse_result
except Exception as e:
logger.error(f"Error parsing screen: {str(e)}")
import traceback
logger.error(traceback.format_exc())
# Create a minimal valid result for error cases
return ParseResult(
elements=[],
screen_info=None,
annotated_image_base64="",
parsed_content_list=[{"error": str(e)}],
metadata=ParserMetadata(
image_size=(0, 0),
num_icons=0,
num_text=0,
device="cpu",
ocr_enabled=False,
latency=0.0,
),
)
def parse_tool_call(self, response: Dict[str, Any]) -> Optional[Dict[str, Any]]:
"""Parse a tool call from the response.
Args:
response: Response from the provider
Returns:
Parsed tool call or None if no tool call found
"""
try:
# Handle OpenAI-style "tool_calls" format (also produced by the Anthropic adapter)
if "tool_calls" in response:
tool_call = response["tool_calls"][0]
return {
"name": tool_call["function"]["name"],
"arguments": tool_call["function"]["arguments"],
}
# Handle legacy OpenAI "function_call" format
if "function_call" in response:
return {
"name": response["function_call"]["name"],
"arguments": response["function_call"]["arguments"],
}
# Handle Groq format (OpenAI-compatible)
if "choices" in response and response["choices"]:
choice = response["choices"][0]
if "function_call" in choice:
return {
"name": choice["function_call"]["name"],
"arguments": choice["function_call"]["arguments"],
}
return None
except Exception as e:
logger.error(f"Error parsing tool call: {str(e)}")
return None
def parse_response(self, response: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
"""Parse a response from any provider.
Args:
response: Response from the provider
Returns:
Tuple of (content, metadata)
"""
try:
content = ""
metadata = {}
# Handle Anthropic format
if "content" in response and isinstance(response["content"], list):
for item in response["content"]:
if item["type"] == "text":
content += item["text"]
# Handle OpenAI format
elif "choices" in response and response["choices"]:
content = response["choices"][0]["message"]["content"]
# Handle direct content
elif isinstance(response.get("content"), str):
content = response["content"]
# Extract metadata if present
if "metadata" in response:
metadata = response["metadata"]
return content, metadata
except Exception as e:
logger.error(f"Error parsing response: {str(e)}")
return str(e), {"error": True}
def format_for_provider(
self, messages: List[Dict[str, Any]], provider: str
) -> List[Dict[str, Any]]:
"""Format messages for a specific provider.
Args:
messages: List of messages to format
provider: Provider to format for
Returns:
Formatted messages
"""
try:
formatted = []
for msg in messages:
formatted_msg = {"role": msg["role"]}
# Handle content formatting
if isinstance(msg["content"], list):
# For providers that support multimodal
if provider in ["anthropic", "openai"]:
formatted_msg["content"] = msg["content"]
else:
# Extract text only for other providers
text_content = next(
(item["text"] for item in msg["content"] if item["type"] == "text"), ""
)
formatted_msg["content"] = text_content
else:
formatted_msg["content"] = msg["content"]
formatted.append(formatted_msg)
return formatted
except Exception as e:
logger.error(f"Error formatting messages: {str(e)}")
return messages # Return original messages on error
async def calculate_click_coordinates(
self, box_id: int, parsed_screen: ParseResult
) -> Tuple[int, int]:
"""Calculate click coordinates based on box ID.
Args:
box_id: The ID of the box to click
parsed_screen: The parsed screen information
Returns:
Tuple of (x, y) coordinates
Raises:
ValueError: If box_id is invalid or missing from parsed screen
"""
# First try to use structured elements data
logger.info(f"Elements count: {len(parsed_screen.elements)}")
# Try to find element with matching ID
for element in parsed_screen.elements:
if element.id == box_id:
logger.info(f"Found element with ID {box_id}: {element}")
bbox = element.bbox
# Get screen dimensions from the metadata if available, or fallback
width = parsed_screen.metadata.width if parsed_screen.metadata else 1920
height = parsed_screen.metadata.height if parsed_screen.metadata else 1080
logger.info(f"Screen dimensions: width={width}, height={height}")
# Create a dictionary from the element's bbox for calculate_element_center
bbox_dict = {"x1": bbox.x1, "y1": bbox.y1, "x2": bbox.x2, "y2": bbox.y2}
from ...core.visualization import calculate_element_center
center_x, center_y = calculate_element_center(bbox_dict, width, height)
logger.info(f"Calculated center: ({center_x}, {center_y})")
# Validate coordinates - if they're (0,0) or unreasonably small,
# use a default position in the center of the screen
if center_x == 0 and center_y == 0:
logger.warning("Got (0,0) coordinates, using fallback position")
center_x = width // 2
center_y = height // 2
logger.info(f"Using fallback center: ({center_x}, {center_y})")
return center_x, center_y
# If we couldn't find the box, use center of screen
logger.error(
f"Box ID {box_id} not found in structured elements (count={len(parsed_screen.elements)})"
)
# Use center of screen as fallback
width = parsed_screen.metadata.width if parsed_screen.metadata else 1920
height = parsed_screen.metadata.height if parsed_screen.metadata else 1080
logger.warning(f"Using fallback position in center of screen ({width//2}, {height//2})")
return width // 2, height // 2
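# Hedged sketch (illustrative, not part of the original file):
# calculate_element_center is expected to map a bbox to a pixel midpoint,
# roughly
#     center_x = int((bbox["x1"] + bbox["x2"]) / 2 * width)
#     center_y = int((bbox["y1"] + bbox["y2"]) / 2 * height)
# assuming normalized [0, 1] coordinates; absolute pixel boxes would skip the
# width/height scaling.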

View File

@@ -1,64 +0,0 @@
"""Prompts for the Omni agent."""
SYSTEM_PROMPT = """
You are using a macOS device.
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
You may be given a history of plans and actions; these are the responses from previous loops.
You should carefully consider your plan based on the task, the screenshot, and the history of actions.
Your available "Next Action" options only include:
- type_text: types a string of text.
- left_click: moves the mouse to the box id and left-clicks.
- right_click: moves the mouse to the box id and right-clicks.
- double_click: moves the mouse to the box id and double-clicks.
- move_cursor: moves the mouse to the box id.
- scroll_up: scrolls the screen up to view previous content.
- scroll_down: scrolls the screen down when the desired element is not visible or you need to see more content.
- hotkey: presses a sequence of keys.
- wait: waits for 1 second for the device to load or respond.
Based on the visual information from the screenshot image and the detected bounding boxes, determine the next action, the Box ID you should operate on (a Box ID is only required for 'left_click', 'right_click', 'double_click', and 'move_cursor'), and the value (only provide a Value if the action is 'type_text' or 'hotkey') in order to complete the task.
Output format:
{
"Explanation": str, # describe what is in the current screen, taking into account the history, then describe your step-by-step thoughts on how to achieve the task, choose one action from available actions at a time.
"Action": "action_type, action description" | "None" # one action at a time, describe it in short and precisely.
"Box ID": n,
"Value": "xxx" # only provide value field if the action is type, else don't include value key
}
One Example:
{
"Explanation": "The current screen shows google result of amazon, in previous action I have searched amazon on google. Then I need to click on the first search results to go to amazon.com.",
"Action": "left_click",
"Box ID": 4
}
Another Example:
{
"Explanation": "The current screen shows the front page of amazon. There is no previous action. Therefore I need to type "Apple watch" in the search bar.",
"Action": "type_text",
"Box ID": 2,
"Value": "Apple watch"
}
Another Example:
{
"Explanation": "I am starting a Spotlight search to find the Safari browser.",
"Action": "hotkey",
"Value": "command+space"
}
IMPORTANT NOTES:
1. You should only give a single action at a time.
2. The Box ID is the id of the element you should operate on; it is a number. Its background color corresponds to the color of the element's bounding box.
3. You should give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task.
4. Attach the next action prediction in the "Action" field.
5. For starting applications, always use the "hotkey" action with command+space for starting a Spotlight search.
6. When the task is completed, don't complete additional actions. You should say "Action": "None" in the json field.
7. Tasks may involve buying multiple products or navigating through multiple pages. Break such tasks into subgoals and complete each subgoal one by one, in the order given by the instructions.
8. Avoid choosing the same action/element multiple times in a row; if that happens, reflect on what may have gone wrong and predict a different action.
9. Reflect on whether the element is clickable; for example, consider whether it is a hyperlink, a button, or plain text.
10. If you are prompted with a login page or a captcha page, or you think the next action needs the user's permission, you should say "Action": "None" in the json field.
"""

View File

@@ -1,30 +0,0 @@
"""Omni provider tools - compatible with multiple LLM providers."""
from ....core.tools import BaseTool, ToolResult, ToolError, ToolFailure, CLIResult
from .base import BaseOmniTool
from .computer import ComputerTool
from .bash import BashTool
from .manager import ToolManager
# Re-export the tools with Omni-specific names for backward compatibility
OmniToolResult = ToolResult
OmniToolError = ToolError
OmniToolFailure = ToolFailure
OmniCLIResult = CLIResult
# We'll export specific tools once implemented
__all__ = [
"BaseTool",
"BaseOmniTool",
"ToolResult",
"ToolError",
"ToolFailure",
"CLIResult",
"OmniToolResult",
"OmniToolError",
"OmniToolFailure",
"OmniCLIResult",
"ComputerTool",
"BashTool",
"ToolManager",
]

View File

@@ -1,29 +0,0 @@
"""Omni-specific tool base classes."""
from abc import ABCMeta, abstractmethod
from typing import Any, Dict
from ....core.tools.base import BaseTool
class BaseOmniTool(BaseTool, metaclass=ABCMeta):
"""Abstract base class for Omni provider tools."""
def __init__(self):
"""Initialize the base Omni tool."""
# No specific initialization needed yet, but included for future extensibility
pass
@abstractmethod
async def __call__(self, **kwargs) -> Any:
"""Executes the tool with the given arguments."""
...
@abstractmethod
def to_params(self) -> Dict[str, Any]:
"""Convert tool to Omni provider-specific API parameters.
Returns:
Dictionary with tool parameters for the specific API
"""
raise NotImplementedError
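An illustrative concrete subclass, showing the two methods every Omni tool must implement (hypothetical `EchoTool`, not part of the codebase):

class EchoTool(BaseOmniTool):
    """Toy tool that echoes its input back."""

    name = "echo"
    description = "Echo back the provided text"

    def to_params(self) -> Dict[str, Any]:
        # Same function-call schema shape the real Omni tools use.
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {"text": {"type": "string"}},
                    "required": ["text"],
                },
            },
        }

    async def __call__(self, **kwargs) -> Any:
        return kwargs.get("text", "")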

View File

@@ -1,74 +0,0 @@
"""Bash tool for Omni provider."""
import logging
from typing import Any, Dict
from computer import Computer
from ....core.tools import ToolResult, ToolError
from .base import BaseOmniTool
logger = logging.getLogger(__name__)
class BashTool(BaseOmniTool):
"""Tool for executing bash commands."""
name = "bash"
description = "Execute bash commands on the system"
def __init__(self, computer: Computer):
"""Initialize the bash tool.
Args:
computer: Computer instance
"""
super().__init__()
self.computer = computer
def to_params(self) -> Dict[str, Any]:
"""Convert tool to API parameters.
Returns:
Dictionary with tool parameters
"""
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The bash command to execute",
},
},
"required": ["command"],
},
},
}
async def __call__(self, **kwargs) -> ToolResult:
"""Execute bash command.
Args:
**kwargs: Command parameters
Returns:
Tool execution result
"""
try:
command = kwargs.get("command", "")
if not command:
return ToolResult(error="No command specified")
# A real implementation would run the command through the computer interface.
# Until this tool is fully integrated, we log the command and return a
# placeholder result instead of executing anything.
logger.info(f"Would execute command: {command}")
return ToolResult(output=f"Command executed (placeholder): {command}")
except Exception as e:
logger.error(f"Error in bash tool: {str(e)}")
return ToolResult(error=f"Error: {str(e)}")
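A hedged usage sketch (run inside an async context, with `computer` an already-initialized Computer instance):

tool = BashTool(computer)
result = await tool(command="echo hello")
print(result.output or result.error)  # placeholder output until the tool is fully integrated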

View File

@@ -1,179 +0,0 @@
"""Computer tool for Omni provider."""
import logging
from typing import Any, Dict
import json
from computer import Computer
from ....core.tools import ToolResult, ToolError
from .base import BaseOmniTool
from ..parser import ParseResult
logger = logging.getLogger(__name__)
class ComputerTool(BaseOmniTool):
"""Tool for interacting with the computer UI."""
name = "computer"
description = "Interact with the computer's graphical user interface"
def __init__(self, computer: Computer):
"""Initialize the computer tool.
Args:
computer: Computer instance
"""
super().__init__()
self.computer = computer
# Default to standard screen dimensions (will be set more accurately during initialization)
self.screen_dimensions = {"width": 1440, "height": 900}
async def initialize_dimensions(self) -> None:
"""Initialize screen dimensions."""
# For now, we'll use default values
# In the future, we can implement proper screen dimension detection
logger.info(f"Using default screen dimensions: {self.screen_dimensions}")
def to_params(self) -> Dict[str, Any]:
"""Convert tool to API parameters.
Returns:
Dictionary with tool parameters
"""
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"left_click",
"right_click",
"double_click",
"move_cursor",
"drag_to",
"type_text",
"press_key",
"hotkey",
"scroll_up",
"scroll_down",
],
"description": "The action to perform",
},
"x": {
"type": "number",
"description": "X coordinate for click or cursor movement",
},
"y": {
"type": "number",
"description": "Y coordinate for click or cursor movement",
},
"box_id": {
"type": "integer",
"description": "ID of the UI element to interact with",
},
"text": {
"type": "string",
"description": "Text to type",
},
"key": {
"type": "string",
"description": "Key to press",
},
"keys": {
"type": "array",
"items": {"type": "string"},
"description": "Keys to press as hotkey combination",
},
"amount": {
"type": "integer",
"description": "Amount to scroll",
},
"duration": {
"type": "number",
"description": "Duration for drag operations",
},
},
"required": ["action"],
},
},
}
async def __call__(self, **kwargs) -> ToolResult:
"""Execute computer action.
Args:
**kwargs: Action parameters
Returns:
Tool execution result
"""
try:
action = kwargs.get("action", "").lower()
if not action:
return ToolResult(error="No action specified")
# Execute the action on the computer
method = getattr(self.computer.interface, action, None)
if not method:
return ToolResult(error=f"Unsupported action: {action}")
# Prepare arguments based on action type
args = {}
if action in ["left_click", "right_click", "double_click", "move_cursor"]:
x = kwargs.get("x")
y = kwargs.get("y")
if x is None or y is None:
box_id = kwargs.get("box_id")
if box_id is None:
return ToolResult(error="Box ID or coordinates required")
# Get coordinates from box_id implementation would be here
# For now, return error
return ToolResult(error="Box ID-based clicking not implemented yet")
args["x"] = x
args["y"] = y
elif action == "drag_to":
x = kwargs.get("x")
y = kwargs.get("y")
if x is None or y is None:
return ToolResult(error="Coordinates required for drag_to")
args.update(
{
"x": x,
"y": y,
"button": kwargs.get("button", "left"),
"duration": float(kwargs.get("duration", 0.5)),
}
)
elif action == "type_text":
text = kwargs.get("text")
if not text:
return ToolResult(error="Text required for type_text")
args["text"] = text
elif action == "press_key":
key = kwargs.get("key")
if not key:
return ToolResult(error="Key required for press_key")
args["key"] = key
elif action == "hotkey":
keys = kwargs.get("keys")
if not keys:
return ToolResult(error="Keys required for hotkey")
# Call with positional arguments instead of kwargs
await method(*keys)
return ToolResult(output=f"Hotkey executed: {'+'.join(keys)}")
elif action in ["scroll_down", "scroll_up"]:
args["clicks"] = int(kwargs.get("amount", 1))
# Execute action with prepared arguments
await method(**args)
return ToolResult(output=f"Action {action} executed successfully")
except Exception as e:
logger.error(f"Error executing computer action: {str(e)}")
return ToolResult(error=f"Error: {str(e)}")
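Usage sketch for the tool above (async context assumed; note that hotkey expects a list of keys, which are passed through as positional arguments):

tool = ComputerTool(computer)
await tool(action="left_click", x=200, y=150)
await tool(action="type_text", text="Apple watch")
await tool(action="hotkey", keys=["command", "space"])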

View File

@@ -1,61 +0,0 @@
"""Tool manager for the Omni provider."""
from typing import Any, Dict, List
from computer.computer import Computer
from ....core.tools import BaseToolManager, ToolResult
from ....core.tools.collection import ToolCollection
from .computer import ComputerTool
from .bash import BashTool
from ....core.types import LLMProvider
class ToolManager(BaseToolManager):
"""Manages Omni provider tool initialization and execution."""
def __init__(self, computer: Computer, provider: LLMProvider):
"""Initialize the tool manager.
Args:
computer: Computer instance for computer-related tools
provider: The LLM provider being used
"""
super().__init__(computer)
self.provider = provider
# Initialize Omni-specific tools
self.computer_tool = ComputerTool(self.computer)
self.bash_tool = BashTool(self.computer)
def _initialize_tools(self) -> ToolCollection:
"""Initialize all available tools."""
return ToolCollection(self.computer_tool, self.bash_tool)
async def _initialize_tools_specific(self) -> None:
"""Initialize Omni provider-specific tool requirements."""
await self.computer_tool.initialize_dimensions()
def get_tool_params(self) -> List[Dict[str, Any]]:
"""Get tool parameters for API calls.
Returns:
List of tool parameters for the current provider's API
"""
if self.tools is None:
raise RuntimeError("Tools not initialized. Call initialize() first.")
return self.tools.to_params()
async def execute_tool(self, name: str, tool_input: dict[str, Any]) -> ToolResult:
"""Execute a tool with the given input.
Args:
name: Name of the tool to execute
tool_input: Input parameters for the tool
Returns:
Result of the tool execution
"""
if self.tools is None:
raise RuntimeError("Tools not initialized. Call initialize() first.")
return await self.tools.run(name=name, tool_input=tool_input)
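Sketch of the expected call sequence (the provider value is illustrative):

manager = ToolManager(computer, provider=LLMProvider.OPENAI)
await manager.initialize()  # must run before the calls below
params = manager.get_tool_params()  # JSON schemas to pass to the LLM API
result = await manager.execute_tool("bash", {"command": "ls"})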

View File

@@ -1,236 +0,0 @@
"""Main entry point for computer agents."""
import asyncio
import json
import logging
import os
from typing import Any, Dict, List, Optional
from som.models import ParseResult
from ...core.types import AgentResponse
logger = logging.getLogger(__name__)
async def to_openai_agent_response_format(
response: Any,
messages: List[Dict[str, Any]],
parsed_screen: Optional[ParseResult] = None,
parser: Optional[Any] = None,
model: Optional[str] = None,
) -> AgentResponse:
"""Create an OpenAI computer use agent compatible response format.
Args:
response: The original API response
messages: List of messages in standard OpenAI format
parsed_screen: Optional pre-parsed screen information
parser: Optional parser instance for coordinate calculation
model: Optional model name
Returns:
A response formatted according to OpenAI's computer use agent standard, including:
- All standard OpenAI computer use agent fields
- Original response in response.choices[0].message
- Full message history in messages field
"""
from datetime import datetime
import time
# Create a unique ID for this response
response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(response)}"
reasoning_id = f"rs_{response_id}"
action_id = f"cu_{response_id}"
call_id = f"call_{response_id}"
# Extract the last assistant message
assistant_msg = None
for msg in reversed(messages):
if msg["role"] == "assistant":
assistant_msg = msg
break
if not assistant_msg:
# If no assistant message found, create a default one
assistant_msg = {"role": "assistant", "content": "No response available"}
# Initialize output array
output_items = []
# Extract reasoning and action details from the response
content = assistant_msg["content"]
reasoning_text = None
action_details = None
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
try:
# Try to parse JSON from text block
text_content = item.get("text", "")
parsed_json = json.loads(text_content)
# Get reasoning text
if reasoning_text is None:
reasoning_text = parsed_json.get("Explanation", "")
# Extract action details
action = parsed_json.get("Action", "").lower()
text_input = parsed_json.get("Text", "")
value = parsed_json.get("Value", "") # Also handle Value field
box_id = parsed_json.get("Box ID") # Extract Box ID
if action in ["click", "left_click"]:
# Always calculate coordinates from Box ID for click actions
x, y = 100, 100 # Default fallback values
if parsed_screen and box_id is not None and parser is not None:
try:
box_id_int = (
box_id
if isinstance(box_id, int)
else int(str(box_id)) if str(box_id).isdigit() else None
)
if box_id_int is not None:
# Use the parser's method to calculate coordinates
x, y = await parser.calculate_click_coordinates(
box_id_int, parsed_screen
)
except Exception as e:
logger.error(
f"Error extracting coordinates for Box ID {box_id}: {str(e)}"
)
action_details = {
"type": "click",
"button": "left",
"box_id": (
(
box_id
if isinstance(box_id, int)
else int(box_id) if str(box_id).isdigit() else None
)
if box_id is not None
else None
),
"x": x,
"y": y,
}
elif action in ["type", "type_text"] and (text_input or value):
action_details = {
"type": "type",
"text": text_input or value,
}
elif action == "hotkey" and value:
action_details = {
"type": "hotkey",
"keys": value,
}
elif action == "scroll":
# Use default coordinates for scrolling
delta_x = 0
delta_y = 0
# Try to extract scroll delta values from content if available
scroll_data = parsed_json.get("Scroll", {})
if scroll_data:
delta_x = scroll_data.get("delta_x", 0)
delta_y = scroll_data.get("delta_y", 0)
action_details = {
"type": "scroll",
"x": 100,
"y": 100,
"scroll_x": delta_x,
"scroll_y": delta_y,
}
elif action == "none":
# Handle case when action is None (task completion)
action_details = {"type": "none", "description": "Task completed"}
except json.JSONDecodeError:
# If not JSON, just use as reasoning text
if reasoning_text is None:
reasoning_text = ""
reasoning_text += item.get("text", "")
# Add reasoning item if we have text content
if reasoning_text:
output_items.append(
{
"type": "reasoning",
"id": reasoning_id,
"summary": [
{
"type": "summary_text",
"text": reasoning_text[:200], # Truncate to reasonable length
}
],
}
)
# If no action details extracted, use default
if not action_details:
action_details = {
"type": "click",
"button": "left",
"x": 100,
"y": 100,
}
# Add computer_call item
computer_call = {
"type": "computer_call",
"id": action_id,
"call_id": call_id,
"action": action_details,
"pending_safety_checks": [],
"status": "completed",
}
output_items.append(computer_call)
# Extract user and assistant messages from the history
user_messages = []
assistant_messages = []
for msg in messages:
if msg["role"] == "user":
user_messages.append(msg)
elif msg["role"] == "assistant":
assistant_messages.append(msg)
# Create the OpenAI-compatible response format with all expected fields
return {
"id": response_id,
"object": "response",
"created_at": int(time.time()),
"status": "completed",
"error": None,
"incomplete_details": None,
"instructions": None,
"max_output_tokens": None,
"model": model or "unknown",
"output": output_items,
"parallel_tool_calls": True,
"previous_response_id": None,
"reasoning": {"effort": "medium", "generate_summary": "concise"},
"store": True,
"temperature": 1.0,
"text": {"format": {"type": "text"}},
"tool_choice": "auto",
"tools": [
{
"type": "computer_use_preview",
"display_height": 768,
"display_width": 1024,
"environment": "mac",
}
],
"top_p": 1.0,
"truncation": "auto",
"usage": {
"input_tokens": 0, # Placeholder values
"input_tokens_details": {"cached_tokens": 0},
"output_tokens": 0, # Placeholder values
"output_tokens_details": {"reasoning_tokens": 0},
"total_tokens": 0, # Placeholder values
},
"user": None,
"metadata": {},
# Include the original response for backward compatibility
"response": {"choices": [{"message": assistant_msg, "finish_reason": "stop"}]},
}
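Sketch of consuming the formatted response (assuming `raw_response`, `messages`, `parsed_screen`, and `parser` come from the surrounding agent loop):

response = await to_openai_agent_response_format(raw_response, messages, parsed_screen, parser)
for item in response["output"]:
    if item["type"] == "computer_call":
        print(item["action"])  # e.g. {"type": "click", "button": "left", ...}
    elif item["type"] == "reasoning":
        print(item["summary"][0]["text"])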

View File

@@ -1,6 +0,0 @@
"""OpenAI Agent Response API provider for computer control."""
from .types import LLMProvider
from .loop import OpenAILoop
__all__ = ["OpenAILoop", "LLMProvider"]

View File

@@ -1,456 +0,0 @@
"""API handler for the OpenAI provider."""
import logging
import requests
import os
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from datetime import datetime
if TYPE_CHECKING:
from .loop import OpenAILoop
logger = logging.getLogger(__name__)
class OpenAIAPIHandler:
"""Handler for OpenAI API interactions."""
def __init__(self, loop: "OpenAILoop"):
"""Initialize the API handler.
Args:
loop: OpenAI loop instance
"""
self.loop = loop
self.api_key = os.getenv("OPENAI_API_KEY")
if not self.api_key:
raise ValueError("OPENAI_API_KEY environment variable not set")
self.api_base = "https://api.openai.com/v1"
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
# Add organization if specified
org_id = os.getenv("OPENAI_ORG")
if org_id:
self.headers["OpenAI-Organization"] = org_id
logger.info("Initialized OpenAI API handler")
async def send_initial_request(
self,
messages: List[Dict[str, Any]],
display_width: str,
display_height: str,
previous_response_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Send an initial request to the OpenAI API with a screenshot.
Args:
messages: List of message objects in standard format
display_width: Width of the display in pixels
display_height: Height of the display in pixels
previous_response_id: Optional ID of the previous response to link requests
Returns:
API response
"""
# Convert display dimensions to integers
try:
width = int(display_width)
height = int(display_height)
except (ValueError, TypeError) as e:
logger.error(f"Failed to convert display dimensions to integers: {str(e)}")
raise ValueError(
f"Display dimensions must be integers: width={display_width}, height={display_height}"
)
# Extract the latest text message and screenshot from messages
latest_text = None
latest_screenshot = None
for msg in reversed(messages):
if not isinstance(msg, dict):
continue
content = msg.get("content", [])
if isinstance(content, str) and not latest_text:
latest_text = content
continue
if not isinstance(content, list):
continue
for item in content:
if not isinstance(item, dict):
continue
# Look for text if we don't have it yet
if not latest_text and item.get("type") == "text" and "text" in item:
latest_text = item.get("text", "")
# Look for an image if we don't have it yet
if not latest_screenshot and item.get("type") == "image":
source = item.get("source", {})
if source.get("type") == "base64" and "data" in source:
latest_screenshot = source["data"]
# Prepare the input array
input_array = []
# Add the text message if found
if latest_text:
input_array.append({"role": "user", "content": latest_text})
# Add the screenshot if found and no previous_response_id is provided
if latest_screenshot and not previous_response_id:
input_array.append(
{
"type": "message",
"role": "user",
"content": [
{
"type": "input_image",
"image_url": f"data:image/png;base64,{latest_screenshot}",
}
],
}
)
# Prepare the request payload - using minimal format from docs
payload = {
"model": "computer-use-preview",
"tools": [
{
"type": "computer_use_preview",
"display_width": width,
"display_height": height,
"environment": "mac", # We're on macOS
}
],
"input": input_array,
"reasoning": {
"generate_summary": "concise",
},
"truncation": "auto",
}
# Add previous_response_id if provided
if previous_response_id:
payload["previous_response_id"] = previous_response_id
# Log the request using the BaseLoop's log_api_call method
self.loop._log_api_call("request", payload)
# Log for debug purposes
logger.info("Sending initial request to OpenAI API")
logger.debug(f"Request payload: {self._sanitize_response(payload)}")
# Send the request
response = requests.post(
f"{self.api_base}/responses",
headers=self.headers,
json=payload,
)
if response.status_code != 200:
error_message = f"OpenAI API error: {response.status_code} {response.text}"
logger.error(error_message)
# Log the error using the BaseLoop's log_api_call method
self.loop._log_api_call("error", payload, error=Exception(error_message))
raise Exception(error_message)
response_data = response.json()
# Log the response using the BaseLoop's log_api_call method
self.loop._log_api_call("response", payload, response_data)
# Log for debug purposes
logger.info("Received response from OpenAI API")
logger.debug(f"Response data: {self._sanitize_response(response_data)}")
return response_data
async def send_computer_call_request(
self,
messages: List[Dict[str, Any]],
display_width: str,
display_height: str,
previous_response_id: str,
) -> Dict[str, Any]:
"""Send a request to the OpenAI API with computer_call_output.
Args:
messages: List of message objects in standard format
display_width: Width of the display in pixels
display_height: Height of the display in pixels
previous_response_id: ID of the previous response to link requests
Returns:
API response
"""
# Convert display dimensions to integers
try:
width = int(display_width)
height = int(display_height)
except (ValueError, TypeError) as e:
logger.error(f"Failed to convert display dimensions to integers: {str(e)}")
raise ValueError(
f"Display dimensions must be integers: width={display_width}, height={display_height}"
)
# Find the most recent computer_call_output with call_id
call_id = None
screenshot_base64 = None
# Look for call_id and screenshot in messages
for msg in reversed(messages):
if not isinstance(msg, dict):
continue
# Check if the message itself has a call_id
if "call_id" in msg and not call_id:
call_id = msg["call_id"]
content = msg.get("content", [])
if not isinstance(content, list):
continue
for item in content:
if not isinstance(item, dict):
continue
# Look for call_id
if not call_id and "call_id" in item:
call_id = item["call_id"]
# Look for screenshot in computer_call_output
if not screenshot_base64 and item.get("type") == "computer_call_output":
output = item.get("output", {})
if isinstance(output, dict) and "image_url" in output:
image_url = output.get("image_url", "")
if image_url.startswith("data:image/png;base64,"):
screenshot_base64 = image_url[len("data:image/png;base64,") :]
# Look for screenshot in image type
if not screenshot_base64 and item.get("type") == "image":
source = item.get("source", {})
if source.get("type") == "base64" and "data" in source:
screenshot_base64 = source["data"]
if not call_id or not screenshot_base64:
logger.error("Missing call_id or screenshot for computer_call_output")
logger.error(f"Last message: {messages[-1] if messages else None}")
raise ValueError("Cannot create computer call request: missing call_id or screenshot")
# Prepare the request payload using minimal format from docs
payload = {
"model": "computer-use-preview",
"previous_response_id": previous_response_id,
"tools": [
{
"type": "computer_use_preview",
"display_width": width,
"display_height": height,
"environment": "mac", # We're on macOS
}
],
"input": [
{
"type": "computer_call_output",
"call_id": call_id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_base64}",
},
}
],
"truncation": "auto",
}
# Log the request using the BaseLoop's log_api_call method
self.loop._log_api_call("request", payload)
# Log for debug purposes
logger.info("Sending computer call request to OpenAI API")
logger.debug(f"Request payload: {self._sanitize_response(payload)}")
# Send the request
response = requests.post(
f"{self.api_base}/responses",
headers=self.headers,
json=payload,
)
if response.status_code != 200:
error_message = f"OpenAI API error: {response.status_code} {response.text}"
logger.error(error_message)
# Log the error using the BaseLoop's log_api_call method
self.loop._log_api_call("error", payload, error=Exception(error_message))
raise Exception(error_message)
response_data = response.json()
# Log the response using the BaseLoop's log_api_call method
self.loop._log_api_call("response", payload, response_data)
# Log for debug purposes
logger.info("Received response from OpenAI API")
logger.debug(f"Response data: {self._sanitize_response(response_data)}")
return response_data
def _format_messages_for_agent_response(
self, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Format messages for the OpenAI Agent Response API.
The Agent Response API requires specific content types:
- For user messages: use "input_text", "input_image", etc.
- For assistant messages: use "output_text" only
Additionally, when using the computer tool, only one image can be sent.
Args:
messages: List of standard messages
Returns:
Messages formatted for the Agent Response API
"""
formatted_messages = []
has_image = False # Track if we've already included an image
# We need to process messages in reverse to ensure we keep the most recent image
# but preserve the original order in the final output
reversed_messages = list(reversed(messages))
temp_formatted = []
for msg in reversed_messages:
if not msg:
continue
role = msg.get("role", "user")
content = msg.get("content", "")
logger.debug(f"Processing message - Role: {role}, Content type: {type(content)}")
if isinstance(content, list):
logger.debug(
f"List content items: {[item.get('type') for item in content if isinstance(item, dict)]}"
)
if isinstance(content, str):
# For string content, create a message with the appropriate text type
if role == "user":
temp_formatted.append(
{"role": role, "content": [{"type": "input_text", "text": content}]}
)
elif role == "assistant":
# For assistant, we need explicit output_text
temp_formatted.append(
{"role": role, "content": [{"type": "output_text", "text": content}]}
)
elif role == "system":
# System messages need to be formatted as input_text as well
temp_formatted.append(
{"role": role, "content": [{"type": "input_text", "text": content}]}
)
elif isinstance(content, list):
# For list content, convert each item to the correct type based on role
formatted_content = []
has_image_in_this_message = False
for item in content:
if not isinstance(item, dict):
continue
item_type = item.get("type")
if role == "user":
# Handle user message formatting
if item_type == "text" or item_type == "input_text":
# Text from user is input_text
formatted_content.append(
{"type": "input_text", "text": item.get("text", "")}
)
elif (item_type == "image" or item_type == "image_url") and not has_image:
# Only include the first/most recent image we encounter
if item_type == "image":
# Image from user is input_image
source = item.get("source", {})
if source.get("type") == "base64" and "data" in source:
formatted_content.append(
{
"type": "input_image",
"image_url": f"data:image/png;base64,{source['data']}",
}
)
has_image = True
has_image_in_this_message = True
elif item_type == "image_url":
# Convert "image_url" to "input_image"
formatted_content.append(
{
"type": "input_image",
"image_url": item.get("image_url", {}).get("url", ""),
}
)
has_image = True
has_image_in_this_message = True
elif role == "assistant":
# Handle assistant message formatting - only output_text is supported
if item_type == "text" or item_type == "output_text":
formatted_content.append(
{"type": "output_text", "text": item.get("text", "")}
)
if formatted_content:
# If this message had an image, mark it for inclusion
temp_formatted.append(
{
"role": role,
"content": formatted_content,
"_had_image": has_image_in_this_message, # Temporary marker
}
)
# Reverse back to original order and cleanup
for msg in reversed(temp_formatted):
# Remove our temporary marker
if "_had_image" in msg:
del msg["_had_image"]
formatted_messages.append(msg)
# Log summary for debugging
num_images = sum(
1
for msg in formatted_messages
for item in (msg.get("content", []) if isinstance(msg.get("content"), list) else [])
if isinstance(item, dict) and item.get("type") == "input_image"
)
logger.info(f"Formatted {len(messages)} messages for OpenAI API with {num_images} images")
return formatted_messages
def _sanitize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
"""Sanitize response for logging by removing large image data.
Args:
response: Response to sanitize
Returns:
Sanitized response
"""
from .utils import sanitize_message
# Deep copy to avoid modifying the original
sanitized = response.copy()
# Sanitize output items if present
if "output" in sanitized and isinstance(sanitized["output"], list):
sanitized["output"] = [sanitize_message(item) for item in sanitized["output"]]
return sanitized
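The two request methods are designed to chain through previous_response_id; a sketch of the flow (dimension strings illustrative):

handler = OpenAIAPIHandler(loop)
first = await handler.send_initial_request(messages, "1024", "768")
follow_up = await handler.send_computer_call_request(
    messages, "1024", "768", previous_response_id=first["id"]
)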

View File

@@ -1,472 +0,0 @@
"""OpenAI Agent Response API provider implementation."""
import logging
import asyncio
import base64
from typing import Any, Dict, List, Optional, AsyncGenerator, Callable, Awaitable, TYPE_CHECKING
from computer import Computer
from ...core.base import BaseLoop
from ...core.types import AgentResponse
from ...core.messages import StandardMessageManager, ImageRetentionConfig
from .api_handler import OpenAIAPIHandler
from .response_handler import OpenAIResponseHandler
from .tools.manager import ToolManager
from .types import LLMProvider, ResponseItemType
logger = logging.getLogger(__name__)
class OpenAILoop(BaseLoop):
"""OpenAI-specific implementation of the agent loop.
This class extends BaseLoop to provide specialized support for OpenAI's Agent Response API
with computer control capabilities.
"""
###########################################
# INITIALIZATION AND CONFIGURATION
###########################################
def __init__(
self,
api_key: str,
computer: Computer,
model: str = "computer-use-preview",
only_n_most_recent_images: Optional[int] = 2,
base_dir: Optional[str] = "trajectories",
max_retries: int = 3,
retry_delay: float = 1.0,
save_trajectory: bool = True,
acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
**kwargs,
):
"""Initialize the OpenAI loop.
Args:
api_key: OpenAI API key
computer: Computer instance
model: Model name (ignored, always uses computer-use-preview)
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
base_dir: Base directory for saving experiment data
max_retries: Maximum number of retries for API calls
retry_delay: Delay between retries in seconds
save_trajectory: Whether to save trajectory data
acknowledge_safety_check_callback: Optional callback for safety check acknowledgment
**kwargs: Additional provider-specific arguments
"""
# Always use computer-use-preview model
if model != "computer-use-preview":
logger.info(
f"Overriding provided model '{model}' with required model 'computer-use-preview'"
)
# Initialize base class with core config
super().__init__(
computer=computer,
model="computer-use-preview", # Always use computer-use-preview
api_key=api_key,
max_retries=max_retries,
retry_delay=retry_delay,
base_dir=base_dir,
save_trajectory=save_trajectory,
only_n_most_recent_images=only_n_most_recent_images,
**kwargs,
)
# Initialize message manager
self.message_manager = StandardMessageManager(
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
)
# OpenAI-specific attributes
self.provider = LLMProvider.OPENAI
self.client = None
self.retry_count = 0
self.acknowledge_safety_check_callback = acknowledge_safety_check_callback
self.queue = asyncio.Queue() # Initialize queue
self.last_response_id = None # Store the last response ID across runs
self.loop_task = None # Store the loop task for cancellation
# Initialize handlers
self.api_handler = OpenAIAPIHandler(self)
self.response_handler = OpenAIResponseHandler(self)
# Initialize tool manager with callback
self.tool_manager = ToolManager(
computer=computer, acknowledge_safety_check_callback=acknowledge_safety_check_callback
)
###########################################
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
###########################################
async def initialize_client(self) -> None:
"""Initialize the OpenAI API client and tools.
Implements abstract method from BaseLoop to set up the OpenAI-specific
client, tool manager, and message manager.
"""
try:
# Initialize tool manager
await self.tool_manager.initialize()
except Exception as e:
logger.error(f"Error initializing OpenAI client: {str(e)}")
self.client = None
raise RuntimeError(f"Failed to initialize OpenAI client: {str(e)}")
###########################################
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
###########################################
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
"""Run the agent loop with provided messages.
Args:
messages: List of message objects in standard format
Yields:
Agent response format
"""
try:
logger.info("Starting OpenAI loop run")
# Create queue for response streaming
self.queue = asyncio.Queue()
# Ensure tool manager is initialized
await self.tool_manager.initialize()
# Start loop in background task
self.loop_task = asyncio.create_task(self._run_loop(self.queue, messages))
# Process and yield messages as they arrive
while True:
try:
item = await self.queue.get()
if item is None: # Stop signal
break
yield item
self.queue.task_done()
except Exception as e:
logger.error(f"Error processing queue item: {str(e)}")
continue
# Wait for loop to complete
await self.loop_task
# Send completion message
yield {
"role": "assistant",
"content": "Task completed successfully.",
"metadata": {"title": "✅ Complete"},
}
except Exception as e:
logger.error(f"Error executing task: {str(e)}")
yield {
"role": "assistant",
"content": f"Error: {str(e)}",
"metadata": {"title": "❌ Error"},
}
async def cancel(self) -> None:
"""Cancel the currently running agent loop task.
This method stops the ongoing processing in the agent loop
by cancelling the loop_task if it exists and is running.
"""
if self.loop_task and not self.loop_task.done():
logger.info("Cancelling OpenAI loop task")
self.loop_task.cancel()
try:
# Wait for the task to be cancelled with a timeout
await asyncio.wait_for(self.loop_task, timeout=2.0)
except asyncio.TimeoutError:
logger.warning("Timeout while waiting for loop task to cancel")
except asyncio.CancelledError:
logger.info("Loop task cancelled successfully")
except Exception as e:
logger.error(f"Error while cancelling loop task: {str(e)}")
finally:
# Put None in the queue to signal any waiting consumers to stop
await self.queue.put(None)
logger.info("OpenAI loop task cancelled")
else:
logger.info("No active OpenAI loop task to cancel")
###########################################
# AGENT LOOP IMPLEMENTATION
###########################################
async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
"""Run the agent loop with provided messages.
Args:
queue: Queue for response streaming
messages: List of messages in standard format
"""
try:
# Use the instance-level last_response_id instead of creating a local variable
# This way it persists between runs
# Capture initial screenshot
try:
# Take screenshot
screenshot = await self.computer.interface.screenshot()
logger.info("Screenshot captured successfully")
# Convert to base64 if needed
if isinstance(screenshot, bytes):
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
elif isinstance(screenshot, (bytearray, memoryview)):
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
else:
screenshot_base64 = str(screenshot)
# Emit screenshot callbacks
await self.handle_screenshot(screenshot_base64, action_type="initial_state")
self._save_screenshot(screenshot_base64, action_type="state")
# First add any existing user messages that were passed to run()
user_query = None
for msg in messages:
if msg.get("role") == "user":
user_content = msg.get("content", "")
if isinstance(user_content, str) and user_content:
user_query = user_content
# Add the user's original query to the message manager
self.message_manager.add_user_message(
[{"type": "text", "text": user_content}]
)
break
# Add screenshot to message manager
message_content = [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": screenshot_base64,
},
}
]
# Add appropriate text with the screenshot
message_content.append(
{
"type": "text",
"text": user_query,
}
)
# Add the screenshot and text to the message manager
self.message_manager.add_user_message(message_content)
# Process user request and convert our standard message format to one OpenAI expects
messages = self.message_manager.messages
logger.info(f"Starting agent loop with {len(messages)} messages")
# Create initial turn directory
if self.save_trajectory:
self._create_turn_dir()
# Call API
screen_size = await self.computer.interface.get_screen_size()
response = await self.api_handler.send_initial_request(
messages=self.message_manager.get_messages(), # Apply image retention policy
display_width=str(screen_size["width"]),
display_height=str(screen_size["height"]),
previous_response_id=self.last_response_id,
)
# Store response ID for next request
# OpenAI API response structure: the ID is in the response dictionary
if isinstance(response, dict) and "id" in response:
self.last_response_id = response["id"] # Update instance variable
logger.info(f"Received response with ID: {self.last_response_id}")
else:
logger.warning(
f"Could not find response ID in OpenAI response: {type(response)}"
)
# Don't reset last_response_id to None - keep the previous value if available
# Log standardized response for ease of parsing.
# Since this is already the OpenAI Responses format, no conversion to the agent response format is needed.
self._log_api_call("agent_response", request=None, response=response)
# Process API response
await queue.put(response)
# Loop to continue processing responses until task is complete
task_complete = False
while not task_complete:
# Check if there are any computer calls
output_items = response.get("output", []) or []
computer_calls = [
item for item in output_items if item.get("type") == "computer_call"
]
if not computer_calls:
logger.info("No computer calls in response, task may be complete.")
task_complete = True
continue
# Process the first computer call
computer_call = computer_calls[0]
action = computer_call.get("action", {})
call_id = computer_call.get("call_id")
# Check for safety checks
pending_safety_checks = computer_call.get("pending_safety_checks", [])
acknowledged_safety_checks = []
if pending_safety_checks:
# Log safety checks
for check in pending_safety_checks:
logger.warning(
f"Safety check: {check.get('code')} - {check.get('message')}"
)
# If we have a callback, use it to acknowledge safety checks
if self.acknowledge_safety_check_callback:
acknowledged = await self.acknowledge_safety_check_callback(
pending_safety_checks
)
if not acknowledged:
logger.warning("Safety check acknowledgment failed")
await queue.put(
{
"role": "assistant",
"content": "Safety checks were not acknowledged. Cannot proceed with action.",
"metadata": {"title": "⚠️ Safety Warning"},
}
)
continue
acknowledged_safety_checks = pending_safety_checks
# Execute the action
try:
# Create a new turn directory for this action if saving trajectories
if self.save_trajectory:
self._create_turn_dir()
# Execute the tool
result = await self.tool_manager.execute_tool("computer", action)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
if isinstance(screenshot, bytes):
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
elif isinstance(screenshot, (bytearray, memoryview)):
screenshot_base64 = base64.b64encode(bytes(screenshot)).decode("utf-8")
else:
screenshot_base64 = str(screenshot)
# Process screenshot through hooks
action_type = f"after_{action.get('type', 'action')}"
await self.handle_screenshot(screenshot_base64, action_type=action_type)
self._save_screenshot(screenshot_base64, action_type=action_type)
# Create computer_call_output
computer_call_output = {
"type": "computer_call_output",
"call_id": call_id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_base64}",
},
}
# Add acknowledged safety checks if any
if acknowledged_safety_checks:
computer_call_output["acknowledged_safety_checks"] = (
acknowledged_safety_checks
)
# Save to message manager for history
self.message_manager.add_system_message(
f"[Computer action executed: {action.get('type')}]"
)
self.message_manager.add_user_message([computer_call_output])
# For follow-up requests with previous_response_id, we only need to send
# the computer_call_output, not the full message history
# The API handler will extract this from the message history
if isinstance(self.last_response_id, str):
response = await self.api_handler.send_computer_call_request(
messages=self.message_manager.get_messages(), # Apply image retention policy
display_width=str(screen_size["width"]),
display_height=str(screen_size["height"]),
previous_response_id=self.last_response_id, # Use instance variable
)
# Store response ID for next request
if isinstance(response, dict) and "id" in response:
self.last_response_id = response["id"] # Update instance variable
logger.info(f"Received response with ID: {self.last_response_id}")
else:
logger.warning(
f"Could not find response ID in OpenAI response: {type(response)}"
)
# Keep using the previous response ID if we can't find a new one
# Process the response
# await self.response_handler.process_response(response, queue)
self._log_api_call("agent_response", request=None, response=response)
await queue.put(response)
except Exception as e:
logger.error(f"Error executing computer action: {str(e)}")
await queue.put(
{
"role": "assistant",
"content": f"Error executing action: {str(e)}",
"metadata": {"title": "❌ Error"},
}
)
task_complete = True
except Exception as e:
logger.error(f"Error capturing initial screenshot: {str(e)}")
await queue.put(
{
"role": "assistant",
"content": f"Error capturing screenshot: {str(e)}",
"metadata": {"title": "❌ Error"},
}
)
await queue.put(None) # Signal that we're done
return
# Signal that we're done
await queue.put(None)
except Exception as e:
logger.error(f"Error in _run_loop: {str(e)}")
await queue.put(
{
"role": "assistant",
"content": f"Error: {str(e)}",
"metadata": {"title": "❌ Error"},
}
)
await queue.put(None) # Signal that we're done
def get_last_response_id(self) -> Optional[str]:
"""Get the last response ID.
Returns:
The last response ID or None if no response has been received
"""
return self.last_response_id
def set_last_response_id(self, response_id: str) -> None:
"""Set the last response ID.
Args:
response_id: OpenAI response ID to set
"""
self.last_response_id = response_id
logger.info(f"Manually set response ID to: {self.last_response_id}")

View File

@@ -1,205 +0,0 @@
"""Response handler for the OpenAI provider."""
import logging
import asyncio
import traceback
from typing import Any, Dict, List, Optional, TYPE_CHECKING, AsyncGenerator
import base64
from ...core.types import AgentResponse
from .types import ResponseItemType
if TYPE_CHECKING:
from .loop import OpenAILoop
logger = logging.getLogger(__name__)
class OpenAIResponseHandler:
"""Handler for OpenAI API responses."""
def __init__(self, loop: "OpenAILoop"):
"""Initialize the response handler.
Args:
loop: OpenAI loop instance
"""
self.loop = loop
logger.info("Initialized OpenAI response handler")
async def process_response(self, response: Dict[str, Any], queue: asyncio.Queue) -> None:
"""Process the response from the OpenAI API.
Args:
response: Response from the API
queue: Queue for response streaming
"""
try:
# Get output items
output_items = response.get("output", []) or []
# Process each output item
for item in output_items:
if not isinstance(item, dict):
continue
item_type = item.get("type")
# For computer_call items, we only need to add to the queue
# The loop is now handling executing the action and creating the computer_call_output
if item_type == ResponseItemType.COMPUTER_CALL:
# Send computer_call to queue so it can be processed
await queue.put(item)
elif item_type == ResponseItemType.MESSAGE:
# Send message to queue
await queue.put(item)
elif item_type == ResponseItemType.REASONING:
# Process reasoning summary
summary = None
if "summary" in item and isinstance(item["summary"], list):
for summary_item in item["summary"]:
if (
isinstance(summary_item, dict)
and summary_item.get("type") == "summary_text"
):
summary = summary_item.get("text")
break
if summary:
# Log the reasoning summary
logger.info(f"Reasoning summary: {summary}")
# Send reasoning summary to queue with a special format
await queue.put(
{
"role": "assistant",
"content": f"[Reasoning: {summary}]",
"metadata": {"title": "💭 Reasoning", "is_summary": True},
}
)
# Also pass the original reasoning item to the queue for complete context
await queue.put(item)
except Exception as e:
logger.error(f"Error processing response: {str(e)}")
await queue.put(
{
"role": "assistant",
"content": f"Error processing response: {str(e)}",
"metadata": {"title": "❌ Error"},
}
)
def _process_message_item(self, item: Dict[str, Any]) -> AgentResponse:
"""Process a message item from the response.
Args:
item: Message item from the response
Returns:
Processed message in AgentResponse format
"""
# Extract content items - add null check
content_items = item.get("content", []) or []
# Extract text from content items - use output_text type from OpenAI
text = ""
for content_item in content_items:
# Skip if content_item is None or not a dict
if content_item is None or not isinstance(content_item, dict):
continue
# In OpenAI Agent Response API, text content is in "output_text" type items
if content_item.get("type") == "output_text":
text += content_item.get("text", "")
# Create agent response
return {
"role": "assistant",
"content": text
or "I don't have a response for that right now.", # Provide fallback when text is empty
"metadata": {"title": "💬 Response"},
}
async def _process_computer_call(self, item: Dict[str, Any], queue: asyncio.Queue) -> None:
"""Process a computer call item from the response.
Args:
item: Computer call item
queue: Queue to add responses to
"""
try:
# Log the computer call
action = item.get("action", {}) or {}
if not isinstance(action, dict):
logger.warning(f"Expected dict for action, got {type(action)}")
action = {}
action_type = action.get("type", "unknown")
logger.info(f"Processing computer call: {action_type}")
# Execute the tool call
result = await self.loop.tool_manager.execute_tool("computer", action)
# Add any message to the conversation history and queue
if result and result.base64_image:
# Update message history with the call output
self.loop.message_manager.add_user_message(
[{"type": "text", "text": f"[Computer action completed: {action_type}]"}]
)
# Add image to messages (using correct content types for Agent Response API)
self.loop.message_manager.add_user_message(
[
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": result.base64_image,
},
}
]
)
# If browser environment, include URL if available
# if (
# hasattr(self.loop.computer, "environment")
# and self.loop.computer.environment == "browser"
# ):
# try:
# if hasattr(self.loop.computer.interface, "get_current_url"):
# current_url = await self.loop.computer.interface.get_current_url()
# self.loop.message_manager.add_user_message(
# [
# {
# "type": "text",
# "text": f"Current URL: {current_url}",
# }
# ]
# )
# except Exception as e:
# logger.warning(f"Failed to get current URL: {str(e)}")
# Log successful completion
logger.info(f"Computer call {action_type} executed successfully")
except Exception as e:
logger.error(f"Error executing computer call: {str(e)}")
logger.debug(traceback.format_exc())
# Add error to conversation
self.loop.message_manager.add_user_message(
[{"type": "text", "text": f"Error executing computer action: {str(e)}"}]
)
# Send error to queue
error_response = {
"role": "assistant",
"content": f"Error executing computer action: {str(e)}",
"metadata": {"title": "❌ Error"},
}
await queue.put(error_response)

View File

@@ -1,15 +0,0 @@
"""OpenAI tools module for computer control."""
from .manager import ToolManager
from .computer import ComputerTool
from .base import BaseOpenAITool, ToolResult, ToolError, ToolFailure, CLIResult
__all__ = [
"ToolManager",
"ComputerTool",
"BaseOpenAITool",
"ToolResult",
"ToolError",
"ToolFailure",
"CLIResult",
]

View File

@@ -1,79 +0,0 @@
"""OpenAI-specific tool base classes."""
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, fields, replace
from typing import Any, Dict, List, Optional
from ....core.tools.base import BaseTool
class BaseOpenAITool(BaseTool, metaclass=ABCMeta):
"""Abstract base class for OpenAI-defined tools."""
def __init__(self):
"""Initialize the base OpenAI tool."""
# No specific initialization needed yet, but included for future extensibility
pass
@abstractmethod
async def __call__(self, **kwargs) -> Any:
"""Executes the tool with the given arguments."""
...
@abstractmethod
def to_params(self) -> Dict[str, Any]:
"""Convert tool to OpenAI-specific API parameters.
Returns:
Dictionary with tool parameters for OpenAI API
"""
raise NotImplementedError
@dataclass(kw_only=True, frozen=True)
class ToolResult:
"""Represents the result of a tool execution."""
output: str | None = None
error: str | None = None
base64_image: str | None = None
system: str | None = None
content: list[dict] | None = None
def __bool__(self):
return any(getattr(self, field.name) for field in fields(self))
def __add__(self, other: "ToolResult"):
def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True):
if field and other_field:
if concatenate:
return field + other_field
raise ValueError("Cannot combine tool results")
return field or other_field
return ToolResult(
output=combine_fields(self.output, other.output),
error=combine_fields(self.error, other.error),
base64_image=combine_fields(self.base64_image, other.base64_image, False),
system=combine_fields(self.system, other.system),
content=self.content or other.content, # Use first non-None content
)
def replace(self, **kwargs):
"""Returns a new ToolResult with the given fields replaced."""
return replace(self, **kwargs)
class CLIResult(ToolResult):
"""A ToolResult that can be rendered as a CLI output."""
class ToolFailure(ToolResult):
"""A ToolResult that represents a failure."""
class ToolError(Exception):
"""Raised when a tool encounters an error."""
def __init__(self, message):
self.message = message

View File

@@ -1,326 +0,0 @@
"""Computer tool for OpenAI."""
import asyncio
import base64
import logging
from typing import Literal, Any, Dict, Optional, List, Union
from computer.computer import Computer
from .base import BaseOpenAITool, ToolError, ToolResult
from ....core.tools.computer import BaseComputerTool
TYPING_DELAY_MS = 12
TYPING_GROUP_SIZE = 50
# Key mapping for special keys
KEY_MAPPING = {
"enter": "return",
"backspace": "delete",
"delete": "forwarddelete",
"escape": "esc",
"pageup": "page_up",
"pagedown": "page_down",
"arrowup": "up",
"arrowdown": "down",
"arrowleft": "left",
"arrowright": "right",
"home": "home",
"end": "end",
"tab": "tab",
"space": "space",
"shift": "shift",
"control": "control",
"alt": "alt",
"meta": "command",
}
Action = Literal[
"key",
"type",
"mouse_move",
"left_click",
"right_click",
"double_click",
"screenshot",
"scroll",
"drag",
]
class ComputerTool(BaseComputerTool, BaseOpenAITool):
"""
A tool that allows the agent to interact with the screen, keyboard, and mouse of the current computer.
"""
name: Literal["computer"] = "computer"
api_type: Literal["computer_use_preview"] = "computer_use_preview"
width: Optional[int] = None
height: Optional[int] = None
display_num: Optional[int] = None
computer: Computer # The CUA Computer instance
logger = logging.getLogger(__name__)
def __init__(self, computer: Computer):
"""Initialize the computer tool.
Args:
computer: Computer instance
"""
self.computer = computer
self.width = None
self.height = None
self.logger = logging.getLogger(__name__)
# Initialize the base computer tool first
BaseComputerTool.__init__(self, computer)
# Then initialize the OpenAI tool
BaseOpenAITool.__init__(self)
# Additional initialization
self.width = None # Will be initialized from computer interface
self.height = None # Will be initialized from computer interface
self.display_num = None
def to_params(self) -> Dict[str, Any]:
"""Convert tool to API parameters.
Returns:
Dictionary with tool parameters
"""
if self.width is None or self.height is None:
raise RuntimeError(
"Screen dimensions not initialized. Call initialize_dimensions() first."
)
return {
"type": self.api_type,
"display_width": self.width,
"display_height": self.height,
"display_number": self.display_num,
}
async def initialize_dimensions(self):
"""Initialize screen dimensions from the computer interface."""
try:
display_size = await self.computer.interface.get_screen_size()
self.width = display_size["width"]
self.height = display_size["height"]
assert isinstance(self.width, int) and isinstance(self.height, int)
self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
except Exception as e:
# Fall back to defaults if we can't get accurate dimensions
self.width = 1024
self.height = 768
self.logger.warning(
f"Failed to get screen dimensions, using defaults: {self.width}x{self.height}. Error: {e}"
)
async def __call__(
self,
*,
type: str, # OpenAI uses 'type' instead of 'action'
text: Optional[str] = None,
**kwargs,
):
try:
# Ensure dimensions are initialized
if self.width is None or self.height is None:
await self.initialize_dimensions()
if self.width is None or self.height is None:
raise ToolError("Failed to initialize screen dimensions")
if type == "type":
if text is None:
raise ToolError("text is required for type action")
return await self.handle_typing(text)
elif type == "click":
# Map button to correct action name
button = kwargs.get("button")
if button is None:
raise ToolError("button is required for click action")
return await self.handle_click(button, kwargs["x"], kwargs["y"])
elif type == "keypress":
# Check for keys in kwargs if text is None
if text is None:
if "keys" in kwargs and isinstance(kwargs["keys"], list):
# Pass the keys list directly instead of joining and then splitting
return await self.handle_key(kwargs["keys"])
else:
raise ToolError("Either 'text' or 'keys' is required for keypress action")
return await self.handle_key(text)
elif type == "mouse_move":
if "coordinates" not in kwargs:
raise ToolError("coordinates is required for mouse_move action")
return await self.handle_mouse_move(
kwargs["coordinates"][0], kwargs["coordinates"][1]
)
elif type == "scroll":
# Get x, y coordinates directly from kwargs
x = kwargs.get("x")
y = kwargs.get("y")
if x is None or y is None:
raise ToolError("x and y coordinates are required for scroll action")
scroll_x = kwargs.get("scroll_x", 0) // 50
scroll_y = kwargs.get("scroll_y", 0) // 50
return await self.handle_scroll(x, y, scroll_x, scroll_y)
elif type == "drag":
path = kwargs.get("path")
if not path or not isinstance(path, list) or len(path) < 2:
raise ToolError("path is required for drag action and must contain at least 2 points")
return await self.handle_drag(path)
elif type == "screenshot":
return await self.screenshot()
elif type == "wait":
duration = kwargs.get("duration", 1.0)
await asyncio.sleep(duration)
return await self.screenshot()
else:
raise ToolError(f"Unsupported action: {type}")
except Exception as e:
self.logger.error(f"Error in ComputerTool.__call__: {str(e)}")
raise ToolError(f"Failed to execute {type}: {str(e)}")
async def handle_click(self, button: str, x: int, y: int) -> ToolResult:
"""Handle mouse clicks."""
try:
# Perform the click based on button type
if button == "left":
await self.computer.interface.left_click(x, y)
elif button == "right":
await self.computer.interface.right_click(x, y)
elif button == "double":
await self.computer.interface.double_click(x, y)
else:
raise ToolError(f"Unsupported button type: {button}")
# Wait briefly for UI to update
await asyncio.sleep(0.3)
return ToolResult(
output=f"Performed {button} click at ({x}, {y})",
)
except Exception as e:
self.logger.error(f"Error in handle_click: {str(e)}")
raise ToolError(f"Failed to perform {button} click at ({x}, {y}): {str(e)}")
async def handle_typing(self, text: str) -> ToolResult:
"""Handle typing text with a small delay between characters."""
try:
# Type the text with a small delay
await self.computer.interface.type_text(text)
await asyncio.sleep(0.3)
return ToolResult(output=f"Typed: {text}")
except Exception as e:
self.logger.error(f"Error in handle_typing: {str(e)}")
raise ToolError(f"Failed to type '{text}': {str(e)}")
async def handle_key(self, key: Union[str, List[str]]) -> ToolResult:
"""Handle key press, supporting both single keys and combinations.
Args:
key: Either a string (e.g. "ctrl+c") or a list of keys (e.g. ["ctrl", "c"])
"""
try:
# Check if key is already a list
if isinstance(key, list):
keys = [k.strip().lower() for k in key]
else:
# Split key string into list if it's a combination (e.g. "ctrl+c")
keys = [k.strip().lower() for k in key.split("+")]
# Map each key
mapped_keys = [KEY_MAPPING.get(k, k) for k in keys]
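# e.g. "cmd+shift+t" -> ["cmd", "shift", "t"], each remapped via KEY_MAPPING
# to the key names the computer interface expects (mapping table assumed).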
if len(mapped_keys) > 1:
# For key combinations (like Ctrl+C)
await self.computer.interface.hotkey(*mapped_keys)
else:
# Single key press
await self.computer.interface.press_key(mapped_keys[0])
# Wait briefly
await asyncio.sleep(0.3)
return ToolResult(output=f"Pressed key: {key}")
except Exception as e:
self.logger.error(f"Error in handle_key: {str(e)}")
raise ToolError(f"Failed to press key '{key}': {str(e)}")
async def handle_mouse_move(self, x: int, y: int) -> ToolResult:
"""Handle mouse movement."""
try:
# Move cursor to position
await self.computer.interface.move_cursor(x, y)
# Wait briefly
await asyncio.sleep(0.2)
return ToolResult(output=f"Moved cursor to ({x}, {y})")
except Exception as e:
self.logger.error(f"Error in handle_mouse_move: {str(e)}")
raise ToolError(f"Failed to move cursor to ({x}, {y}): {str(e)}")
async def handle_scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> ToolResult:
"""Handle scrolling."""
try:
# Move cursor to position first
await self.computer.interface.move_cursor(x, y)
# Scroll based on direction
if scroll_y > 0:
await self.computer.interface.scroll_down(abs(scroll_y))
elif scroll_y < 0:
await self.computer.interface.scroll_up(abs(scroll_y))
# Wait for UI to update
await asyncio.sleep(0.5)
return ToolResult(output=f"Scrolled at ({x}, {y}) by ({scroll_x}, {scroll_y})")
except Exception as e:
self.logger.error(f"Error in handle_scroll: {str(e)}")
raise ToolError(f"Failed to scroll at ({x}, {y}): {str(e)}")
async def handle_drag(self, path: List[Dict[str, int]]) -> ToolResult:
"""Handle mouse drag operation using a path of coordinates.
Args:
path: List of coordinate points {"x": int, "y": int} defining the drag path
Returns:
ToolResult with the operation result and screenshot
"""
try:
# Convert from [{"x": x, "y": y}, ...] format to [(x, y), ...] format
points = [(p["x"], p["y"]) for p in path]
# Perform drag action
if len(points) == 2:
await self.computer.interface.move_cursor(points[0][0], points[0][1])
await self.computer.interface.drag_to(points[1][0], points[1][1])
else:
await self.computer.interface.drag(points, button="left")
# Wait for UI to update
await asyncio.sleep(0.5)
return ToolResult(
output=f"Dragged from ({path[0]['x']}, {path[0]['y']}) to ({path[-1]['x']}, {path[-1]['y']})",
)
except Exception as e:
self.logger.error(f"Error in handle_drag: {str(e)}")
raise ToolError(f"Failed to perform drag operation: {str(e)}")
async def screenshot(self) -> ToolResult:
"""Take a screenshot."""
try:
# Take screenshot
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(output="Screenshot taken", base64_image=base64_screenshot)
except Exception as e:
self.logger.error(f"Error in screenshot: {str(e)}")
raise ToolError(f"Failed to take screenshot: {str(e)}")

View File

@@ -1,106 +0,0 @@
"""Tool manager for the OpenAI provider."""
import logging
from typing import Dict, Any, Optional, List, Callable, Awaitable, Union
from computer import Computer
from ..types import ComputerAction, ResponseItemType
from .computer import ComputerTool
from ....core.tools.base import ToolResult, ToolFailure
from ....core.tools.collection import ToolCollection
logger = logging.getLogger(__name__)
class ToolManager:
"""Manager for computer tools in the OpenAI agent."""
def __init__(
self,
computer: Computer,
acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
):
"""Initialize the tool manager.
Args:
computer: Computer instance
acknowledge_safety_check_callback: Optional callback for safety check acknowledgment
"""
self.computer = computer
self.acknowledge_safety_check_callback = acknowledge_safety_check_callback
self._initialized = False
self.computer_tool = ComputerTool(computer)
self.tools = None
logger.info("Initialized OpenAI ToolManager")
async def initialize(self) -> None:
"""Initialize the tool manager."""
if not self._initialized:
logger.info("Initializing OpenAI ToolManager")
# Initialize the computer tool
await self.computer_tool.initialize_dimensions()
# Initialize tool collection
self.tools = ToolCollection(self.computer_tool)
self._initialized = True
logger.info("OpenAI ToolManager initialized")
async def get_tools_definition(self) -> List[Dict[str, Any]]:
"""Get the tools definition for the OpenAI agent.
Returns:
Tools definition for the OpenAI agent
"""
if not self.tools:
raise RuntimeError("Tools not initialized. Call initialize() first.")
# For the OpenAI Agent Response API, we use a special "computer-preview" tool
# which provides the correct interface for computer control
display_width, display_height = await self._get_computer_dimensions()
# Get environment, using "mac" as default since we're on macOS
environment = getattr(self.computer, "environment", "mac")
# Ensure environment is one of the allowed values
if environment not in ["windows", "mac", "linux", "browser"]:
logger.warning(f"Invalid environment value: {environment}, using 'mac' instead")
environment = "mac"
return [
{
"type": "computer-preview",
"display_width": display_width,
"display_height": display_height,
"environment": environment,
}
]
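# Illustrative return value, assuming a 1920x1080 macOS VM:
#   [{"type": "computer-preview", "display_width": 1920,
#     "display_height": 1080, "environment": "mac"}]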
async def _get_computer_dimensions(self) -> tuple[int, int]:
"""Get the dimensions of the computer display.
Returns:
Tuple of (width, height)
"""
# If computer tool is initialized, use its dimensions
if self.computer_tool.width is not None and self.computer_tool.height is not None:
return (self.computer_tool.width, self.computer_tool.height)
# Try to get from computer.interface if available
screen_size = await self.computer.interface.get_screen_size()
return (int(screen_size["width"]), int(screen_size["height"]))
async def execute_tool(self, name: str, tool_input: Dict[str, Any]) -> ToolResult:
"""Execute a tool with the given input.
Args:
name: Name of the tool to execute
tool_input: Input parameters for the tool
Returns:
Result of the tool execution
"""
if not self.tools:
raise RuntimeError("Tools not initialized. Call initialize() first.")
return await self.tools.run(name=name, tool_input=tool_input)

View File

@@ -1,36 +0,0 @@
"""Type definitions for the OpenAI provider."""
from enum import StrEnum
from typing import Dict, List, Optional
from dataclasses import dataclass
class LLMProvider(StrEnum):
"""OpenAI LLM provider types."""
OPENAI = "openai"
class ResponseItemType(StrEnum):
"""Types of items in OpenAI Agent Response output."""
MESSAGE = "message"
COMPUTER_CALL = "computer_call"
COMPUTER_CALL_OUTPUT = "computer_call_output"
REASONING = "reasoning"
@dataclass
class ComputerAction:
"""Represents a computer action to be performed."""
type: str
x: Optional[int] = None
y: Optional[int] = None
text: Optional[str] = None
button: Optional[str] = None
keys: Optional[List[str]] = None
ms: Optional[int] = None
scroll_x: Optional[int] = None
scroll_y: Optional[int] = None
path: Optional[List[Dict[str, int]]] = None

View File

@@ -1,98 +0,0 @@
"""Utility functions for the OpenAI provider."""
import logging
import json
import base64
from typing import Any, Dict, List, Optional
from ...core.types import AgentResponse
logger = logging.getLogger(__name__)
def format_images_for_openai(images_base64: List[str]) -> List[Dict[str, Any]]:
"""Format images for OpenAI Agent Response API.
Args:
images_base64: List of base64 encoded images
Returns:
List of formatted image items for Agent Response API
"""
return [
{"type": "input_image", "image_url": f"data:image/png;base64,{image}"}
for image in images_base64
]
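# Illustrative result for one image (base64 payload truncated):
#   format_images_for_openai(["iVBORw0..."]) ->
#   [{"type": "input_image", "image_url": "data:image/png;base64,iVBORw0..."}]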
def extract_message_content(message: Dict[str, Any]) -> str:
"""Extract text content from a message.
Args:
message: Message to extract content from
Returns:
Text content from the message
"""
if isinstance(message.get("content"), str):
return message["content"]
if isinstance(message.get("content"), list):
text = ""
role = message.get("role", "user")
for item in message["content"]:
if isinstance(item, dict):
# For user messages
if role == "user" and item.get("type") == "input_text":
text += item.get("text", "")
# For standard format
elif item.get("type") == "text":
text += item.get("text", "")
# For assistant messages in Agent Response API format
elif item.get("type") == "output_text":
text += item.get("text", "")
return text
return ""
def sanitize_message(msg: Dict[str, Any]) -> Dict[str, Any]:
"""Sanitize a message for logging by removing large image data.
Args:
msg: Message to sanitize
Returns:
Sanitized message
"""
if not isinstance(msg, dict):
return msg
sanitized = msg.copy()
# Handle message content
if isinstance(sanitized.get("content"), list):
sanitized_content = []
for item in sanitized["content"]:
if isinstance(item, dict):
# Handle various image types
if item.get("type") == "image_url" and "image_url" in item:
sanitized_content.append({"type": "image_url", "image_url": "[omitted]"})
elif item.get("type") == "input_image" and "image_url" in item:
sanitized_content.append({"type": "input_image", "image_url": "[omitted]"})
elif item.get("type") == "image" and "source" in item:
sanitized_content.append({"type": "image", "source": "[omitted]"})
else:
sanitized_content.append(item)
else:
sanitized_content.append(item)
sanitized["content"] = sanitized_content
# Handle computer_call_output
if sanitized.get("type") == "computer_call_output" and "output" in sanitized:
output = sanitized["output"]
if isinstance(output, dict) and "image_url" in output:
sanitized["output"] = {**output, "image_url": "[omitted]"}
return sanitized
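# Illustrative: a computer_call_output such as
#   {"type": "computer_call_output", "output": {"image_url": "data:image/png;base64,..."}}
# is logged as {"type": "computer_call_output", "output": {"image_url": "[omitted]"}}.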

View File

@@ -1 +0,0 @@
"""UI-TARS Agent provider package."""

View File

@@ -1,35 +0,0 @@
"""Base client implementation for Omni providers."""
import logging
from typing import Dict, List, Optional, Any, Tuple
logger = logging.getLogger(__name__)
class BaseUITarsClient:
"""Base class for provider-specific clients."""
def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
"""Initialize base client.
Args:
api_key: Optional API key
model: Optional model name
"""
self.api_key = api_key
self.model = model
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
) -> Dict[str, Any]:
"""Run interleaved chat completion.
Args:
messages: List of message dicts
system: System prompt
max_tokens: Optional max tokens override
Returns:
Response dict
"""
raise NotImplementedError

View File

@@ -1,263 +0,0 @@
"""MLX LVM client implementation."""
import io
import logging
import base64
import tempfile
import os
import re
import math
from typing import Dict, List, Optional, Any, cast, Tuple
from PIL import Image
from .base import BaseUITarsClient
import mlx.core as mx
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config
from transformers.tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
# Constants for smart_resize
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
def round_by_factor(number: float, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: float, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: float, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
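# Worked example (illustrative): a 1920x1080 screenshot already falls inside
# [MIN_PIXELS, MAX_PIXELS], so each side is only snapped to the nearest
# multiple of IMAGE_FACTOR (28):
#   smart_resize(height=1080, width=1920) -> (1092, 1932)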
class MLXVLMUITarsClient(BaseUITarsClient):
"""MLX LVM client implementation class."""
def __init__(
self,
model: str = "mlx-community/UI-TARS-1.5-7B-4bit"
):
"""Initialize MLX LVM client.
Args:
model: Model name or path (defaults to mlx-community/UI-TARS-1.5-7B-4bit)
"""
# Load model and processor
model_obj, processor = load(
model,
processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
)
self.config = load_config(model)
self.model = model_obj
self.processor = processor
self.model_name = model
def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
"""Process coordinates in box tokens based on image resizing using smart_resize approach.
Args:
text: Text containing box tokens
original_size: Original image size (width, height)
model_size: Model processed image size (width, height)
Returns:
Text with processed coordinates
"""
# Find all box tokens
box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
def process_coords(match):
model_x, model_y = int(match.group(1)), int(match.group(2))
# Scale coordinates from model space to original image space
# Both original_size and model_size are in (width, height) format
new_x = int(model_x * original_size[0] / model_size[0]) # Width
new_y = int(model_y * original_size[1] / model_size[1]) # Height
return f"<|box_start|>({new_x},{new_y})<|box_end|>"
return re.sub(box_pattern, process_coords, text)
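# Illustrative mapping, assuming original_size=(1920, 1080) and a
# smart_resize model_size=(1932, 1092):
#   "<|box_start|>(966,546)<|box_end|>" -> "<|box_start|>(960,540)<|box_end|>"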
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
) -> Dict[str, Any]:
"""Run interleaved chat completion.
Args:
messages: List of message dicts
system: System prompt
max_tokens: Optional max tokens override
Returns:
Response dict
"""
# Ensure the system message is included
if not any(msg.get("role") == "system" for msg in messages):
messages = [{"role": "system", "content": system}] + messages
# Copy the message list so the original is not mutated; individual
# messages are copied below before modification
processed_messages = messages.copy()
# Extract images and process messages
images = []
original_sizes = {} # Track original sizes of images for coordinate mapping
model_sizes = {} # Track model processed sizes
image_index = 0
for msg_idx, msg in enumerate(messages):
content = msg.get("content", [])
if not isinstance(content, list):
continue
# Create a copy of the content list to modify
processed_content = []
for item_idx, item in enumerate(content):
if item.get("type") == "image_url":
image_url = item.get("image_url", {}).get("url", "")
pil_image = None
if image_url.startswith("data:image/"):
# Extract base64 data
base64_data = image_url.split(',')[1]
# Convert base64 to PIL Image
image_data = base64.b64decode(base64_data)
pil_image = Image.open(io.BytesIO(image_data))
else:
# Handle file path or URL
pil_image = Image.open(image_url)
# Store original image size for coordinate mapping
original_size = pil_image.size
original_sizes[image_index] = original_size
# Use smart_resize to determine model size
# Note: smart_resize expects (height, width) but PIL gives (width, height)
height, width = original_size[1], original_size[0]
new_height, new_width = smart_resize(height, width)
# Store model size in (width, height) format for consistent coordinate processing
model_sizes[image_index] = (new_width, new_height)
# Resize the image using the calculated dimensions from smart_resize
resized_image = pil_image.resize((new_width, new_height))
images.append(resized_image)
image_index += 1
# Copy items to processed content list
processed_content.append(item.copy())
# Update the processed message content
processed_messages[msg_idx] = msg.copy()
processed_messages[msg_idx]["content"] = processed_content
logger.info(f"resized {len(images)} from {original_sizes[0]} to {model_sizes[0]}")
# Process user text input with box coordinates after image processing
# Swap original_size and model_size arguments for inverse transformation
for msg_idx, msg in enumerate(processed_messages):
if msg.get("role") == "user" and isinstance(msg.get("content"), str):
if "<|box_start|>" in msg.get("content") and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
orig_size = original_sizes[0]
model_size = model_sizes[0]
# Swap arguments to perform inverse transformation for user input
processed_messages[msg_idx]["content"] = self._process_coordinates(msg["content"], model_size, orig_size)
try:
# Format prompt according to model requirements using the processor directly
prompt = self.processor.apply_chat_template(
processed_messages,
tokenize=False,
add_generation_prompt=True
)
tokenizer = cast(PreTrainedTokenizer, self.processor)
print("generating response...")
# Generate response
text_content, usage = generate(
self.model,
tokenizer,
str(prompt),
images,
verbose=False,
max_tokens=max_tokens
)
logger.debug("Agent generation output: %s", text_content)
except Exception as e:
logger.error(f"Error generating response: {str(e)}")
return {
"choices": [
{
"message": {
"role": "assistant",
"content": f"Error generating response: {str(e)}"
},
"finish_reason": "error"
}
],
"model": self.model_name,
"error": str(e)
}
# Process coordinates in the response back to original image space
if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
# Get original image size and model size (using the first image)
orig_size = original_sizes[0]
model_size = model_sizes[0]
# Check if output contains box tokens that need processing
if "<|box_start|>" in text_content:
# Process coordinates from model space back to original image space
text_content = self._process_coordinates(text_content, orig_size, model_size)
# Format response to match OpenAI format
response = {
"choices": [
{
"message": {
"role": "assistant",
"content": text_content
},
"finish_reason": "stop"
}
],
"model": self.model_name,
"usage": usage
}
return response

View File

@@ -1,214 +0,0 @@
"""OpenAI-compatible client implementation."""
import os
import logging
from typing import Dict, List, Optional, Any
import aiohttp
import re
from .base import BaseUITarsClient
import asyncio
logger = logging.getLogger(__name__)
# OpenAI-compatible client for UI-TARS
class OAICompatClient(BaseUITarsClient):
"""OpenAI-compatible API client implementation.
This client can be used with any service that implements the OpenAI API protocol, including:
- Hugging Face Text Generation Inference (TGI) endpoints
- vLLM
- LM Studio
- LocalAI
- Ollama (with OpenAI compatibility)
- Text Generation WebUI
- Any other service with OpenAI API compatibility
"""
def __init__(
self,
api_key: Optional[str] = None,
model: str = "Qwen2.5-VL-7B-Instruct",
provider_base_url: Optional[str] = "http://localhost:8000/v1",
max_tokens: int = 4096,
temperature: float = 0.0,
):
"""Initialize the OpenAI-compatible client.
Args:
api_key: Not used for local endpoints, usually set to "EMPTY"
model: Model name to use
provider_base_url: API base URL. Typically in the format "http://localhost:PORT/v1"
Examples:
- vLLM: "http://localhost:8000/v1"
- LM Studio: "http://localhost:1234/v1"
- LocalAI: "http://localhost:8080/v1"
- Ollama: "http://localhost:11434/v1"
max_tokens: Maximum tokens to generate
temperature: Generation temperature
"""
super().__init__(api_key=api_key or "EMPTY", model=model)
self.api_key = api_key or "EMPTY" # Local endpoints typically don't require an API key
self.model = model
self.provider_base_url = (
provider_base_url or "http://localhost:8000/v1"
) # Use default if None
self.max_tokens = max_tokens
self.temperature = temperature
def _extract_base64_image(self, text: str) -> Optional[str]:
"""Extract base64 image data from an HTML img tag."""
pattern = r'data:image/[^;]+;base64,([^"]+)'
match = re.search(pattern, text)
return match.group(1) if match else None
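# e.g. '<img src="data:image/png;base64,iVBORw0...">' -> "iVBORw0..."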
def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Create a loggable version of messages with image data truncated."""
loggable_messages = []
for msg in messages:
if isinstance(msg.get("content"), list):
new_content = []
for content in msg["content"]:
if content.get("type") == "image":
new_content.append(
{"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
)
else:
new_content.append(content)
loggable_messages.append({"role": msg["role"], "content": new_content})
else:
loggable_messages.append(msg)
return loggable_messages
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
) -> Dict[str, Any]:
"""Run interleaved chat completion.
Args:
messages: List of message dicts
system: System prompt
max_tokens: Optional max tokens override
Returns:
Response dict
"""
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
final_messages = [
{
"role": "system",
"content": [
{ "type": "text", "text": system }
]
}
]
# Process messages
for item in messages:
if isinstance(item, dict):
if isinstance(item["content"], list):
# Content is already in the correct format
final_messages.append(item)
else:
# Single string content, check for image
base64_img = self._extract_base64_image(item["content"])
if base64_img:
message = {
"role": item["role"],
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
}
],
}
else:
message = {
"role": item["role"],
"content": [{"type": "text", "text": item["content"]}],
}
final_messages.append(message)
else:
# String content, check for image
base64_img = self._extract_base64_image(item)
if base64_img:
message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
}
],
}
else:
message = {"role": "user", "content": [{"type": "text", "text": item}]}
final_messages.append(message)
payload = {
"model": self.model,
"messages": final_messages,
"max_tokens": max_tokens or self.max_tokens,
"temperature": self.temperature,
"top_p": 0.7,
}
try:
async with aiohttp.ClientSession() as session:
# Use default base URL if none provided
base_url = self.provider_base_url or "http://localhost:8000/v1"
# Check if the base URL already includes the chat/completions endpoint
endpoint_url = base_url
if not endpoint_url.endswith("/chat/completions"):
# If URL is RunPod format, make it OpenAI compatible
if endpoint_url.startswith("https://api.runpod.ai/v2/"):
# Extract RunPod endpoint ID
parts = endpoint_url.split("/")
if len(parts) >= 5:
runpod_id = parts[4]
endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
# If the URL ends with /v1, append /chat/completions
elif endpoint_url.endswith("/v1"):
endpoint_url = f"{endpoint_url}/chat/completions"
# If the URL doesn't end with /v1, make sure it has a proper structure
elif not endpoint_url.endswith("/"):
endpoint_url = f"{endpoint_url}/chat/completions"
else:
endpoint_url = f"{endpoint_url}chat/completions"
# Log the endpoint URL for debugging
logger.debug(f"Using endpoint URL: {endpoint_url}")
async with session.post(endpoint_url, headers=headers, json=payload) as response:
# Log the status and content type
logger.debug(f"Status: {response.status}")
logger.debug(f"Content-Type: {response.headers.get('Content-Type')}")
# Get the raw text of the response
response_text = await response.text()
logger.debug(f"Response content: {response_text}")
# If 503, the endpoint is still warming up; wait before raising so upstream retry logic can try again
if response.status == 503:
logger.error("Endpoint is still warming up, retrying in 30 seconds...")
await asyncio.sleep(30)
raise Exception(f"Endpoint is still warming up: {response_text}")
# Try to parse as JSON if the content type is appropriate
if "application/json" in response.headers.get('Content-Type', ''):
response_json = await response.json()
else:
raise Exception(f"Response is not JSON format")
if response.status != 200:
logger.error(f"Error in API call: {response_text}")
raise Exception(f"API error: {response_text}")
return response_json
except Exception as e:
logger.error(f"Error in API call: {str(e)}")
raise

View File

@@ -1,660 +0,0 @@
"""UI-TARS-specific agent loop implementation."""
import logging
import asyncio
import re
import os
import json
import base64
import copy
from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator, cast
from httpx import ConnectError, ReadTimeout
from ...core.base import BaseLoop
from ...core.messages import StandardMessageManager, ImageRetentionConfig
from ...core.types import AgentResponse, LLMProvider
from ...core.visualization import VisualizationHelper
from computer import Computer
from .utils import add_box_token, parse_actions, parse_action_parameters, to_agent_response_format
from .tools.manager import ToolManager
from .tools.computer import ToolResult
from .prompts import COMPUTER_USE, SYSTEM_PROMPT, MAC_SPECIFIC_NOTES
from .clients.oaicompat import OAICompatClient
from .clients.mlxvlm import MLXVLMUITarsClient
logger = logging.getLogger(__name__)
class UITARSLoop(BaseLoop):
"""UI-TARS-specific implementation of the agent loop.
This class extends BaseLoop to provide support for the UI-TARS model
with computer control capabilities.
"""
###########################################
# INITIALIZATION AND CONFIGURATION
###########################################
def __init__(
self,
computer: Computer,
api_key: str,
model: str,
provider: Optional[LLMProvider] = None,
provider_base_url: Optional[str] = "http://localhost:8000/v1",
only_n_most_recent_images: Optional[int] = 2,
base_dir: Optional[str] = "trajectories",
max_retries: int = 3,
retry_delay: float = 1.0,
save_trajectory: bool = True,
**kwargs,
):
"""Initialize the loop.
Args:
computer: Computer instance
api_key: API key (may not be needed for local endpoints)
model: Model name (e.g., "ui-tars")
provider_base_url: Base URL for the API provider
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
base_dir: Base directory for saving experiment data
max_retries: Maximum number of retries for API calls
retry_delay: Delay between retries in seconds
save_trajectory: Whether to save trajectory data
provider: The LLM provider to use (defaults to OAICOMPAT if not specified)
"""
# Set provider before initializing base class
self.provider = provider or LLMProvider.OAICOMPAT
self.provider_base_url = provider_base_url
# Initialize message manager with image retention config
self.message_manager = StandardMessageManager(
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
)
# Initialize base class (which will set up experiment manager)
super().__init__(
computer=computer,
model=model,
api_key=api_key,
max_retries=max_retries,
retry_delay=retry_delay,
base_dir=base_dir,
save_trajectory=save_trajectory,
only_n_most_recent_images=only_n_most_recent_images,
**kwargs,
)
# Set API client attributes
self.client = None
self.retry_count = 0
self.loop_task = None # Store the loop task for cancellation
# Initialize visualization helper
self.viz_helper = VisualizationHelper(agent=self)
# Initialize tool manager
self.tool_manager = ToolManager(computer=computer)
logger.info("UITARSLoop initialized with StandardMessageManager")
async def initialize(self) -> None:
"""Initialize the loop by setting up tools and clients."""
# Initialize base class
await super().initialize()
# Initialize tool manager with error handling
try:
logger.info("Initializing tool manager...")
await self.tool_manager.initialize()
logger.info("Tool manager initialized successfully.")
except Exception as e:
logger.error(f"Error initializing tool manager: {str(e)}")
logger.warning("Will attempt to initialize tools on first use.")
# Initialize client for the selected provider
try:
await self.initialize_client()
except Exception as e:
logger.error(f"Error initializing client: {str(e)}")
raise RuntimeError(f"Failed to initialize client: {str(e)}")
###########################################
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
###########################################
async def initialize_client(self) -> None:
"""Initialize the appropriate client.
Implements abstract method from BaseLoop to set up the specific
provider client based on the configured provider.
"""
try:
if self.provider == LLMProvider.MLXVLM:
logger.info(f"Initializing MLX VLM client for UI-TARS with model {self.model}...")
self.client = MLXVLMUITarsClient(
model=self.model,
)
logger.info(f"Initialized MLX VLM client with model {self.model}")
else:
# Default to OAICompat client for other providers
logger.info(f"Initializing OAICompat client for UI-TARS with model {self.model}...")
self.client = OAICompatClient(
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
model=self.model,
provider_base_url=self.provider_base_url,
)
logger.info(f"Initialized OAICompat client with model {self.model}")
except Exception as e:
logger.error(f"Error initializing client: {str(e)}")
self.client = None
raise RuntimeError(f"Failed to initialize client: {str(e)}")
###########################################
# MESSAGE FORMATTING
###########################################
def to_uitars_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Convert messages to UI-TARS compatible format.
Args:
messages: List of messages in standard format
Returns:
List of messages formatted for UI-TARS
"""
# Create a copy of the messages to avoid modifying the original
uitars_messages = copy.deepcopy(messages)
# Find the first user message to modify
first_user_idx = None
instruction = ""
for idx, msg in enumerate(uitars_messages):
if msg.get("role") == "user":
first_user_idx = idx
content = msg.get("content", "")
if isinstance(content, str):
instruction = content
break
elif isinstance(content, list):
for item in content:
if item.get("type") == "text":
instruction = item.get("text", "")
break
if instruction:
break
# Only modify the first user message if found
if first_user_idx is not None and instruction:
# Create the computer use prompt
user_prompt = COMPUTER_USE.format(
instruction='\n'.join([instruction, MAC_SPECIFIC_NOTES]),
language="English"
)
# Replace the content of the first user message
if isinstance(uitars_messages[first_user_idx].get("content", ""), str):
uitars_messages[first_user_idx]["content"] = [{"type": "text", "text": user_prompt}]
elif isinstance(uitars_messages[first_user_idx].get("content", ""), list):
# Find and replace only the text part, keeping images
for i, item in enumerate(uitars_messages[first_user_idx]["content"]):
if item.get("type") == "text":
uitars_messages[first_user_idx]["content"][i]["text"] = user_prompt
break
# Add box tokens to assistant responses
for idx, msg in enumerate(uitars_messages):
if msg.get("role") == "assistant":
content = msg.get("content", "")
if content and isinstance(content, list):
for i, part in enumerate(content):
if part.get('type') == 'text':
uitars_messages[idx]["content"][i]["text"] = add_box_token(part['text'])
return uitars_messages
###########################################
# API CALL HANDLING
###########################################
async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any:
"""Make API call to provider with retry logic."""
# Create new turn directory for this API call
self._create_turn_dir()
request_data = None
last_error = None
for attempt in range(self.max_retries):
try:
# Ensure client is initialized
if self.client is None:
logger.info(
f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..."
)
await self.initialize_client()
if self.client is None:
raise RuntimeError("Failed to initialize client")
# Get messages in standard format from the message manager
self.message_manager.messages = messages.copy()
prepared_messages = self.message_manager.get_messages()
# Convert messages to UI-TARS format
uitars_messages = self.to_uitars_format(prepared_messages)
# Log request
request_data = {
"messages": uitars_messages,
"max_tokens": self.max_tokens,
"system": system_prompt,
}
self._log_api_call("request", request_data)
# Make API call
response = await self.client.run_interleaved(
messages=uitars_messages,
system=system_prompt,
max_tokens=self.max_tokens,
)
# Log success response
self._log_api_call("response", request_data, response)
return response
except (ConnectError, ReadTimeout) as e:
last_error = e
logger.warning(
f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
)
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff
# Reset client on connection errors to force re-initialization
self.client = None
continue
except RuntimeError as e:
# Handle client initialization errors specifically
last_error = e
self._log_api_call("error", request_data, error=e)
logger.error(
f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
)
if attempt < self.max_retries - 1:
# Reset client to force re-initialization
self.client = None
await asyncio.sleep(self.retry_delay)
continue
except Exception as e:
# Log unexpected error
last_error = e
self._log_api_call("error", request_data, error=e)
logger.error(f"Unexpected error in API call: {str(e)}")
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay)
continue
# If we get here, all retries failed
error_message = f"API call failed after {self.max_retries} attempts"
if last_error:
error_message += f": {str(last_error)}"
logger.error(error_message)
raise RuntimeError(error_message)
###########################################
# RESPONSE AND ACTION HANDLING
###########################################
async def _handle_response(
self, response: Any, messages: List[Dict[str, Any]]
) -> Tuple[bool, bool]:
"""Handle API response.
Args:
response: API response
messages: List of messages to update
Returns:
Tuple of (should_continue, action_screenshot_saved)
"""
action_screenshot_saved = False
try:
# Step 1: Extract the raw response text
raw_text = None
try:
# OpenAI-compatible response format
raw_text = response["choices"][0]["message"]["content"]
except (KeyError, TypeError, IndexError) as e:
logger.error(f"Invalid response format: {str(e)}")
return True, action_screenshot_saved
# Step 2: Add the response to message history
self.message_manager.add_assistant_message([{"type": "text", "text": raw_text}])
# Step 3: Parse actions from the response
parsed_actions = parse_actions(raw_text)
if not parsed_actions:
logger.warning("No action found in the response")
return True, action_screenshot_saved
# Step 4: Execute each action
for action in parsed_actions:
action_type = None
# Handle "finished" action
if action.startswith("finished"):
logger.info("Agent completed the task")
return False, action_screenshot_saved
# Process other action types (click, type, etc.)
try:
# Parse action parameters using the utility function
action_name, tool_args = parse_action_parameters(action)
if not action_name:
logger.warning(f"Could not parse action: {action}")
continue
# Mark actions that would create screenshots
if action_name in ["click", "left_double", "right_single", "drag", "scroll"]:
action_screenshot_saved = True
# Execute the tool with prepared arguments
await self._ensure_tools_initialized()
# Let's log what we're about to execute for debugging
logger.info(f"Executing computer tool with arguments: {tool_args}")
result = await self.tool_manager.execute_tool(name="computer", tool_input=tool_args)
# Handle the result
if hasattr(result, "error") and result.error:
logger.error(f"Error executing tool: {result.error}")
else:
# Action was successful
logger.info(f"Successfully executed {action_name}")
# Save screenshot if one was returned and we haven't already saved one
if hasattr(result, "base64_image") and result.base64_image:
self._save_screenshot(result.base64_image, action_type=action_name)
action_screenshot_saved = True
except Exception as e:
logger.error(f"Error executing action {action}: {str(e)}")
# Continue the loop if there are actions to process
return True, action_screenshot_saved
except Exception as e:
logger.error(f"Error handling response: {str(e)}")
# Add error message using the message manager
error_message = [{"type": "text", "text": f"Error: {str(e)}"}]
self.message_manager.add_assistant_message(error_message)
raise
###########################################
# SCREEN HANDLING
###########################################
async def _get_current_screen(self, save_screenshot: bool = True) -> str:
"""Get the current screen as a base64 encoded image.
Args:
save_screenshot: Whether to save the screenshot
Returns:
Base64 encoded screenshot
"""
try:
# Take a screenshot
screenshot = await self.computer.interface.screenshot()
# Convert to base64
img_base64 = base64.b64encode(screenshot).decode("utf-8")
# Process screenshot through hooks and save if needed
await self.handle_screenshot(img_base64, action_type="state")
# Save screenshot if requested
if save_screenshot and self.save_trajectory:
self._save_screenshot(img_base64, action_type="state")
return img_base64
except Exception as e:
logger.error(f"Error getting current screen: {str(e)}")
raise
###########################################
# SYSTEM PROMPT
###########################################
def _get_system_prompt(self) -> str:
"""Get the system prompt for the model."""
return SYSTEM_PROMPT
###########################################
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
###########################################
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
"""Run the agent loop with provided messages.
Args:
messages: List of messages in standard OpenAI format
Yields:
Agent response format
"""
try:
logger.info(f"Starting UITARSLoop run with {len(messages)} messages")
# Initialize the message manager with the provided messages
self.message_manager.messages = messages.copy()
# Create queue for response streaming
queue = asyncio.Queue()
# Start loop in background task
self.loop_task = asyncio.create_task(self._run_loop(queue, messages))
# Process and yield messages as they arrive
while True:
try:
item = await queue.get()
if item is None: # Stop signal
break
yield item
queue.task_done()
except Exception as e:
logger.error(f"Error processing queue item: {str(e)}")
continue
# Wait for loop to complete
await self.loop_task
# Send completion message
yield {
"role": "assistant",
"content": "Task completed successfully.",
"metadata": {"title": "✅ Complete"},
}
except Exception as e:
logger.error(f"Error in run method: {str(e)}")
yield {
"role": "assistant",
"content": f"Error: {str(e)}",
"metadata": {"title": "❌ Error"},
}
async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
"""Internal method to run the agent loop with provided messages.
Args:
queue: Queue to put responses into
messages: List of messages in standard OpenAI format
"""
# Continue running until explicitly told to stop
running = True
turn_created = False
# Track if an action-specific screenshot has been saved this turn
action_screenshot_saved = False
attempt = 0
max_attempts = 3
try:
while running and attempt < max_attempts:
try:
# Create a new turn directory if it's not already created
if not turn_created:
self._create_turn_dir()
turn_created = True
# Ensure client is initialized
if self.client is None:
logger.info("Initializing client...")
await self.initialize_client()
if self.client is None:
raise RuntimeError("Failed to initialize client")
logger.info("Client initialized successfully")
# Get current screen
base64_screenshot = await self._get_current_screen()
# Add screenshot to message history
self.message_manager.add_user_message(
[
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{base64_screenshot}"},
}
]
)
logger.info("Added screenshot to message history")
# Get system prompt
system_prompt = self._get_system_prompt()
# Make API call with retries
response = await self._make_api_call(
self.message_manager.messages, system_prompt
)
# Handle the response (may execute actions)
# Returns: (should_continue, action_screenshot_saved)
should_continue, new_screenshot_saved = await self._handle_response(
response, self.message_manager.messages
)
# Update whether an action screenshot was saved this turn
action_screenshot_saved = action_screenshot_saved or new_screenshot_saved
agent_response = await to_agent_response_format(
response,
messages,
model=self.model,
)
# Log standardized response for ease of parsing
self._log_api_call("agent_response", request=None, response=agent_response)
# Put the response in the queue
await queue.put(agent_response)
# Check if we should continue this conversation
running = should_continue
# Create a new turn directory if we're continuing
if running:
turn_created = False
# Reset attempt counter on success
attempt = 0
except Exception as e:
attempt += 1
error_msg = f"Error in run method (attempt {attempt}/{max_attempts}): {str(e)}"
logger.error(error_msg)
# If this is our last attempt, provide more info about the error
if attempt >= max_attempts:
logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")
await queue.put({
"role": "assistant",
"content": f"Error: {str(e)}",
"metadata": {"title": "❌ Error"},
})
# Create a brief delay before retrying
await asyncio.sleep(1)
finally:
# Signal that we're done
await queue.put(None)
async def cancel(self) -> None:
"""Cancel the currently running agent loop task.
This method stops the ongoing processing in the agent loop
by cancelling the loop_task if it exists and is running.
"""
if self.loop_task and not self.loop_task.done():
logger.info("Cancelling UITARS loop task")
self.loop_task.cancel()
try:
# Wait for the task to be cancelled with a timeout
await asyncio.wait_for(self.loop_task, timeout=2.0)
except asyncio.TimeoutError:
logger.warning("Timeout while waiting for loop task to cancel")
except asyncio.CancelledError:
logger.info("Loop task cancelled successfully")
except Exception as e:
logger.error(f"Error while cancelling loop task: {str(e)}")
finally:
logger.info("UITARS loop task cancelled")
else:
logger.info("No active UITARS loop task to cancel")
###########################################
# UTILITY METHODS
###########################################
async def _ensure_tools_initialized(self) -> None:
"""Ensure the tool manager and tools are initialized before use."""
if not hasattr(self.tool_manager, "tools") or self.tool_manager.tools is None:
logger.info("Tools not initialized. Initializing now...")
await self.tool_manager.initialize()
logger.info("Tools initialized successfully.")
async def process_model_response(self, response_text: str) -> Optional[Dict[str, Any]]:
"""Process model response to extract tool calls.
Args:
response_text: Model response text
Returns:
Extracted tool information, or None if no tool call was found
"""
# UI-TARS doesn't use the standard tool call format, so we parse its actions differently
parsed_actions = parse_actions(response_text)
if parsed_actions:
return {"actions": parsed_actions}
return None

View File

@@ -1,63 +0,0 @@
"""Prompts for UI-TARS agent."""
MAC_SPECIFIC_NOTES = """
(You are operating on macOS; use 'cmd' instead of 'ctrl' for most shortcuts, e.g., hotkey(key='cmd c') for copy, hotkey(key='cmd v') for paste, hotkey(key='cmd t') for a new tab.)
"""
SYSTEM_PROMPT = "You are a helpful assistant."
COMPUTER_USE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""
MOBILE_USE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
```
Thought: ...
Action: ...
```
## Action Space
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
long_press(start_box='<|box_start|>(x1,y1)<|box_end|>')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
open_app(app_name=\'\')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
press_home()
press_back()
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
## User Instruction
{instruction}
"""

View File

@@ -1 +0,0 @@
"""UI-TARS tools package."""

View File

@@ -1,283 +0,0 @@
"""Computer tool for UI-TARS."""
import asyncio
import base64
import logging
import re
from typing import Any, Dict, List, Optional, Literal, Union
from computer import Computer
from ....core.tools.base import ToolResult, ToolFailure
from ....core.tools.computer import BaseComputerTool
logger = logging.getLogger(__name__)
class ComputerTool(BaseComputerTool):
"""
A tool that allows the UI-TARS agent to interact with the screen, keyboard, and mouse.
"""
name: str = "computer"
width: Optional[int] = None
height: Optional[int] = None
computer: Computer
def __init__(self, computer: Computer):
"""Initialize the computer tool.
Args:
computer: Computer instance
"""
super().__init__(computer)
self.computer = computer
self.width = None
self.height = None
self.logger = logging.getLogger(__name__)
def to_params(self) -> Dict[str, Any]:
"""Convert tool to API parameters.
Returns:
Dictionary with tool parameters
"""
if self.width is None or self.height is None:
raise RuntimeError(
"Screen dimensions not initialized. Call initialize_dimensions() first."
)
return {
"type": "computer",
"display_width": self.width,
"display_height": self.height,
}
async def initialize_dimensions(self) -> None:
"""Initialize screen dimensions from the computer interface."""
try:
display_size = await self.computer.interface.get_screen_size()
self.width = display_size["width"]
self.height = display_size["height"]
self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
except Exception as e:
# Fall back to defaults if we can't get accurate dimensions
self.width = 1024
self.height = 768
self.logger.warning(
f"Failed to get screen dimensions, using defaults: {self.width}x{self.height}. Error: {e}"
)
async def __call__(
self,
*,
action: str,
**kwargs,
) -> ToolResult:
"""Execute a computer action.
Args:
action: The action to perform (based on UI-TARS action space)
**kwargs: Additional parameters for the action
Returns:
ToolResult containing action output and possibly a base64 image
"""
try:
# Ensure dimensions are initialized
if self.width is None or self.height is None:
await self.initialize_dimensions()
if self.width is None or self.height is None:
return ToolFailure(error="Failed to initialize screen dimensions")
# Handle actions defined in UI-TARS action space (from prompts.py)
# Handle standard click (left click)
if action == "click":
if "x" in kwargs and "y" in kwargs:
x, y = kwargs["x"], kwargs["y"]
await self.computer.interface.left_click(x, y)
# Wait briefly for UI to update
await asyncio.sleep(0.5)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Clicked at ({x}, {y})",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing coordinates for click action")
# Handle double click
elif action == "left_double":
if "x" in kwargs and "y" in kwargs:
x, y = kwargs["x"], kwargs["y"]
await self.computer.interface.double_click(x, y)
# Wait briefly for UI to update
await asyncio.sleep(0.5)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Double-clicked at ({x}, {y})",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing coordinates for left_double action")
# Handle right click
elif action == "right_single":
if "x" in kwargs and "y" in kwargs:
x, y = kwargs["x"], kwargs["y"]
await self.computer.interface.right_click(x, y)
# Wait briefly for UI to update
await asyncio.sleep(0.5)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Right-clicked at ({x}, {y})",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing coordinates for right_single action")
# Handle typing text
elif action == "type_text":
if "text" in kwargs:
text = kwargs["text"]
await self.computer.interface.type_text(text)
# Wait for UI to update
await asyncio.sleep(0.3)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Typed: {text}",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing text for type action")
# Handle hotkey
elif action == "hotkey":
if "keys" in kwargs:
keys = kwargs["keys"]
if len(keys) > 1:
await self.computer.interface.hotkey(*keys)
else:
# Single key press
await self.computer.interface.press_key(keys[0])
# Wait for UI to update
await asyncio.sleep(0.3)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Pressed hotkey: {', '.join(keys)}",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing keys for hotkey action")
# Handle drag action
elif action == "drag":
if all(k in kwargs for k in ["start_x", "start_y", "end_x", "end_y"]):
start_x, start_y = kwargs["start_x"], kwargs["start_y"]
end_x, end_y = kwargs["end_x"], kwargs["end_y"]
# Perform drag
await self.computer.interface.move_cursor(start_x, start_y)
await self.computer.interface.drag_to(end_x, end_y)
# Wait for UI to update
await asyncio.sleep(0.5)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Dragged from ({start_x}, {start_y}) to ({end_x}, {end_y})",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing coordinates for drag action")
# Handle scroll action
elif action == "scroll":
if all(k in kwargs for k in ["x", "y", "direction"]):
x, y = kwargs["x"], kwargs["y"]
direction = kwargs["direction"]
# Move cursor to position
await self.computer.interface.move_cursor(x, y)
# Scroll based on direction
if direction == "down":
await self.computer.interface.scroll_down(5)
elif direction == "up":
await self.computer.interface.scroll_up(5)
elif direction == "right":
pass # await self.computer.interface.scroll_right(5)
elif direction == "left":
pass # await self.computer.interface.scroll_left(5)
else:
return ToolFailure(error=f"Invalid scroll direction: {direction}")
# Wait for UI to update
await asyncio.sleep(0.5)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Scrolled {direction} at ({x}, {y})",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing parameters for scroll action")
# Handle wait action
elif action == "wait":
# Sleep for 5 seconds as specified in the action space
await asyncio.sleep(5)
# Take screenshot after waiting
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output="Waited for 5 seconds",
base64_image=base64_screenshot,
)
# Handle finished action (task completion)
elif action == "finished":
content = kwargs.get("content", "Task completed")
return ToolResult(
output=f"Task finished: {content}",
)
else:
return ToolFailure(error=f"Unsupported action: {action}")
except Exception as e:
self.logger.error(f"Error in ComputerTool.__call__: {str(e)}")
return ToolFailure(error=f"Failed to execute {action}: {str(e)}")

View File

@@ -1,60 +0,0 @@
"""Tool manager for the UI-TARS provider."""
import logging
from typing import Any, Dict, List, Optional
from computer import Computer
from ....core.tools import BaseToolManager
from ....core.tools.collection import ToolCollection
from .computer import ComputerTool
logger = logging.getLogger(__name__)
class ToolManager(BaseToolManager):
"""Manages UI-TARS provider tool initialization and execution."""
def __init__(self, computer: Computer):
"""Initialize the tool manager.
Args:
computer: Computer instance for computer-related tools
"""
super().__init__(computer)
# Initialize UI-TARS-specific tools
self.computer_tool = ComputerTool(self.computer)
self._initialized = False
def _initialize_tools(self) -> ToolCollection:
"""Initialize all available tools."""
return ToolCollection(self.computer_tool)
async def _initialize_tools_specific(self) -> None:
"""Initialize UI-TARS provider-specific tool requirements."""
await self.computer_tool.initialize_dimensions()
def get_tool_params(self) -> List[Dict[str, Any]]:
"""Get tool parameters for API calls.
Returns:
List of tool parameters for the current provider's API
"""
if self.tools is None:
raise RuntimeError("Tools not initialized. Call initialize() first.")
return self.tools.to_params()
async def execute_tool(self, name: str, tool_input: dict[str, Any]) -> Any:
"""Execute a tool with the given input.
Args:
name: Name of the tool to execute
tool_input: Input parameters for the tool
Returns:
Result of the tool execution
"""
if self.tools is None:
raise RuntimeError("Tools not initialized. Call initialize() first.")
return await self.tools.run(name=name, tool_input=tool_input)

Some files were not shown because too many files have changed in this diff.