mirror of https://github.com/trycua/computer.git (synced 2026-01-01 19:10:30 -06:00)

Commit: removed agent2
@@ -1,381 +0,0 @@
<div align="center">
<h1>
  <div class="image-wrapper" style="display: inline-block;">
    <picture>
      <source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="../../../img/logo_white.png" style="display: block; margin: auto;">
      <source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="../../../img/logo_black.png" style="display: block; margin: auto;">
      <img alt="Shows my svg">
    </picture>
  </div>

  [](#)
  [](#)
  [](https://discord.com/invite/mVnXXpdE85)
  [](https://pypi.org/project/cua-computer/)
</h1>
</div>
**cua-agent** is a general Computer-Use framework with liteLLM integration for running agentic workflows on macOS, Windows, and Linux sandboxes. It provides a unified interface for computer-use agents across multiple LLM providers, along with an advanced callback system for extensibility.

## Features

- **Safe Computer-Use/Tool-Use**: Uses the Computer SDK for sandboxed desktops
- **Multi-Agent Support**: Anthropic Claude, OpenAI computer-use-preview, UI-TARS, Omniparser + any LLM
- **Multi-API Support**: Take advantage of liteLLM support for 100+ LLMs / model APIs, including local models (`huggingface-local/`, `ollama_chat/`, `mlx/`)
- **Cross-Platform**: Works on Windows, macOS, and Linux with cloud and local computer instances
- **Extensible Callbacks**: Built-in support for image retention, cache control, PII anonymization, budget limits, and trajectory tracking

## Install

```bash
pip install "cua-agent[all]"

# or install specific providers
pip install "cua-agent[openai]"     # OpenAI computer-use-preview support
pip install "cua-agent[anthropic]"  # Anthropic Claude support
pip install "cua-agent[omni]"       # Omniparser + any LLM support
pip install "cua-agent[uitars]"     # UI-TARS
pip install "cua-agent[uitars-mlx]" # UI-TARS + MLX support
pip install "cua-agent[uitars-hf]"  # UI-TARS + Hugging Face support
pip install "cua-agent[ui]"         # Gradio UI support
```
## Quick Start

```python
import asyncio
import os

from agent import ComputerAgent
from computer import Computer

async def main():
    # Set up a sandboxed computer instance
    async with Computer(
        os_type="linux",
        provider_type="cloud",
        name=os.getenv("CUA_CONTAINER_NAME"),
        api_key=os.getenv("CUA_API_KEY")
    ) as computer:
        # Create the agent
        agent = ComputerAgent(
            model="anthropic/claude-3-5-sonnet-20241022",
            tools=[computer],
            only_n_most_recent_images=3,
            trajectory_dir="trajectories",
            max_trajectory_budget=5.0  # $5 budget limit
        )

        # Run the agent
        messages = [{"role": "user", "content": "Take a screenshot and tell me what you see"}]

        async for result in agent.run(messages):
            for item in result["output"]:
                if item["type"] == "message":
                    print(item["content"][0]["text"])

if __name__ == "__main__":
    asyncio.run(main())
```
## Supported Models

### Anthropic Claude (Computer Use API)

```python
model="anthropic/claude-3-5-sonnet-20241022"
model="anthropic/claude-3-5-sonnet-20240620"
model="anthropic/claude-opus-4-20250514"
model="anthropic/claude-sonnet-4-20250514"
```

### OpenAI Computer Use Preview

```python
model="openai/computer-use-preview"
```

### UI-TARS (Local or Hugging Face Inference)

```python
model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B"
model="ollama_chat/0000/ui-tars-1.5-7b"
```

### Omniparser + Any LLM

```python
model="omniparser+ollama_chat/mistral-small3.2"
model="omniparser+vertex_ai/gemini-pro"
model="omniparser+anthropic/claude-3-5-sonnet-20241022"
model="omniparser+openai/gpt-4o"
```
## Custom Tools

Define custom tools as decorated functions. Function tools need a docstring, since the agent derives each tool's schema from it (via `litellm.utils.function_to_dict`):

```python
from computer.helpers import sandboxed

@sandboxed()
def read_file(location: str) -> str:
    """Read contents of a file

    Parameters
    ----------
    location : str
        Path to the file to read

    Returns
    -------
    str
        Contents of the file or error message
    """
    try:
        with open(location, 'r') as f:
            return f.read()
    except Exception as e:
        return f"Error reading file: {str(e)}"

def calculate(a: int, b: int) -> int:
    """Calculate the sum of two integers"""
    return a + b

# Use with the agent
agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    tools=[computer, read_file, calculate]
)
```
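Async tools are supported as well: the agent awaits coroutine functions directly and runs plain functions in a thread pool. A minimal sketch (`wait_and_echo` is a hypothetical tool name):

```python
import asyncio

async def wait_and_echo(text: str, seconds: int) -> str:
    """Wait for the given number of seconds, then echo the text back"""
    # The agent detects coroutine functions (inspect.iscoroutinefunction)
    # and awaits them directly, so this never blocks the event loop.
    await asyncio.sleep(seconds)
    return text

agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    tools=[computer, wait_and_echo]
)
```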
## Callbacks System

cua-agent provides a comprehensive callback system for extending functionality:

### Built-in Callbacks

```python
import logging

from agent.callbacks import (
    ImageRetentionCallback,
    TrajectorySaverCallback,
    BudgetManagerCallback,
    LoggingCallback
)

agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    tools=[computer],
    callbacks=[
        ImageRetentionCallback(only_n_most_recent_images=3),
        TrajectorySaverCallback(trajectory_dir="trajectories"),
        BudgetManagerCallback(max_budget=10.0, raise_error=True),
        LoggingCallback(level=logging.INFO)
    ]
)
```

### Custom Callbacks

```python
from agent.callbacks.base import AsyncCallbackHandler

class CustomCallback(AsyncCallbackHandler):
    async def on_llm_start(self, messages):
        """Preprocess messages before the LLM call"""
        # Add custom preprocessing logic
        return messages

    async def on_llm_end(self, messages):
        """Postprocess messages after the LLM call"""
        # Add custom postprocessing logic
        return messages

    async def on_usage(self, usage):
        """Track usage information (usage arrives as a JSON-compatible dict)"""
        print(f"Tokens used: {usage['total_tokens']}")
```
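Pass the handler to the agent like any built-in callback:

```python
agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    tools=[computer],
    callbacks=[CustomCallback()]
)
```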
## Budget Management

Control costs with built-in budget management:

```python
# Simple budget limit
agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    max_trajectory_budget=5.0  # $5 limit
)

# Advanced budget configuration
agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    max_trajectory_budget={
        "max_budget": 10.0,
        "raise_error": True,           # Raise an error when exceeded
        "reset_after_each_run": False  # Budget persists across runs
    }
)
```
## Trajectory Management

Save and replay agent conversations:

```python
agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    trajectory_dir="trajectories",  # Auto-save trajectories
    tools=[computer]
)

# Trajectories are saved with:
# - Complete conversation history
# - Usage statistics and costs
# - Timestamps and metadata
# - Screenshots and computer actions
```
## Configuration Options

### ComputerAgent Parameters

- `model`: Model identifier (required)
- `tools`: List of computer objects and decorated functions
- `callbacks`: List of callback handlers for extensibility
- `only_n_most_recent_images`: Limit recent images to prevent context overflow
- `verbosity`: Logging level (logging.INFO, logging.DEBUG, etc.)
- `trajectory_dir`: Directory to save conversation trajectories
- `max_retries`: Maximum API call retries (default: 3)
- `screenshot_delay`: Delay between actions and screenshots (default: 0.5s)
- `use_prompt_caching`: Enable prompt caching for supported models
- `max_trajectory_budget`: Budget limit configuration

A combined example follows this list.
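For reference, a single constructor call exercising most of these options (values are illustrative; `computer` is a Computer instance as in the Quick Start):

```python
import logging

agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    tools=[computer],
    only_n_most_recent_images=3,
    verbosity=logging.INFO,
    trajectory_dir="trajectories",
    max_retries=3,
    screenshot_delay=0.5,
    use_prompt_caching=False,
    max_trajectory_budget=5.0
)
```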
### Environment Variables

```bash
# Computer instance (cloud)
export CUA_CONTAINER_NAME="your-container-name"
export CUA_API_KEY="your-cua-api-key"

# LLM API keys
export ANTHROPIC_API_KEY="your-anthropic-key"
export OPENAI_API_KEY="your-openai-key"
```
## Advanced Usage

### Streaming Responses

```python
async for result in agent.run(messages, stream=True):
    # Process streaming chunks
    for item in result["output"]:
        if item["type"] == "message":
            print(item["content"][0]["text"], end="", flush=True)
        elif item["type"] == "computer_call":
            action = item["action"]
            print(f"\n[Action: {action['type']}]")
```

### Interactive Chat Loop

```python
history = []
while True:
    user_input = input("> ")
    if user_input.lower() in ['quit', 'exit']:
        break

    history.append({"role": "user", "content": user_input})

    async for result in agent.run(history):
        history += result["output"]

        # Display assistant responses
        for item in result["output"]:
            if item["type"] == "message":
                print(item["content"][0]["text"])
```

### Error Handling

Budget violations raise `BudgetExceededError` when the budget manager is configured with `raise_error=True`; otherwise the run simply stops:

```python
from agent.callbacks.budget_manager import BudgetExceededError

try:
    async for result in agent.run(messages):
        # Process results
        pass
except BudgetExceededError:
    print("Budget limit exceeded")
except Exception as e:
    print(f"Agent error: {e}")
```
## API Reference

### ComputerAgent.run()

```python
async def run(
    self,
    messages: Messages,
    stream: bool = False,
    **kwargs
) -> AsyncGenerator[Dict[str, Any], None]:
    """
    Run the agent with the given messages.

    Args:
        messages: List of message dictionaries
        stream: Whether to stream the response
        **kwargs: Additional arguments

    Returns:
        AsyncGenerator that yields response chunks
    """
```

### Message Format

```python
messages = [
    {
        "role": "user",
        "content": "Take a screenshot and describe what you see"
    },
    {
        "role": "assistant",
        "content": "I'll take a screenshot for you."
    }
]
```

### Response Format

```python
{
    "output": [
        {
            "type": "message",
            "role": "assistant",
            "content": [{"type": "output_text", "text": "I can see..."}]
        },
        {
            "type": "computer_call",
            "action": {"type": "screenshot"},
            "call_id": "call_123"
        },
        {
            "type": "computer_call_output",
            "call_id": "call_123",
            "output": {"image_url": "data:image/png;base64,..."}
        }
    ],
    "usage": {
        "prompt_tokens": 150,
        "completion_tokens": 75,
        "total_tokens": 225,
        "response_cost": 0.01
    }
}
```

## License

MIT License - see LICENSE file for details.
@@ -1,19 +0,0 @@
"""
agent - Decorator-based Computer Use Agent with liteLLM integration
"""

from .decorators import agent_loop
from .agent import ComputerAgent
from .types import Messages, AgentResponse

# Import loops to register them
from . import loops

__all__ = [
    "agent_loop",
    "ComputerAgent",
    "Messages",
    "AgentResponse",
]

__version__ = "0.4.0b3"
@@ -1,21 +0,0 @@
"""
Entry point for running the agent CLI module.

Usage:
    python -m agent.cli <model_string>
"""

import sys
import asyncio
from .cli import main

if __name__ == "__main__":
    # Check if 'cli' is specified as the module
    if len(sys.argv) > 1 and sys.argv[1] == "cli":
        # Remove 'cli' from the arguments and run the CLI
        sys.argv.pop(1)
        asyncio.run(main())
    else:
        print("Usage: python -m agent.cli <model_string>")
        print("Example: python -m agent.cli openai/computer-use-preview")
        sys.exit(1)
@@ -1,9 +0,0 @@
"""
Adapters package for agent - custom LLM adapters for liteLLM.
"""

from .huggingfacelocal_adapter import HuggingFaceLocalAdapter

__all__ = [
    "HuggingFaceLocalAdapter",
]
@@ -1,229 +0,0 @@
import asyncio
import warnings
from typing import Iterator, AsyncIterator, Dict, List, Any, Optional

from litellm.types.utils import GenericStreamingChunk, ModelResponse
from litellm.llms.custom_llm import CustomLLM
from litellm import completion, acompletion

# Try to import HuggingFace dependencies
try:
    import torch
    from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False


class HuggingFaceLocalAdapter(CustomLLM):
    """HuggingFace Local Adapter for running vision-language models locally."""

    def __init__(self, device: str = "auto", **kwargs):
        """Initialize the adapter.

        Args:
            device: Device to load the model on ("auto", "cuda", "cpu", etc.)
            **kwargs: Additional arguments
        """
        super().__init__()
        self.device = device
        self.models = {}      # Cache for loaded models
        self.processors = {}  # Cache for loaded processors

    def _load_model_and_processor(self, model_name: str):
        """Load the model and processor if not already cached.

        Args:
            model_name: Name of the model to load

        Returns:
            Tuple of (model, processor)
        """
        if model_name not in self.models:
            # Load model
            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map=self.device,
                attn_implementation="sdpa"
            )

            # Load processor
            processor = AutoProcessor.from_pretrained(model_name)

            # Cache them
            self.models[model_name] = model
            self.processors[model_name] = processor

        return self.models[model_name], self.processors[model_name]

    def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Convert OpenAI-format messages to HuggingFace format.

        Args:
            messages: Messages in OpenAI format

        Returns:
            Messages in HuggingFace format
        """
        converted_messages = []

        for message in messages:
            converted_message = {
                "role": message["role"],
                "content": []
            }

            content = message.get("content", [])
            if isinstance(content, str):
                # Simple text content
                converted_message["content"].append({
                    "type": "text",
                    "text": content
                })
            elif isinstance(content, list):
                # Multi-modal content
                for item in content:
                    if item.get("type") == "text":
                        converted_message["content"].append({
                            "type": "text",
                            "text": item.get("text", "")
                        })
                    elif item.get("type") == "image_url":
                        # Convert image_url format to image format
                        image_url = item.get("image_url", {}).get("url", "")
                        converted_message["content"].append({
                            "type": "image",
                            "image": image_url
                        })

            converted_messages.append(converted_message)

        return converted_messages
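    # Example of the conversion performed above:
    #   input  (OpenAI format):
    #     {"role": "user", "content": [
    #         {"type": "text", "text": "Describe this"},
    #         {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}
    #   output (HuggingFace format):
    #     {"role": "user", "content": [
    #         {"type": "text", "text": "Describe this"},
    #         {"type": "image", "image": "data:image/png;base64,..."}]}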
    def _generate(self, **kwargs) -> str:
        """Generate a response using the local HuggingFace model.

        Args:
            **kwargs: Keyword arguments containing messages and model info

        Returns:
            Generated text response
        """
        if not HF_AVAILABLE:
            raise ImportError(
                "HuggingFace transformers dependencies not found. "
                "Please install with: pip install \"cua-agent[uitars-hf]\""
            )

        # Extract messages and model from kwargs
        messages = kwargs.get('messages', [])
        model_name = kwargs.get('model', 'ByteDance-Seed/UI-TARS-1.5-7B')
        max_new_tokens = kwargs.get('max_tokens', 128)

        # Warn about ignored kwargs
        ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
        if ignored_kwargs:
            warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")

        # Load model and processor
        model, processor = self._load_model_and_processor(model_name)

        # Convert messages to HuggingFace format
        hf_messages = self._convert_messages(messages)

        # Apply chat template and tokenize
        inputs = processor.apply_chat_template(
            hf_messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        )

        # Move inputs to the same device as the model
        if torch.cuda.is_available() and self.device != "cpu":
            inputs = inputs.to("cuda")

        # Generate response
        with torch.no_grad():
            generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

        # Trim input tokens from the output
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]

        # Decode output
        output_text = processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )

        return output_text[0] if output_text else ""

    def completion(self, *args, **kwargs) -> ModelResponse:
        """Synchronous completion method.

        Returns:
            ModelResponse with the generated text
        """
        generated_text = self._generate(**kwargs)

        return completion(
            model=f"huggingface-local/{kwargs['model']}",
            mock_response=generated_text,
        )

    async def acompletion(self, *args, **kwargs) -> ModelResponse:
        """Asynchronous completion method.

        Returns:
            ModelResponse with the generated text
        """
        # Run _generate in a thread pool to avoid blocking the event loop
        generated_text = await asyncio.to_thread(self._generate, **kwargs)

        return await acompletion(
            model=f"huggingface-local/{kwargs['model']}",
            mock_response=generated_text,
        )

    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
        """Synchronous streaming method.

        Returns:
            Iterator of GenericStreamingChunk
        """
        generated_text = self._generate(**kwargs)

        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": generated_text,
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        yield generic_streaming_chunk

    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
        """Asynchronous streaming method.

        Returns:
            AsyncIterator of GenericStreamingChunk
        """
        # Run _generate in a thread pool to avoid blocking the event loop
        generated_text = await asyncio.to_thread(self._generate, **kwargs)

        generic_streaming_chunk: GenericStreamingChunk = {
            "finish_reason": "stop",
            "index": 0,
            "is_finished": True,
            "text": generated_text,
            "tool_use": None,
            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
        }

        yield generic_streaming_chunk
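# Registering the adapter with liteLLM (as done in agent.py below), so that
# "huggingface-local/<model>" model strings route through this class:
#
#     import litellm
#     adapter = HuggingFaceLocalAdapter(device="auto")
#     litellm.custom_provider_map = [
#         {"provider": "huggingface-local", "custom_handler": adapter}
#     ]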
@@ -1,577 +0,0 @@
"""
ComputerAgent - Main agent class that selects and runs agent loops
"""

import asyncio
import inspect
import json
from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set

import litellm
import litellm.utils
from litellm.responses.utils import Usage

from .types import Messages, Computer
from .decorators import find_agent_loop
from .computer_handler import OpenAIComputerHandler, acknowledge_safety_check_callback, check_blocklisted_url
from .adapters import HuggingFaceLocalAdapter
from .callbacks import ImageRetentionCallback, LoggingCallback, TrajectorySaverCallback, BudgetManagerCallback


def get_json(obj: Any, max_depth: int = 10) -> Any:
    """Serialize an arbitrary object into JSON-compatible data, guarding
    against circular references and excessive nesting depth."""

    def custom_serializer(o: Any, depth: int = 0, seen: Optional[Set[int]] = None) -> Any:
        if seen is None:
            seen = set()

        # Use model_dump() if available (e.g., pydantic models)
        if hasattr(o, 'model_dump'):
            return o.model_dump()

        # Check depth limit
        if depth > max_depth:
            return f"<max_depth_exceeded:{max_depth}>"

        # Check for circular references using the object id
        obj_id = id(o)
        if obj_id in seen:
            return f"<circular_reference:{type(o).__name__}>"

        # Handle Computer objects
        if hasattr(o, '__class__') and 'computer' in getattr(o, '__class__').__name__.lower():
            return f"<computer:{o.__class__.__name__}>"

        # Handle objects with __dict__
        if hasattr(o, '__dict__'):
            seen.add(obj_id)
            try:
                result = {}
                for k, v in o.__dict__.items():
                    if v is not None:
                        # Recursively serialize with updated depth and seen set
                        serialized_value = custom_serializer(v, depth + 1, seen.copy())
                        result[k] = serialized_value
                return result
            finally:
                seen.discard(obj_id)

        # Handle common types that might contain nested objects
        elif isinstance(o, dict):
            seen.add(obj_id)
            try:
                return {
                    k: custom_serializer(v, depth + 1, seen.copy())
                    for k, v in o.items()
                    if v is not None
                }
            finally:
                seen.discard(obj_id)

        elif isinstance(o, (list, tuple, set)):
            seen.add(obj_id)
            try:
                return [
                    custom_serializer(item, depth + 1, seen.copy())
                    for item in o
                    if item is not None
                ]
            finally:
                seen.discard(obj_id)

        # Basic types that json.dumps can handle
        elif isinstance(o, (str, int, float, bool)) or o is None:
            return o

        # Fallback to string representation
        else:
            return str(o)

    def remove_nones(obj: Any) -> Any:
        if isinstance(obj, dict):
            return {k: remove_nones(v) for k, v in obj.items() if v is not None}
        elif isinstance(obj, list):
            return [remove_nones(item) for item in obj if item is not None]
        return obj

    # Serialize with circular-reference and depth protection
    serialized = custom_serializer(obj)

    # Convert to a JSON string and back to ensure JSON compatibility
    json_str = json.dumps(serialized)
    parsed = json.loads(json_str)

    # Final cleanup of any remaining None values
    return remove_nones(parsed)

def sanitize_message(msg: Any) -> Any:
    """Return a copy of the message with image_url omitted for computer_call_output messages."""
    if msg.get("type") == "computer_call_output":
        output = msg.get("output", {})
        if isinstance(output, dict):
            sanitized = msg.copy()
            sanitized["output"] = {**output, "image_url": "[omitted]"}
            return sanitized
    return msg

class ComputerAgent:
    """
    Main agent class that automatically selects the appropriate agent loop
    based on the model and executes tool calls.
    """

    def __init__(
        self,
        model: str,
        tools: Optional[List[Any]] = None,
        custom_loop: Optional[Callable] = None,
        only_n_most_recent_images: Optional[int] = None,
        callbacks: Optional[List[Any]] = None,
        verbosity: Optional[int] = None,
        trajectory_dir: Optional[str] = None,
        max_retries: Optional[int] = 3,
        screenshot_delay: Optional[float | int] = 0.5,
        use_prompt_caching: Optional[bool] = False,
        max_trajectory_budget: Optional[float | dict] = None,
        **kwargs
    ):
        """
        Initialize ComputerAgent.

        Args:
            model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
            tools: List of tools (computer objects, decorated functions, etc.)
            custom_loop: Custom agent loop function to use instead of auto-selection
            only_n_most_recent_images: If set, only keep the N most recent images in message history. Adds ImageRetentionCallback automatically.
            callbacks: List of AsyncCallbackHandler instances for preprocessing/postprocessing
            verbosity: Logging level (logging.DEBUG, logging.INFO, etc.). If set, adds LoggingCallback automatically
            trajectory_dir: If set, saves trajectory data (screenshots, responses) to this directory. Adds TrajectorySaverCallback automatically.
            max_retries: Maximum number of retries for failed API calls
            screenshot_delay: Delay before screenshots, in seconds
            use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with Anthropic providers.
            max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when the budget is exceeded
            **kwargs: Additional arguments passed to the agent loop
        """
        self.model = model
        self.tools = tools or []
        self.custom_loop = custom_loop
        self.only_n_most_recent_images = only_n_most_recent_images
        self.callbacks = callbacks or []
        self.verbosity = verbosity
        self.trajectory_dir = trajectory_dir
        self.max_retries = max_retries
        self.screenshot_delay = screenshot_delay
        self.use_prompt_caching = use_prompt_caching
        self.kwargs = kwargs

        # == Add built-in callbacks ==

        # Add logging callback if verbosity is set
        if self.verbosity is not None:
            self.callbacks.append(LoggingCallback(level=self.verbosity))

        # Add image retention callback if only_n_most_recent_images is set
        if self.only_n_most_recent_images:
            self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))

        # Add trajectory saver callback if trajectory_dir is set
        if self.trajectory_dir:
            self.callbacks.append(TrajectorySaverCallback(self.trajectory_dir))

        # Add budget manager if max_trajectory_budget is set
        if max_trajectory_budget:
            if isinstance(max_trajectory_budget, dict):
                self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
            else:
                self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))

        # == Enable local model providers w/ liteLLM ==

        # Register local model providers
        hf_adapter = HuggingFaceLocalAdapter(
            device="auto"
        )
        litellm.custom_provider_map = [
            {"provider": "huggingface-local", "custom_handler": hf_adapter}
        ]

        # == Initialize computer agent ==

        # Find the appropriate agent loop
        if custom_loop:
            self.agent_loop = custom_loop
            self.agent_loop_info = None
        else:
            loop_info = find_agent_loop(model)
            if not loop_info:
                raise ValueError(f"No agent loop found for model: {model}")
            self.agent_loop = loop_info.func
            self.agent_loop_info = loop_info

        self.tool_schemas = []
        self.computer_handler = None

    async def _initialize_computers(self):
        """Initialize computer objects and build tool schemas."""
        if not self.tool_schemas:
            for tool in self.tools:
                if hasattr(tool, '_initialized') and not tool._initialized:
                    await tool.run()

            # Process tools and create tool schemas
            self.tool_schemas = self._process_tools()

            # Find the computer tool and create an interface adapter
            computer_handler = None
            for schema in self.tool_schemas:
                if schema["type"] == "computer":
                    computer_handler = OpenAIComputerHandler(schema["computer"].interface)
                    break
            self.computer_handler = computer_handler
    def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
        """Process input messages and normalize them for the agent loop."""
        if isinstance(input, str):
            return [{"role": "user", "content": input}]
        return [get_json(msg) for msg in input]

    def _process_tools(self) -> List[Dict[str, Any]]:
        """Process tools and create schemas for the agent loop."""
        schemas = []

        for tool in self.tools:
            # Check if it's a computer object (has an interface attribute)
            if hasattr(tool, 'interface'):
                # This is a computer tool - it will be handled by the agent loop
                schemas.append({
                    "type": "computer",
                    "computer": tool
                })
            elif callable(tool):
                # Use litellm.utils.function_to_dict to extract a schema from the docstring
                try:
                    function_schema = litellm.utils.function_to_dict(tool)
                    schemas.append({
                        "type": "function",
                        "function": function_schema
                    })
                except Exception as e:
                    print(f"Warning: Could not process tool {tool}: {e}")
            else:
                print(f"Warning: Unknown tool type: {tool}")

        return schemas

    def _get_tool(self, name: str) -> Optional[Callable]:
        """Get a tool by name."""
        for tool in self.tools:
            if hasattr(tool, '__name__') and tool.__name__ == name:
                return tool
            elif hasattr(tool, 'func') and tool.func.__name__ == name:
                return tool
        return None

    # ============================================================================
    # AGENT RUN LOOP LIFECYCLE HOOKS
    # ============================================================================
    async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """Initialize run tracking by calling callbacks."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_run_start'):
                await callback.on_run_start(kwargs, old_items)

    async def _on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
        """Finalize run tracking by calling callbacks."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_run_end'):
                await callback.on_run_end(kwargs, old_items, new_items)

    async def _on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
        """Check whether the run should continue by calling callbacks."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_run_continue'):
                should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
                if not should_continue:
                    return False
        return True

    async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Prepare messages for the LLM call by applying callbacks."""
        result = messages
        for callback in self.callbacks:
            if hasattr(callback, 'on_llm_start'):
                result = await callback.on_llm_start(result)
        return result

    async def _on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Postprocess messages after the LLM call by applying callbacks."""
        result = messages
        for callback in self.callbacks:
            if hasattr(callback, 'on_llm_end'):
                result = await callback.on_llm_end(result)
        return result

    async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
        """Called when responses are received."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_responses'):
                await callback.on_responses(get_json(kwargs), get_json(responses))

    async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a computer call is about to start."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_computer_call_start'):
                await callback.on_computer_call_start(get_json(item))

    async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
        """Called when a computer call has completed."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_computer_call_end'):
                await callback.on_computer_call_end(get_json(item), get_json(result))

    async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a function call is about to start."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_function_call_start'):
                await callback.on_function_call_start(get_json(item))

    async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
        """Called when a function call has completed."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_function_call_end'):
                await callback.on_function_call_end(get_json(item), get_json(result))

    async def _on_text(self, item: Dict[str, Any]) -> None:
        """Called when a text message is encountered."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_text'):
                await callback.on_text(get_json(item))

    async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
        """Called when an LLM API call is about to start."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_api_start'):
                await callback.on_api_start(get_json(kwargs))

    async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
        """Called when an LLM API call has completed."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_api_end'):
                await callback.on_api_end(get_json(kwargs), get_json(result))

    async def _on_usage(self, usage: Dict[str, Any]) -> None:
        """Called when usage information is received."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_usage'):
                await callback.on_usage(get_json(usage))

    async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
        """Called when a screenshot is taken."""
        for callback in self.callbacks:
            if hasattr(callback, 'on_screenshot'):
                await callback.on_screenshot(screenshot, name)

    # ============================================================================
    # AGENT OUTPUT PROCESSING
    # ============================================================================
    async def _handle_item(self, item: Any, computer: Optional[Computer] = None) -> List[Dict[str, Any]]:
        """Handle each output item; may trigger a computer action plus screenshot."""

        item_type = item.get("type", None)

        if item_type == "message":
            await self._on_text(item)
            return []

        if item_type == "computer_call":
            await self._on_computer_call_start(item)
            if not computer:
                raise ValueError("Computer handler is required for computer calls")

            # Perform the computer action
            action = item.get("action")
            action_type = action.get("type")

            # Extract action arguments (all fields except 'type')
            action_args = {k: v for k, v in action.items() if k != "type"}

            # Execute the computer action
            computer_method = getattr(computer, action_type, None)
            if computer_method:
                await computer_method(**action_args)
            else:
                print(f"Unknown computer action: {action_type}")
                return []

            # Take a screenshot after the action
            if self.screenshot_delay and self.screenshot_delay > 0:
                await asyncio.sleep(self.screenshot_delay)
            screenshot_base64 = await computer.screenshot()
            await self._on_screenshot(screenshot_base64, "screenshot_after")

            # Handle safety checks
            pending_checks = item.get("pending_safety_checks", [])
            acknowledged_checks = []
            for check in pending_checks:
                check_message = check.get("message", str(check))
                if acknowledge_safety_check_callback(check_message):
                    acknowledged_checks.append(check)
                else:
                    raise ValueError(f"Safety check failed: {check_message}")

            # Create the call output
            call_output = {
                "type": "computer_call_output",
                "call_id": item.get("call_id"),
                "acknowledged_safety_checks": acknowledged_checks,
                "output": {
                    "type": "input_image",
                    "image_url": f"data:image/png;base64,{screenshot_base64}",
                },
            }

            # Additional URL safety checks for browser environments
            if await computer.get_environment() == "browser":
                current_url = await computer.get_current_url()
                call_output["output"]["current_url"] = current_url
                check_blocklisted_url(current_url)

            result = [call_output]
            await self._on_computer_call_end(item, result)
            return result

        if item_type == "function_call":
            await self._on_function_call_start(item)
            # Perform the function call
            function = self._get_tool(item.get("name"))
            if not function:
                raise ValueError(f"Function {item.get('name')} not found")

            args = json.loads(item.get("arguments"))

            # Execute the function - use asyncio.to_thread for non-async functions
            if inspect.iscoroutinefunction(function):
                result = await function(**args)
            else:
                result = await asyncio.to_thread(function, **args)

            # Create the function call output
            call_output = {
                "type": "function_call_output",
                "call_id": item.get("call_id"),
                "output": str(result),
            }

            result = [call_output]
            await self._on_function_call_end(item, result)
            return result

        return []

    # ============================================================================
    # MAIN AGENT LOOP
    # ============================================================================
    async def run(
        self,
        messages: Messages,
        stream: bool = False,
        **kwargs
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """
        Run the agent with the given messages using the Computer protocol handler pattern.

        Args:
            messages: List of message dictionaries
            stream: Whether to stream the response
            **kwargs: Additional arguments

        Returns:
            AsyncGenerator that yields response chunks
        """

        await self._initialize_computers()

        # Merge kwargs
        merged_kwargs = {**self.kwargs, **kwargs}

        old_items = self._process_input(messages)
        new_items = []

        # Initialize run tracking
        run_kwargs = {
            "messages": messages,
            "stream": stream,
            "model": self.model,
            "agent_loop": self.agent_loop.__name__,
            **merged_kwargs
        }
        await self._on_run_start(run_kwargs, old_items)

        # Keep looping until the most recent output item is an assistant message
        while (new_items[-1].get("role") != "assistant") if new_items else True:
            # Lifecycle hook: Check whether we should continue based on callbacks (e.g., budget manager)
            should_continue = await self._on_run_continue(run_kwargs, old_items, new_items)
            if not should_continue:
                break

            # Lifecycle hook: Prepare messages for the LLM call
            # Use cases:
            # - PII anonymization
            # - Image retention policy
            combined_messages = old_items + new_items
            preprocessed_messages = await self._on_llm_start(combined_messages)

            loop_kwargs = {
                "messages": preprocessed_messages,
                "model": self.model,
                "tools": self.tool_schemas,
                "stream": False,
                "computer_handler": self.computer_handler,
                "max_retries": self.max_retries,
                "use_prompt_caching": self.use_prompt_caching,
                **merged_kwargs
            }

            # Run one agent loop iteration
            result = await self.agent_loop(
                **loop_kwargs,
                _on_api_start=self._on_api_start,
                _on_api_end=self._on_api_end,
                _on_usage=self._on_usage,
                _on_screenshot=self._on_screenshot,
            )
            result = get_json(result)

            # Lifecycle hook: Postprocess messages after the LLM call
            # Use cases:
            # - PII deanonymization (if you want tool calls to see PII)
            result["output"] = await self._on_llm_end(result.get("output", []))
            await self._on_responses(loop_kwargs, result)

            # Yield the agent response
            yield result

            # Add the agent response to new_items
            new_items += result.get("output")

            # Handle computer actions
            for item in result.get("output"):
                partial_items = await self._handle_item(item, self.computer_handler)
                new_items += partial_items

                # Yield the partial response
                yield {
                    "output": partial_items,
                    "usage": Usage(
                        prompt_tokens=0,
                        completion_tokens=0,
                        total_tokens=0,
                    )
                }

        # Use run_kwargs here: loop_kwargs is undefined if the loop body never ran
        # (e.g., a budget manager stopped the run on the first continue check)
        await self._on_run_end(run_kwargs, old_items, new_items)
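# A custom agent loop (the `custom_loop` constructor argument) is invoked with
# the keyword arguments assembled in run() above. A hedged sketch of the
# expected signature, inferred from loop_kwargs:
#
#     async def my_loop(*, messages, model, tools, stream, computer_handler,
#                       max_retries, use_prompt_caching,
#                       _on_api_start, _on_api_end, _on_usage, _on_screenshot,
#                       **kwargs):
#         ...  # call the LLM and return a response carrying an "output" list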
@@ -1,17 +0,0 @@
"""
Callback system for ComputerAgent preprocessing and postprocessing hooks.
"""

from .base import AsyncCallbackHandler
from .image_retention import ImageRetentionCallback
from .logging import LoggingCallback
from .trajectory_saver import TrajectorySaverCallback
from .budget_manager import BudgetManagerCallback

__all__ = [
    "AsyncCallbackHandler",
    "ImageRetentionCallback",
    "LoggingCallback",
    "TrajectorySaverCallback",
    "BudgetManagerCallback",
]
@@ -1,153 +0,0 @@
"""
Base callback handler interface for ComputerAgent preprocessing and postprocessing hooks.
"""

from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional, Union


class AsyncCallbackHandler(ABC):
    """
    Base class for async callback handlers that can preprocess messages before
    the agent loop and postprocess output after the agent loop.
    """

    async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """Called at the start of an agent run loop."""
        pass

    async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
        """Called at the end of an agent run loop."""
        pass

    async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
        """Called during the agent run loop to determine whether execution should continue.

        Args:
            kwargs: Run arguments
            old_items: Original messages
            new_items: New messages generated during the run

        Returns:
            True to continue execution, False to stop
        """
        return True

    async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Called before messages are sent to the agent loop.

        Args:
            messages: List of message dictionaries to preprocess

        Returns:
            List of preprocessed message dictionaries
        """
        return messages

    async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Called after the agent loop returns output.

        Args:
            output: List of output message dictionaries to postprocess

        Returns:
            List of postprocessed output dictionaries
        """
        return output

    async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
        """
        Called when a computer call is about to start.

        Args:
            item: The computer call item dictionary
        """
        pass

    async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
        """
        Called when a computer call has completed.

        Args:
            item: The computer call item dictionary
            result: The result of the computer call
        """
        pass

    async def on_function_call_start(self, item: Dict[str, Any]) -> None:
        """
        Called when a function call is about to start.

        Args:
            item: The function call item dictionary
        """
        pass

    async def on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
        """
        Called when a function call has completed.

        Args:
            item: The function call item dictionary
            result: The result of the function call
        """
        pass

    async def on_text(self, item: Dict[str, Any]) -> None:
        """
        Called when a text message is encountered.

        Args:
            item: The message item dictionary
        """
        pass

    async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
        """
        Called when an API call is about to start.

        Args:
            kwargs: The kwargs being passed to the API call
        """
        pass

    async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
        """
        Called when an API call has completed.

        Args:
            kwargs: The kwargs that were passed to the API call
            result: The result of the API call
        """
        pass

    async def on_usage(self, usage: Dict[str, Any]) -> None:
        """
        Called when usage information is received.

        Args:
            usage: The usage information
        """
        pass

    async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
        """
        Called when a screenshot is taken.

        Args:
            screenshot: The screenshot image
            name: The name of the screenshot
        """
        pass

    async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
        """
        Called when responses are received.

        Args:
            kwargs: The kwargs being passed to the agent loop
            responses: The responses received
        """
        pass
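# Minimal example of a handler built on this interface (a sketch):
#
#     class PrintUsageCallback(AsyncCallbackHandler):
#         async def on_usage(self, usage):
#             print("usage:", usage)
#
# Unimplemented hooks inherit the no-op defaults above, so a handler only
# overrides the events it cares about.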
@@ -1,44 +0,0 @@
from typing import Dict, List, Any
from .base import AsyncCallbackHandler


class BudgetExceededError(Exception):
    """Exception raised when the budget is exceeded."""
    pass


class BudgetManagerCallback(AsyncCallbackHandler):
    """Budget manager callback that tracks usage costs and can stop execution when the budget is exceeded."""

    def __init__(self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False):
        """
        Initialize BudgetManagerCallback.

        Args:
            max_budget: Maximum budget allowed
            reset_after_each_run: Whether to reset the budget after each run
            raise_error: Whether to raise an error when the budget is exceeded
        """
        self.max_budget = max_budget
        self.reset_after_each_run = reset_after_each_run
        self.raise_error = raise_error
        self.total_cost = 0.0

    async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """Reset the budget if configured to do so."""
        if self.reset_after_each_run:
            self.total_cost = 0.0

    async def on_usage(self, usage: Dict[str, Any]) -> None:
        """Track usage costs."""
        if "response_cost" in usage:
            self.total_cost += usage["response_cost"]

    async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
        """Check whether the budget allows continuation."""
        if self.total_cost >= self.max_budget:
            if self.raise_error:
                raise BudgetExceededError(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
            else:
                print(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
                return False
        return True
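# Example (sketch): stop a run after roughly $2 of tracked spend, raising
# instead of silently stopping:
#
#     callback = BudgetManagerCallback(max_budget=2.0, raise_error=True)
#     agent = ComputerAgent(model="...", tools=[...], callbacks=[callback])
#
# Cost tracking relies on the "response_cost" field reported via on_usage.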
@@ -1,139 +0,0 @@
"""
Image retention callback handler that limits the number of recent images in message history.
"""

from typing import List, Dict, Any, Optional
from .base import AsyncCallbackHandler


class ImageRetentionCallback(AsyncCallbackHandler):
    """
    Callback handler that applies an image retention policy to limit the number
    of recent images in message history, preventing context window overflow.
    """

    def __init__(self, only_n_most_recent_images: Optional[int] = None):
        """
        Initialize the image retention callback.

        Args:
            only_n_most_recent_images: If set, only keep the N most recent images in message history
        """
        self.only_n_most_recent_images = only_n_most_recent_images

    async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Apply the image retention policy to messages before they are sent to the agent loop.

        Args:
            messages: List of message dictionaries

        Returns:
            List of messages with the image retention policy applied
        """
        if self.only_n_most_recent_images is None:
            return messages

        return self._apply_image_retention(messages)

    def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Apply the image retention policy to keep only the N most recent images.

        Removes computer_call_output items with image_url and their corresponding computer_call items,
        keeping only the most recent N image pairs based on the only_n_most_recent_images setting.

        Args:
            messages: List of message dictionaries

        Returns:
            Filtered list of messages with image retention applied
        """
        if self.only_n_most_recent_images is None:
            return messages

        # First pass: Assign a call_id to reasoning items based on the next computer_call
        messages_with_call_ids = []
        for i, msg in enumerate(messages):
            msg_copy = msg.copy() if isinstance(msg, dict) else msg

            # If this is a reasoning item without a call_id, find the next computer_call
            if (msg_copy.get("type") == "reasoning" and
                    not msg_copy.get("call_id")):
                # Look ahead for the next computer_call
                for j in range(i + 1, len(messages)):
                    next_msg = messages[j]
                    if (next_msg.get("type") == "computer_call" and
                            next_msg.get("call_id")):
                        msg_copy["call_id"] = next_msg.get("call_id")
                        break

            messages_with_call_ids.append(msg_copy)

        # Find computer_call_output items with images and collect their call_ids,
        # most recent first
        image_call_ids = []
        for msg in reversed(messages_with_call_ids):
            if (msg.get("type") == "computer_call_output" and
                    isinstance(msg.get("output"), dict) and
                    "image_url" in msg.get("output", {})):
                call_id = msg.get("call_id")
                if call_id and call_id not in image_call_ids:
                    image_call_ids.append(call_id)
                    if len(image_call_ids) >= self.only_n_most_recent_images:
                        break

        # Keep the most recent N image call_ids
        keep_call_ids = set(image_call_ids[:self.only_n_most_recent_images])

        # Filter messages: remove computer_call, computer_call_output, and reasoning items for old images
        filtered_messages = []
        for msg in messages_with_call_ids:
            msg_type = msg.get("type")
            call_id = msg.get("call_id")

            # Remove old computer_call items
            if msg_type == "computer_call" and call_id not in keep_call_ids:
                # Check whether this call_id corresponds to an image call
                has_image_output = any(
                    m.get("type") == "computer_call_output" and
                    m.get("call_id") == call_id and
                    isinstance(m.get("output"), dict) and
                    "image_url" in m.get("output", {})
                    for m in messages_with_call_ids
                )
                if has_image_output:
                    continue  # Skip this computer_call

            # Remove old computer_call_output items with images
            if (msg_type == "computer_call_output" and
                    call_id not in keep_call_ids and
                    isinstance(msg.get("output"), dict) and
                    "image_url" in msg.get("output", {})):
                continue  # Skip this computer_call_output

            # Remove old reasoning items that are paired with removed computer calls
            if (msg_type == "reasoning" and
                    call_id and call_id not in keep_call_ids):
                # Check whether this call_id corresponds to an image call that is being removed
                has_image_output = any(
                    m.get("type") == "computer_call_output" and
                    m.get("call_id") == call_id and
                    isinstance(m.get("output"), dict) and
                    "image_url" in m.get("output", {})
                    for m in messages_with_call_ids
                )
                if has_image_output:
                    continue  # Skip this reasoning item

            filtered_messages.append(msg)

        # Clean up: Remove call_id from reasoning items before returning
        final_messages = []
        for msg in filtered_messages:
            if msg.get("type") == "reasoning" and "call_id" in msg:
                # Create a copy without call_id for reasoning items
                cleaned_msg = {k: v for k, v in msg.items() if k != "call_id"}
                final_messages.append(cleaned_msg)
            else:
                final_messages.append(msg)

        return final_messages
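# Worked toy example with only_n_most_recent_images=1: given the history
#
#     computer_call(call_id="a") -> computer_call_output(call_id="a", image_url=...)
#     computer_call(call_id="b") -> computer_call_output(call_id="b", image_url=...)
#
# only call_id "b" is kept; the computer_call and computer_call_output for "a"
# (and any reasoning item paired with "a") are dropped before the next LLM
# call, shrinking the prompt while preserving the latest screenshot.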
@@ -1,247 +0,0 @@
"""
Logging callback for ComputerAgent that provides configurable logging of agent lifecycle events.
"""

import json
import logging
from typing import Dict, List, Any, Optional, Union
from .base import AsyncCallbackHandler


def sanitize_image_urls(data: Any) -> Any:
    """
    Recursively search for 'image_url' keys and set their values to '[omitted]'.

    Args:
        data: Any data structure (dict, list, or primitive type)

    Returns:
        A deep copy of the data with all 'image_url' values replaced with '[omitted]'
    """
    if isinstance(data, dict):
        # Create a sanitized copy of the dictionary
        sanitized = {}
        for key, value in data.items():
            if key == "image_url":
                sanitized[key] = "[omitted]"
            else:
                # Recursively sanitize the value
                sanitized[key] = sanitize_image_urls(value)
        return sanitized

    elif isinstance(data, list):
        # Recursively sanitize each item in the list
        return [sanitize_image_urls(item) for item in data]

    else:
        # For primitive types (str, int, bool, None, etc.), return as-is
        return data
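# Example:
#     sanitize_image_urls({"output": {"image_url": "data:image/png;base64,AAAA"}})
#     -> {"output": {"image_url": "[omitted]"}}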
class LoggingCallback(AsyncCallbackHandler):
    """
    Callback handler that logs agent lifecycle events with configurable verbosity.

    Logging levels:
    - DEBUG: All events including API calls, message preprocessing, and detailed outputs
    - INFO: Major lifecycle events (start/end, messages, outputs)
    - WARNING: Only warnings and errors
    - ERROR: Only errors
    """

    def __init__(self, logger: Optional[logging.Logger] = None, level: int = logging.INFO):
        """
        Initialize the logging callback.

        Args:
            logger: Logger instance to use. If None, creates a logger named 'agent.ComputerAgent'
            level: Logging level (logging.DEBUG, logging.INFO, etc.)
        """
        self.logger = logger or logging.getLogger('agent.ComputerAgent')
        self.level = level
        # Initialize here as well so on_usage is safe even before on_run_start fires
        self.total_usage: Dict[str, Any] = {}

        # Set up the logger if it doesn't have handlers
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        self.logger.setLevel(level)

    def _update_usage(self, usage: Dict[str, Any]) -> None:
        """Update running usage totals by recursively summing numeric fields."""
        def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
            for key, value in source.items():
                if isinstance(value, dict):
                    if key not in target:
                        target[key] = {}
                    add_dicts(target[key], value)
                else:
                    if key not in target:
                        target[key] = 0
                    target[key] += value
        add_dicts(self.total_usage, usage)

    async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """Called before the run starts."""
        self.total_usage = {}

    async def on_usage(self, usage: Dict[str, Any]) -> None:
        """Called when usage information is received."""
        self._update_usage(usage)

    async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
        """Called after the run ends."""
        def format_dict(d, indent=0):
            lines = []
            prefix = f" - {' ' * indent}"
            for key, value in d.items():
                if isinstance(value, dict):
                    lines.append(f"{prefix}{key}:")
                    lines.extend(format_dict(value, indent + 1))
                elif isinstance(value, float):
                    lines.append(f"{prefix}{key}: ${value:.4f}")
                else:
                    lines.append(f"{prefix}{key}: {value}")
            return lines

        formatted_output = "\n".join(format_dict(self.total_usage))
        self.logger.info(f"Total usage:\n{formatted_output}")

    async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Called before LLM processing starts."""
        if self.logger.isEnabledFor(logging.INFO):
            self.logger.info(f"LLM processing started with {len(messages)} messages")
        if self.logger.isEnabledFor(logging.DEBUG):
            sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
            self.logger.debug(f"LLM input messages: {json.dumps(sanitized_messages, indent=2)}")
        return messages

    async def on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Called after LLM processing ends."""
        if self.logger.isEnabledFor(logging.DEBUG):
            sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
            self.logger.debug(f"LLM output: {json.dumps(sanitized_messages, indent=2)}")
        return messages

    async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
        """Called when a computer call starts."""
||||
action = item.get("action", {})
|
||||
action_type = action.get("type", "unknown")
|
||||
action_args = {k: v for k, v in action.items() if k != "type"}
|
||||
|
||||
# INFO level logging for the action
|
||||
self.logger.info(f"Computer: {action_type}({action_args})")
|
||||
|
||||
# DEBUG level logging for full details
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug(f"Computer call started: {json.dumps(action, indent=2)}")
|
||||
|
||||
async def on_computer_call_end(self, item: Dict[str, Any], result: Any) -> None:
|
||||
"""Called when a computer call ends."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
action = item.get("action", "unknown")
|
||||
self.logger.debug(f"Computer call completed: {json.dumps(action, indent=2)}")
|
||||
if result:
|
||||
sanitized_result = sanitize_image_urls(result)
|
||||
self.logger.debug(f"Computer call result: {json.dumps(sanitized_result, indent=2)}")
|
||||
|
||||
async def on_function_call_start(self, item: Dict[str, Any]) -> None:
|
||||
"""Called when a function call starts."""
|
||||
name = item.get("name", "unknown")
|
||||
arguments = item.get("arguments", "{}")
|
||||
|
||||
# INFO level logging for the function call
|
||||
self.logger.info(f"Function: {name}({arguments})")
|
||||
|
||||
# DEBUG level logging for full details
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug(f"Function call started: {name}")
|
||||
|
||||
async def on_function_call_end(self, item: Dict[str, Any], result: Any) -> None:
|
||||
"""Called when a function call ends."""
|
||||
# INFO level logging for function output (similar to function_call_output)
|
||||
if result:
|
||||
# Handle both list and direct result formats
|
||||
if isinstance(result, list) and len(result) > 0:
|
||||
output = result[0].get("output", str(result)) if isinstance(result[0], dict) else str(result[0])
|
||||
else:
|
||||
output = str(result)
|
||||
|
||||
# Truncate long outputs
|
||||
if len(output) > 100:
|
||||
output = output[:100] + "..."
|
||||
|
||||
self.logger.info(f"Output: {output}")
|
||||
|
||||
# DEBUG level logging for full details
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
name = item.get("name", "unknown")
|
||||
self.logger.debug(f"Function call completed: {name}")
|
||||
if result:
|
||||
self.logger.debug(f"Function call result: {json.dumps(result, indent=2)}")
|
||||
|
||||
async def on_text(self, item: Dict[str, Any]) -> None:
|
||||
"""Called when a text message is encountered."""
|
||||
# Get the role to determine if it's Agent or User
|
||||
role = item.get("role", "unknown")
|
||||
content_items = item.get("content", [])
|
||||
|
||||
# Process content items to build display text
|
||||
text_parts = []
|
||||
for content_item in content_items:
|
||||
content_type = content_item.get("type", "output_text")
|
||||
if content_type == "output_text":
|
||||
text_content = content_item.get("text", "")
|
||||
if not text_content.strip():
|
||||
text_parts.append("[empty]")
|
||||
else:
|
||||
# Truncate long text and add ellipsis
|
||||
if len(text_content) > 2048:
|
||||
text_parts.append(text_content[:2048] + "...")
|
||||
else:
|
||||
text_parts.append(text_content)
|
||||
else:
|
||||
# Non-text content, show as [type]
|
||||
text_parts.append(f"[{content_type}]")
|
||||
|
||||
# Join all text parts
|
||||
display_text = ''.join(text_parts) if text_parts else "[empty]"
|
||||
|
||||
# Log with appropriate level and format
|
||||
if role == "assistant":
|
||||
self.logger.info(f"Agent: {display_text}")
|
||||
elif role == "user":
|
||||
self.logger.info(f"User: {display_text}")
|
||||
else:
|
||||
# Fallback for unknown roles, use debug level
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug(f"Text message ({role}): {display_text}")
|
||||
|
||||
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
|
||||
"""Called when an API call is about to start."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
model = kwargs.get("model", "unknown")
|
||||
self.logger.debug(f"API call starting for model: {model}")
|
||||
# Log sanitized messages if present
|
||||
if "messages" in kwargs:
|
||||
sanitized_messages = sanitize_image_urls(kwargs["messages"])
|
||||
self.logger.debug(f"API call messages: {json.dumps(sanitized_messages, indent=2)}")
|
||||
elif "input" in kwargs:
|
||||
sanitized_input = sanitize_image_urls(kwargs["input"])
|
||||
self.logger.debug(f"API call input: {json.dumps(sanitized_input, indent=2)}")
|
||||
|
||||
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
|
||||
"""Called when an API call has completed."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
model = kwargs.get("model", "unknown")
|
||||
self.logger.debug(f"API call completed for model: {model}")
|
||||
self.logger.debug(f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}")
|
||||
|
||||
async def on_screenshot(self, item: Union[str, bytes], name: str = "screenshot") -> None:
|
||||
"""Called when a screenshot is taken."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
image_size = len(item) / 1024
|
||||
self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")
|
||||
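
A minimal sketch of attaching this handler to an agent; the `callbacks` parameter name is an assumption here, not confirmed API:

```python
import logging

# Verbose debugging during development
callback = LoggingCallback(level=logging.DEBUG)

agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    tools=[computer],
    callbacks=[callback],  # assumed parameter name
)
```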
@@ -1,259 +0,0 @@
"""
PII anonymization callback handler using Microsoft Presidio for text and image redaction.
"""

from typing import List, Dict, Any, Optional, Tuple
from .base import AsyncCallbackHandler
import base64
import io
import logging

try:
    from presidio_analyzer import AnalyzerEngine
    from presidio_anonymizer import AnonymizerEngine, DeanonymizeEngine
    from presidio_anonymizer.entities import RecognizerResult, OperatorConfig
    from presidio_image_redactor import ImageRedactorEngine
    from PIL import Image
    PRESIDIO_AVAILABLE = True
except ImportError:
    PRESIDIO_AVAILABLE = False

logger = logging.getLogger(__name__)


class PIIAnonymizationCallback(AsyncCallbackHandler):
    """
    Callback handler that anonymizes PII in text and images using Microsoft Presidio.

    This handler:
    1. Anonymizes PII in messages before sending them to the agent loop
    2. Deanonymizes PII in tool calls and message outputs after the agent loop
    3. Redacts PII from images in computer_call_output messages
    """

    def __init__(
        self,
        anonymize_text: bool = True,
        anonymize_images: bool = True,
        entities_to_anonymize: Optional[List[str]] = None,
        anonymization_operator: str = "replace",
        image_redaction_color: Tuple[int, int, int] = (255, 192, 203)  # Pink
    ):
        """
        Initialize the PII anonymization callback.

        Args:
            anonymize_text: Whether to anonymize text content
            anonymize_images: Whether to redact images
            entities_to_anonymize: List of entity types to anonymize (None for all)
            anonymization_operator: Presidio operator to use ("replace", "mask", "redact", etc.)
            image_redaction_color: RGB color for image redaction
        """
        if not PRESIDIO_AVAILABLE:
            raise ImportError(
                "Presidio is not available. Install with: "
                "pip install presidio-analyzer presidio-anonymizer presidio-image-redactor"
            )

        self.anonymize_text = anonymize_text
        self.anonymize_images = anonymize_images
        self.entities_to_anonymize = entities_to_anonymize
        self.anonymization_operator = anonymization_operator
        self.image_redaction_color = image_redaction_color

        # Initialize Presidio engines
        self.analyzer = AnalyzerEngine()
        self.anonymizer = AnonymizerEngine()
        self.deanonymizer = DeanonymizeEngine()
        self.image_redactor = ImageRedactorEngine()

        # Store anonymization mappings for deanonymization
        self.anonymization_mappings: Dict[str, Any] = {}

    async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Anonymize PII in messages before sending them to the agent loop.

        Args:
            messages: List of message dictionaries

        Returns:
            List of messages with PII anonymized
        """
        if not self.anonymize_text and not self.anonymize_images:
            return messages

        anonymized_messages = []
        for msg in messages:
            anonymized_msg = await self._anonymize_message(msg)
            anonymized_messages.append(anonymized_msg)

        return anonymized_messages

    async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Deanonymize PII in tool calls and message outputs after the agent loop.

        Args:
            output: List of output dictionaries

        Returns:
            List of output items with PII deanonymized for tool calls
        """
        if not self.anonymize_text:
            return output

        deanonymized_output = []
        for item in output:
            # Only deanonymize tool calls and computer_call messages
            if item.get("type") in ["computer_call", "computer_call_output"]:
                deanonymized_item = await self._deanonymize_item(item)
                deanonymized_output.append(deanonymized_item)
            else:
                deanonymized_output.append(item)

        return deanonymized_output

    async def _anonymize_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Anonymize PII in a single message."""
        msg_copy = message.copy()

        # Anonymize text content
        if self.anonymize_text:
            msg_copy = await self._anonymize_text_content(msg_copy)

        # Redact images in computer_call_output
        if self.anonymize_images and msg_copy.get("type") == "computer_call_output":
            msg_copy = await self._redact_image_content(msg_copy)

        return msg_copy

    async def _anonymize_text_content(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Anonymize text content in a message."""
        msg_copy = message.copy()

        # Handle the content array
        content = msg_copy.get("content", [])
        if isinstance(content, str):
            anonymized_text, _ = await self._anonymize_text(content)
            msg_copy["content"] = anonymized_text
        elif isinstance(content, list):
            anonymized_content = []
            for item in content:
                if isinstance(item, dict) and item.get("type") == "text":
                    text = item.get("text", "")
                    anonymized_text, _ = await self._anonymize_text(text)
                    item_copy = item.copy()
                    item_copy["text"] = anonymized_text
                    anonymized_content.append(item_copy)
                else:
                    anonymized_content.append(item)
            msg_copy["content"] = anonymized_content

        return msg_copy

    async def _redact_image_content(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Redact PII from images in computer_call_output messages."""
        msg_copy = message.copy()
        output = msg_copy.get("output", {})

        if isinstance(output, dict) and "image_url" in output:
            try:
                # Extract the base64 image data
                image_url = output["image_url"]
                if image_url.startswith("data:image/"):
                    # Parse the data URL
                    header, data = image_url.split(",", 1)
                    image_data = base64.b64decode(data)

                    # Load the image with PIL
                    image = Image.open(io.BytesIO(image_data))

                    # Redact PII from the image
                    redacted_image = self.image_redactor.redact(image, self.image_redaction_color)

                    # Convert back to base64
                    buffer = io.BytesIO()
                    redacted_image.save(buffer, format="PNG")
                    redacted_data = base64.b64encode(buffer.getvalue()).decode()

                    # Update the image URL
                    output_copy = output.copy()
                    output_copy["image_url"] = f"data:image/png;base64,{redacted_data}"
                    msg_copy["output"] = output_copy

            except Exception as e:
                logger.warning(f"Failed to redact image: {e}")

        return msg_copy

    async def _deanonymize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Deanonymize PII in tool calls and computer outputs."""
        item_copy = item.copy()

        # Handle computer_call arguments
        if item.get("type") == "computer_call":
            args = item_copy.get("args", {})
            if isinstance(args, dict):
                deanonymized_args = {}
                for key, value in args.items():
                    if isinstance(value, str):
                        deanonymized_value, _ = await self._deanonymize_text(value)
                        deanonymized_args[key] = deanonymized_value
                    else:
                        deanonymized_args[key] = value
                item_copy["args"] = deanonymized_args

        return item_copy

    async def _anonymize_text(self, text: str) -> Tuple[str, List[RecognizerResult]]:
        """Anonymize PII in text and return the anonymized text and analyzer results."""
        if not text.strip():
            return text, []

        try:
            # Analyze the text for PII
            analyzer_results = self.analyzer.analyze(
                text=text,
                entities=self.entities_to_anonymize,
                language="en"
            )

            if not analyzer_results:
                return text, []

            # Anonymize the text
            anonymized_result = self.anonymizer.anonymize(
                text=text,
                analyzer_results=analyzer_results,
                operators={entity_type: OperatorConfig(self.anonymization_operator)
                           for entity_type in set(result.entity_type for result in analyzer_results)}
            )

            # Store the mapping for deanonymization
            mapping_key = str(hash(text))
            self.anonymization_mappings[mapping_key] = {
                "original": text,
                "anonymized": anonymized_result.text,
                "results": analyzer_results
            }

            return anonymized_result.text, analyzer_results

        except Exception as e:
            logger.warning(f"Failed to anonymize text: {e}")
            return text, []

    async def _deanonymize_text(self, text: str) -> Tuple[str, bool]:
        """Attempt to deanonymize text using the stored mappings."""
        try:
            # Look for matching anonymized text in the mappings
            for mapping_key, mapping in self.anonymization_mappings.items():
                if mapping["anonymized"] == text:
                    return mapping["original"], True

            # If no mapping is found, return the text unchanged
            return text, False

        except Exception as e:
            logger.warning(f"Failed to deanonymize text: {e}")
            return text, False
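
A sketch of enabling the anonymizer, assuming the agent accepts a `callbacks` list (the parameter name is an assumption); the Presidio extras must be installed as the ImportError above notes:

```python
pii_callback = PIIAnonymizationCallback(
    anonymize_text=True,
    anonymize_images=True,
    entities_to_anonymize=["EMAIL_ADDRESS", "PHONE_NUMBER"],  # None anonymizes every supported entity
    anonymization_operator="replace",
)

agent = ComputerAgent(
    model="anthropic/claude-3-5-sonnet-20241022",
    tools=[computer],
    callbacks=[pii_callback],  # assumed parameter name
)
```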
@@ -1,305 +0,0 @@
"""
Trajectory saving callback handler for ComputerAgent.
"""

import json
import uuid
from datetime import datetime
import base64
from pathlib import Path
from typing import List, Dict, Any, Optional, Union, override
from PIL import Image, ImageDraw
import io

from .base import AsyncCallbackHandler


def sanitize_image_urls(data: Any) -> Any:
    """
    Recursively search for 'image_url' keys and set their values to '[omitted]'.

    Args:
        data: Any data structure (dict, list, or primitive type)

    Returns:
        A deep copy of the data with all 'image_url' values replaced with '[omitted]'
    """
    if isinstance(data, dict):
        # Create a copy of the dictionary, omitting image payloads
        sanitized = {}
        for key, value in data.items():
            if key == "image_url":
                sanitized[key] = "[omitted]"
            else:
                # Recursively sanitize the value
                sanitized[key] = sanitize_image_urls(value)
        return sanitized

    elif isinstance(data, list):
        # Recursively sanitize each item in the list
        return [sanitize_image_urls(item) for item in data]

    else:
        # For primitive types (str, int, bool, None, etc.), return as-is
        return data


class TrajectorySaverCallback(AsyncCallbackHandler):
    """
    Callback handler that saves agent trajectories to disk.

    Saves each run as a separate trajectory with a unique ID, and each turn
    within the trajectory gets its own folder with screenshots and responses.
    """

    def __init__(self, trajectory_dir: str):
        """
        Initialize the trajectory saver.

        Args:
            trajectory_dir: Base directory to save trajectories
        """
        self.trajectory_dir = Path(trajectory_dir)
        self.trajectory_id: Optional[str] = None
        self.current_turn: int = 0
        self.current_artifact: int = 0
        self.model: Optional[str] = None
        self.total_usage: Dict[str, Any] = {}

        # Ensure the trajectory directory exists
        self.trajectory_dir.mkdir(parents=True, exist_ok=True)

    def _get_turn_dir(self) -> Path:
        """Get the directory for the current turn."""
        if not self.trajectory_id:
            raise ValueError("Trajectory not initialized - call on_run_start first")

        # format: trajectory_id/turn_000
        turn_dir = self.trajectory_dir / self.trajectory_id / f"turn_{self.current_turn:03d}"
        turn_dir.mkdir(parents=True, exist_ok=True)
        return turn_dir

    def _save_artifact(self, name: str, artifact: Union[str, bytes, Dict[str, Any]]) -> None:
        """Save an artifact to the current turn directory."""
        turn_dir = self._get_turn_dir()
        if isinstance(artifact, bytes):
            # format: turn_000/0000_name.png
            artifact_filename = f"{self.current_artifact:04d}_{name}"
            artifact_path = turn_dir / f"{artifact_filename}.png"
            with open(artifact_path, "wb") as f:
                f.write(artifact)
        else:
            # format: turn_000/0000_name.json
            artifact_filename = f"{self.current_artifact:04d}_{name}"
            artifact_path = turn_dir / f"{artifact_filename}.json"
            with open(artifact_path, "w") as f:
                json.dump(sanitize_image_urls(artifact), f, indent=2)
        self.current_artifact += 1
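
Given the naming scheme above, a completed run lays out on disk roughly like this (the IDs are illustrative):

```
trajectories/
└── 2025-01-15_claude-3-5-sonne_142301_a1b2/
    ├── metadata.json
    ├── turn_000/
    │   ├── 0000_api_start.json
    │   ├── 0001_api_result.json
    │   └── 0002_screenshot.png
    └── turn_001/
        └── ...
```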

    def _update_usage(self, usage: Dict[str, Any]) -> None:
        """Recursively accumulate usage statistics into self.total_usage."""
        def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
            for key, value in source.items():
                if isinstance(value, dict):
                    if key not in target:
                        target[key] = {}
                    add_dicts(target[key], value)
                else:
                    if key not in target:
                        target[key] = 0
                    target[key] += value
        add_dicts(self.total_usage, usage)

    @override
    async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
        """Initialize trajectory tracking for a new run."""
        model = kwargs.get("model", "unknown")
        model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
        if "+" in model:
            model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short

        # id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
        now = datetime.now()
        self.trajectory_id = f"{now.strftime('%Y-%m-%d')}_{model_name_short}_{now.strftime('%H%M%S')}_{str(uuid.uuid4())[:4]}"
        self.current_turn = 0
        self.current_artifact = 0
        self.model = model
        self.total_usage = {}

        # Create the trajectory directory
        trajectory_path = self.trajectory_dir / self.trajectory_id
        trajectory_path.mkdir(parents=True, exist_ok=True)

        # Save trajectory metadata
        metadata = {
            "trajectory_id": self.trajectory_id,
            "created_at": datetime.now().isoformat(),
            "status": "running",
            "kwargs": kwargs,
        }

        with open(trajectory_path / "metadata.json", "w") as f:
            json.dump(metadata, f, indent=2)

    @override
    async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
        """Finalize run tracking by updating metadata with completion status, usage, and new items."""
        if not self.trajectory_id:
            return

        # Update metadata with completion status, total usage, and new items
        trajectory_path = self.trajectory_dir / self.trajectory_id
        metadata_path = trajectory_path / "metadata.json"

        # Read the existing metadata
        if metadata_path.exists():
            with open(metadata_path, "r") as f:
                metadata = json.load(f)
        else:
            metadata = {}

        # Update metadata with completion info
        metadata.update({
            "status": "completed",
            "completed_at": datetime.now().isoformat(),
            "total_usage": self.total_usage,
            "new_items": sanitize_image_urls(new_items),
            "total_turns": self.current_turn
        })

        # Save the updated metadata
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

    @override
    async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
        """Save API call arguments."""
        if not self.trajectory_id:
            return

        self._save_artifact("api_start", {"kwargs": kwargs})

    @override
    async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
        """Save the API call result."""
        if not self.trajectory_id:
            return

        self._save_artifact("api_result", {"kwargs": kwargs, "result": result})

    @override
    async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
        """Save a screenshot."""
        if isinstance(screenshot, str):
            screenshot = base64.b64decode(screenshot)
        self._save_artifact(name, screenshot)

    @override
    async def on_usage(self, usage: Dict[str, Any]) -> None:
        """Called when usage information is received."""
        self._update_usage(usage)

    @override
    async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
        """Save responses to the current turn directory and update usage statistics."""
        if not self.trajectory_id:
            return

        # Save the responses
        response_data = {
            "timestamp": datetime.now().isoformat(),
            "model": self.model,
            "kwargs": kwargs,
            "response": responses
        }

        self._save_artifact("agent_response", response_data)

        # Increment the turn counter
        self.current_turn += 1

    def _draw_crosshair_on_image(self, image_bytes: bytes, x: int, y: int) -> bytes:
        """
        Draw a red dot and crosshair at the specified coordinates on the image.

        Args:
            image_bytes: The original image as bytes
            x: X coordinate for the crosshair
            y: Y coordinate for the crosshair

        Returns:
            Modified image as bytes with a red dot and crosshair
        """
        # Open the image
        image = Image.open(io.BytesIO(image_bytes))
        draw = ImageDraw.Draw(image)

        # Crosshair appearance (red, 2 px thick)
        crosshair_size = 20
        line_width = 2
        color = "red"

        # Horizontal line
        draw.line([(x - crosshair_size, y), (x + crosshair_size, y)], fill=color, width=line_width)
        # Vertical line
        draw.line([(x, y - crosshair_size), (x, y + crosshair_size)], fill=color, width=line_width)

        # Draw the center dot (filled circle)
        dot_radius = 3
        draw.ellipse([(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color)

        # Convert back to bytes
        output = io.BytesIO()
        image.save(output, format='PNG')
        return output.getvalue()

    @override
    async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
        """
        Called when a computer call has completed.
        Saves screenshots and the computer call output.
        """
        if not self.trajectory_id:
            return

        self._save_artifact("computer_call_result", {"item": item, "result": result})

        # If the action has x/y coordinates and the result contains a screenshot,
        # also save an annotated copy with a crosshair at the click point
        action = item.get("action", {})
        if "x" in action and "y" in action:
            # Look for a screenshot in the result
            for result_item in result:
                if (result_item.get("type") == "computer_call_output" and
                        result_item.get("output", {}).get("type") == "input_image"):

                    image_url = result_item["output"]["image_url"]

                    # Extract the base64 image data
                    if image_url.startswith("data:image/"):
                        # Format: data:image/png;base64,<base64_data>
                        base64_data = image_url.split(",", 1)[1]
                    else:
                        # Assume it's just base64 data
                        base64_data = image_url

                    try:
                        # Decode the image
                        image_bytes = base64.b64decode(base64_data)

                        # Draw a crosshair at the action coordinates
                        annotated_image = self._draw_crosshair_on_image(
                            image_bytes,
                            int(action["x"]),
                            int(action["y"])
                        )

                        # Save as screenshot_action
                        self._save_artifact("screenshot_action", annotated_image)

                    except Exception as e:
                        # If annotation fails, just log and continue
                        print(f"Failed to annotate screenshot: {e}")

                    break  # Only process the first screenshot found

        # Increment the turn counter
        self.current_turn += 1
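
A minimal sketch of driving the saver by hand; in normal use the agent invokes these hooks itself, and that wiring is assumed rather than shown:

```python
import asyncio

async def demo():
    saver = TrajectorySaverCallback(trajectory_dir="trajectories")
    await saver.on_run_start({"model": "anthropic/claude-3-5-sonnet-20241022"}, [])
    await saver.on_screenshot(b"\x89PNG...placeholder bytes")  # writes turn_000/0000_screenshot.png
    await saver.on_run_end({}, [], [])  # marks metadata.json as completed

asyncio.run(demo())
```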
@@ -1,297 +0,0 @@
"""
CLI chat interface for agent - Computer Use Agent

Usage:
    python -m agent.cli <model_string>

Examples:
    python -m agent.cli openai/computer-use-preview
    python -m agent.cli anthropic/claude-3-5-sonnet-20241022
    python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
"""

try:
    import asyncio
    import argparse
    import os
    import sys
    import json
    from typing import List, Dict, Any
    import dotenv
    from yaspin import yaspin
except ImportError:
    if __name__ == "__main__":
        raise ImportError(
            "CLI dependencies not found. "
            "Please install with: pip install \"cua-agent[cli]\""
        )

# Load environment variables
dotenv.load_dotenv()


# Color codes for terminal output
class Colors:
    RESET = '\033[0m'
    BOLD = '\033[1m'
    DIM = '\033[2m'

    # Text colors
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    BLUE = '\033[34m'
    MAGENTA = '\033[35m'
    CYAN = '\033[36m'
    WHITE = '\033[37m'
    GRAY = '\033[90m'

    # Background colors
    BG_RED = '\033[41m'
    BG_GREEN = '\033[42m'
    BG_YELLOW = '\033[43m'
    BG_BLUE = '\033[44m'


def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = False, end: str = "\n"):
    """Print colored text to the terminal."""
    prefix = ""
    if bold:
        prefix += Colors.BOLD
    if dim:
        prefix += Colors.DIM
    if color:
        prefix += color

    print(f"{prefix}{text}{Colors.RESET}", end=end)


def print_action(action_type: str, details: Dict[str, Any]):
    """Print a computer action with nice formatting."""
    # Format the action details
    args_str = ""
    if action_type == "click" and "x" in details and "y" in details:
        args_str = f"({details['x']}, {details['y']})"
    elif action_type == "type" and "text" in details:
        text = details["text"]
        if len(text) > 50:
            text = text[:47] + "..."
        args_str = f'"{text}"'
    elif action_type == "key" and "key" in details:
        args_str = f"'{details['key']}'"
    elif action_type == "scroll" and "x" in details and "y" in details:
        args_str = f"({details['x']}, {details['y']})"

    print_colored(f"🛠️ {action_type}{args_str}", dim=True)


def print_welcome(model: str, agent_loop: str, container_name: str):
    """Print the welcome message."""
    print_colored(f"Connected to {container_name} ({model}, {agent_loop})")
    print_colored("Type 'exit' to quit.", dim=True)


async def ainput(prompt: str = ""):
    return await asyncio.to_thread(input, prompt)


async def chat_loop(agent, model: str, container_name: str):
    """Main chat loop with the agent."""
    print_welcome(model, agent.agent_loop.__name__, container_name)

    history = []

    while True:
        # Get user input with a prompt
        print_colored("> ", end="")
        user_input = await ainput()

        if user_input.lower() in ['exit', 'quit', 'q']:
            print_colored("\n👋 Goodbye!")
            break

        if not user_input:
            continue

        # Add the user message to history
        history.append({"role": "user", "content": user_input})

        # Stream responses from the agent with a spinner
        with yaspin(text="Thinking...", spinner="line", attrs=["dark"]) as spinner:
            spinner.hide()

            async for result in agent.run(history):
                # Add agent responses to history
                history.extend(result.get("output", []))

                # Process and display the output
                for item in result.get("output", []):
                    if item.get("type") == "message":
                        # Display the agent's text response
                        content = item.get("content", [])
                        for content_part in content:
                            if content_part.get("text"):
                                text = content_part.get("text", "").strip()
                                if text:
                                    spinner.hide()
                                    print_colored(text)

                    elif item.get("type") == "computer_call":
                        # Display the computer action
                        action = item.get("action", {})
                        action_type = action.get("type", "")
                        if action_type:
                            spinner.hide()
                            print_action(action_type, action)
                            spinner.text = f"Performing {action_type}..."
                            spinner.show()

                    elif item.get("type") == "function_call":
                        # Display the function call
                        function_name = item.get("name", "")
                        spinner.hide()
                        print_colored(f"🔧 Calling function: {function_name}", dim=True)
                        spinner.text = f"Calling {function_name}..."
                        spinner.show()

                    elif item.get("type") == "function_call_output":
                        # Display function output (dimmed)
                        output = item.get("output", "")
                        if output and len(output.strip()) > 0:
                            spinner.hide()
                            print_colored(f"📤 {output}", dim=True)

            spinner.hide()


async def main():
    """Main CLI entry point."""
    parser = argparse.ArgumentParser(
        description="CUA Agent CLI - Interactive computer use assistant",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python -m agent.cli openai/computer-use-preview
  python -m agent.cli anthropic/claude-3-5-sonnet-20241022
  python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
  python -m agent.cli huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B
        """
    )

    parser.add_argument(
        "model",
        help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')"
    )

    parser.add_argument(
        "--images",
        type=int,
        default=3,
        help="Number of recent images to keep in context (default: 3)"
    )

    parser.add_argument(
        "--trajectory",
        action="store_true",
        help="Save the trajectory for debugging"
    )

    parser.add_argument(
        "--budget",
        type=float,
        help="Maximum budget for the session (in dollars)"
    )

    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose logging"
    )

    args = parser.parse_args()

    # Check for required environment variables
    container_name = os.getenv("CUA_CONTAINER_NAME")
    cua_api_key = os.getenv("CUA_API_KEY")

    # Prompt for missing environment variables
    if not container_name:
        print_colored("CUA_CONTAINER_NAME not set.", dim=True)
        print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
        container_name = input("Enter your CUA container name: ").strip()
        if not container_name:
            print_colored("❌ Container name is required.")
            sys.exit(1)

    if not cua_api_key:
        print_colored("CUA_API_KEY not set.", dim=True)
        cua_api_key = input("Enter your CUA API key: ").strip()
        if not cua_api_key:
            print_colored("❌ API key is required.")
            sys.exit(1)

    # Map model prefixes to the provider API key they require
    # (prefix/env-var pairs rather than a dict, since prefixes may repeat)
    provider_api_keys = [
        ("openai/", "OPENAI_API_KEY"),
        ("anthropic/", "ANTHROPIC_API_KEY"),
        ("omniparser+openai/", "OPENAI_API_KEY"),
        ("omniparser+anthropic/", "ANTHROPIC_API_KEY"),
    ]

    # Find the matching provider and check for its API key
    for prefix, env_var in provider_api_keys:
        if args.model.startswith(prefix):
            if not os.getenv(env_var):
                print_colored(f"{env_var} not set.", dim=True)
                api_key = input(f"Enter your {env_var.replace('_', ' ').title()}: ").strip()
                if not api_key:
                    print_colored(f"❌ {env_var.replace('_', ' ').title()} is required.")
                    sys.exit(1)
                # Set the environment variable for the session
                os.environ[env_var] = api_key
            break

    # Import here to avoid import errors if dependencies are missing
    try:
        from agent import ComputerAgent
        from computer import Computer
    except ImportError as e:
        print_colored(f"❌ Import error: {e}", Colors.RED, bold=True)
        print_colored("Make sure the agent and computer libraries are installed.", Colors.YELLOW)
        sys.exit(1)

    # Create the computer instance
    async with Computer(
        os_type="linux",
        provider_type="cloud",
        name=container_name,
        api_key=cua_api_key
    ) as computer:

        # Create the agent
        agent_kwargs = {
            "model": args.model,
            "tools": [computer],
            "only_n_most_recent_images": args.images,
            "verbosity": 20 if args.verbose else 30,  # logging.DEBUG vs logging.WARNING
        }

        if args.trajectory:
            agent_kwargs["trajectory_dir"] = "trajectories"

        if args.budget:
            agent_kwargs["max_trajectory_budget"] = {
                "max_budget": args.budget,
                "raise_error": True,
                "reset_after_each_run": False
            }

        agent = ComputerAgent(**agent_kwargs)

        # Start the chat loop
        await chat_loop(agent, args.model, container_name)


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except (KeyboardInterrupt, EOFError):
        print_colored("\n\n👋 Goodbye!")
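
Putting the flags together, a typical invocation looks like this (the container name and keys are prompted for interactively if the environment variables are unset):

```bash
export CUA_CONTAINER_NAME=my-container
export CUA_API_KEY=sk-...
python -m agent.cli anthropic/claude-3-5-sonnet-20241022 --images 5 --trajectory --budget 2.50 --verbose
```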
@@ -1,107 +0,0 @@
"""
Computer handler implementation for the OpenAI computer-use-preview protocol.
"""

import asyncio
import base64
from typing import Dict, List, Any, Literal

from .types import Computer


class OpenAIComputerHandler:
    """Computer handler that implements the Computer protocol using the computer interface."""

    def __init__(self, computer_interface):
        """Initialize with a computer interface (from the tool schema)."""
        self.interface = computer_interface

    async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
        """Get the current environment type."""
        # For now, return a default - this could be enhanced to detect the actual environment
        return "windows"

    async def get_dimensions(self) -> tuple[int, int]:
        """Get screen dimensions as (width, height)."""
        screen_size = await self.interface.get_screen_size()
        return screen_size["width"], screen_size["height"]

    async def screenshot(self) -> str:
        """Take a screenshot and return it as a base64 string."""
        screenshot_bytes = await self.interface.screenshot()
        return base64.b64encode(screenshot_bytes).decode('utf-8')

    async def click(self, x: int, y: int, button: str = "left") -> None:
        """Click at the given coordinates with the specified button."""
        if button == "left":
            await self.interface.left_click(x, y)
        elif button == "right":
            await self.interface.right_click(x, y)
        else:
            # Default to a left click for unknown buttons
            await self.interface.left_click(x, y)

    async def double_click(self, x: int, y: int) -> None:
        """Double-click at the given coordinates."""
        await self.interface.double_click(x, y)

    async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
        """Scroll at the given coordinates by the specified amounts."""
        await self.interface.move_cursor(x, y)
        await self.interface.scroll(scroll_x, scroll_y)

    async def type(self, text: str) -> None:
        """Type text."""
        await self.interface.type_text(text)

    async def wait(self, ms: int = 1000) -> None:
        """Wait for the specified number of milliseconds."""
        await asyncio.sleep(ms / 1000.0)

    async def move(self, x: int, y: int) -> None:
        """Move the cursor to the given coordinates."""
        await self.interface.move_cursor(x, y)

    async def keypress(self, keys: List[str]) -> None:
        """Press a key combination."""
        if len(keys) == 1:
            await self.interface.press_key(keys[0])
        else:
            # Handle key combinations
            await self.interface.hotkey(*keys)

    async def drag(self, path: List[Dict[str, int]]) -> None:
        """Drag along the specified path."""
        if not path:
            return

        # Start the drag from the first point
        start = path[0]
        await self.interface.mouse_down(start["x"], start["y"])

        # Move through the path
        for point in path[1:]:
            await self.interface.move_cursor(point["x"], point["y"])

        # End the drag at the last point
        end = path[-1]
        await self.interface.mouse_up(end["x"], end["y"])

    async def get_current_url(self) -> str:
        """Get the current URL (for browser environments)."""
        # This would need to be implemented for the specific browser interface;
        # for now, return an empty string
        return ""


def acknowledge_safety_check_callback(message: str) -> bool:
    """Safety check callback for user acknowledgment."""
    response = input(
        f"Safety Check Warning: {message}\nDo you want to acknowledge and proceed? (y/n): "
    ).lower()
    return response.strip() == "y"


def check_blocklisted_url(url: str) -> None:
    """Check if a URL is blocklisted (placeholder implementation)."""
    # This would contain the actual URL checking logic
    pass
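
For instance, inside an async context the handler can replay a short drag path like this (a sketch; `iface` stands in for whatever computer interface the tool schema provided):

```python
handler = OpenAIComputerHandler(iface)  # iface is a placeholder for the real interface object
await handler.drag([
    {"x": 100, "y": 100},  # mouse_down here
    {"x": 150, "y": 120},  # intermediate move
    {"x": 200, "y": 140},  # mouse_up here
])
```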
@@ -1,90 +0,0 @@
"""
Decorators for agent - agent_loop decorator
"""

import asyncio
import inspect
from typing import Dict, List, Any, Callable, Optional
from functools import wraps

from .types import AgentLoopInfo

# Global registry of agent loops
_agent_loops: List[AgentLoopInfo] = []


def agent_loop(models: str, priority: int = 0):
    """
    Decorator to register an agent loop function.

    Args:
        models: Regex pattern to match supported models
        priority: Priority for loop selection (higher = more priority)
    """
    def decorator(func: Callable):
        # Validate the function signature
        sig = inspect.signature(func)
        required_params = {'messages', 'model'}
        func_params = set(sig.parameters.keys())

        if not required_params.issubset(func_params):
            missing = required_params - func_params
            raise ValueError(f"Agent loop function must have parameters: {missing}")

        # Register the loop
        loop_info = AgentLoopInfo(
            func=func,
            models_regex=models,
            priority=priority
        )
        _agent_loops.append(loop_info)

        # Sort by priority (highest first)
        _agent_loops.sort(key=lambda x: x.priority, reverse=True)

        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Run the function in a task and relay its result through a queue
            # so the call can be cancelled cleanly from the outside
            queue = asyncio.Queue()
            task = None

            try:
                async def run_loop():
                    try:
                        result = await func(*args, **kwargs)
                        await queue.put(('result', result))
                    except Exception as e:
                        await queue.put(('error', e))

                task = asyncio.create_task(run_loop())

                # Wait for a result or propagate cancellation
                event_type, data = await queue.get()

                if event_type == 'error':
                    raise data
                return data

            except asyncio.CancelledError:
                if task:
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass
                raise

        return wrapper

    return decorator


def get_agent_loops() -> List[AgentLoopInfo]:
    """Get all registered agent loops."""
    return _agent_loops.copy()


def find_agent_loop(model: str) -> Optional[AgentLoopInfo]:
    """Find the best matching agent loop for a model."""
    for loop_info in _agent_loops:
        if loop_info.matches_model(model):
            return loop_info
    return None
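
A sketch of registering a custom loop; the regex, names, and returned structure are illustrative, and `matches_model` is assumed to apply the registered regex:

```python
@agent_loop(models=r"echo/.*", priority=10)
async def echo_loop(messages, model, **kwargs):
    # Toy loop: echo the last user message back as an assistant message
    last = messages[-1] if messages else {"content": ""}
    return {
        "output": [{
            "type": "message",
            "role": "assistant",
            "content": [{"type": "output_text", "text": str(last.get("content", ""))}],
        }]
    }

info = find_agent_loop("echo/test-model")  # highest-priority loop whose regex matches
```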
@@ -1,11 +0,0 @@
"""
Agent loops for the agent package.
"""

# Import the loops so their @agent_loop decorators register them
from . import anthropic
from . import openai
from . import uitars
from . import omniparser

__all__ = ["anthropic", "openai", "uitars", "omniparser"]
@@ -1,728 +0,0 @@
"""
Anthropic hosted tools agent loop implementation using liteLLM
"""

import asyncio
import json
import re
from typing import Dict, List, Any, AsyncGenerator, Union, Optional
import litellm
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig

from ..decorators import agent_loop
from ..types import Messages, AgentResponse, Tools
from ..responses import (
    make_reasoning_item,
    make_output_text_item,
    make_click_item,
    make_double_click_item,
    make_drag_item,
    make_keypress_item,
    make_move_item,
    make_scroll_item,
    make_type_item,
    make_wait_item,
    make_input_image_item,
    make_screenshot_item
)

# Model version mapping to tool version and beta flag
MODEL_TOOL_MAPPING = [
    # Claude 4 models
    {
        "pattern": r"claude-4|claude-opus-4|claude-sonnet-4",
        "tool_version": "computer_20250124",
        "beta_flag": "computer-use-2025-01-24"
    },
    # Claude 3.7 models
    {
        "pattern": r"claude-3\.?7|claude-3-7",
        "tool_version": "computer_20250124",
        "beta_flag": "computer-use-2025-01-24"
    },
    # Claude 3.5 models (fallback)
    {
        "pattern": r"claude-3\.?5|claude-3-5",
        "tool_version": "computer_20241022",
        "beta_flag": "computer-use-2024-10-22"
    }
]


def _get_tool_config_for_model(model: str) -> Dict[str, str]:
    """Get the tool version and beta flag for the given model."""
    for mapping in MODEL_TOOL_MAPPING:
        if re.search(mapping["pattern"], model, re.IGNORECASE):
            return {
                "tool_version": mapping["tool_version"],
                "beta_flag": mapping["beta_flag"]
            }

    # Default to the Claude 3.5 configuration
    return {
        "tool_version": "computer_20241022",
        "beta_flag": "computer-use-2024-10-22"
    }
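
For instance, a Claude 3.7 model string resolves to the newer tool version and beta flag:

```python
cfg = _get_tool_config_for_model("anthropic/claude-3-7-sonnet-20250219")
print(cfg)
# {'tool_version': 'computer_20250124', 'beta_flag': 'computer-use-2025-01-24'}
```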

def _map_computer_tool_to_anthropic(computer_tool: Any, tool_version: str) -> Dict[str, Any]:
    """Map a computer tool to Anthropic's hosted tool schema."""
    return {
        "type": tool_version,
        "function": {
            "name": "computer",
            "parameters": {
                "display_height_px": getattr(computer_tool, 'display_height', 768),
                "display_width_px": getattr(computer_tool, 'display_width', 1024),
                "display_number": getattr(computer_tool, 'display_number', 1),
            },
        },
    }


def _prepare_tools_for_anthropic(tool_schemas: List[Dict[str, Any]], model: str) -> Tools:
    """Prepare tools for the Anthropic API format."""
    tool_config = _get_tool_config_for_model(model)
    anthropic_tools = []

    for schema in tool_schemas:
        if schema["type"] == "computer":
            # Map the computer tool to Anthropic format
            anthropic_tools.append(_map_computer_tool_to_anthropic(
                schema["computer"],
                tool_config["tool_version"]
            ))
        elif schema["type"] == "function":
            # Function tools - convert to Anthropic format
            function_schema = schema["function"]
            anthropic_tools.append({
                "type": "function",
                "function": {
                    "name": function_schema["name"],
                    "description": function_schema.get("description", ""),
                    "parameters": function_schema.get("parameters", {})
                }
            })

    return anthropic_tools


def _convert_responses_items_to_completion_messages(messages: Messages) -> List[Dict[str, Any]]:
    """Convert responses_items message format to liteLLM completion format."""
    completion_messages = []

    for message in messages:
        msg_type = message.get("type")
        role = message.get("role")

        # Handle user messages (both with and without an explicit type)
        if role == "user" or msg_type == "user":
            content = message.get("content", "")
            if isinstance(content, list):
                # Multi-modal content - convert input_image to image format
                converted_content = []
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "input_image":
                        # Convert input_image to Anthropic image format
                        image_url = item.get("image_url", "")
                        if image_url and image_url != "[omitted]":
                            # Extract the base64 data from the data URL
                            if "," in image_url:
                                base64_data = image_url.split(",")[-1]
                            else:
                                base64_data = image_url

                            converted_content.append({
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/png",
                                    "data": base64_data
                                }
                            })
                    else:
                        # Keep other content types as-is
                        converted_content.append(item)

                completion_messages.append({
                    "role": "user",
                    "content": converted_content if converted_content else content
                })
            else:
                # Text content
                completion_messages.append({
                    "role": "user",
                    "content": content
                })

        # Handle assistant messages
        elif role == "assistant":
            content = message.get("content", [])
            if isinstance(content, str):
                content = [{"type": "output_text", "text": content}]

            content = "\n".join(item.get("text", "") for item in content)
            completion_messages.append({
                "role": "assistant",
                "content": content
            })

        elif msg_type == "reasoning":
            # Reasoning becomes part of an assistant message
            summary = message.get("summary", [])
            reasoning_text = ""

            if isinstance(summary, list) and summary:
                # Extract text from the summary items
                for item in summary:
                    if isinstance(item, dict) and item.get("type") == "summary_text":
                        reasoning_text = item.get("text", "")
                        break
            else:
                # Fall back to the direct reasoning field
                reasoning_text = message.get("reasoning", "")

            if reasoning_text:
                completion_messages.append({
                    "role": "assistant",
                    "content": reasoning_text
                })

        elif msg_type == "computer_call":
            # A computer call becomes a tool use in an assistant message
            action = message.get("action", {})
            action_type = action.get("type")
            call_id = message.get("call_id", "call_1")

            tool_use_content = []

            if action_type == "click":
                tool_use_content.append({
                    "type": "tool_use",
                    "id": call_id,
                    "name": "computer",
                    "input": {
                        "action": "click",
                        "coordinate": [action.get("x", 0), action.get("y", 0)]
                    }
                })
            elif action_type == "type":
                tool_use_content.append({
                    "type": "tool_use",
                    "id": call_id,
                    "name": "computer",
                    "input": {
                        "action": "type",
                        "text": action.get("text", "")
                    }
                })
            elif action_type == "key":
                tool_use_content.append({
                    "type": "tool_use",
                    "id": call_id,
                    "name": "computer",
                    "input": {
                        "action": "key",
                        "key": action.get("key", "")
                    }
                })
            elif action_type == "wait":
                tool_use_content.append({
                    "type": "tool_use",
                    "id": call_id,
                    "name": "computer",
                    "input": {
                        "action": "wait"
                    }
                })
            elif action_type == "screenshot":
                tool_use_content.append({
                    "type": "tool_use",
                    "id": call_id,
                    "name": "computer",
                    "input": {
                        "action": "screenshot"
                    }
                })

            # Convert tool_use_content to OpenAI tool_calls format
            openai_tool_calls = []
            for tool_use in tool_use_content:
                openai_tool_calls.append({
                    "id": tool_use["id"],
                    "type": "function",
                    "function": {
                        "name": tool_use["name"],
                        "arguments": json.dumps(tool_use["input"])
                    }
                })

            # If the last completion message is an assistant message, extend its tool_calls
            if completion_messages and completion_messages[-1].get("role") == "assistant":
                if "tool_calls" not in completion_messages[-1]:
                    completion_messages[-1]["tool_calls"] = []
                completion_messages[-1]["tool_calls"].extend(openai_tool_calls)
            else:
                # Create a new assistant message with the tool calls
                completion_messages.append({
                    "role": "assistant",
                    "content": None,
                    "tool_calls": openai_tool_calls
                })

        elif msg_type == "computer_call_output":
            # A computer call output becomes an OpenAI function result
            output = message.get("output", {})
            call_id = message.get("call_id", "call_1")

            if output.get("type") == "input_image":
                # Screenshot result - convert to OpenAI format with image_url content
                image_url = output.get("image_url", "")
                completion_messages.append({
                    "role": "function",
                    "name": "computer",
                    "tool_call_id": call_id,
                    "content": [{
                        "type": "image_url",
                        "image_url": {
                            "url": image_url
                        }
                    }]
                })
            else:
                # Text result - convert to OpenAI format
                completion_messages.append({
                    "role": "function",
                    "name": "computer",
                    "tool_call_id": call_id,
                    "content": str(output)
                })

    return completion_messages
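
As a sanity check, a click call and its screenshot output convert as follows (values are illustrative):

```python
history = [
    {"role": "user", "content": "Open the settings menu"},
    {"type": "computer_call", "call_id": "call_1",
     "action": {"type": "click", "x": 100, "y": 200}},
    {"type": "computer_call_output", "call_id": "call_1",
     "output": {"type": "input_image", "image_url": "data:image/png;base64,..."}},
]
converted = _convert_responses_items_to_completion_messages(history)
# -> [{'role': 'user', 'content': 'Open the settings menu'},
#     {'role': 'assistant', 'content': None, 'tool_calls': [
#         {'id': 'call_1', 'type': 'function', 'function': {
#             'name': 'computer',
#             'arguments': '{"action": "click", "coordinate": [100, 200]}'}}]},
#     {'role': 'function', 'name': 'computer', 'tool_call_id': 'call_1',
#      'content': [{'type': 'image_url',
#                   'image_url': {'url': 'data:image/png;base64,...'}}]}]
```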
|
||||
def _convert_completion_to_responses_items(response: Any) -> List[Dict[str, Any]]:
    """Convert liteLLM completion response to responses_items message format."""
    responses_items = []

    if not response or not hasattr(response, 'choices') or not response.choices:
        return responses_items

    choice = response.choices[0]
    message = choice.message

    # Handle text content
    if hasattr(message, 'content') and message.content:
        if isinstance(message.content, str):
            responses_items.append(make_output_text_item(message.content))
        elif isinstance(message.content, list):
            for content_item in message.content:
                if isinstance(content_item, dict):
                    if content_item.get("type") == "text":
                        responses_items.append(make_output_text_item(content_item.get("text", "")))
                    elif content_item.get("type") == "tool_use":
                        # Convert tool use to computer call
                        tool_input = content_item.get("input", {})
                        action_type = tool_input.get("action")
                        call_id = content_item.get("id")

                        # Action reference:
                        # https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/computer-use-tool#available-actions

                        # Basic actions (all versions)
                        if action_type == "screenshot":
                            responses_items.append(make_screenshot_item(call_id=call_id))
                        elif action_type == "left_click":
                            coordinate = tool_input.get("coordinate", [0, 0])
                            responses_items.append(make_click_item(
                                x=coordinate[0] if len(coordinate) > 0 else 0,
                                y=coordinate[1] if len(coordinate) > 1 else 0,
                                call_id=call_id
                            ))
                        elif action_type == "type":
                            responses_items.append(make_type_item(
                                text=tool_input.get("text", ""),
                                call_id=call_id
                            ))
                        elif action_type == "key":
                            # make_keypress_item expects a list of keys, so split
                            # combinations like "ctrl+s" into their parts
                            responses_items.append(make_keypress_item(
                                keys=tool_input.get("key", "").split("+"),
                                call_id=call_id
                            ))
                        elif action_type == "mouse_move":
                            # Mouse move - create a custom action item
                            coordinate = tool_input.get("coordinate", [0, 0])
                            responses_items.append({
                                "type": "computer_call",
                                "call_id": call_id,
                                "action": {
                                    "type": "mouse_move",
                                    "x": coordinate[0] if len(coordinate) > 0 else 0,
                                    "y": coordinate[1] if len(coordinate) > 1 else 0
                                }
                            })

                        # Enhanced actions (computer_20250124), available in Claude 4 and Claude Sonnet 3.7
                        elif action_type == "scroll":
                            coordinate = tool_input.get("coordinate", [0, 0])
                            direction = tool_input.get("scroll_direction", "down")
                            amount = tool_input.get("scroll_amount", 3)
                            # make_scroll_item takes scroll_x/scroll_y offsets, so map
                            # direction/amount onto them (same convention as the
                            # tool_calls branch below)
                            scroll_x = amount if direction == "left" else \
                                       -amount if direction == "right" else 0
                            scroll_y = amount if direction == "up" else \
                                       -amount if direction == "down" else 0
                            responses_items.append(make_scroll_item(
                                x=coordinate[0] if len(coordinate) > 0 else 0,
                                y=coordinate[1] if len(coordinate) > 1 else 0,
                                scroll_x=scroll_x,
                                scroll_y=scroll_y,
                                call_id=call_id
                            ))
                        elif action_type == "left_click_drag":
                            start_coord = tool_input.get("start_coordinate", [0, 0])
                            end_coord = tool_input.get("end_coordinate", [0, 0])
                            # make_drag_item takes a path of points
                            responses_items.append(make_drag_item(
                                path=[
                                    {"x": start_coord[0] if len(start_coord) > 0 else 0,
                                     "y": start_coord[1] if len(start_coord) > 1 else 0},
                                    {"x": end_coord[0] if len(end_coord) > 0 else 0,
                                     "y": end_coord[1] if len(end_coord) > 1 else 0},
                                ],
                                call_id=call_id
                            ))
                        elif action_type == "right_click":
                            coordinate = tool_input.get("coordinate", [0, 0])
                            responses_items.append(make_click_item(
                                x=coordinate[0] if len(coordinate) > 0 else 0,
                                y=coordinate[1] if len(coordinate) > 1 else 0,
                                button="right",
                                call_id=call_id
                            ))
                        elif action_type == "middle_click":
                            coordinate = tool_input.get("coordinate", [0, 0])
                            responses_items.append(make_click_item(
                                x=coordinate[0] if len(coordinate) > 0 else 0,
                                y=coordinate[1] if len(coordinate) > 1 else 0,
                                button="wheel",
                                call_id=call_id
                            ))
                        elif action_type == "double_click":
                            coordinate = tool_input.get("coordinate", [0, 0])
                            responses_items.append(make_double_click_item(
                                x=coordinate[0] if len(coordinate) > 0 else 0,
                                y=coordinate[1] if len(coordinate) > 1 else 0,
                                call_id=call_id
                            ))
                        elif action_type == "triple_click":
                            # coordinate = tool_input.get("coordinate", [0, 0])
                            # responses_items.append({
                            #     "type": "computer_call",
                            #     "call_id": call_id,
                            #     "action": {
                            #         "type": "triple_click",
                            #         "x": coordinate[0] if len(coordinate) > 0 else 0,
                            #         "y": coordinate[1] if len(coordinate) > 1 else 0
                            #     }
                            # })
                            raise NotImplementedError("triple_click")
                        elif action_type == "left_mouse_down":
                            # coordinate = tool_input.get("coordinate", [0, 0])
                            # responses_items.append({
                            #     "type": "computer_call",
                            #     "call_id": call_id,
                            #     "action": {
                            #         "type": "mouse_down",
                            #         "button": "left",
                            #         "x": coordinate[0] if len(coordinate) > 0 else 0,
                            #         "y": coordinate[1] if len(coordinate) > 1 else 0
                            #     }
                            # })
                            raise NotImplementedError("left_mouse_down")
                        elif action_type == "left_mouse_up":
                            # coordinate = tool_input.get("coordinate", [0, 0])
                            # responses_items.append({
                            #     "type": "computer_call",
                            #     "call_id": call_id,
                            #     "action": {
                            #         "type": "mouse_up",
                            #         "button": "left",
                            #         "x": coordinate[0] if len(coordinate) > 0 else 0,
                            #         "y": coordinate[1] if len(coordinate) > 1 else 0
                            #     }
                            # })
                            raise NotImplementedError("left_mouse_up")
                        elif action_type == "hold_key":
                            # responses_items.append({
                            #     "type": "computer_call",
                            #     "call_id": call_id,
                            #     "action": {
                            #         "type": "key_hold",
                            #         "key": tool_input.get("key", "")
                            #     }
                            # })
                            raise NotImplementedError("hold_key")
                        elif action_type == "wait":
                            responses_items.append(make_wait_item(
                                call_id=call_id
                            ))
                        else:
                            raise ValueError(f"Unknown action type: {action_type}")

    # Handle tool calls (alternative format)
    if hasattr(message, 'tool_calls') and message.tool_calls:
        for tool_call in message.tool_calls:
            if tool_call.function.name == "computer":
                try:
                    args = json.loads(tool_call.function.arguments)
                    action_type = args.get("action")
                    call_id = tool_call.id

                    # Basic actions (all versions)
                    if action_type == "screenshot":
                        responses_items.append(make_screenshot_item(
                            call_id=call_id
                        ))
                    elif action_type in ["click", "left_click"]:
                        coordinate = args.get("coordinate", [0, 0])
                        responses_items.append(make_click_item(
                            x=coordinate[0] if len(coordinate) > 0 else 0,
                            y=coordinate[1] if len(coordinate) > 1 else 0,
                            call_id=call_id
                        ))
                    elif action_type == "type":
                        responses_items.append(make_type_item(
                            text=args.get("text", ""),
                            call_id=call_id
                        ))
                    elif action_type == "key":
                        # make_keypress_item expects a list of keys
                        responses_items.append(make_keypress_item(
                            keys=args.get("key", "").split("+"),
                            call_id=call_id
                        ))
                    elif action_type == "mouse_move":
                        coordinate = args.get("coordinate", [0, 0])
                        responses_items.append(make_move_item(
                            x=coordinate[0] if len(coordinate) > 0 else 0,
                            y=coordinate[1] if len(coordinate) > 1 else 0,
                            call_id=call_id
                        ))

                    # Enhanced actions (computer_20250124), available in Claude 4 and Claude Sonnet 3.7
                    elif action_type == "scroll":
                        coordinate = args.get("coordinate", [0, 0])
                        direction = args.get("scroll_direction", "down")
                        amount = args.get("scroll_amount", 3)
                        scroll_x = amount if direction == "left" else \
                                   -amount if direction == "right" else 0
                        scroll_y = amount if direction == "up" else \
                                   -amount if direction == "down" else 0
                        responses_items.append(make_scroll_item(
                            x=coordinate[0] if len(coordinate) > 0 else 0,
                            y=coordinate[1] if len(coordinate) > 1 else 0,
                            scroll_x=scroll_x,
                            scroll_y=scroll_y,
                            call_id=call_id
                        ))
                    elif action_type == "left_click_drag":
                        start_coord = args.get("start_coordinate", [0, 0])
                        end_coord = args.get("end_coordinate", [0, 0])
                        # make_drag_item takes a path of points
                        responses_items.append(make_drag_item(
                            path=[
                                {"x": start_coord[0] if len(start_coord) > 0 else 0,
                                 "y": start_coord[1] if len(start_coord) > 1 else 0},
                                {"x": end_coord[0] if len(end_coord) > 0 else 0,
                                 "y": end_coord[1] if len(end_coord) > 1 else 0},
                            ],
                            call_id=call_id
                        ))
                    elif action_type == "right_click":
                        coordinate = args.get("coordinate", [0, 0])
                        responses_items.append(make_click_item(
                            x=coordinate[0] if len(coordinate) > 0 else 0,
                            y=coordinate[1] if len(coordinate) > 1 else 0,
                            button="right",
                            call_id=call_id
                        ))
                    elif action_type == "middle_click":
                        coordinate = args.get("coordinate", [0, 0])
                        # "wheel" is the middle button in make_click_item's schema
                        responses_items.append(make_click_item(
                            x=coordinate[0] if len(coordinate) > 0 else 0,
                            y=coordinate[1] if len(coordinate) > 1 else 0,
                            button="wheel",
                            call_id=call_id
                        ))
                    elif action_type == "double_click":
                        coordinate = args.get("coordinate", [0, 0])
                        responses_items.append(make_double_click_item(
                            x=coordinate[0] if len(coordinate) > 0 else 0,
                            y=coordinate[1] if len(coordinate) > 1 else 0,
                            call_id=call_id
                        ))
                    elif action_type == "triple_click":
                        raise NotImplementedError("triple_click")
                    elif action_type == "left_mouse_down":
                        raise NotImplementedError("left_mouse_down")
                    elif action_type == "left_mouse_up":
                        raise NotImplementedError("left_mouse_up")
                    elif action_type == "hold_key":
                        raise NotImplementedError("hold_key")
                    elif action_type == "wait":
                        responses_items.append(make_wait_item(
                            call_id=call_id
                        ))
                except json.JSONDecodeError:
                    print("Failed to decode tool call arguments")
                    # Skip malformed tool calls
                    continue

    return responses_items

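# Illustrative example (added note; not part of the original module): a Claude
# content block like
#
#   {"type": "tool_use", "id": "toolu_1", "name": "computer",
#    "input": {"action": "left_click", "coordinate": [120, 80]}}
#
# is converted via make_click_item into roughly
#
#   {"type": "computer_call", "call_id": "toolu_1",
#    "action": {"type": "click", "button": "left", "x": 120, "y": 80},
#    "pending_safety_checks": [], "status": "completed"}
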
def _add_cache_control(completion_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Add cache control to completion messages"""
    num_writes = 0
    for message in completion_messages:
        message["cache_control"] = { "type": "ephemeral" }
        num_writes += 1
        # Cache control has a maximum of 4 blocks
        if num_writes >= 4:
            break

    return completion_messages

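# Note (added for clarity): the Anthropic API allows at most four
# cache_control breakpoints per request, so only the first four messages are
# marked, e.g.
#
#   msgs = [{"role": "user", "content": f"turn {i}"} for i in range(6)]
#   _add_cache_control(msgs)  # only msgs[0..3] gain {"cache_control": {"type": "ephemeral"}}
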
def _combine_completion_messages(completion_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Combine completion messages with the same role"""
    if not completion_messages:
        return completion_messages

    combined_messages = []

    for message in completion_messages:
        # If this is the first message or the role differs from the last, add as a new message
        if not combined_messages or combined_messages[-1]["role"] != message["role"]:
            # Ensure content is in list format and normalize text content
            new_message = message.copy()
            new_message["content"] = _normalize_content(message.get("content", ""))

            # Copy tool_calls if present
            if "tool_calls" in message:
                new_message["tool_calls"] = message["tool_calls"].copy()

            combined_messages.append(new_message)
        else:
            # Same role as previous message, combine them
            last_message = combined_messages[-1]

            # Combine content
            current_content = _normalize_content(message.get("content", ""))
            last_message["content"].extend(current_content)

            # Combine tool_calls if present
            if "tool_calls" in message:
                if "tool_calls" not in last_message:
                    last_message["tool_calls"] = []
                last_message["tool_calls"].extend(message["tool_calls"])

    # Post-process to merge consecutive text blocks
    for message in combined_messages:
        message["content"] = _merge_consecutive_text(message["content"])

    return combined_messages


def _normalize_content(content) -> List[Dict[str, Any]]:
    """Normalize content to list format"""
    if isinstance(content, str):
        if content.strip():  # Only add non-empty strings
            return [{"type": "text", "text": content}]
        else:
            return []
    elif isinstance(content, list):
        return content.copy()
    else:
        return []


def _merge_consecutive_text(content_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Merge consecutive text blocks with newlines"""
    if not content_list:
        return content_list

    merged = []

    for item in content_list:
        if (item.get("type") == "text" and
                merged and
                merged[-1].get("type") == "text"):
            # Merge with the previous text block
            merged[-1]["text"] += "\n" + item["text"]
        else:
            merged.append(item.copy())

    return merged

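# Illustrative example (added for clarity):
#
#   _merge_consecutive_text([{"type": "text", "text": "Thought: ..."},
#                            {"type": "text", "text": "Action: ..."}])
#   # -> [{"type": "text", "text": "Thought: ...\nAction: ..."}]
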
@agent_loop(models=r".*claude-.*", priority=5)
|
||||
async def anthropic_hosted_tools_loop(
|
||||
messages: Messages,
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
|
||||
"""
|
||||
Anthropic hosted tools agent loop using liteLLM acompletion.
|
||||
|
||||
Supports Anthropic's computer use models with hosted tools.
|
||||
"""
|
||||
tools = tools or []
|
||||
|
||||
# Get tool configuration for this model
|
||||
tool_config = _get_tool_config_for_model(model)
|
||||
|
||||
# Prepare tools for Anthropic API
|
||||
anthropic_tools = _prepare_tools_for_anthropic(tools, model)
|
||||
|
||||
# Convert responses_items messages to completion format
|
||||
completion_messages = _convert_responses_items_to_completion_messages(messages)
|
||||
if use_prompt_caching:
|
||||
# First combine messages to reduce number of blocks
|
||||
completion_messages = _combine_completion_messages(completion_messages)
|
||||
# Then add cache control, anthropic requires explicit "cache_control" dicts
|
||||
completion_messages = _add_cache_control(completion_messages)
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
"messages": completion_messages,
|
||||
"tools": anthropic_tools if anthropic_tools else None,
|
||||
"stream": stream,
|
||||
"num_retries": max_retries,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
# Add beta header for computer use
|
||||
if anthropic_tools:
|
||||
api_kwargs["headers"] = {
|
||||
"anthropic-beta": tool_config["beta_flag"]
|
||||
}
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
# Use liteLLM acompletion
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
# Convert response to responses_items format
|
||||
responses_items = _convert_completion_to_responses_items(response)
|
||||
|
||||
# Extract usage information
|
||||
responses_usage = {
|
||||
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(responses_usage)
|
||||
|
||||
# Create agent response
|
||||
agent_response = {
|
||||
"output": responses_items,
|
||||
"usage": responses_usage
|
||||
}
|
||||
|
||||
return agent_response
|
||||
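# Usage sketch (illustrative; assumes the surrounding agent framework wires in
# a computer tool and the callback hooks):
#
#   result = await anthropic_hosted_tools_loop(
#       messages=[{"role": "user", "content": "Take a screenshot"}],
#       model="anthropic/claude-3-5-sonnet-20241022",
#       tools=[{"type": "computer", "computer": computer}],
#       use_prompt_caching=True,
#   )
#   for item in result["output"]:
#       ...  # message / reasoning / computer_call items
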
@@ -1,339 +0,0 @@
"""
Omniparser (Set-of-Marks) agent loop implementation using liteLLM
"""

import asyncio
import json
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
import litellm
import inspect
import base64

from ..decorators import agent_loop
from ..types import Messages, AgentResponse, Tools

SOM_TOOL_SCHEMA = {
    "type": "function",
    "name": "computer",
    "description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
    "parameters": {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": [
                    "screenshot",
                    "click",
                    "double_click",
                    "drag",
                    "type",
                    "keypress",
                    "scroll",
                    "move",
                    "wait",
                    "get_current_url",
                    "get_dimensions",
                    "get_environment"
                ],
                "description": "The action to perform"
            },
            "element_id": {
                "type": "integer",
                "description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
            },
            "start_element_id": {
                "type": "integer",
                "description": "The ID of the element to start dragging from (required for drag action)"
            },
            "end_element_id": {
                "type": "integer",
                "description": "The ID of the element to drag to (required for drag action)"
            },
            "text": {
                "type": "string",
                "description": "The text to type (required for type action)"
            },
            "keys": {
                "type": "string",
                "description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
            },
            "button": {
                "type": "string",
                "description": "The mouse button to use for click action (left, right, wheel, back, forward). Default: left"
            },
            "scroll_x": {
                "type": "integer",
                "description": "Horizontal scroll amount for scroll action (positive for right, negative for left)"
            },
            "scroll_y": {
                "type": "integer",
                "description": "Vertical scroll amount for scroll action (positive for down, negative for up)"
            }
        },
        "required": [
            "action"
        ]
    }
}

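# Illustrative call (added note): with this schema the model addresses elements
# by their overlaid ID rather than by pixel coordinates, e.g.
#
#   {"name": "computer",
#    "arguments": '{"action": "click", "element_id": 12, "button": "left"}'}
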
OMNIPARSER_AVAILABLE = False
try:
    from som import OmniParser
    OMNIPARSER_AVAILABLE = True
except ImportError:
    pass

OMNIPARSER_SINGLETON = None


def get_parser():
    global OMNIPARSER_SINGLETON
    if OMNIPARSER_SINGLETON is None:
        OMNIPARSER_SINGLETON = OmniParser()
    return OMNIPARSER_SINGLETON


def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Get the last computer_call_output message from a messages list.

    Args:
        messages: List of messages to search through

    Returns:
        The last computer_call_output message dict, or None if not found
    """
    for message in reversed(messages):
        if isinstance(message, dict) and message.get("type") == "computer_call_output":
            return message
    return None


def _prepare_tools_for_omniparser(tool_schemas: List[Dict[str, Any]]) -> Tuple[Tools, dict]:
    """Prepare tools for OpenAI API format"""
    omniparser_tools = []
    id2xy = dict()

    for schema in tool_schemas:
        if schema["type"] == "computer":
            omniparser_tools.append(SOM_TOOL_SCHEMA)
            if "id2xy" in schema:
                id2xy = schema["id2xy"]
            else:
                schema["id2xy"] = id2xy
        elif schema["type"] == "function":
            # Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
            # Schema should be: {type, name, description, parameters}
            omniparser_tools.append({ "type": "function", **schema["function"] })

    return omniparser_tools, id2xy

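# Note (added for clarity): id2xy is shared by storing it on the computer
# tool's schema, so element IDs discovered while parsing one screenshot can be
# resolved back to coordinates on later turns. After parsing, id2xy might hold
# e.g. {12: (512.0, 384.0)}.
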
async def replace_function_with_computer_call(item: Dict[str, Any], id2xy: Dict[int, Tuple[float, float]]):
    item_type = item.get("type")

    def _get_xy(element_id: Optional[int]) -> Union[Tuple[float, float], Tuple[None, None]]:
        if element_id is None:
            return (None, None)
        return id2xy.get(element_id, (None, None))

    if item_type == "function_call":
        fn_name = item.get("name")
        fn_args = json.loads(item.get("arguments", "{}"))

        item_id = item.get("id")
        call_id = item.get("call_id")

        if fn_name == "computer":
            action = fn_args.get("action")
            element_id = fn_args.get("element_id")
            start_element_id = fn_args.get("start_element_id")
            end_element_id = fn_args.get("end_element_id")
            text = fn_args.get("text")
            keys = fn_args.get("keys")
            button = fn_args.get("button")
            scroll_x = fn_args.get("scroll_x")
            scroll_y = fn_args.get("scroll_y")

            x, y = _get_xy(element_id)
            start_x, start_y = _get_xy(start_element_id)
            end_x, end_y = _get_xy(end_element_id)

            action_args = {
                "type": action,
                "x": x,
                "y": y,
                "start_x": start_x,
                "start_y": start_y,
                "end_x": end_x,
                "end_y": end_y,
                "text": text,
                "keys": keys,
                "button": button,
                "scroll_x": scroll_x,
                "scroll_y": scroll_y
            }
            # Remove None values to keep the JSON clean
            action_args = {k: v for k, v in action_args.items() if v is not None}

            return [{
                "type": "computer_call",
                "action": action_args,
                "id": item_id,
                "call_id": call_id,
                "status": "completed"
            }]

    return [item]

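# Illustrative round trip (added note): with id2xy = {3: (512.0, 384.0)}, a
# model function call
#
#   {"type": "function_call", "name": "computer", "call_id": "call_1",
#    "arguments": '{"action": "click", "element_id": 3}'}
#
# becomes (abridged)
#
#   {"type": "computer_call", "call_id": "call_1", "status": "completed",
#    "action": {"type": "click", "x": 512.0, "y": 384.0}}
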
async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[Tuple[float, float], int]):
    """
    Convert computer_call back to function_call format.
    Also handles computer_call_output -> function_call_output conversion.

    Args:
        item: The item to convert
        xy2id: Mapping from (x, y) coordinates to element IDs
    """
    item_type = item.get("type")

    def _get_element_id(x: Optional[float], y: Optional[float]) -> Optional[int]:
        """Get element ID from coordinates, return None if coordinates are None"""
        if x is None or y is None:
            return None
        return xy2id.get((x, y))

    if item_type == "computer_call":
        action_data = item.get("action", {})

        # Extract coordinates and convert back to element IDs
        element_id = _get_element_id(action_data.get("x"), action_data.get("y"))
        start_element_id = _get_element_id(action_data.get("start_x"), action_data.get("start_y"))
        end_element_id = _get_element_id(action_data.get("end_x"), action_data.get("end_y"))

        # Build function arguments
        fn_args = {
            "action": action_data.get("type"),
            "element_id": element_id,
            "start_element_id": start_element_id,
            "end_element_id": end_element_id,
            "text": action_data.get("text"),
            "keys": action_data.get("keys"),
            "button": action_data.get("button"),
            "scroll_x": action_data.get("scroll_x"),
            "scroll_y": action_data.get("scroll_y")
        }

        # Remove None values to keep the JSON clean
        fn_args = {k: v for k, v in fn_args.items() if v is not None}

        return [{
            "type": "function_call",
            "name": "computer",
            "arguments": json.dumps(fn_args),
            "id": item.get("id"),
            "call_id": item.get("call_id"),
            "status": "completed",

            # Fall back to string representation
            "content": f"Used tool: {action_data.get('type')}({json.dumps(fn_args)})"
        }]

    elif item_type == "computer_call_output":
        # Simple conversion: computer_call_output -> function_call_output
        return [{
            "type": "function_call_output",
            "call_id": item.get("call_id"),
            "content": [item.get("output")],
            "id": item.get("id"),
            "status": "completed"
        }]

    return [item]

@agent_loop(models=r"omniparser\+.*|omni\+.*", priority=10)
async def omniparser_loop(
    messages: Messages,
    model: str,
    tools: Optional[List[Dict[str, Any]]] = None,
    max_retries: Optional[int] = None,
    stream: bool = False,
    computer_handler=None,
    use_prompt_caching: Optional[bool] = False,
    _on_api_start=None,
    _on_api_end=None,
    _on_usage=None,
    _on_screenshot=None,
    **kwargs
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
    """
    Omniparser (SOM) agent loop using liteLLM responses.

    Annotates screenshots with numbered UI elements and lets any liteLLM
    model drive the computer through element IDs.
    """
    if not OMNIPARSER_AVAILABLE:
        raise ValueError("omniparser loop requires som to be installed. Install it with `pip install cua-som`.")

    tools = tools or []

    llm_model = model.split('+')[-1]

    # Prepare tools for OpenAI API
    openai_tools, id2xy = _prepare_tools_for_omniparser(tools)

    # Find the last computer_call_output
    last_computer_call_output = get_last_computer_call_output(messages)
    if last_computer_call_output:
        image_url = last_computer_call_output.get("output", {}).get("image_url", "")
        image_data = image_url.split(",")[-1]
        if image_data:
            parser = get_parser()
            result = parser.parse(image_data)
            if _on_screenshot:
                await _on_screenshot(result.annotated_image_base64, "annotated_image")
            for element in result.elements:
                id2xy[element.id] = ((element.bbox.x1 + element.bbox.x2) / 2, (element.bbox.y1 + element.bbox.y2) / 2)

    # Handle computer calls -> function calls
    # (invert id2xy so coordinates can be mapped back to element IDs, as
    # replace_computer_call_with_function expects)
    xy2id = {xy: element_id for element_id, xy in id2xy.items()}
    new_messages = []
    for message in messages:
        if not isinstance(message, dict):
            message = message.__dict__
        new_messages += await replace_computer_call_with_function(message, xy2id)
    messages = new_messages

    # Prepare API call kwargs
    api_kwargs = {
        "model": llm_model,
        "input": messages,
        "tools": openai_tools if openai_tools else None,
        "stream": stream,
        "reasoning": {"summary": "concise"},
        "truncation": "auto",
        "num_retries": max_retries,
        **kwargs
    }

    # Call API start hook
    if _on_api_start:
        await _on_api_start(api_kwargs)

    # Use liteLLM responses
    response = await litellm.aresponses(**api_kwargs)

    # Call API end hook
    if _on_api_end:
        await _on_api_end(api_kwargs, response)

    # Extract usage information
    response.usage = {
        **response.usage.model_dump(),
        "response_cost": response._hidden_params.get("response_cost", 0.0),
    }
    if _on_usage:
        await _on_usage(response.usage)

    # Handle SOM function calls -> xy computer calls
    new_output = []
    for i in range(len(response.output)):
        new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy)
    response.output = new_output

    return response

@@ -1,95 +0,0 @@
"""
OpenAI computer-use-preview agent loop implementation using liteLLM
"""

import asyncio
import json
from typing import Dict, List, Any, AsyncGenerator, Union, Optional
import litellm

from ..decorators import agent_loop
from ..types import Messages, AgentResponse, Tools


def _map_computer_tool_to_openai(computer_tool: Any) -> Dict[str, Any]:
    """Map a computer tool to OpenAI's computer-use-preview tool schema"""
    return {
        "type": "computer_use_preview",
        "display_width": getattr(computer_tool, 'display_width', 1024),
        "display_height": getattr(computer_tool, 'display_height', 768),
        "environment": getattr(computer_tool, 'environment', "linux")  # mac, windows, linux, browser
    }


def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools:
    """Prepare tools for OpenAI API format"""
    openai_tools = []

    for schema in tool_schemas:
        if schema["type"] == "computer":
            # Map computer tool to OpenAI format
            openai_tools.append(_map_computer_tool_to_openai(schema["computer"]))
        elif schema["type"] == "function":
            # Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
            # Schema should be: {type, name, description, parameters}
            openai_tools.append({ "type": "function", **schema["function"] })

    return openai_tools

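# Illustrative mapping (added note): a computer tool reporting a 1920x1080
# display becomes
#
#   {"type": "computer_use_preview", "display_width": 1920,
#    "display_height": 1080, "environment": "linux"}
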
@agent_loop(models=r".*computer-use-preview.*", priority=10)
async def openai_computer_use_loop(
    messages: Messages,
    model: str,
    tools: Optional[List[Dict[str, Any]]] = None,
    max_retries: Optional[int] = None,
    stream: bool = False,
    computer_handler=None,
    use_prompt_caching: Optional[bool] = False,
    _on_api_start=None,
    _on_api_end=None,
    _on_usage=None,
    _on_screenshot=None,
    **kwargs
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
    """
    OpenAI computer-use-preview agent loop using liteLLM responses.

    Supports OpenAI's computer use preview models.
    """
    tools = tools or []

    # Prepare tools for OpenAI API
    openai_tools = _prepare_tools_for_openai(tools)

    # Prepare API call kwargs
    api_kwargs = {
        "model": model,
        "input": messages,
        "tools": openai_tools if openai_tools else None,
        "stream": stream,
        "reasoning": {"summary": "concise"},
        "truncation": "auto",
        "num_retries": max_retries,
        **kwargs
    }

    # Call API start hook
    if _on_api_start:
        await _on_api_start(api_kwargs)

    # Use liteLLM responses
    response = await litellm.aresponses(**api_kwargs)

    # Call API end hook
    if _on_api_end:
        await _on_api_end(api_kwargs, response)

    # Extract usage information
    response.usage = {
        **response.usage.model_dump(),
        "response_cost": response._hidden_params.get("response_cost", 0.0),
    }
    if _on_usage:
        await _on_usage(response.usage)

    return response

@@ -1,688 +0,0 @@
"""
UITARS agent loop implementation using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B
"""

import asyncio
import json
import base64
import math
import re
import ast
from typing import Dict, List, Any, AsyncGenerator, Union, Optional
from io import BytesIO
from PIL import Image
import litellm
from litellm.types.utils import ModelResponse
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
from litellm.responses.utils import Usage
from openai.types.responses.response_computer_tool_call_param import ActionType, ResponseComputerToolCallParam
from openai.types.responses.response_input_param import ComputerCallOutput
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam, Summary

from ..decorators import agent_loop
from ..types import Messages, AgentResponse, Tools
from ..responses import (
    make_reasoning_item,
    make_output_text_item,
    make_click_item,
    make_double_click_item,
    make_drag_item,
    make_keypress_item,
    make_scroll_item,
    make_type_item,
    make_wait_item,
    make_input_image_item
)

# Constants from reference code
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200

FINISH_WORD = "finished"
WAIT_WORD = "wait"
ENV_FAIL_WORD = "error_env"
CALL_USER = "call_user"

# Action space prompt for UITARS
UITARS_ACTION_SPACE = """
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
hotkey(key='')
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
wait() #Sleep for 5s and take a screenshot to check for any changes.
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
"""

UITARS_PROMPT_TEMPLATE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.

## Output Format
```
Thought: ...
Action: ...
```

## Action Space
{action_space}

## Note
- Use {language} in `Thought` part.
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.

## User Instruction
{instruction}
"""

def round_by_factor(number: float, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:
    1. Both dimensions (height and width) are divisible by 'factor'.
    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is maintained as closely as possible.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar

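# Worked example (added for clarity): a 1920x1080 screenshot is already within
# the pixel budget, so both sides are simply snapped to multiples of
# IMAGE_FACTOR (28):
#
#   >>> smart_resize(height=1080, width=1920)
#   (1092, 1932)
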
def escape_single_quotes(text):
    """Escape single quotes in text for safe string formatting."""
    pattern = r"(?<!\\)'"
    return re.sub(pattern, r"\\'", text)


def parse_action(action_str):
    """Parse action string into structured format."""
    try:
        node = ast.parse(action_str, mode='eval')
        if not isinstance(node, ast.Expression):
            raise ValueError("Not an expression")

        call = node.body
        if not isinstance(call, ast.Call):
            raise ValueError("Not a function call")

        # Get function name
        if isinstance(call.func, ast.Name):
            func_name = call.func.id
        elif isinstance(call.func, ast.Attribute):
            func_name = call.func.attr
        else:
            func_name = None

        # Get keyword arguments
        kwargs = {}
        for kw in call.keywords:
            key = kw.arg
            if isinstance(kw.value, ast.Constant):
                value = kw.value.value
            elif isinstance(kw.value, ast.Str):  # Compatibility with older Python
                value = kw.value.s
            else:
                value = None
            kwargs[key] = value

        return {
            'function': func_name,
            'args': kwargs
        }

    except Exception as e:
        print(f"Failed to parse action '{action_str}': {e}")
        return None

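# Illustrative example (added for clarity):
#
#   >>> parse_action("click(start_box='(100,200)')")
#   {'function': 'click', 'args': {'start_box': '(100,200)'}}
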
def parse_uitars_response(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
    """Parse UITARS model response into structured actions."""
    text = text.strip()

    # Extract thought
    thought = None
    if text.startswith("Thought:"):
        thought_match = re.search(r"Thought: (.+?)(?=\s*Action:|$)", text, re.DOTALL)
        if thought_match:
            thought = thought_match.group(1).strip()

    # Extract action
    if "Action:" not in text:
        raise ValueError("No Action found in response")

    action_str = text.split("Action:")[-1].strip()

    # Handle special case for type actions
    if "type(content" in action_str:
        def escape_quotes(match):
            return match.group(1)

        pattern = r"type\(content='(.*?)'\)"
        content = re.sub(pattern, escape_quotes, action_str)
        action_str = escape_single_quotes(content)
        action_str = "type(content='" + action_str + "')"

    # Parse the action
    parsed_action = parse_action(action_str.replace("\n", "\\n").lstrip())
    if parsed_action is None:
        raise ValueError(f"Action can't parse: {action_str}")

    action_type = parsed_action["function"]
    params = parsed_action["args"]

    # Process parameters
    action_inputs = {}
    for param_name, param in params.items():
        if param == "":
            continue
        param = str(param).lstrip()
        action_inputs[param_name.strip()] = param

        # Handle coordinate parameters
        if "start_box" in param_name or "end_box" in param_name:
            # Parse coordinates like '(x,y)' or '(x1,y1,x2,y2)'
            numbers = param.replace("(", "").replace(")", "").split(",")
            float_numbers = [float(num.strip()) / 1000 for num in numbers]  # Normalize to 0-1 range

            if len(float_numbers) == 2:
                # Single point, duplicate for box format
                float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]

            action_inputs[param_name.strip()] = str(float_numbers)

    return [{
        "thought": thought,
        "action_type": action_type,
        "action_inputs": action_inputs,
        "text": text
    }]

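# Illustrative example (added for clarity): coordinates come back in UITARS's
# 0-1000 grid and are normalized into the 0-1 range here:
#
#   >>> parse_uitars_response("Thought: open it.\nAction: click(start_box='(500,300)')", 1920, 1080)
#   [{'thought': 'open it.', 'action_type': 'click',
#     'action_inputs': {'start_box': '[0.5, 0.3, 0.5, 0.3]'},
#     'text': "Thought: open it.\nAction: click(start_box='(500,300)')"}]
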
def convert_to_computer_actions(parsed_responses: List[Dict[str, Any]], image_width: int, image_height: int) -> List[ResponseComputerToolCallParam | ResponseOutputMessageParam]:
    """Convert parsed UITARS responses to computer actions."""
    computer_actions = []

    for response in parsed_responses:
        action_type = response.get("action_type")
        action_inputs = response.get("action_inputs", {})

        if action_type == "finished":
            finished_text = action_inputs.get("content", "Task completed successfully.")
            computer_actions.append(make_output_text_item(finished_text))
            break

        elif action_type == "wait":
            computer_actions.append(make_wait_item())

        elif action_type == "call_user":
            computer_actions.append(make_output_text_item("I need assistance from the user to proceed with this task."))

        elif action_type in ["click", "left_single"]:
            start_box = action_inputs.get("start_box")
            if start_box:
                # ast.literal_eval is safer than eval for parsing the box list
                coords = ast.literal_eval(start_box)
                x = int((coords[0] + coords[2]) / 2 * image_width)
                y = int((coords[1] + coords[3]) / 2 * image_height)

                computer_actions.append(make_click_item(x, y, "left"))

        elif action_type == "double_click":
            start_box = action_inputs.get("start_box")
            if start_box:
                coords = ast.literal_eval(start_box)
                x = int((coords[0] + coords[2]) / 2 * image_width)
                y = int((coords[1] + coords[3]) / 2 * image_height)

                computer_actions.append(make_double_click_item(x, y))

        elif action_type == "right_click":
            start_box = action_inputs.get("start_box")
            if start_box:
                coords = ast.literal_eval(start_box)
                x = int((coords[0] + coords[2]) / 2 * image_width)
                y = int((coords[1] + coords[3]) / 2 * image_height)

                computer_actions.append(make_click_item(x, y, "right"))

        elif action_type == "type":
            content = action_inputs.get("content", "")
            computer_actions.append(make_type_item(content))

        elif action_type == "hotkey":
            key = action_inputs.get("key", "")
            keys = key.split()
            computer_actions.append(make_keypress_item(keys))

        elif action_type == "press":
            key = action_inputs.get("key", "")
            computer_actions.append(make_keypress_item([key]))

        elif action_type == "scroll":
            start_box = action_inputs.get("start_box")
            direction = action_inputs.get("direction", "down")

            if start_box:
                coords = ast.literal_eval(start_box)
                x = int((coords[0] + coords[2]) / 2 * image_width)
                y = int((coords[1] + coords[3]) / 2 * image_height)
            else:
                x, y = image_width // 2, image_height // 2

            scroll_y = 5 if "up" in direction.lower() else -5
            computer_actions.append(make_scroll_item(x, y, 0, scroll_y))

        elif action_type == "drag":
            start_box = action_inputs.get("start_box")
            end_box = action_inputs.get("end_box")

            if start_box and end_box:
                start_coords = ast.literal_eval(start_box)
                end_coords = ast.literal_eval(end_box)

                start_x = int((start_coords[0] + start_coords[2]) / 2 * image_width)
                start_y = int((start_coords[1] + start_coords[3]) / 2 * image_height)
                end_x = int((end_coords[0] + end_coords[2]) / 2 * image_width)
                end_y = int((end_coords[1] + end_coords[3]) / 2 * image_height)

                path = [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}]
                computer_actions.append(make_drag_item(path))

    return computer_actions

def pil_to_base64(image: Image.Image) -> str:
    """Convert PIL image to base64 string."""
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def process_image_for_uitars(image_data: str, max_pixels: int = MAX_PIXELS, min_pixels: int = MIN_PIXELS) -> tuple[Image.Image, int, int]:
    """Process image for UITARS model input."""
    # Decode base64 image
    if image_data.startswith('data:image'):
        image_data = image_data.split(',')[1]

    image_bytes = base64.b64decode(image_data)
    image = Image.open(BytesIO(image_bytes))

    original_width, original_height = image.size

    # Resize image according to UITARS requirements
    if image.width * image.height > max_pixels:
        resize_factor = math.sqrt(max_pixels / (image.width * image.height))
        width = int(image.width * resize_factor)
        height = int(image.height * resize_factor)
        image = image.resize((width, height))

    if image.width * image.height < min_pixels:
        resize_factor = math.sqrt(min_pixels / (image.width * image.height))
        width = math.ceil(image.width * resize_factor)
        height = math.ceil(image.height * resize_factor)
        image = image.resize((width, height))

    if image.mode != "RGB":
        image = image.convert("RGB")

    return image, original_width, original_height


def sanitize_message(msg: Any) -> Any:
    """Return a copy of the message with image_url omitted within content parts"""
    if isinstance(msg, dict):
        result = {}
        for key, value in msg.items():
            if key == "content" and isinstance(value, list):
                result[key] = [
                    {k: v for k, v in item.items() if k != "image_url"} if isinstance(item, dict) else item
                    for item in value
                ]
            else:
                result[key] = value
        return result
    elif isinstance(msg, list):
        return [sanitize_message(item) for item in msg]
    else:
        return msg

def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any]]:
    """
    Convert UITARS internal message format back to LiteLLM format.

    This function processes reasoning, computer_call, and computer_call_output messages
    and converts them to the appropriate LiteLLM assistant message format.

    Args:
        messages: List of UITARS internal messages

    Returns:
        List of LiteLLM formatted messages
    """
    litellm_messages = []
    current_assistant_content = []

    for message in messages:
        if isinstance(message, dict):
            message_type = message.get("type")

            if message_type == "reasoning":
                # Extract reasoning text from summary
                summary = message.get("summary", [])
                if summary and isinstance(summary, list):
                    for summary_item in summary:
                        if isinstance(summary_item, dict) and summary_item.get("type") == "summary_text":
                            reasoning_text = summary_item.get("text", "")
                            if reasoning_text:
                                current_assistant_content.append(f"Thought: {reasoning_text}")

            elif message_type == "computer_call":
                # Convert computer action to UITARS action format
                action = message.get("action", {})
                action_type = action.get("type")

                if action_type == "click":
                    x, y = action.get("x", 0), action.get("y", 0)
                    button = action.get("button", "left")
                    if button == "left":
                        action_text = f"Action: click(start_box='({x},{y})')"
                    elif button == "right":
                        action_text = f"Action: right_single(start_box='({x},{y})')"
                    else:
                        action_text = f"Action: click(start_box='({x},{y})')"

                elif action_type == "double_click":
                    x, y = action.get("x", 0), action.get("y", 0)
                    action_text = f"Action: left_double(start_box='({x},{y})')"

                elif action_type == "drag":
                    start_x, start_y = action.get("start_x", 0), action.get("start_y", 0)
                    end_x, end_y = action.get("end_x", 0), action.get("end_y", 0)
                    action_text = f"Action: drag(start_box='({start_x},{start_y})', end_box='({end_x},{end_y})')"

                elif action_type == "key":
                    key = action.get("key", "")
                    action_text = f"Action: hotkey(key='{key}')"

                elif action_type == "type":
                    text = action.get("text", "")
                    # Escape single quotes in the text
                    escaped_text = escape_single_quotes(text)
                    action_text = f"Action: type(content='{escaped_text}')"

                elif action_type == "scroll":
                    x, y = action.get("x", 0), action.get("y", 0)
                    direction = action.get("direction", "down")
                    action_text = f"Action: scroll(start_box='({x},{y})', direction='{direction}')"

                elif action_type == "wait":
                    action_text = "Action: wait()"

                else:
                    # Fallback for unknown action types
                    action_text = f"Action: {action_type}({action})"

                current_assistant_content.append(action_text)

                # When we hit a computer_call_output, finalize the current assistant message
                if current_assistant_content:
                    litellm_messages.append({
                        "role": "assistant",
                        "content": [{"type": "text", "text": "\n".join(current_assistant_content)}]
                    })
                    current_assistant_content = []

            elif message_type == "computer_call_output":
                # Add screenshot from computer call output
                output = message.get("output", {})
                if isinstance(output, dict) and output.get("type") == "input_image":
                    image_url = output.get("image_url", "")
                    if image_url:
                        litellm_messages.append({
                            "role": "user",
                            "content": [{"type": "image_url", "image_url": {"url": image_url}}]
                        })

            elif message.get("role") == "user":
                # # Handle user messages
                # content = message.get("content", "")
                # if isinstance(content, str):
                #     litellm_messages.append({
                #         "role": "user",
                #         "content": content
                #     })
                # elif isinstance(content, list):
                #     litellm_messages.append({
                #         "role": "user",
                #         "content": content
                #     })
                pass

    # Add any remaining assistant content
    # (normalized to the same text-block shape used above)
    if current_assistant_content:
        litellm_messages.append({
            "role": "assistant",
            "content": [{"type": "text", "text": "\n".join(current_assistant_content)}]
        })

    return litellm_messages

@agent_loop(models=r"(?i).*ui-?tars.*", priority=10)
async def uitars_loop(
    messages: Messages,
    model: str,
    tools: Optional[List[Dict[str, Any]]] = None,
    max_retries: Optional[int] = None,
    stream: bool = False,
    computer_handler=None,
    use_prompt_caching: Optional[bool] = False,
    _on_api_start=None,
    _on_api_end=None,
    _on_usage=None,
    _on_screenshot=None,
    **kwargs
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
    """
    UITARS agent loop using liteLLM for the ByteDance-Seed/UI-TARS-1.5-7B model.

    Supports UITARS vision-language models for computer control.
    """
    tools = tools or []

    # Create response items
    response_items = []

    # Find the computer tool for screen dimensions
    computer_tool = None
    for tool_schema in tools:
        if tool_schema["type"] == "computer":
            computer_tool = tool_schema["computer"]
            break

    # Get screen dimensions
    screen_width, screen_height = 1024, 768
    if computer_tool:
        try:
            screen_width, screen_height = await computer_tool.get_dimensions()
        except Exception:
            pass

    # Process messages to extract the instruction and image
    instruction = ""
    image_data = None

    # Convert messages to a list if given a string
    if isinstance(messages, str):
        messages = [{"role": "user", "content": messages}]

    # Extract instruction and latest screenshot
    for message in reversed(messages):
        if isinstance(message, dict):
            content = message.get("content", "")

            # Handle different content formats
            if isinstance(content, str):
                if not instruction and message.get("role") == "user":
                    instruction = content
            elif isinstance(content, list):
                for item in content:
                    if isinstance(item, dict):
                        if item.get("type") == "text" and not instruction:
                            instruction = item.get("text", "")
                        elif item.get("type") == "image_url" and not image_data:
                            image_url = item.get("image_url", {})
                            if isinstance(image_url, dict):
                                image_data = image_url.get("url", "")
                            else:
                                image_data = image_url

            # Also check for computer_call_output with screenshots
            if message.get("type") == "computer_call_output" and not image_data:
                output = message.get("output", {})
                if isinstance(output, dict) and output.get("type") == "input_image":
                    image_data = output.get("image_url", "")

            if instruction and image_data:
                break

    if not instruction:
        instruction = "Help me complete this task by analyzing the screen and taking appropriate actions."

    # Create prompt
    user_prompt = UITARS_PROMPT_TEMPLATE.format(
        instruction=instruction,
        action_space=UITARS_ACTION_SPACE,
        language="English"
    )

    # Convert conversation history to LiteLLM format
    history_messages = convert_uitars_messages_to_litellm(messages)

    # Prepare messages for liteLLM
    litellm_messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant."
        }
    ]

    # Add the current user instruction with screenshot
    current_user_message = {
        "role": "user",
        "content": [
            {"type": "text", "text": user_prompt},
        ]
    }
    litellm_messages.append(current_user_message)

    # Process image for UITARS
    if not image_data:
        # Take a screenshot if none was found in the messages
        if computer_handler:
            image_data = await computer_handler.screenshot()
            if _on_screenshot:
                await _on_screenshot(image_data, "screenshot_before")

            # Add the screenshot to the output items so it can be retained in history
            response_items.append(make_input_image_item(image_data))
        else:
            raise ValueError("No screenshot found in messages and no computer_handler provided")
    processed_image, original_width, original_height = process_image_for_uitars(image_data)
    encoded_image = pil_to_base64(processed_image)

    # Add conversation history
    if history_messages:
        litellm_messages.extend(history_messages)
    else:
        litellm_messages.append({
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
            ]
        })

    # Prepare API call kwargs
    api_kwargs = {
        "model": model,
        "messages": litellm_messages,
        "max_tokens": kwargs.get("max_tokens", 500),
        "temperature": kwargs.get("temperature", 0.0),
        "do_sample": kwargs.get("temperature", 0.0) > 0.0,
        "num_retries": max_retries,
        **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
    }

    # Call API start hook
    if _on_api_start:
        await _on_api_start(api_kwargs)

    # Call liteLLM with the UITARS model
    response = await litellm.acompletion(**api_kwargs)

    # Call API end hook
    if _on_api_end:
        await _on_api_end(api_kwargs, response)

    # Extract response content
    response_content = response.choices[0].message.content.strip()  # type: ignore

    # Parse UITARS response
    parsed_responses = parse_uitars_response(response_content, original_width, original_height)

    # Convert to computer actions
    computer_actions = convert_to_computer_actions(parsed_responses, original_width, original_height)

    # Add computer actions to response items
    thought = parsed_responses[0].get("thought", "")
    if thought:
        response_items.append(make_reasoning_item(thought))
    response_items.extend(computer_actions)

    # Extract usage information
    response_usage = {
        **LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
        "response_cost": response._hidden_params.get("response_cost", 0.0),
    }
    if _on_usage:
        await _on_usage(response_usage)

    # Create agent response
    agent_response = {
        "output": response_items,
        "usage": response_usage
    }

    return agent_response

@@ -1,207 +0,0 @@
"""
Functions for making various Responses API items from different types of responses.
Based on the OpenAI spec for Responses API items.
"""

import base64
import json
import uuid
from typing import List, Dict, Any, Literal, Union, Optional

from openai.types.responses.response_computer_tool_call_param import (
    ResponseComputerToolCallParam,
    ActionClick,
    ActionDoubleClick,
    ActionDrag,
    ActionDragPath,
    ActionKeypress,
    ActionMove,
    ActionScreenshot,
    ActionScroll,
    ActionType as ActionTypeAction,
    ActionWait,
    PendingSafetyCheck
)
from openai.types.responses.response_function_tool_call_param import ResponseFunctionToolCallParam
from openai.types.responses.response_output_text_param import ResponseOutputTextParam
from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam, Summary
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
from openai.types.responses.easy_input_message_param import EasyInputMessageParam
from openai.types.responses.response_input_image_param import ResponseInputImageParam


def random_id():
    return str(uuid.uuid4())


# User message items
def make_input_image_item(image_data: Union[str, bytes]) -> EasyInputMessageParam:
    return EasyInputMessageParam(
        content=[
            ResponseInputImageParam(
                type="input_image",
                image_url=f"data:image/png;base64,{base64.b64encode(image_data).decode('utf-8') if isinstance(image_data, bytes) else image_data}"
            )
        ],
        role="user",
        type="message"
    )


# Text items
def make_reasoning_item(reasoning: str) -> ResponseReasoningItemParam:
    return ResponseReasoningItemParam(
        id=random_id(),
        summary=[
            Summary(text=reasoning, type="summary_text")
        ],
        type="reasoning"
    )


def make_output_text_item(content: str) -> ResponseOutputMessageParam:
    return ResponseOutputMessageParam(
        id=random_id(),
        content=[
            ResponseOutputTextParam(
                text=content,
                type="output_text",
                annotations=[]
            )
        ],
        role="assistant",
        status="completed",
        type="message"
    )


# Function call items
def make_function_call_item(function_name: str, arguments: Dict[str, Any], call_id: Optional[str] = None) -> ResponseFunctionToolCallParam:
    return ResponseFunctionToolCallParam(
        id=random_id(),
        call_id=call_id if call_id else random_id(),
        name=function_name,
        arguments=json.dumps(arguments),
        status="completed",
        type="function_call"
    )


# Computer tool call items
def make_click_item(x: int, y: int, button: Literal["left", "right", "wheel", "back", "forward"] = "left", call_id: Optional[str] = None) -> ResponseComputerToolCallParam:
    return ResponseComputerToolCallParam(
        id=random_id(),
        call_id=call_id if call_id else random_id(),
        action=ActionClick(
            button=button,
            type="click",
            x=x,
            y=y
        ),
        pending_safety_checks=[],
        status="completed",
        type="computer_call"
    )


def make_double_click_item(x: int, y: int, call_id: Optional[str] = None) -> ResponseComputerToolCallParam:
    return ResponseComputerToolCallParam(
        id=random_id(),
        call_id=call_id if call_id else random_id(),
        action=ActionDoubleClick(
            type="double_click",
            x=x,
            y=y
        ),
        pending_safety_checks=[],
        status="completed",
        type="computer_call"
    )


def make_drag_item(path: List[Dict[str, int]], call_id: Optional[str] = None) -> ResponseComputerToolCallParam:
    drag_path = [ActionDragPath(x=point["x"], y=point["y"]) for point in path]
    return ResponseComputerToolCallParam(
        id=random_id(),
        call_id=call_id if call_id else random_id(),
        action=ActionDrag(
            path=drag_path,
            type="drag"
        ),
        pending_safety_checks=[],
        status="completed",
        type="computer_call"
    )


def make_keypress_item(keys: List[str], call_id: Optional[str] = None) -> ResponseComputerToolCallParam:
    return ResponseComputerToolCallParam(
        id=random_id(),
        call_id=call_id if call_id else random_id(),
        action=ActionKeypress(
            keys=keys,
            type="keypress"
        ),
        pending_safety_checks=[],
        status="completed",
        type="computer_call"
    )


def make_move_item(x: int, y: int, call_id: Optional[str] = None) -> ResponseComputerToolCallParam:
    return ResponseComputerToolCallParam(
        id=random_id(),
        call_id=call_id if call_id else random_id(),
        action=ActionMove(
            type="move",
            x=x,
            y=y
        ),
        pending_safety_checks=[],
        status="completed",
        type="computer_call"
    )


def make_screenshot_item(call_id: Optional[str] = None) -> ResponseComputerToolCallParam:
    return ResponseComputerToolCallParam(
        id=random_id(),
        call_id=call_id if call_id else random_id(),
        action=ActionScreenshot(
            type="screenshot"
        ),
        pending_safety_checks=[],
        status="completed",
        type="computer_call"
    )


def make_scroll_item(x: int, y: int, scroll_x: int, scroll_y: int, call_id: Optional[str] = None) -> ResponseComputerToolCallParam:
    return ResponseComputerToolCallParam(
        id=random_id(),
        call_id=call_id if call_id else random_id(),
        action=ActionScroll(
            scroll_x=scroll_x,
            scroll_y=scroll_y,
            type="scroll",
            x=x,
            y=y
        ),
        pending_safety_checks=[],
        status="completed",
        type="computer_call"
    )


def make_type_item(text: str, call_id: Optional[str] = None) -> ResponseComputerToolCallParam:
    return ResponseComputerToolCallParam(
        id=random_id(),
        call_id=call_id if call_id else random_id(),
        action=ActionTypeAction(
            text=text,
            type="type"
        ),
        pending_safety_checks=[],
        status="completed",
        type="computer_call"
    )


def make_wait_item(call_id: Optional[str] = None) -> ResponseComputerToolCallParam:
    return ResponseComputerToolCallParam(
        id=random_id(),
        call_id=call_id if call_id else random_id(),
        action=ActionWait(
            type="wait"
        ),
        pending_safety_checks=[],
        status="completed",
        type="computer_call"
    )
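A short usage sketch for the constructors above: build the reasoning-plus-action sequence the UITARS loop emits after parsing a model response (coordinates and text are made up):

```python
items = [
    make_reasoning_item("The Submit button is at (412, 305); clicking it."),
    make_click_item(412, 305, button="left"),
    make_type_item("hello world"),
    make_screenshot_item(),
]
for item in items:
    # Each constructor returns a Responses API dict with a fresh UUID
    print(item["type"], item["id"])
```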
@@ -1,79 +0,0 @@
"""
Type definitions for agent
"""

import re
from collections.abc import Iterable
from typing import Dict, List, Any, Optional, Callable, Protocol, Literal

from pydantic import BaseModel
from litellm import ResponseInputParam, ResponsesAPIResponse, ToolParam

# Agent input types
Messages = str | ResponseInputParam
Tools = Optional[Iterable[ToolParam]]

# Agent output types
AgentResponse = ResponsesAPIResponse


# Agent loop registration
class AgentLoopInfo(BaseModel):
    """Information about a registered agent loop"""
    func: Callable
    models_regex: str
    priority: int = 0

    def matches_model(self, model: str) -> bool:
        """Check if this loop matches the given model"""
        return bool(re.match(self.models_regex, model))


# Computer tool interface
class Computer(Protocol):
    """Protocol defining the interface for computer interactions."""

    async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
        """Get the current environment type."""
        ...

    async def get_dimensions(self) -> tuple[int, int]:
        """Get screen dimensions as (width, height)."""
        ...

    async def screenshot(self) -> str:
        """Take a screenshot and return as base64 string."""
        ...

    async def click(self, x: int, y: int, button: str = "left") -> None:
        """Click at coordinates with specified button."""
        ...

    async def double_click(self, x: int, y: int) -> None:
        """Double click at coordinates."""
        ...

    async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
        """Scroll at coordinates with specified scroll amounts."""
        ...

    async def type(self, text: str) -> None:
        """Type text."""
        ...

    async def wait(self, ms: int = 1000) -> None:
        """Wait for specified milliseconds."""
        ...

    async def move(self, x: int, y: int) -> None:
        """Move cursor to coordinates."""
        ...

    async def keypress(self, keys: List[str]) -> None:
        """Press key combination."""
        ...

    async def drag(self, path: List[Dict[str, int]]) -> None:
        """Drag along specified path."""
        ...

    async def get_current_url(self) -> str:
        """Get current URL (for browser environments)."""
        ...
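To make the registry dispatch concrete: an `AgentLoopInfo` matches model strings by anchored regex (`re.match`), with `priority` presumably breaking ties between overlapping patterns. The regex below is invented for illustration:

```python
loop = AgentLoopInfo(func=lambda messages: None, models_regex=r"anthropic/.*", priority=10)
assert loop.matches_model("anthropic/claude-3-7-sonnet-20250219")
assert not loop.matches_model("openai/computer-use-preview")
```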
@@ -1,7 +0,0 @@
"""
UI components for agent
"""

from .gradio import launch_ui, create_gradio_ui

__all__ = ["launch_ui", "create_gradio_ui"]
@@ -1,4 +0,0 @@
from .gradio import launch_ui

if __name__ == "__main__":
    launch_ui()
@@ -1,8 +0,0 @@
"""
Gradio UI for agent
"""

from .app import launch_ui
from .ui_components import create_gradio_ui

__all__ = ["launch_ui", "create_gradio_ui"]
@@ -1,248 +0,0 @@
"""
Advanced Gradio UI for Computer-Use Agent (cua-agent)

This is a Gradio interface for the Computer-Use Agent v0.4.x (cua-agent)
with an advanced UI for model selection and configuration.

Supported Agent Models:
- OpenAI: openai/computer-use-preview
- Anthropic: anthropic/claude-3-5-sonnet-20241022, anthropic/claude-3-7-sonnet-20250219
- UI-TARS: huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B
- Omniparser: omniparser+anthropic/claude-3-5-sonnet-20241022, omniparser+ollama_chat/gemma3

Requirements:
- Mac with Apple Silicon (M1/M2/M3/M4), Linux, or Windows
- macOS 14 (Sonoma) or newer / Ubuntu 20.04+
- Python 3.11+
- Lume CLI installed (https://github.com/trycua/cua)
- OpenAI or Anthropic API key
"""

import os
import asyncio
import logging
import json
import platform
from pathlib import Path
from typing import Dict, List, Optional, AsyncGenerator, Any, Tuple, Union, cast

import gradio as gr
from gradio.components.chatbot import MetadataDict

# Import from agent package
from agent import ComputerAgent
from agent.types import Messages, AgentResponse
from computer import Computer

# Global variables
global_agent = None
global_computer = None
SETTINGS_FILE = Path(".gradio_settings.json")


import dotenv
if dotenv.load_dotenv():
    print(f"DEBUG - Loaded environment variables from {dotenv.find_dotenv()}")
else:
    print("DEBUG - No .env file found")


# --- Settings Load/Save Functions ---
def load_settings() -> Dict[str, Any]:
    """Loads settings from the JSON file."""
    if SETTINGS_FILE.exists():
        try:
            with open(SETTINGS_FILE, "r") as f:
                settings = json.load(f)
            if isinstance(settings, dict):
                print(f"DEBUG - Loaded settings from {SETTINGS_FILE}")
                return settings
        except (json.JSONDecodeError, IOError) as e:
            print(f"Warning: Could not load settings from {SETTINGS_FILE}: {e}")
    return {}


def save_settings(settings: Dict[str, Any]):
    """Saves settings to the JSON file."""
    settings.pop("provider_api_key", None)
    try:
        with open(SETTINGS_FILE, "w") as f:
            json.dump(settings, f, indent=4)
        print(f"DEBUG - Saved settings to {SETTINGS_FILE}")
    except IOError as e:
        print(f"Warning: Could not save settings to {SETTINGS_FILE}: {e}")
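# Round-trip sketch for the settings helpers above (values illustrative).
# save_settings() drops "provider_api_key" before writing, so secrets never
# land in .gradio_settings.json:
#
#     save_settings({"agent_loop": "OMNI", "provider_api_key": "sk-..."})
#     load_settings()  # -> {"agent_loop": "OMNI"}
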
# # Custom Screenshot Handler for Gradio chat
# class GradioChatScreenshotHandler:
#     """Custom handler that adds screenshots to the Gradio chatbot."""
#
#     def __init__(self, chatbot_history: List[gr.ChatMessage]):
#         self.chatbot_history = chatbot_history
#         print("GradioChatScreenshotHandler initialized")
#
#     async def on_screenshot(self, screenshot_base64: str, action_type: str = "") -> None:
#         """Add screenshot to chatbot when a screenshot is taken."""
#         image_markdown = f"![Screenshot](data:image/png;base64,{screenshot_base64})"
#
#         if self.chatbot_history is not None:
#             self.chatbot_history.append(
#                 gr.ChatMessage(
#                     role="assistant",
#                     content=image_markdown,
#                     metadata={"title": f"🖥️ Screenshot - {action_type}", "status": "done"},
#                 )
#             )


# Detect platform capabilities
is_mac = platform.system().lower() == "darwin"
is_lume_available = is_mac or (os.environ.get("PYLUME_HOST", "localhost") != "localhost")

print("PYLUME_HOST: ", os.environ.get("PYLUME_HOST", "localhost"))
print("is_mac: ", is_mac)
print("Lume available: ", is_lume_available)

# Map model names to agent model strings
MODEL_MAPPINGS = {
    "openai": {
        "default": "openai/computer-use-preview",
        "OpenAI: Computer-Use Preview": "openai/computer-use-preview",
    },
    "anthropic": {
        "default": "anthropic/claude-3-7-sonnet-20250219",
        "Anthropic: Claude 4 Opus (20250514)": "anthropic/claude-opus-4-20250514",
        "Anthropic: Claude 4 Sonnet (20250514)": "anthropic/claude-sonnet-4-20250514",
        "Anthropic: Claude 3.7 Sonnet (20250219)": "anthropic/claude-3-7-sonnet-20250219",
        "Anthropic: Claude 3.5 Sonnet (20240620)": "anthropic/claude-3-5-sonnet-20240620",
    },
    "omni": {
        "default": "omniparser+openai/gpt-4o",
        "OMNI: OpenAI GPT-4o": "omniparser+openai/gpt-4o",
        "OMNI: OpenAI GPT-4o mini": "omniparser+openai/gpt-4o-mini",
        "OMNI: Claude 3.7 Sonnet (20250219)": "omniparser+anthropic/claude-3-7-sonnet-20250219",
        "OMNI: Claude 3.5 Sonnet (20240620)": "omniparser+anthropic/claude-3-5-sonnet-20240620",
    },
    "uitars": {
        "default": "huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B" if is_mac else "ui-tars",
        "huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B": "huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
    },
}


def get_model_string(model_name: str, loop_provider: str) -> str:
    """Determine the agent model string based on the input."""
    if model_name == "Custom model (OpenAI compatible API)":
        return "custom_oaicompat"
    elif model_name == "Custom model (ollama)":
        return "custom_ollama"
    elif loop_provider == "OMNI-OLLAMA" or model_name.startswith("OMNI: Ollama "):
        if model_name.startswith("OMNI: Ollama "):
            ollama_model = model_name.split("OMNI: Ollama ", 1)[1]
            return f"omniparser+ollama_chat/{ollama_model}"
        return "omniparser+ollama_chat/llama3"

    # Map based on loop provider
    mapping = MODEL_MAPPINGS.get(loop_provider.lower(), MODEL_MAPPINGS["openai"])
    return mapping.get(model_name, mapping["default"])
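# How get_model_string resolves in practice (inputs illustrative):
#
#     get_model_string("OMNI: Ollama llama3", "OMNI")
#         -> "omniparser+ollama_chat/llama3"
#     get_model_string("Anthropic: Claude 3.7 Sonnet (20250219)", "ANTHROPIC")
#         -> "anthropic/claude-3-7-sonnet-20250219"
#     get_model_string("anything unrecognized", "ANTHROPIC")
#         -> "anthropic/claude-3-7-sonnet-20250219"  (the provider default)
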
def get_ollama_models() -> List[str]:
    """Get available models from Ollama if installed."""
    try:
        import subprocess
        result = subprocess.run(["ollama", "list"], capture_output=True, text=True)
        if result.returncode == 0:
            lines = result.stdout.strip().split("\n")
            if len(lines) < 2:
                return []
            models = []
            for line in lines[1:]:
                parts = line.split()
                if parts:
                    model_name = parts[0]
                    models.append(f"OMNI: Ollama {model_name}")
            return models
        return []
    except Exception as e:
        logging.error(f"Error getting Ollama models: {e}")
        return []


def create_computer_instance(
    verbosity: int = logging.INFO,
    os_type: str = "macos",
    provider_type: str = "lume",
    name: Optional[str] = None,
    api_key: Optional[str] = None
) -> Computer:
    """Create or get the global Computer instance."""
    global global_computer
    if global_computer is None:
        global_computer = Computer(
            verbosity=verbosity,
            os_type=os_type,
            provider_type=provider_type,
            name=name if name else "",
            api_key=api_key
        )
    return global_computer


def create_agent(
    model_string: str,
    save_trajectory: bool = True,
    only_n_most_recent_images: int = 3,
    verbosity: int = logging.INFO,
    custom_model_name: Optional[str] = None,
    computer_os: str = "macos",
    computer_provider: str = "lume",
    computer_name: Optional[str] = None,
    computer_api_key: Optional[str] = None,
    max_trajectory_budget: Optional[float] = None,
) -> ComputerAgent:
    """Create or update the global agent with the specified parameters."""
    global global_agent

    # Create the computer
    computer = create_computer_instance(
        verbosity=verbosity,
        os_type=computer_os,
        provider_type=computer_provider,
        name=computer_name,
        api_key=computer_api_key
    )

    # Handle custom models
    if model_string == "custom_oaicompat" and custom_model_name:
        model_string = custom_model_name
    elif model_string == "custom_ollama" and custom_model_name:
        model_string = f"omniparser+ollama_chat/{custom_model_name}"

    # Create agent kwargs
    agent_kwargs = {
        "model": model_string,
        "tools": [computer],
        "only_n_most_recent_images": only_n_most_recent_images,
        "verbosity": verbosity,
    }

    if save_trajectory:
        agent_kwargs["trajectory_dir"] = "trajectories"

    if max_trajectory_budget:
        agent_kwargs["max_trajectory_budget"] = {"max_budget": max_trajectory_budget, "raise_error": True}

    global_agent = ComputerAgent(**agent_kwargs)
    return global_agent


def launch_ui():
    """Standalone function to launch the Gradio app."""
    from agent.ui.gradio.ui_components import create_gradio_ui
    print("Starting Gradio app for CUA Agent...")
    demo = create_gradio_ui()
    demo.launch(share=False, inbrowser=True)


if __name__ == "__main__":
    launch_ui()
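The helpers above can also be driven without the UI. A minimal sketch, assuming the same environment as the UI; note how `create_agent` resolves the custom-ollama placeholder and wraps the budget into the dict form `ComputerAgent` expects:

```python
from agent.ui.gradio.app import create_agent

agent = create_agent(
    model_string="custom_ollama",
    custom_model_name="llama3",    # resolved to "omniparser+ollama_chat/llama3"
    computer_os="linux",
    computer_provider="cloud",
    max_trajectory_budget=2.0,     # becomes {"max_budget": 2.0, "raise_error": True}
)
```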
@@ -1,721 +0,0 @@
"""
UI Components for the Gradio interface
"""

import os
import asyncio
import logging
import json
import platform
from pathlib import Path
from typing import Dict, List, Optional, Any, cast

import gradio as gr
from gradio.components.chatbot import MetadataDict

from .app import (
    load_settings, save_settings, create_agent, get_model_string,
    get_ollama_models, global_agent, global_computer
)

# Global messages array to maintain conversation history
global_messages = []


def create_gradio_ui() -> gr.Blocks:
    """Create a Gradio UI for the Computer-Use Agent."""

    # Load settings
    saved_settings = load_settings()

    # Check for API keys
    openai_api_key = os.environ.get("OPENAI_API_KEY", "")
    anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY", "")
    cua_api_key = os.environ.get("CUA_API_KEY", "")

    # Model choices
    openai_models = ["OpenAI: Computer-Use Preview"]
    anthropic_models = [
        "Anthropic: Claude 4 Opus (20250514)",
        "Anthropic: Claude 4 Sonnet (20250514)",
        "Anthropic: Claude 3.7 Sonnet (20250219)",
        "Anthropic: Claude 3.5 Sonnet (20240620)",
    ]
    omni_models = [
        "OMNI: OpenAI GPT-4o",
        "OMNI: OpenAI GPT-4o mini",
        "OMNI: Claude 3.7 Sonnet (20250219)",
        "OMNI: Claude 3.5 Sonnet (20240620)"
    ]

    # Check if API keys are available
    has_openai_key = bool(openai_api_key)
    has_anthropic_key = bool(anthropic_api_key)
    has_cua_key = bool(cua_api_key)

    # Get Ollama models for OMNI
    ollama_models = get_ollama_models()
    if ollama_models:
        omni_models += ollama_models

    # Detect platform
    is_mac = platform.system().lower() == "darwin"

    # Format model choices
    provider_to_models = {
        "OPENAI": openai_models,
        "ANTHROPIC": anthropic_models,
        "OMNI": omni_models + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
        "UITARS": ([
            "huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
        ] if is_mac else []) + ["Custom model (OpenAI compatible API)"],
    }

    # Apply saved settings
    initial_loop = saved_settings.get("agent_loop", "OMNI")
    available_models_for_loop = provider_to_models.get(initial_loop, [])
    saved_model_choice = saved_settings.get("model_choice")
    if saved_model_choice and saved_model_choice in available_models_for_loop:
        initial_model = saved_model_choice
    else:
        if initial_loop == "OPENAI":
            initial_model = openai_models[0] if openai_models else "No models available"
        elif initial_loop == "ANTHROPIC":
            initial_model = anthropic_models[0] if anthropic_models else "No models available"
        else:  # OMNI
            initial_model = omni_models[0] if omni_models else "Custom model (OpenAI compatible API)"

    initial_custom_model = saved_settings.get("custom_model", "Qwen2.5-VL-7B-Instruct")
    initial_provider_base_url = saved_settings.get("provider_base_url", "http://localhost:1234/v1")
    initial_save_trajectory = saved_settings.get("save_trajectory", True)
    initial_recent_images = saved_settings.get("recent_images", 3)

    # Example prompts
    example_messages = [
        "Create a Python virtual environment, install pandas and matplotlib, then plot stock data",
        "Open a PDF in Preview, add annotations, and save it as a compressed version",
        "Open Safari, search for 'macOS automation tools', and save the first three results as bookmarks",
        "Configure SSH keys and set up a connection to a remote server",
    ]
    def generate_python_code(agent_loop_choice, model_name, tasks, recent_images=3, save_trajectory=True, computer_os="linux", computer_provider="cloud", container_name="", cua_cloud_api_key="", max_budget=None):
        """Generate Python code for the current configuration and tasks."""
        tasks_str = ""
        for task in tasks:
            if task and task.strip():
                tasks_str += f'            "{task}",\n'

        model_string = get_model_string(model_name, agent_loop_choice)

        computer_args = []
        if computer_os != "macos":
            computer_args.append(f'os_type="{computer_os}"')
        if computer_provider != "lume":
            computer_args.append(f'provider_type="{computer_provider}"')
        if container_name:
            computer_args.append(f'name="{container_name}"')
        if cua_cloud_api_key:
            computer_args.append(f'api_key="{cua_cloud_api_key}"')

        computer_args_str = ", ".join(computer_args)
        if computer_args_str:
            computer_args_str = f"({computer_args_str})"
        else:
            computer_args_str = "()"

        code = f'''import asyncio
from computer import Computer
from agent import ComputerAgent

async def main():
    async with Computer{computer_args_str} as computer:
        agent = ComputerAgent(
            model="{model_string}",
            tools=[computer],
            only_n_most_recent_images={recent_images},'''

        if save_trajectory:
            code += '''
            trajectory_dir="trajectories",'''

        if max_budget:
            code += f'''
            max_trajectory_budget={{"max_budget": {max_budget}, "raise_error": True}},'''

        code += '''
        )
'''

        if tasks_str:
            code += f'''
        # Prompts for the computer-use agent
        tasks = [
{tasks_str.rstrip()}
        ]

        for task in tasks:
            print(f"Executing task: {{task}}")
            messages = [{{"role": "user", "content": task}}]
            async for result in agent.run(messages):
                for item in result["output"]:
                    if item["type"] == "message":
                        print(item["content"][0]["text"])'''
        else:
            code += f'''
        # Execute a single task
        task = "Search for information about CUA on GitHub"
        print(f"Executing task: {{task}}")
        messages = [{{"role": "user", "content": task}}]
        async for result in agent.run(messages):
            for item in result["output"]:
                if item["type"] == "message":
                    print(item["content"][0]["text"])'''

        code += '''

if __name__ == "__main__":
    asyncio.run(main())'''

        return code

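    # Usage note (inputs illustrative): generate_python_code("ANTHROPIC",
    # "Anthropic: Claude 3.7 Sonnet (20250219)", ["Take a screenshot"],
    # max_budget=1.0) returns a standalone script that opens a Computer,
    # builds a ComputerAgent for the resolved model string, and streams each
    # task through agent.run(); the UI shows it in the "Python Code" accordion.
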
    # Create the Gradio interface
    with gr.Blocks(title="Computer-Use Agent") as demo:
        with gr.Row():
            # Left column for settings
            with gr.Column(scale=1):
                # Logo
                gr.HTML(
                    """
                    <div style="display: flex; justify-content: center; margin-bottom: 0.5em">
                        <img alt="CUA Logo" style="width: 80px;"
                             src="https://github.com/trycua/cua/blob/main/img/logo_black.png?raw=true" />
                    </div>
                    """
                )

                # Python code accordion
                with gr.Accordion("Python Code", open=False):
                    code_display = gr.Code(
                        language="python",
                        value=generate_python_code(initial_loop, "gpt-4o", []),
                        interactive=False,
                    )

                with gr.Accordion("Computer Configuration", open=True):
                    computer_os = gr.Radio(
                        choices=["macos", "linux", "windows"],
                        label="Operating System",
                        value="macos",
                        info="Select the operating system for the computer",
                    )

                    is_windows = platform.system().lower() == "windows"
                    is_mac = platform.system().lower() == "darwin"

                    providers = ["cloud"]
                    if is_mac:
                        providers += ["lume"]
                    if is_windows:
                        providers += ["winsandbox"]

                    computer_provider = gr.Radio(
                        choices=providers,
                        label="Provider",
                        value="lume" if is_mac else "cloud",
                        info="Select the computer provider",
                    )

                    container_name = gr.Textbox(
                        label="Container Name",
                        placeholder="Enter container name (optional)",
                        value=os.environ.get("CUA_CONTAINER_NAME", ""),
                        info="Optional name for the container",
                    )

                    cua_cloud_api_key = gr.Textbox(
                        label="CUA Cloud API Key",
                        placeholder="Enter your CUA Cloud API key",
                        value=os.environ.get("CUA_API_KEY", ""),
                        type="password",
                        info="Required for cloud provider",
                        visible=(not has_cua_key)
                    )

                with gr.Accordion("Agent Configuration", open=True):
                    agent_loop = gr.Dropdown(
                        choices=["OPENAI", "ANTHROPIC", "OMNI", "UITARS"],
                        label="Agent Loop",
                        value=initial_loop,
                        info="Select the agent loop provider",
                    )

                    # Model selection dropdowns
                    with gr.Group() as model_selection_group:
                        openai_model_choice = gr.Dropdown(
                            choices=openai_models,
                            label="OpenAI Model",
                            value=openai_models[0] if openai_models else "No models available",
                            info="Select OpenAI model",
                            interactive=True,
                            visible=(initial_loop == "OPENAI")
                        )

                        anthropic_model_choice = gr.Dropdown(
                            choices=anthropic_models,
                            label="Anthropic Model",
                            value=anthropic_models[0] if anthropic_models else "No models available",
                            info="Select Anthropic model",
                            interactive=True,
                            visible=(initial_loop == "ANTHROPIC")
                        )

                        omni_model_choice = gr.Dropdown(
                            choices=omni_models + ["Custom model (OpenAI compatible API)", "Custom model (ollama)"],
                            label="OMNI Model",
                            value=omni_models[0] if omni_models else "Custom model (OpenAI compatible API)",
                            info="Select OMNI model or choose a custom model option",
                            interactive=True,
                            visible=(initial_loop == "OMNI")
                        )

                        uitars_model_choice = gr.Dropdown(
                            choices=provider_to_models.get("UITARS", ["No models available"]),
                            label="UITARS Model",
                            value=provider_to_models.get("UITARS", ["No models available"])[0] if provider_to_models.get("UITARS") else "No models available",
                            info="Select UITARS model",
                            interactive=True,
                            visible=(initial_loop == "UITARS")
                        )

                        model_choice = gr.Textbox(visible=False)

                    # API key inputs
                    with gr.Group(visible=not has_openai_key and (initial_loop == "OPENAI" or initial_loop == "OMNI")) as openai_key_group:
                        openai_api_key_input = gr.Textbox(
                            label="OpenAI API Key",
                            placeholder="Enter your OpenAI API key",
                            value=os.environ.get("OPENAI_API_KEY", ""),
                            interactive=True,
                            type="password",
                            info="Required for OpenAI models"
                        )

                    with gr.Group(visible=not has_anthropic_key and (initial_loop == "ANTHROPIC" or initial_loop == "OMNI")) as anthropic_key_group:
                        anthropic_api_key_input = gr.Textbox(
                            label="Anthropic API Key",
                            placeholder="Enter your Anthropic API key",
                            value=os.environ.get("ANTHROPIC_API_KEY", ""),
                            interactive=True,
                            type="password",
                            info="Required for Anthropic models"
                        )

                    # API key handlers
                    def set_openai_api_key(key):
                        if key and key.strip():
                            os.environ["OPENAI_API_KEY"] = key.strip()
                            print("DEBUG - Set OpenAI API key environment variable")
                        return key

                    def set_anthropic_api_key(key):
                        if key and key.strip():
                            os.environ["ANTHROPIC_API_KEY"] = key.strip()
                            print("DEBUG - Set Anthropic API key environment variable")
                        return key

                    openai_api_key_input.change(
                        fn=set_openai_api_key,
                        inputs=[openai_api_key_input],
                        outputs=[openai_api_key_input],
                        queue=False
                    )

                    anthropic_api_key_input.change(
                        fn=set_anthropic_api_key,
                        inputs=[anthropic_api_key_input],
                        outputs=[anthropic_api_key_input],
                        queue=False
                    )

                    # UI update function
                    def update_ui(loop=None, openai_model=None, anthropic_model=None, omni_model=None, uitars_model=None):
                        loop = loop or agent_loop.value

                        model_value = None
                        if loop == "OPENAI" and openai_model:
                            model_value = openai_model
                        elif loop == "ANTHROPIC" and anthropic_model:
                            model_value = anthropic_model
                        elif loop == "OMNI" and omni_model:
                            model_value = omni_model
                        elif loop == "UITARS" and uitars_model:
                            model_value = uitars_model

                        openai_visible = (loop == "OPENAI")
                        anthropic_visible = (loop == "ANTHROPIC")
                        omni_visible = (loop == "OMNI")
                        uitars_visible = (loop == "UITARS")

                        show_openai_key = not has_openai_key and (loop == "OPENAI" or (loop == "OMNI" and model_value and "OpenAI" in model_value and "Custom" not in model_value))
                        show_anthropic_key = not has_anthropic_key and (loop == "ANTHROPIC" or (loop == "OMNI" and model_value and "Claude" in model_value and "Custom" not in model_value))

                        is_custom_openai_api = model_value == "Custom model (OpenAI compatible API)"
                        is_custom_ollama = model_value == "Custom model (ollama)"
                        is_any_custom = is_custom_openai_api or is_custom_ollama

                        model_choice_value = model_value if model_value else ""

                        return [
                            gr.update(visible=openai_visible),
                            gr.update(visible=anthropic_visible),
                            gr.update(visible=omni_visible),
                            gr.update(visible=uitars_visible),
                            gr.update(visible=show_openai_key),
                            gr.update(visible=show_anthropic_key),
                            gr.update(visible=is_any_custom),
                            gr.update(visible=is_custom_openai_api),
                            gr.update(visible=is_custom_openai_api),
                            gr.update(value=model_choice_value)
                        ]
                    # Custom model inputs
                    custom_model = gr.Textbox(
                        label="Custom Model Name",
                        placeholder="Enter custom model name (e.g., Qwen2.5-VL-7B-Instruct or llama3)",
                        value=initial_custom_model,
                        visible=(initial_model == "Custom model (OpenAI compatible API)" or initial_model == "Custom model (ollama)"),
                        interactive=True,
                    )

                    provider_base_url = gr.Textbox(
                        label="Provider Base URL",
                        placeholder="Enter provider base URL (e.g., http://localhost:1234/v1)",
                        value=initial_provider_base_url,
                        visible=(initial_model == "Custom model (OpenAI compatible API)"),
                        interactive=True,
                    )

                    provider_api_key = gr.Textbox(
                        label="Provider API Key",
                        placeholder="Enter provider API key (if required)",
                        value="",
                        visible=(initial_model == "Custom model (OpenAI compatible API)"),
                        interactive=True,
                        type="password",
                    )

                    # Connect UI update events
                    for dropdown in [agent_loop, omni_model_choice, uitars_model_choice, openai_model_choice, anthropic_model_choice]:
                        dropdown.change(
                            fn=update_ui,
                            inputs=[agent_loop, openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice],
                            outputs=[
                                openai_model_choice, anthropic_model_choice, omni_model_choice, uitars_model_choice,
                                openai_key_group, anthropic_key_group,
                                custom_model, provider_base_url, provider_api_key,
                                model_choice
                            ],
                            queue=False
                        )

                    save_trajectory = gr.Checkbox(
                        label="Save Trajectory",
                        value=initial_save_trajectory,
                        info="Save the agent's trajectory for debugging",
                        interactive=True,
                    )

                    recent_images = gr.Slider(
                        label="Recent Images",
                        minimum=1,
                        maximum=10,
                        value=initial_recent_images,
                        step=1,
                        info="Number of recent images to keep in context",
                        interactive=True,
                    )

                    max_budget = gr.Number(
                        label="Max Budget ($)",
                        value=lambda: None,
                        minimum=-1,
                        maximum=100.0,
                        step=0.1,
                        info="Optional budget limit for trajectory (0 = no limit)",
                        interactive=True,
                    )

            # Right column for chat interface
            with gr.Column(scale=2):
                gr.Markdown(
                    "Ask me to perform tasks in a virtual environment.<br>Built with <a href='https://github.com/trycua/cua' target='_blank'>github.com/trycua/cua</a>."
                )

                chatbot_history = gr.Chatbot(type="messages")
                msg = gr.Textbox(
                    placeholder="Ask me to perform tasks in a virtual environment"
                )
                clear = gr.Button("Clear")
                cancel_button = gr.Button("Cancel", variant="stop")

                # Add examples
                example_group = gr.Examples(examples=example_messages, inputs=msg)

                # Chat submission function
                def chat_submit(message, history):
                    history.append(gr.ChatMessage(role="user", content=message))
                    return "", history

                # Cancel function
                async def cancel_agent_task(history):
                    global global_agent
                    if global_agent:
                        print("DEBUG - Cancelling agent task")
                        history.append(gr.ChatMessage(role="assistant", content="Task cancelled by user", metadata={"title": "❌ Cancelled"}))
                    else:
                        history.append(gr.ChatMessage(role="assistant", content="No active agent task to cancel", metadata={"title": "ℹ️ Info"}))
                    return history

                # Process response function
                async def process_response(
                    history,
                    openai_model_value,
                    anthropic_model_value,
                    omni_model_value,
                    uitars_model_value,
                    custom_model_value,
                    agent_loop_choice,
                    save_traj,
                    recent_imgs,
                    custom_url_value=None,
                    custom_api_key=None,
                    openai_key_input=None,
                    anthropic_key_input=None,
                    computer_os="linux",
                    computer_provider="cloud",
                    container_name="",
                    cua_cloud_api_key="",
                    max_budget_value=None,
                ):
                    if not history:
                        yield history
                        return

                    # Get the last user message
                    last_user_message = history[-1]["content"]

                    # Get the appropriate model value based on the agent loop
                    if agent_loop_choice == "OPENAI":
                        model_choice_value = openai_model_value
                    elif agent_loop_choice == "ANTHROPIC":
                        model_choice_value = anthropic_model_value
                    elif agent_loop_choice == "OMNI":
                        model_choice_value = omni_model_value
                    elif agent_loop_choice == "UITARS":
                        model_choice_value = uitars_model_value
                    else:
                        model_choice_value = "No models available"

                    # Determine if this is a custom model selection
                    is_custom_model_selected = model_choice_value in ["Custom model (OpenAI compatible API)", "Custom model (ollama)"]

                    # Determine the model name string to analyze
                    if is_custom_model_selected:
                        model_string_to_analyze = custom_model_value
                    else:
                        model_string_to_analyze = model_choice_value

                    try:
                        # Get the model string
                        model_string = get_model_string(model_string_to_analyze, agent_loop_choice)

                        # Set API keys if provided
                        if openai_key_input:
                            os.environ["OPENAI_API_KEY"] = openai_key_input
                        if anthropic_key_input:
                            os.environ["ANTHROPIC_API_KEY"] = anthropic_key_input
                        if cua_cloud_api_key:
                            os.environ["CUA_API_KEY"] = cua_cloud_api_key

                        # Save settings
                        current_settings = {
                            "agent_loop": agent_loop_choice,
                            "model_choice": model_choice_value,
                            "custom_model": custom_model_value,
                            "provider_base_url": custom_url_value,
                            "save_trajectory": save_traj,
                            "recent_images": recent_imgs,
                            "computer_os": computer_os,
                            "computer_provider": computer_provider,
                            "container_name": container_name,
                        }
                        save_settings(current_settings)

                        # Create agent
                        global_agent = create_agent(
                            model_string=model_string,
                            save_trajectory=save_traj,
                            only_n_most_recent_images=recent_imgs,
                            custom_model_name=custom_model_value if is_custom_model_selected else None,
                            computer_os=computer_os,
                            computer_provider=computer_provider,
                            computer_name=container_name,
                            computer_api_key=cua_cloud_api_key,
                            verbosity=logging.DEBUG,
                            max_trajectory_budget=max_budget_value if max_budget_value and max_budget_value > 0 else None,
                        )

                        if global_agent is None:
                            history.append(
                                gr.ChatMessage(
                                    role="assistant",
                                    content="Failed to create agent. Check API keys and configuration.",
                                )
                            )
                            yield history
                            return

                        # Add user message to global history
                        global global_messages
                        global_messages.append({"role": "user", "content": last_user_message})

                        # Stream responses from the agent
                        async for result in global_agent.run(global_messages):
                            global_messages += result.get("output", [])
                            # print(f"DEBUG - Agent response ------- START")
                            # from pprint import pprint
                            # pprint(result)
                            # print(f"DEBUG - Agent response ------- END")

                            # Process the result output
                            for item in result.get("output", []):
                                if item.get("type") == "message":
                                    content = item.get("content", [])
                                    for content_part in content:
                                        if content_part.get("text"):
                                            history.append(gr.ChatMessage(
                                                role=item.get("role", "assistant"),
                                                content=content_part.get("text", ""),
                                                metadata=content_part.get("metadata", {})
                                            ))
                                elif item.get("type") == "computer_call":
                                    action = item.get("action", {})
                                    action_type = action.get("type", "")
                                    if action_type:
                                        action_title = f"🛠️ Performing {action_type}"
                                        if action.get("x") and action.get("y"):
                                            action_title += f" at ({action['x']}, {action['y']})"
                                        history.append(gr.ChatMessage(
                                            role="assistant",
                                            content=f"```json\n{json.dumps(action)}\n```",
                                            metadata={"title": action_title}
                                        ))
                                elif item.get("type") == "function_call":
                                    function_name = item.get("name", "")
                                    arguments = item.get("arguments", "{}")
                                    history.append(gr.ChatMessage(
                                        role="assistant",
                                        content=f"🔧 Calling function: {function_name}\n```json\n{arguments}\n```",
                                        metadata={"title": f"Function Call: {function_name}"}
                                    ))
                                elif item.get("type") == "function_call_output":
                                    output = item.get("output", "")
                                    history.append(gr.ChatMessage(
                                        role="assistant",
                                        content=f"📤 Function output:\n```\n{output}\n```",
                                        metadata={"title": "Function Output"}
                                    ))
                                elif item.get("type") == "computer_call_output":
                                    output = item.get("output", {}).get("image_url", "")
                                    image_markdown = f"![Computer output]({output})"
                                    history.append(gr.ChatMessage(
                                        role="assistant",
                                        content=image_markdown,
                                        metadata={"title": "🖥️ Computer Output"}
                                    ))

                            yield history

                    except Exception as e:
                        import traceback
                        traceback.print_exc()
                        history.append(gr.ChatMessage(role="assistant", content=f"Error: {str(e)}"))
                        yield history
                # Connect the submit button
                submit_event = msg.submit(
                    fn=chat_submit,
                    inputs=[msg, chatbot_history],
                    outputs=[msg, chatbot_history],
                    queue=False,
                ).then(
                    fn=process_response,
                    inputs=[
                        chatbot_history,
                        openai_model_choice,
                        anthropic_model_choice,
                        omni_model_choice,
                        uitars_model_choice,
                        custom_model,
                        agent_loop,
                        save_trajectory,
                        recent_images,
                        provider_base_url,
                        provider_api_key,
                        openai_api_key_input,
                        anthropic_api_key_input,
                        computer_os,
                        computer_provider,
                        container_name,
                        cua_cloud_api_key,
                        max_budget,
                    ],
                    outputs=[chatbot_history],
                    queue=True,
                )

                # Clear button functionality
                def clear_chat():
                    global global_messages
                    global_messages.clear()
                    return None

                clear.click(clear_chat, None, chatbot_history, queue=False)

                # Connect cancel button
                cancel_button.click(
                    cancel_agent_task,
                    [chatbot_history],
                    [chatbot_history],
                    queue=False
                )

                # Code display update function
                def update_code_display(agent_loop, model_choice_val, custom_model_val, chat_history, recent_images_val, save_trajectory_val, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget_val):
                    messages = []
                    if chat_history:
                        for msg in chat_history:
                            if isinstance(msg, dict) and msg.get("role") == "user":
                                messages.append(msg.get("content", ""))

                    return generate_python_code(
                        agent_loop,
                        model_choice_val or custom_model_val or "gpt-4o",
                        messages,
                        recent_images_val,
                        save_trajectory_val,
                        computer_os,
                        computer_provider,
                        container_name,
                        cua_cloud_api_key,
                        max_budget_val
                    )

                # Update code display when configuration changes
                for component in [agent_loop, model_choice, custom_model, chatbot_history, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget]:
                    component.change(
                        update_code_display,
                        inputs=[agent_loop, model_choice, custom_model, chatbot_history, recent_images, save_trajectory, computer_os, computer_provider, container_name, cua_cloud_api_key, max_budget],
                        outputs=[code_display]
                    )

    return demo
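Tying it together, the Blocks app above is what `launch_ui` in `app.py` serves. A minimal sketch of embedding it yourself, using the same launch arguments `app.py` passes:

```python
from agent.ui.gradio.ui_components import create_gradio_ui

demo = create_gradio_ui()
demo.launch(share=False, inbrowser=True)
```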
@@ -1,148 +0,0 @@
"""
Example usage of the agent library with docstring-based tool definitions.
"""

import asyncio
import logging

from agent import agent_loop, ComputerAgent
from agent.types import Messages
from computer import Computer
from computer.helpers import sandboxed


@sandboxed()
def read_file(location: str) -> str:
    """Read contents of a file

    Parameters
    ----------
    location : str
        Path to the file to read

    Returns
    -------
    str
        Contents of the file or error message
    """
    try:
        with open(location, 'r') as f:
            return f.read()
    except Exception as e:
        return f"Error reading file: {str(e)}"


def save_note(content: str, filename: str = "note.txt") -> str:
    """Save content to a note file

    Parameters
    ----------
    content : str
        Content to save to the file
    filename : str, optional
        Name of the file to save to (default is "note.txt")

    Returns
    -------
    str
        Success or error message
    """
    try:
        with open(filename, 'w') as f:
            f.write(content)
        return f"Saved note to {filename}"
    except Exception as e:
        return f"Error saving note: {str(e)}"


def calculate(a: int, b: int) -> int:
    """Calculate the sum of two integers

    Parameters
    ----------
    a : int
        First integer
    b : int
        Second integer

    Returns
    -------
    int
        Sum of the two integers
    """
    return a + b


async def main():
    """Example usage of ComputerAgent with different models"""

    # Example 1: Using Claude with computer and custom tools
    print("=== Example 1: Claude with Computer ===")

    import os
    import dotenv
    import json
    dotenv.load_dotenv()

    assert os.getenv("CUA_CONTAINER_NAME") is not None, "CUA_CONTAINER_NAME is not set"
    assert os.getenv("CUA_API_KEY") is not None, "CUA_API_KEY is not set"

    async with Computer(
        os_type="linux",
        provider_type="cloud",
        name=os.getenv("CUA_CONTAINER_NAME") or "",
        api_key=os.getenv("CUA_API_KEY") or ""
    ) as computer:
        agent = ComputerAgent(
            # Supported models:

            # == OpenAI CUA (computer-use-preview) ==
            model="openai/computer-use-preview",

            # == Anthropic CUA (Claude > 3.5) ==
            # model="anthropic/claude-opus-4-20250514",
            # model="anthropic/claude-sonnet-4-20250514",
            # model="anthropic/claude-3-7-sonnet-20250219",
            # model="anthropic/claude-3-5-sonnet-20240620",

            # == UI-TARS ==
            # model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
            # TODO: add local mlx provider
            # model="mlx-community/UI-TARS-1.5-7B-6bit",
            # model="ollama_chat/0000/ui-tars-1.5-7b",

            # == Omniparser + Any LLM ==
            # model="omniparser+..."
            # model="omniparser+anthropic/claude-opus-4-20250514",

            tools=[computer],
            only_n_most_recent_images=3,
            verbosity=logging.INFO,
            trajectory_dir="trajectories",
            use_prompt_caching=True,
            max_trajectory_budget={"max_budget": 1.0, "raise_error": True, "reset_after_each_run": False},
        )

        history = []
        while True:
            user_input = input("> ")
            history.append({"role": "user", "content": user_input})

            # Non-streaming usage
            async for result in agent.run(history, stream=False):
                history += result["output"]

                # # Print output
                # for item in result["output"]:
                #     if item["type"] == "message":
                #         print(item["content"][0]["text"])
                #     elif item["type"] == "computer_call":
                #         action = item["action"]
                #         action_type = action["type"]
                #         action_args = {k: v for k, v in action.items() if k != "type"}
                #         print(f"{action_type}({action_args})")
                #     elif item["type"] == "function_call":
                #         action = item["name"]
                #         action_args = item["arguments"]
                #         print(f"{action}({action_args})")
                #     elif item["type"] == "function_call_output":
                #         print("===>", item["output"])


if __name__ == "__main__":
    asyncio.run(main())
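The module defines `read_file`, `save_note`, and `calculate` as docstring-documented tools but never registers them in Example 1. A hedged sketch of what registration would look like, assuming plain callables can be passed in `tools` alongside the computer as the module docstring implies:

```python
agent = ComputerAgent(
    model="anthropic/claude-3-7-sonnet-20250219",
    tools=[computer, read_file, save_note, calculate],  # callables exposed as function tools
)
```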
@@ -1,71 +0,0 @@
[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"

[project]
name = "cua-agent"
version = "0.4.0b4"
description = "CUA (Computer Use) Agent for AI-driven computer interaction"
readme = "README.md"
authors = [
    { name = "TryCua", email = "gh@trycua.com" }
]
dependencies = [
    "httpx>=0.27.0",
    "aiohttp>=3.9.3",
    "asyncio",
    "anyio>=4.4.1",
    "typing-extensions>=4.12.2",
    "pydantic>=2.6.4",
    "rich>=13.7.1",
    "python-dotenv>=1.0.1",
    "cua-computer>=0.3.0,<0.5.0",
    "cua-core>=0.1.0,<0.2.0",
    "certifi>=2024.2.2",
    "litellm>=1.74.8"
]
requires-python = ">=3.11"

[project.optional-dependencies]
openai = []
anthropic = []
omni = [
    "ultralytics>=8.0.0",
    "cua-som>=0.1.0,<0.2.0",
]
uitars = []
uitars-mlx = [
    "mlx-vlm>=0.1.27; sys_platform == 'darwin'"
]
uitars-hf = [
    "transformers>=4.54.0"
]
ui = [
    "gradio>=5.23.3",
    "python-dotenv>=1.0.1",
]
cli = [
    "yaspin>=3.1.0",
]
all = [
    # omni requirements
    "ultralytics>=8.0.0",
    "cua-som>=0.1.0,<0.2.0",
    # uitars requirements
    "mlx-vlm>=0.1.27; sys_platform == 'darwin'",
    "transformers>=4.54.0",
    # ui requirements
    "gradio>=5.23.3",
    "python-dotenv>=1.0.1",
    # cli requirements
    "yaspin>=3.1.0",
]

[tool.uv]
constraint-dependencies = ["fastrtc>0.43.0", "mlx-audio>0.2.3"]

[tool.pdm]
distribution = true

[tool.pdm.build]
includes = ["agent/"]