mirror of
https://github.com/trycua/computer.git
synced 2025-12-31 02:19:58 -06:00
Merge branch 'main' into feat/docs/init
This commit is contained in:
@@ -50,7 +50,7 @@
|
||||
# 🚀 Quick Start
|
||||
|
||||
Read our guide on getting started with a Computer-Use Agent:
|
||||
[Computer-Use Agent Quickstart](https://docs.trycua.com/home/guides/usage-guide)
|
||||
[Computer-Use Agent Quickstart](https://trycua.com/docs/guides/usage-guide)
|
||||
|
||||
Get started using Cua services on your machine:
|
||||
[Cua Usage Guide](https://docs.trycua.com/home/guides/cua-usage-guide)
|
||||
|
||||
@@ -8,7 +8,7 @@ import signal
|
||||
from computer import Computer, VMProviderType
|
||||
|
||||
# Import the unified agent class and types
|
||||
from agent import ComputerAgent, LLMProvider, LLM, AgentLoop
|
||||
from agent import ComputerAgent
|
||||
|
||||
# Import utility functions
|
||||
from utils import load_dotenv_files, handle_sigint
|
||||
@@ -19,8 +19,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def run_agent_example():
|
||||
"""Run example of using the ComputerAgent with OpenAI and Omni provider."""
|
||||
print("\n=== Example: ComputerAgent with OpenAI and Omni provider ===")
|
||||
"""Run example of using the ComputerAgent with different models."""
|
||||
print("\n=== Example: ComputerAgent with different models ===")
|
||||
|
||||
try:
|
||||
# Create a local macOS computer
|
||||
@@ -37,28 +37,37 @@ async def run_agent_example():
|
||||
# provider_type=VMProviderType.CLOUD,
|
||||
# )
|
||||
|
||||
# Create Computer instance with async context manager
|
||||
# Create ComputerAgent with new API
|
||||
agent = ComputerAgent(
|
||||
computer=computer,
|
||||
loop=AgentLoop.OPENAI,
|
||||
# loop=AgentLoop.ANTHROPIC,
|
||||
# loop=AgentLoop.UITARS,
|
||||
# loop=AgentLoop.OMNI,
|
||||
model=LLM(provider=LLMProvider.OPENAI), # No model name for Operator CUA
|
||||
# model=LLM(provider=LLMProvider.OPENAI, name="gpt-4o"),
|
||||
# model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"),
|
||||
# model=LLM(provider=LLMProvider.OLLAMA, name="gemma3:4b-it-q4_K_M"),
|
||||
# model=LLM(provider=LLMProvider.MLXVLM, name="mlx-community/UI-TARS-1.5-7B-4bit"),
|
||||
# model=LLM(
|
||||
# provider=LLMProvider.OAICOMPAT,
|
||||
# name="gemma-3-12b-it",
|
||||
# provider_base_url="http://localhost:1234/v1", # LM Studio local endpoint
|
||||
# ),
|
||||
save_trajectory=True,
|
||||
# Supported models:
|
||||
|
||||
# == OpenAI CUA (computer-use-preview) ==
|
||||
model="openai/computer-use-preview",
|
||||
|
||||
# == Anthropic CUA (Claude > 3.5) ==
|
||||
# model="anthropic/claude-opus-4-20250514",
|
||||
# model="anthropic/claude-sonnet-4-20250514",
|
||||
# model="anthropic/claude-3-7-sonnet-20250219",
|
||||
# model="anthropic/claude-3-5-sonnet-20240620",
|
||||
|
||||
# == UI-TARS ==
|
||||
# model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B",
|
||||
# model="mlx/mlx-community/UI-TARS-1.5-7B-6bit",
|
||||
# model="ollama_chat/0000/ui-tars-1.5-7b",
|
||||
|
||||
# == Omniparser + Any LLM ==
|
||||
# model="omniparser+anthropic/claude-opus-4-20250514",
|
||||
# model="omniparser+ollama_chat/gemma3:12b-it-q4_K_M",
|
||||
|
||||
tools=[computer],
|
||||
only_n_most_recent_images=3,
|
||||
verbosity=logging.DEBUG,
|
||||
trajectory_dir="trajectories",
|
||||
use_prompt_caching=True,
|
||||
max_trajectory_budget=1.0,
|
||||
)
|
||||
|
||||
# Example tasks to demonstrate the agent
|
||||
tasks = [
|
||||
"Look for a repository named trycua/cua on GitHub.",
|
||||
"Check the open issues, open the most recent one and read it.",
|
||||
@@ -68,43 +77,35 @@ async def run_agent_example():
|
||||
"Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.",
|
||||
]
|
||||
|
||||
# Use message-based conversation history
|
||||
history = []
|
||||
|
||||
for i, task in enumerate(tasks):
|
||||
print(f"\nExecuting task {i}/{len(tasks)}: {task}")
|
||||
async for result in agent.run(task):
|
||||
print("Response ID: ", result.get("id"))
|
||||
|
||||
# Print detailed usage information
|
||||
usage = result.get("usage")
|
||||
if usage:
|
||||
print("\nUsage Details:")
|
||||
print(f" Input Tokens: {usage.get('input_tokens')}")
|
||||
if "input_tokens_details" in usage:
|
||||
print(f" Input Tokens Details: {usage.get('input_tokens_details')}")
|
||||
print(f" Output Tokens: {usage.get('output_tokens')}")
|
||||
if "output_tokens_details" in usage:
|
||||
print(f" Output Tokens Details: {usage.get('output_tokens_details')}")
|
||||
print(f" Total Tokens: {usage.get('total_tokens')}")
|
||||
|
||||
print("Response Text: ", result.get("text"))
|
||||
|
||||
# Print tools information
|
||||
tools = result.get("tools")
|
||||
if tools:
|
||||
print("\nTools:")
|
||||
print(tools)
|
||||
|
||||
# Print reasoning and tool call outputs
|
||||
outputs = result.get("output", [])
|
||||
for output in outputs:
|
||||
output_type = output.get("type")
|
||||
if output_type == "reasoning":
|
||||
print("\nReasoning Output:")
|
||||
print(output)
|
||||
elif output_type == "computer_call":
|
||||
print("\nTool Call Output:")
|
||||
print(output)
|
||||
|
||||
print(f"\n✅ Task {i+1}/{len(tasks)} completed: {task}")
|
||||
print(f"\nExecuting task {i+1}/{len(tasks)}: {task}")
|
||||
|
||||
# Add user message to history
|
||||
history.append({"role": "user", "content": task})
|
||||
|
||||
# Run agent with conversation history
|
||||
async for result in agent.run(history, stream=False):
|
||||
# Add agent outputs to history
|
||||
history += result.get("output", [])
|
||||
|
||||
# Print output for debugging
|
||||
for item in result.get("output", []):
|
||||
if item.get("type") == "message":
|
||||
content = item.get("content", [])
|
||||
for content_part in content:
|
||||
if content_part.get("text"):
|
||||
print(f"Agent: {content_part.get('text')}")
|
||||
elif item.get("type") == "computer_call":
|
||||
action = item.get("action", {})
|
||||
action_type = action.get("type", "")
|
||||
print(f"Computer Action: {action_type}({action})")
|
||||
elif item.get("type") == "computer_call_output":
|
||||
print("Computer Output: [Screenshot/Result]")
|
||||
|
||||
print(f"✅ Task {i+1}/{len(tasks)} completed: {task}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_agent_example: {e}")
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
<h1>
|
||||
<div class="image-wrapper" style="display: inline-block;">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="../../img/logo_white.png" style="display: block; margin: auto;">
|
||||
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="../../img/logo_black.png" style="display: block; margin: auto;">
|
||||
<source media="(prefers-color-scheme: dark)" alt="logo" height="150" srcset="../../../img/logo_white.png" style="display: block; margin: auto;">
|
||||
<source media="(prefers-color-scheme: light)" alt="logo" height="150" srcset="../../../img/logo_black.png" style="display: block; margin: auto;">
|
||||
<img alt="Shows my svg">
|
||||
</picture>
|
||||
</div>
|
||||
@@ -15,208 +15,367 @@
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
**cua-agent** is a general Computer-Use framework for running multi-app agentic workflows targeting macOS and Linux sandbox created with Cua, supporting local (Ollama) and cloud model providers (OpenAI, Anthropic, Groq, DeepSeek, Qwen).
|
||||
**cua-agent** is a general Computer-Use framework with liteLLM integration for running agentic workflows on macOS, Windows, and Linux sandboxes. It provides a unified interface for computer-use agents across multiple LLM providers with advanced callback system for extensibility.
|
||||
|
||||
### Get started with Agent
|
||||
## Features
|
||||
|
||||
<div align="center">
|
||||
<img src="../../img/agent.png"/>
|
||||
</div>
|
||||
- **Safe Computer-Use/Tool-Use**: Using Computer SDK for sandboxed desktops
|
||||
- **Multi-Agent Support**: Anthropic Claude, OpenAI computer-use-preview, UI-TARS, Omniparser + any LLM
|
||||
- **Multi-API Support**: Take advantage of liteLLM supporting 100+ LLMs / model APIs, including local models (`huggingface-local/`, `ollama_chat/`, `mlx/`)
|
||||
- **Cross-Platform**: Works on Windows, macOS, and Linux with cloud and local computer instances
|
||||
- **Extensible Callbacks**: Built-in support for image retention, cache control, PII anonymization, budget limits, and trajectory tracking
|
||||
|
||||
## Install
|
||||
|
||||
```bash
|
||||
pip install "cua-agent[all]"
|
||||
|
||||
# or install specific loop providers
|
||||
pip install "cua-agent[openai]" # OpenAI Cua Loop
|
||||
pip install "cua-agent[anthropic]" # Anthropic Cua Loop
|
||||
pip install "cua-agent[uitars]" # UI-Tars support
|
||||
pip install "cua-agent[omni]" # Cua Loop based on OmniParser (includes Ollama for local models)
|
||||
pip install "cua-agent[ui]" # Gradio UI for the agent
|
||||
pip install "cua-agent[uitars-mlx]" # MLX UI-Tars support
|
||||
# or install specific providers
|
||||
pip install "cua-agent[openai]" # OpenAI computer-use-preview support
|
||||
pip install "cua-agent[anthropic]" # Anthropic Claude support
|
||||
pip install "cua-agent[omni]" # Omniparser + any LLM support
|
||||
pip install "cua-agent[uitars]" # UI-TARS
|
||||
pip install "cua-agent[uitars-mlx]" # UI-TARS + MLX support
|
||||
pip install "cua-agent[uitars-hf]" # UI-TARS + Huggingface support
|
||||
pip install "cua-agent[ui]" # Gradio UI support
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
```bash
|
||||
async with Computer() as macos_computer:
|
||||
# Create agent with loop and provider
|
||||
agent = ComputerAgent(
|
||||
computer=macos_computer,
|
||||
loop=AgentLoop.OPENAI,
|
||||
model=LLM(provider=LLMProvider.OPENAI)
|
||||
# or
|
||||
# loop=AgentLoop.ANTHROPIC,
|
||||
# model=LLM(provider=LLMProvider.ANTHROPIC)
|
||||
# or
|
||||
# loop=AgentLoop.OMNI,
|
||||
# model=LLM(provider=LLMProvider.OLLAMA, name="gemma3")
|
||||
# or
|
||||
# loop=AgentLoop.UITARS,
|
||||
# model=LLM(provider=LLMProvider.OAICOMPAT, name="ByteDance-Seed/UI-TARS-1.5-7B", provider_base_url="https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1")
|
||||
)
|
||||
|
||||
tasks = [
|
||||
"Look for a repository named trycua/cua on GitHub.",
|
||||
"Check the open issues, open the most recent one and read it.",
|
||||
"Clone the repository in users/lume/projects if it doesn't exist yet.",
|
||||
"Open the repository with an app named Cursor (on the dock, black background and white cube icon).",
|
||||
"From Cursor, open Composer if not already open.",
|
||||
"Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.",
|
||||
]
|
||||
|
||||
for i, task in enumerate(tasks):
|
||||
print(f"\nExecuting task {i}/{len(tasks)}: {task}")
|
||||
async for result in agent.run(task):
|
||||
print(result)
|
||||
|
||||
print(f"\n✅ Task {i+1}/{len(tasks)} completed: {task}")
|
||||
```
|
||||
|
||||
Refer to these notebooks for step-by-step guides on how to use the Computer-Use Agent (CUA):
|
||||
|
||||
- [Agent Notebook](../../notebooks/agent_nb.ipynb) - Complete examples and workflows
|
||||
|
||||
## Using the Gradio UI
|
||||
|
||||
The agent includes a Gradio-based user interface for easier interaction.
|
||||
|
||||
<div align="center">
|
||||
<img src="../../img/agent_gradio_ui.png"/>
|
||||
</div>
|
||||
|
||||
To use it:
|
||||
|
||||
```bash
|
||||
# Install with Gradio support
|
||||
pip install "cua-agent[ui]"
|
||||
```
|
||||
|
||||
### Create a simple launcher script
|
||||
## Quick Start
|
||||
|
||||
```python
|
||||
# launch_ui.py
|
||||
from agent.ui.gradio.app import create_gradio_ui
|
||||
import asyncio
|
||||
import os
|
||||
from agent import ComputerAgent
|
||||
from computer import Computer
|
||||
|
||||
app = create_gradio_ui()
|
||||
app.launch(share=False)
|
||||
async def main():
|
||||
# Set up computer instance
|
||||
async with Computer(
|
||||
os_type="linux",
|
||||
provider_type="cloud",
|
||||
name=os.getenv("CUA_CONTAINER_NAME"),
|
||||
api_key=os.getenv("CUA_API_KEY")
|
||||
) as computer:
|
||||
|
||||
# Create agent
|
||||
agent = ComputerAgent(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
tools=[computer],
|
||||
only_n_most_recent_images=3,
|
||||
trajectory_dir="trajectories",
|
||||
max_trajectory_budget=5.0 # $5 budget limit
|
||||
)
|
||||
|
||||
# Run agent
|
||||
messages = [{"role": "user", "content": "Take a screenshot and tell me what you see"}]
|
||||
|
||||
async for result in agent.run(messages):
|
||||
for item in result["output"]:
|
||||
if item["type"] == "message":
|
||||
print(item["content"][0]["text"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
### Setting up API Keys
|
||||
## Supported Models
|
||||
|
||||
For the Gradio UI to show available models, you need to set API keys as environment variables:
|
||||
|
||||
```bash
|
||||
# For OpenAI models
|
||||
export OPENAI_API_KEY=your_openai_key_here
|
||||
|
||||
# For Anthropic models
|
||||
export ANTHROPIC_API_KEY=your_anthropic_key_here
|
||||
|
||||
# Launch with both keys set
|
||||
OPENAI_API_KEY=your_key ANTHROPIC_API_KEY=your_key python launch_ui.py
|
||||
### Anthropic Claude (Computer Use API)
|
||||
```python
|
||||
model="anthropic/claude-3-5-sonnet-20241022"
|
||||
model="anthropic/claude-3-5-sonnet-20240620"
|
||||
model="anthropic/claude-opus-4-20250514"
|
||||
model="anthropic/claude-sonnet-4-20250514"
|
||||
```
|
||||
|
||||
Without these environment variables, the UI will show "No models available" for the corresponding providers, but you can still use local models with the OMNI loop provider.
|
||||
### OpenAI Computer Use Preview
|
||||
```python
|
||||
model="openai/computer-use-preview"
|
||||
```
|
||||
|
||||
### Using Local Models
|
||||
### UI-TARS (Local or Huggingface Inference)
|
||||
```python
|
||||
model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B"
|
||||
model="ollama_chat/0000/ui-tars-1.5-7b"
|
||||
```
|
||||
|
||||
You can use local models with the OMNI loop provider by selecting "Custom model..." from the dropdown. The default provider URL is set to `http://localhost:1234/v1` which works with LM Studio.
|
||||
### Omniparser + Any LLM
|
||||
```python
|
||||
model="omniparser+ollama_chat/mistral-small3.2"
|
||||
model="omniparser+vertex_ai/gemini-pro"
|
||||
model="omniparser+anthropic/claude-3-5-sonnet-20241022"
|
||||
model="omniparser+openai/gpt-4o"
|
||||
```
|
||||
|
||||
If you're using a different local model server:
|
||||
- vLLM: `http://localhost:8000/v1`
|
||||
- LocalAI: `http://localhost:8080/v1`
|
||||
- Ollama with OpenAI compat API: `http://localhost:11434/v1`
|
||||
## Custom Tools
|
||||
|
||||
The Gradio UI provides:
|
||||
- Selection of different agent loops (OpenAI, Anthropic, OMNI)
|
||||
- Model selection for each provider
|
||||
- Configuration of agent parameters
|
||||
- Chat interface for interacting with the agent
|
||||
|
||||
### Using UI-TARS
|
||||
|
||||
The UI-TARS models are available in two forms:
|
||||
|
||||
1. **MLX UI-TARS models** (Default): These models run locally using MLXVLM provider
|
||||
- `mlx-community/UI-TARS-1.5-7B-4bit` (default) - 4-bit quantized version
|
||||
- `mlx-community/UI-TARS-1.5-7B-6bit` - 6-bit quantized version for higher quality
|
||||
|
||||
```python
|
||||
agent = ComputerAgent(
|
||||
computer=macos_computer,
|
||||
loop=AgentLoop.UITARS,
|
||||
model=LLM(provider=LLMProvider.MLXVLM, name="mlx-community/UI-TARS-1.5-7B-4bit")
|
||||
)
|
||||
```
|
||||
|
||||
2. **OpenAI-compatible UI-TARS**: For using the original ByteDance model
|
||||
- If you want to use the original ByteDance UI-TARS model via an OpenAI-compatible API, follow the [deployment guide](https://github.com/bytedance/UI-TARS/blob/main/README_deploy.md)
|
||||
- This will give you a provider URL like `https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1` which you can use in the code or Gradio UI:
|
||||
|
||||
```python
|
||||
agent = ComputerAgent(
|
||||
computer=macos_computer,
|
||||
loop=AgentLoop.UITARS,
|
||||
model=LLM(provider=LLMProvider.OAICOMPAT, name="tgi",
|
||||
provider_base_url="https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1")
|
||||
)
|
||||
```
|
||||
|
||||
## Agent Loops
|
||||
|
||||
The `cua-agent` package provides three agent loops variations, based on different CUA models providers and techniques:
|
||||
|
||||
| Agent Loop | Supported Models | Description | Set-Of-Marks |
|
||||
|:-----------|:-----------------|:------------|:-------------|
|
||||
| `AgentLoop.OPENAI` | • `computer_use_preview` | Use OpenAI Operator CUA model | Not Required |
|
||||
| `AgentLoop.ANTHROPIC` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219` | Use Anthropic Computer-Use | Not Required |
|
||||
| `AgentLoop.UITARS` | • `mlx-community/UI-TARS-1.5-7B-4bit` (default)<br>• `mlx-community/UI-TARS-1.5-7B-6bit`<br>• `ByteDance-Seed/UI-TARS-1.5-7B` (via openAI-compatible endpoint) | Uses UI-TARS models with MLXVLM (default) or OAICOMPAT providers | Not Required |
|
||||
| `AgentLoop.OMNI` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219`<br>• `gpt-4.5-preview`<br>• `gpt-4o`<br>• `gpt-4`<br>• `phi4`<br>• `phi4-mini`<br>• `gemma3`<br>• `...`<br>• `Any Ollama or OpenAI-compatible model` | Use OmniParser for element pixel-detection (SoM) and any VLMs for UI Grounding and Reasoning | OmniParser |
|
||||
|
||||
## AgentResponse
|
||||
The `AgentResponse` class represents the structured output returned after each agent turn. It contains the agent's response, reasoning, tool usage, and other metadata. The response format aligns with the new [OpenAI Agent SDK specification](https://platform.openai.com/docs/api-reference/responses) for better consistency across different agent loops.
|
||||
Define custom tools using decorated functions:
|
||||
|
||||
```python
|
||||
async for result in agent.run(task):
|
||||
print("Response ID: ", result.get("id"))
|
||||
from computer.helpers import sandboxed
|
||||
|
||||
# Print detailed usage information
|
||||
usage = result.get("usage")
|
||||
if usage:
|
||||
print("\nUsage Details:")
|
||||
print(f" Input Tokens: {usage.get('input_tokens')}")
|
||||
if "input_tokens_details" in usage:
|
||||
print(f" Input Tokens Details: {usage.get('input_tokens_details')}")
|
||||
print(f" Output Tokens: {usage.get('output_tokens')}")
|
||||
if "output_tokens_details" in usage:
|
||||
print(f" Output Tokens Details: {usage.get('output_tokens_details')}")
|
||||
print(f" Total Tokens: {usage.get('total_tokens')}")
|
||||
@sandboxed()
|
||||
def read_file(location: str) -> str:
|
||||
"""Read contents of a file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
location : str
|
||||
Path to the file to read
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
Contents of the file or error message
|
||||
"""
|
||||
try:
|
||||
with open(location, 'r') as f:
|
||||
return f.read()
|
||||
except Exception as e:
|
||||
return f"Error reading file: {str(e)}"
|
||||
|
||||
print("Response Text: ", result.get("text"))
|
||||
def calculate(a: int, b: int) -> int:
|
||||
"""Calculate the sum of two integers"""
|
||||
return a + b
|
||||
|
||||
# Print tools information
|
||||
tools = result.get("tools")
|
||||
if tools:
|
||||
print("\nTools:")
|
||||
print(tools)
|
||||
|
||||
# Print reasoning and tool call outputs
|
||||
outputs = result.get("output", [])
|
||||
for output in outputs:
|
||||
output_type = output.get("type")
|
||||
if output_type == "reasoning":
|
||||
print("\nReasoning Output:")
|
||||
print(output)
|
||||
elif output_type == "computer_call":
|
||||
print("\nTool Call Output:")
|
||||
print(output)
|
||||
# Use with agent
|
||||
agent = ComputerAgent(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
tools=[computer, read_file, calculate]
|
||||
)
|
||||
```
|
||||
|
||||
**Note on Settings Persistence:**
|
||||
## Callbacks System
|
||||
|
||||
* The Gradio UI automatically saves your configuration (Agent Loop, Model Choice, Custom Base URL, Save Trajectory state, Recent Images count) to a file named `.gradio_settings.json` in the project's root directory when you successfully run a task.
|
||||
* This allows your preferences to persist between sessions.
|
||||
* API keys entered into the custom provider field are **not** saved in this file for security reasons. Manage API keys using environment variables (e.g., `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`) or a `.env` file.
|
||||
* It's recommended to add `.gradio_settings.json` to your `.gitignore` file.
|
||||
agent provides a comprehensive callback system for extending functionality:
|
||||
|
||||
### Built-in Callbacks
|
||||
|
||||
```python
|
||||
from agent.callbacks import (
|
||||
ImageRetentionCallback,
|
||||
TrajectorySaverCallback,
|
||||
BudgetManagerCallback,
|
||||
LoggingCallback
|
||||
)
|
||||
|
||||
agent = ComputerAgent(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
tools=[computer],
|
||||
callbacks=[
|
||||
ImageRetentionCallback(only_n_most_recent_images=3),
|
||||
TrajectorySaverCallback(trajectory_dir="trajectories"),
|
||||
BudgetManagerCallback(max_budget=10.0, raise_error=True),
|
||||
LoggingCallback(level=logging.INFO)
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### Custom Callbacks
|
||||
|
||||
```python
|
||||
from agent.callbacks.base import AsyncCallbackHandler
|
||||
|
||||
class CustomCallback(AsyncCallbackHandler):
|
||||
async def on_llm_start(self, messages):
|
||||
"""Preprocess messages before LLM call"""
|
||||
# Add custom preprocessing logic
|
||||
return messages
|
||||
|
||||
async def on_llm_end(self, messages):
|
||||
"""Postprocess messages after LLM call"""
|
||||
# Add custom postprocessing logic
|
||||
return messages
|
||||
|
||||
async def on_usage(self, usage):
|
||||
"""Track usage information"""
|
||||
print(f"Tokens used: {usage.total_tokens}")
|
||||
```
|
||||
|
||||
## Budget Management
|
||||
|
||||
Control costs with built-in budget management:
|
||||
|
||||
```python
|
||||
# Simple budget limit
|
||||
agent = ComputerAgent(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
max_trajectory_budget=5.0 # $5 limit
|
||||
)
|
||||
|
||||
# Advanced budget configuration
|
||||
agent = ComputerAgent(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
max_trajectory_budget={
|
||||
"max_budget": 10.0,
|
||||
"raise_error": True, # Raise error when exceeded
|
||||
"reset_after_each_run": False # Persistent across runs
|
||||
}
|
||||
)
|
||||
```
|
||||
|
||||
## Trajectory Management
|
||||
|
||||
Save and replay agent conversations:
|
||||
|
||||
```python
|
||||
agent = ComputerAgent(
|
||||
model="anthropic/claude-3-5-sonnet-20241022",
|
||||
trajectory_dir="trajectories", # Auto-save trajectories
|
||||
tools=[computer]
|
||||
)
|
||||
|
||||
# Trajectories are saved with:
|
||||
# - Complete conversation history
|
||||
# - Usage statistics and costs
|
||||
# - Timestamps and metadata
|
||||
# - Screenshots and computer actions
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### ComputerAgent Parameters
|
||||
|
||||
- `model`: Model identifier (required)
|
||||
- `tools`: List of computer objects and decorated functions
|
||||
- `callbacks`: List of callback handlers for extensibility
|
||||
- `only_n_most_recent_images`: Limit recent images to prevent context overflow
|
||||
- `verbosity`: Logging level (logging.INFO, logging.DEBUG, etc.)
|
||||
- `trajectory_dir`: Directory to save conversation trajectories
|
||||
- `max_retries`: Maximum API call retries (default: 3)
|
||||
- `screenshot_delay`: Delay between actions and screenshots (default: 0.5s)
|
||||
- `use_prompt_caching`: Enable prompt caching for supported models
|
||||
- `max_trajectory_budget`: Budget limit configuration
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Computer instance (cloud)
|
||||
export CUA_CONTAINER_NAME="your-container-name"
|
||||
export CUA_API_KEY="your-cua-api-key"
|
||||
|
||||
# LLM API keys
|
||||
export ANTHROPIC_API_KEY="your-anthropic-key"
|
||||
export OPENAI_API_KEY="your-openai-key"
|
||||
```
|
||||
|
||||
## Advanced Usage
|
||||
|
||||
### Streaming Responses
|
||||
|
||||
```python
|
||||
async for result in agent.run(messages, stream=True):
|
||||
# Process streaming chunks
|
||||
for item in result["output"]:
|
||||
if item["type"] == "message":
|
||||
print(item["content"][0]["text"], end="", flush=True)
|
||||
elif item["type"] == "computer_call":
|
||||
action = item["action"]
|
||||
print(f"\n[Action: {action['type']}]")
|
||||
```
|
||||
|
||||
### Interactive Chat Loop
|
||||
|
||||
```python
|
||||
history = []
|
||||
while True:
|
||||
user_input = input("> ")
|
||||
if user_input.lower() in ['quit', 'exit']:
|
||||
break
|
||||
|
||||
history.append({"role": "user", "content": user_input})
|
||||
|
||||
async for result in agent.run(history):
|
||||
history += result["output"]
|
||||
|
||||
# Display assistant responses
|
||||
for item in result["output"]:
|
||||
if item["type"] == "message":
|
||||
print(item["content"][0]["text"])
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
try:
|
||||
async for result in agent.run(messages):
|
||||
# Process results
|
||||
pass
|
||||
except BudgetExceededException:
|
||||
print("Budget limit exceeded")
|
||||
except Exception as e:
|
||||
print(f"Agent error: {e}")
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
### ComputerAgent.run()
|
||||
|
||||
```python
|
||||
async def run(
|
||||
self,
|
||||
messages: Messages,
|
||||
stream: bool = False,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
Run the agent with the given messages.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
stream: Whether to stream the response
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
AsyncGenerator that yields response chunks
|
||||
"""
|
||||
```
|
||||
|
||||
### Message Format
|
||||
|
||||
```python
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Take a screenshot and describe what you see"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I'll take a screenshot for you."
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
### Response Format
|
||||
|
||||
```python
|
||||
{
|
||||
"output": [
|
||||
{
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "output_text", "text": "I can see..."}]
|
||||
},
|
||||
{
|
||||
"type": "computer_call",
|
||||
"action": {"type": "screenshot"},
|
||||
"call_id": "call_123"
|
||||
},
|
||||
{
|
||||
"type": "computer_call_output",
|
||||
"call_id": "call_123",
|
||||
"output": {"image_url": "data:image/png;base64,..."}
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 150,
|
||||
"completion_tokens": 75,
|
||||
"total_tokens": 225,
|
||||
"response_cost": 0.01,
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
MIT License - see LICENSE file for details.
|
||||
@@ -1,12 +1,27 @@
|
||||
"""CUA (Computer Use) Agent for AI-driven computer interaction."""
|
||||
"""
|
||||
agent - Decorator-based Computer Use Agent with liteLLM integration
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
import sys
|
||||
|
||||
__version__ = "0.1.0"
|
||||
from .decorators import agent_loop
|
||||
from .agent import ComputerAgent
|
||||
from .types import Messages, AgentResponse
|
||||
|
||||
# Initialize logging
|
||||
logger = logging.getLogger("agent")
|
||||
# Import loops to register them
|
||||
from . import loops
|
||||
|
||||
__all__ = [
|
||||
"agent_loop",
|
||||
"ComputerAgent",
|
||||
"Messages",
|
||||
"AgentResponse"
|
||||
]
|
||||
|
||||
__version__ = "0.4.0"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize telemetry when the package is imported
|
||||
try:
|
||||
@@ -18,7 +33,7 @@ try:
|
||||
)
|
||||
|
||||
# Import set_dimension from our own telemetry module
|
||||
from .core.telemetry import set_dimension
|
||||
from .telemetry import set_dimension
|
||||
|
||||
# Check if telemetry is enabled
|
||||
if is_telemetry_enabled():
|
||||
@@ -47,9 +62,3 @@ except ImportError as e:
|
||||
except Exception as e:
|
||||
# Other issues with telemetry
|
||||
logger.warning(f"Error initializing telemetry: {e}")
|
||||
|
||||
from .core.types import LLMProvider, LLM
|
||||
from .core.factory import AgentLoop
|
||||
from .core.agent import ComputerAgent
|
||||
|
||||
__all__ = ["AgentLoop", "LLMProvider", "LLM", "ComputerAgent"]
|
||||
|
||||
21
libs/python/agent/agent/__main__.py
Normal file
21
libs/python/agent/agent/__main__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Entry point for running agent CLI module.
|
||||
|
||||
Usage:
|
||||
python -m agent.cli <model_string>
|
||||
"""
|
||||
|
||||
import sys
|
||||
import asyncio
|
||||
from .cli import main
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Check if 'cli' is specified as the module
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "cli":
|
||||
# Remove 'cli' from arguments and run CLI
|
||||
sys.argv.pop(1)
|
||||
asyncio.run(main())
|
||||
else:
|
||||
print("Usage: python -m agent.cli <model_string>")
|
||||
print("Example: python -m agent.cli openai/computer-use-preview")
|
||||
sys.exit(1)
|
||||
9
libs/python/agent/agent/adapters/__init__.py
Normal file
9
libs/python/agent/agent/adapters/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Adapters package for agent - Custom LLM adapters for LiteLLM
|
||||
"""
|
||||
|
||||
from .huggingfacelocal_adapter import HuggingFaceLocalAdapter
|
||||
|
||||
__all__ = [
|
||||
"HuggingFaceLocalAdapter",
|
||||
]
|
||||
229
libs/python/agent/agent/adapters/huggingfacelocal_adapter.py
Normal file
229
libs/python/agent/agent/adapters/huggingfacelocal_adapter.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import asyncio
|
||||
import warnings
|
||||
from typing import Iterator, AsyncIterator, Dict, List, Any, Optional
|
||||
from litellm.types.utils import GenericStreamingChunk, ModelResponse
|
||||
from litellm.llms.custom_llm import CustomLLM
|
||||
from litellm import completion, acompletion
|
||||
|
||||
# Try to import HuggingFace dependencies
|
||||
try:
|
||||
import torch
|
||||
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
|
||||
HF_AVAILABLE = True
|
||||
except ImportError:
|
||||
HF_AVAILABLE = False
|
||||
|
||||
|
||||
class HuggingFaceLocalAdapter(CustomLLM):
|
||||
"""HuggingFace Local Adapter for running vision-language models locally."""
|
||||
|
||||
def __init__(self, device: str = "auto", **kwargs):
|
||||
"""Initialize the adapter.
|
||||
|
||||
Args:
|
||||
device: Device to load model on ("auto", "cuda", "cpu", etc.)
|
||||
**kwargs: Additional arguments
|
||||
"""
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.models = {} # Cache for loaded models
|
||||
self.processors = {} # Cache for loaded processors
|
||||
|
||||
def _load_model_and_processor(self, model_name: str):
|
||||
"""Load model and processor if not already cached.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to load
|
||||
|
||||
Returns:
|
||||
Tuple of (model, processor)
|
||||
"""
|
||||
if model_name not in self.models:
|
||||
# Load model
|
||||
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch.float16,
|
||||
device_map=self.device,
|
||||
attn_implementation="sdpa"
|
||||
)
|
||||
|
||||
# Load processor
|
||||
processor = AutoProcessor.from_pretrained(model_name)
|
||||
|
||||
# Cache them
|
||||
self.models[model_name] = model
|
||||
self.processors[model_name] = processor
|
||||
|
||||
return self.models[model_name], self.processors[model_name]
|
||||
|
||||
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Convert OpenAI format messages to HuggingFace format.
|
||||
|
||||
Args:
|
||||
messages: Messages in OpenAI format
|
||||
|
||||
Returns:
|
||||
Messages in HuggingFace format
|
||||
"""
|
||||
converted_messages = []
|
||||
|
||||
for message in messages:
|
||||
converted_message = {
|
||||
"role": message["role"],
|
||||
"content": []
|
||||
}
|
||||
|
||||
content = message.get("content", [])
|
||||
if isinstance(content, str):
|
||||
# Simple text content
|
||||
converted_message["content"].append({
|
||||
"type": "text",
|
||||
"text": content
|
||||
})
|
||||
elif isinstance(content, list):
|
||||
# Multi-modal content
|
||||
for item in content:
|
||||
if item.get("type") == "text":
|
||||
converted_message["content"].append({
|
||||
"type": "text",
|
||||
"text": item.get("text", "")
|
||||
})
|
||||
elif item.get("type") == "image_url":
|
||||
# Convert image_url format to image format
|
||||
image_url = item.get("image_url", {}).get("url", "")
|
||||
converted_message["content"].append({
|
||||
"type": "image",
|
||||
"image": image_url
|
||||
})
|
||||
|
||||
converted_messages.append(converted_message)
|
||||
|
||||
return converted_messages
|
||||
|
||||
def _generate(self, **kwargs) -> str:
|
||||
"""Generate response using the local HuggingFace model.
|
||||
|
||||
Args:
|
||||
**kwargs: Keyword arguments containing messages and model info
|
||||
|
||||
Returns:
|
||||
Generated text response
|
||||
"""
|
||||
if not HF_AVAILABLE:
|
||||
raise ImportError(
|
||||
"HuggingFace transformers dependencies not found. "
|
||||
"Please install with: pip install \"cua-agent[uitars-hf]\""
|
||||
)
|
||||
|
||||
# Extract messages and model from kwargs
|
||||
messages = kwargs.get('messages', [])
|
||||
model_name = kwargs.get('model', 'ByteDance-Seed/UI-TARS-1.5-7B')
|
||||
max_new_tokens = kwargs.get('max_tokens', 128)
|
||||
|
||||
# Warn about ignored kwargs
|
||||
ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
|
||||
if ignored_kwargs:
|
||||
warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")
|
||||
|
||||
# Load model and processor
|
||||
model, processor = self._load_model_and_processor(model_name)
|
||||
|
||||
# Convert messages to HuggingFace format
|
||||
hf_messages = self._convert_messages(messages)
|
||||
|
||||
# Apply chat template and tokenize
|
||||
inputs = processor.apply_chat_template(
|
||||
hf_messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt"
|
||||
)
|
||||
|
||||
# Move inputs to the same device as model
|
||||
if torch.cuda.is_available() and self.device != "cpu":
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Generate response
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
|
||||
|
||||
# Trim input tokens from output
|
||||
generated_ids_trimmed = [
|
||||
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
|
||||
]
|
||||
|
||||
# Decode output
|
||||
output_text = processor.batch_decode(
|
||||
generated_ids_trimmed,
|
||||
skip_special_tokens=True,
|
||||
clean_up_tokenization_spaces=False
|
||||
)
|
||||
|
||||
return output_text[0] if output_text else ""
|
||||
|
||||
def completion(self, *args, **kwargs) -> ModelResponse:
|
||||
"""Synchronous completion method.
|
||||
|
||||
Returns:
|
||||
ModelResponse with generated text
|
||||
"""
|
||||
generated_text = self._generate(**kwargs)
|
||||
|
||||
return completion(
|
||||
model=f"huggingface-local/{kwargs['model']}",
|
||||
mock_response=generated_text,
|
||||
)
|
||||
|
||||
async def acompletion(self, *args, **kwargs) -> ModelResponse:
|
||||
"""Asynchronous completion method.
|
||||
|
||||
Returns:
|
||||
ModelResponse with generated text
|
||||
"""
|
||||
# Run _generate in thread pool to avoid blocking
|
||||
generated_text = await asyncio.to_thread(self._generate, **kwargs)
|
||||
|
||||
return await acompletion(
|
||||
model=f"huggingface-local/{kwargs['model']}",
|
||||
mock_response=generated_text,
|
||||
)
|
||||
|
||||
def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
|
||||
"""Synchronous streaming method.
|
||||
|
||||
Returns:
|
||||
Iterator of GenericStreamingChunk
|
||||
"""
|
||||
generated_text = self._generate(**kwargs)
|
||||
|
||||
generic_streaming_chunk: GenericStreamingChunk = {
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"is_finished": True,
|
||||
"text": generated_text,
|
||||
"tool_use": None,
|
||||
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
yield generic_streaming_chunk
|
||||
|
||||
async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
|
||||
"""Asynchronous streaming method.
|
||||
|
||||
Returns:
|
||||
AsyncIterator of GenericStreamingChunk
|
||||
"""
|
||||
# Run _generate in thread pool to avoid blocking
|
||||
generated_text = await asyncio.to_thread(self._generate, **kwargs)
|
||||
|
||||
generic_streaming_chunk: GenericStreamingChunk = {
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"is_finished": True,
|
||||
"text": generated_text,
|
||||
"tool_use": None,
|
||||
"usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
|
||||
}
|
||||
|
||||
yield generic_streaming_chunk
|
||||
594
libs/python/agent/agent/agent.py
Normal file
594
libs/python/agent/agent/agent.py
Normal file
@@ -0,0 +1,594 @@
|
||||
"""
|
||||
ComputerAgent - Main agent class that selects and runs agent loops
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set
|
||||
|
||||
from litellm.responses.utils import Usage
|
||||
|
||||
from .types import Messages, Computer
|
||||
from .decorators import find_agent_loop
|
||||
from .computer_handler import OpenAIComputerHandler, acknowledge_safety_check_callback, check_blocklisted_url
|
||||
import json
|
||||
import litellm
|
||||
import litellm.utils
|
||||
import inspect
|
||||
from .adapters import HuggingFaceLocalAdapter
|
||||
from .callbacks import (
|
||||
ImageRetentionCallback,
|
||||
LoggingCallback,
|
||||
TrajectorySaverCallback,
|
||||
BudgetManagerCallback,
|
||||
TelemetryCallback,
|
||||
)
|
||||
|
||||
def get_json(obj: Any, max_depth: int = 10) -> Any:
|
||||
def custom_serializer(o: Any, depth: int = 0, seen: Set[int] = None) -> Any:
|
||||
if seen is None:
|
||||
seen = set()
|
||||
|
||||
# Use model_dump() if available
|
||||
if hasattr(o, 'model_dump'):
|
||||
return o.model_dump()
|
||||
|
||||
# Check depth limit
|
||||
if depth > max_depth:
|
||||
return f"<max_depth_exceeded:{max_depth}>"
|
||||
|
||||
# Check for circular references using object id
|
||||
obj_id = id(o)
|
||||
if obj_id in seen:
|
||||
return f"<circular_reference:{type(o).__name__}>"
|
||||
|
||||
# Handle Computer objects
|
||||
if hasattr(o, '__class__') and 'computer' in getattr(o, '__class__').__name__.lower():
|
||||
return f"<computer:{o.__class__.__name__}>"
|
||||
|
||||
# Handle objects with __dict__
|
||||
if hasattr(o, '__dict__'):
|
||||
seen.add(obj_id)
|
||||
try:
|
||||
result = {}
|
||||
for k, v in o.__dict__.items():
|
||||
if v is not None:
|
||||
# Recursively serialize with updated depth and seen set
|
||||
serialized_value = custom_serializer(v, depth + 1, seen.copy())
|
||||
result[k] = serialized_value
|
||||
return result
|
||||
finally:
|
||||
seen.discard(obj_id)
|
||||
|
||||
# Handle common types that might contain nested objects
|
||||
elif isinstance(o, dict):
|
||||
seen.add(obj_id)
|
||||
try:
|
||||
return {
|
||||
k: custom_serializer(v, depth + 1, seen.copy())
|
||||
for k, v in o.items()
|
||||
if v is not None
|
||||
}
|
||||
finally:
|
||||
seen.discard(obj_id)
|
||||
|
||||
elif isinstance(o, (list, tuple, set)):
|
||||
seen.add(obj_id)
|
||||
try:
|
||||
return [
|
||||
custom_serializer(item, depth + 1, seen.copy())
|
||||
for item in o
|
||||
if item is not None
|
||||
]
|
||||
finally:
|
||||
seen.discard(obj_id)
|
||||
|
||||
# For basic types that json.dumps can handle
|
||||
elif isinstance(o, (str, int, float, bool)) or o is None:
|
||||
return o
|
||||
|
||||
# Fallback to string representation
|
||||
else:
|
||||
return str(o)
|
||||
|
||||
def remove_nones(obj: Any) -> Any:
|
||||
if isinstance(obj, dict):
|
||||
return {k: remove_nones(v) for k, v in obj.items() if v is not None}
|
||||
elif isinstance(obj, list):
|
||||
return [remove_nones(item) for item in obj if item is not None]
|
||||
return obj
|
||||
|
||||
# Serialize with circular reference and depth protection
|
||||
serialized = custom_serializer(obj)
|
||||
|
||||
# Convert to JSON string and back to ensure JSON compatibility
|
||||
json_str = json.dumps(serialized)
|
||||
parsed = json.loads(json_str)
|
||||
|
||||
# Final cleanup of any remaining None values
|
||||
return remove_nones(parsed)
|
||||
|
||||
def sanitize_message(msg: Any) -> Any:
|
||||
"""Return a copy of the message with image_url omitted for computer_call_output messages."""
|
||||
if msg.get("type") == "computer_call_output":
|
||||
output = msg.get("output", {})
|
||||
if isinstance(output, dict):
|
||||
sanitized = msg.copy()
|
||||
sanitized["output"] = {**output, "image_url": "[omitted]"}
|
||||
return sanitized
|
||||
return msg
|
||||
|
||||
class ComputerAgent:
|
||||
"""
|
||||
Main agent class that automatically selects the appropriate agent loop
|
||||
based on the model and executes tool calls.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
tools: Optional[List[Any]] = None,
|
||||
custom_loop: Optional[Callable] = None,
|
||||
only_n_most_recent_images: Optional[int] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
verbosity: Optional[int] = None,
|
||||
trajectory_dir: Optional[str] = None,
|
||||
max_retries: Optional[int] = 3,
|
||||
screenshot_delay: Optional[float | int] = 0.5,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
max_trajectory_budget: Optional[float | dict] = None,
|
||||
telemetry_enabled: Optional[bool] = True,
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Initialize ComputerAgent.
|
||||
|
||||
Args:
|
||||
model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro")
|
||||
tools: List of tools (computer objects, decorated functions, etc.)
|
||||
custom_loop: Custom agent loop function to use instead of auto-selection
|
||||
only_n_most_recent_images: If set, only keep the N most recent images in message history. Adds ImageRetentionCallback automatically.
|
||||
callbacks: List of AsyncCallbackHandler instances for preprocessing/postprocessing
|
||||
verbosity: Logging level (logging.DEBUG, logging.INFO, etc.). If set, adds LoggingCallback automatically
|
||||
trajectory_dir: If set, saves trajectory data (screenshots, responses) to this directory. Adds TrajectorySaverCallback automatically.
|
||||
max_retries: Maximum number of retries for failed API calls
|
||||
screenshot_delay: Delay before screenshots in seconds
|
||||
use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with anthropic providers.
|
||||
max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded
|
||||
telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
|
||||
**kwargs: Additional arguments passed to the agent loop
|
||||
"""
|
||||
self.model = model
|
||||
self.tools = tools or []
|
||||
self.custom_loop = custom_loop
|
||||
self.only_n_most_recent_images = only_n_most_recent_images
|
||||
self.callbacks = callbacks or []
|
||||
self.verbosity = verbosity
|
||||
self.trajectory_dir = trajectory_dir
|
||||
self.max_retries = max_retries
|
||||
self.screenshot_delay = screenshot_delay
|
||||
self.use_prompt_caching = use_prompt_caching
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
self.kwargs = kwargs
|
||||
|
||||
# == Add built-in callbacks ==
|
||||
|
||||
# Add telemetry callback if telemetry_enabled is set
|
||||
if self.telemetry_enabled:
|
||||
if isinstance(self.telemetry_enabled, bool):
|
||||
self.callbacks.append(TelemetryCallback(self))
|
||||
else:
|
||||
self.callbacks.append(TelemetryCallback(self, **self.telemetry_enabled))
|
||||
|
||||
# Add logging callback if verbosity is set
|
||||
if self.verbosity is not None:
|
||||
self.callbacks.append(LoggingCallback(level=self.verbosity))
|
||||
|
||||
# Add image retention callback if only_n_most_recent_images is set
|
||||
if self.only_n_most_recent_images:
|
||||
self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images))
|
||||
|
||||
# Add trajectory saver callback if trajectory_dir is set
|
||||
if self.trajectory_dir:
|
||||
self.callbacks.append(TrajectorySaverCallback(self.trajectory_dir))
|
||||
|
||||
# Add budget manager if max_trajectory_budget is set
|
||||
if max_trajectory_budget:
|
||||
if isinstance(max_trajectory_budget, dict):
|
||||
self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget))
|
||||
else:
|
||||
self.callbacks.append(BudgetManagerCallback(max_trajectory_budget))
|
||||
|
||||
# == Enable local model providers w/ LiteLLM ==
|
||||
|
||||
# Register local model providers
|
||||
hf_adapter = HuggingFaceLocalAdapter(
|
||||
device="auto"
|
||||
)
|
||||
litellm.custom_provider_map = [
|
||||
{"provider": "huggingface-local", "custom_handler": hf_adapter}
|
||||
]
|
||||
|
||||
# == Initialize computer agent ==
|
||||
|
||||
# Find the appropriate agent loop
|
||||
if custom_loop:
|
||||
self.agent_loop = custom_loop
|
||||
self.agent_loop_info = None
|
||||
else:
|
||||
loop_info = find_agent_loop(model)
|
||||
if not loop_info:
|
||||
raise ValueError(f"No agent loop found for model: {model}")
|
||||
self.agent_loop = loop_info.func
|
||||
self.agent_loop_info = loop_info
|
||||
|
||||
self.tool_schemas = []
|
||||
self.computer_handler = None
|
||||
|
||||
async def _initialize_computers(self):
|
||||
"""Initialize computer objects"""
|
||||
if not self.tool_schemas:
|
||||
for tool in self.tools:
|
||||
if hasattr(tool, '_initialized') and not tool._initialized:
|
||||
await tool.run()
|
||||
|
||||
# Process tools and create tool schemas
|
||||
self.tool_schemas = self._process_tools()
|
||||
|
||||
# Find computer tool and create interface adapter
|
||||
computer_handler = None
|
||||
for schema in self.tool_schemas:
|
||||
if schema["type"] == "computer":
|
||||
computer_handler = OpenAIComputerHandler(schema["computer"].interface)
|
||||
break
|
||||
self.computer_handler = computer_handler
|
||||
|
||||
def _process_input(self, input: Messages) -> List[Dict[str, Any]]:
|
||||
"""Process input messages and create schemas for the agent loop"""
|
||||
if isinstance(input, str):
|
||||
return [{"role": "user", "content": input}]
|
||||
return [get_json(msg) for msg in input]
|
||||
|
||||
def _process_tools(self) -> List[Dict[str, Any]]:
|
||||
"""Process tools and create schemas for the agent loop"""
|
||||
schemas = []
|
||||
|
||||
for tool in self.tools:
|
||||
# Check if it's a computer object (has interface attribute)
|
||||
if hasattr(tool, 'interface'):
|
||||
# This is a computer tool - will be handled by agent loop
|
||||
schemas.append({
|
||||
"type": "computer",
|
||||
"computer": tool
|
||||
})
|
||||
elif callable(tool):
|
||||
# Use litellm.utils.function_to_dict to extract schema from docstring
|
||||
try:
|
||||
function_schema = litellm.utils.function_to_dict(tool)
|
||||
schemas.append({
|
||||
"type": "function",
|
||||
"function": function_schema
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"Warning: Could not process tool {tool}: {e}")
|
||||
else:
|
||||
print(f"Warning: Unknown tool type: {tool}")
|
||||
|
||||
return schemas
|
||||
|
||||
def _get_tool(self, name: str) -> Optional[Callable]:
|
||||
"""Get a tool by name"""
|
||||
for tool in self.tools:
|
||||
if hasattr(tool, '__name__') and tool.__name__ == name:
|
||||
return tool
|
||||
elif hasattr(tool, 'func') and tool.func.__name__ == name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
# ============================================================================
|
||||
# AGENT RUN LOOP LIFECYCLE HOOKS
|
||||
# ============================================================================
|
||||
|
||||
async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
|
||||
"""Initialize run tracking by calling callbacks."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_run_start'):
|
||||
await callback.on_run_start(kwargs, old_items)
|
||||
|
||||
async def _on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
|
||||
"""Finalize run tracking by calling callbacks."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_run_end'):
|
||||
await callback.on_run_end(kwargs, old_items, new_items)
|
||||
|
||||
async def _on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
|
||||
"""Check if run should continue by calling callbacks."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_run_continue'):
|
||||
should_continue = await callback.on_run_continue(kwargs, old_items, new_items)
|
||||
if not should_continue:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Prepare messages for the LLM call by applying callbacks."""
|
||||
result = messages
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_llm_start'):
|
||||
result = await callback.on_llm_start(result)
|
||||
return result
|
||||
|
||||
async def _on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Postprocess messages after the LLM call by applying callbacks."""
|
||||
result = messages
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_llm_end'):
|
||||
result = await callback.on_llm_end(result)
|
||||
return result
|
||||
|
||||
async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
|
||||
"""Called when responses are received."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_responses'):
|
||||
await callback.on_responses(get_json(kwargs), get_json(responses))
|
||||
|
||||
async def _on_computer_call_start(self, item: Dict[str, Any]) -> None:
|
||||
"""Called when a computer call is about to start."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_computer_call_start'):
|
||||
await callback.on_computer_call_start(get_json(item))
|
||||
|
||||
async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
|
||||
"""Called when a computer call has completed."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_computer_call_end'):
|
||||
await callback.on_computer_call_end(get_json(item), get_json(result))
|
||||
|
||||
async def _on_function_call_start(self, item: Dict[str, Any]) -> None:
|
||||
"""Called when a function call is about to start."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_function_call_start'):
|
||||
await callback.on_function_call_start(get_json(item))
|
||||
|
||||
async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
|
||||
"""Called when a function call has completed."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_function_call_end'):
|
||||
await callback.on_function_call_end(get_json(item), get_json(result))
|
||||
|
||||
async def _on_text(self, item: Dict[str, Any]) -> None:
|
||||
"""Called when a text message is encountered."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_text'):
|
||||
await callback.on_text(get_json(item))
|
||||
|
||||
async def _on_api_start(self, kwargs: Dict[str, Any]) -> None:
|
||||
"""Called when an LLM API call is about to start."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_api_start'):
|
||||
await callback.on_api_start(get_json(kwargs))
|
||||
|
||||
async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
|
||||
"""Called when an LLM API call has completed."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_api_end'):
|
||||
await callback.on_api_end(get_json(kwargs), get_json(result))
|
||||
|
||||
async def _on_usage(self, usage: Dict[str, Any]) -> None:
|
||||
"""Called when usage information is received."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_usage'):
|
||||
await callback.on_usage(get_json(usage))
|
||||
|
||||
async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
|
||||
"""Called when a screenshot is taken."""
|
||||
for callback in self.callbacks:
|
||||
if hasattr(callback, 'on_screenshot'):
|
||||
await callback.on_screenshot(screenshot, name)
|
||||
|
||||
# ============================================================================
|
||||
# AGENT OUTPUT PROCESSING
|
||||
# ============================================================================
|
||||
|
||||
async def _handle_item(self, item: Any, computer: Optional[Computer] = None) -> List[Dict[str, Any]]:
|
||||
"""Handle each item; may cause a computer action + screenshot."""
|
||||
|
||||
item_type = item.get("type", None)
|
||||
|
||||
if item_type == "message":
|
||||
await self._on_text(item)
|
||||
# # Print messages
|
||||
# if item.get("content"):
|
||||
# for content_item in item.get("content"):
|
||||
# if content_item.get("text"):
|
||||
# print(content_item.get("text"))
|
||||
return []
|
||||
|
||||
if item_type == "computer_call":
|
||||
await self._on_computer_call_start(item)
|
||||
if not computer:
|
||||
raise ValueError("Computer handler is required for computer calls")
|
||||
|
||||
# Perform computer actions
|
||||
action = item.get("action")
|
||||
action_type = action.get("type")
|
||||
|
||||
# Extract action arguments (all fields except 'type')
|
||||
action_args = {k: v for k, v in action.items() if k != "type"}
|
||||
|
||||
# print(f"{action_type}({action_args})")
|
||||
|
||||
# Execute the computer action
|
||||
computer_method = getattr(computer, action_type, None)
|
||||
if computer_method:
|
||||
await computer_method(**action_args)
|
||||
else:
|
||||
print(f"Unknown computer action: {action_type}")
|
||||
return []
|
||||
|
||||
# Take screenshot after action
|
||||
if self.screenshot_delay and self.screenshot_delay > 0:
|
||||
await asyncio.sleep(self.screenshot_delay)
|
||||
screenshot_base64 = await computer.screenshot()
|
||||
await self._on_screenshot(screenshot_base64, "screenshot_after")
|
||||
|
||||
# Handle safety checks
|
||||
pending_checks = item.get("pending_safety_checks", [])
|
||||
acknowledged_checks = []
|
||||
for check in pending_checks:
|
||||
check_message = check.get("message", str(check))
|
||||
if acknowledge_safety_check_callback(check_message):
|
||||
acknowledged_checks.append(check)
|
||||
else:
|
||||
raise ValueError(f"Safety check failed: {check_message}")
|
||||
|
||||
# Create call output
|
||||
call_output = {
|
||||
"type": "computer_call_output",
|
||||
"call_id": item.get("call_id"),
|
||||
"acknowledged_safety_checks": acknowledged_checks,
|
||||
"output": {
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{screenshot_base64}",
|
||||
},
|
||||
}
|
||||
|
||||
# Additional URL safety checks for browser environments
|
||||
if await computer.get_environment() == "browser":
|
||||
current_url = await computer.get_current_url()
|
||||
call_output["output"]["current_url"] = current_url
|
||||
check_blocklisted_url(current_url)
|
||||
|
||||
result = [call_output]
|
||||
await self._on_computer_call_end(item, result)
|
||||
return result
|
||||
|
||||
if item_type == "function_call":
|
||||
await self._on_function_call_start(item)
|
||||
# Perform function call
|
||||
function = self._get_tool(item.get("name"))
|
||||
if not function:
|
||||
raise ValueError(f"Function {item.get("name")} not found")
|
||||
|
||||
args = json.loads(item.get("arguments"))
|
||||
|
||||
# Execute function - use asyncio.to_thread for non-async functions
|
||||
if inspect.iscoroutinefunction(function):
|
||||
result = await function(**args)
|
||||
else:
|
||||
result = await asyncio.to_thread(function, **args)
|
||||
|
||||
# Create function call output
|
||||
call_output = {
|
||||
"type": "function_call_output",
|
||||
"call_id": item.get("call_id"),
|
||||
"output": str(result),
|
||||
}
|
||||
|
||||
result = [call_output]
|
||||
await self._on_function_call_end(item, result)
|
||||
return result
|
||||
|
||||
return []
|
||||
|
||||
# ============================================================================
|
||||
# MAIN AGENT LOOP
|
||||
# ============================================================================
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: Messages,
|
||||
stream: bool = False,
|
||||
**kwargs
|
||||
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""
|
||||
Run the agent with the given messages using Computer protocol handler pattern.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
stream: Whether to stream the response
|
||||
**kwargs: Additional arguments
|
||||
|
||||
Returns:
|
||||
AsyncGenerator that yields response chunks
|
||||
"""
|
||||
|
||||
await self._initialize_computers()
|
||||
|
||||
# Merge kwargs
|
||||
merged_kwargs = {**self.kwargs, **kwargs}
|
||||
|
||||
old_items = self._process_input(messages)
|
||||
new_items = []
|
||||
|
||||
# Initialize run tracking
|
||||
run_kwargs = {
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"model": self.model,
|
||||
"agent_loop": self.agent_loop.__name__,
|
||||
**merged_kwargs
|
||||
}
|
||||
await self._on_run_start(run_kwargs, old_items)
|
||||
|
||||
while new_items[-1].get("role") != "assistant" if new_items else True:
|
||||
# Lifecycle hook: Check if we should continue based on callbacks (e.g., budget manager)
|
||||
should_continue = await self._on_run_continue(run_kwargs, old_items, new_items)
|
||||
if not should_continue:
|
||||
break
|
||||
|
||||
# Lifecycle hook: Prepare messages for the LLM call
|
||||
# Use cases:
|
||||
# - PII anonymization
|
||||
# - Image retention policy
|
||||
combined_messages = old_items + new_items
|
||||
preprocessed_messages = await self._on_llm_start(combined_messages)
|
||||
|
||||
loop_kwargs = {
|
||||
"messages": preprocessed_messages,
|
||||
"model": self.model,
|
||||
"tools": self.tool_schemas,
|
||||
"stream": False,
|
||||
"computer_handler": self.computer_handler,
|
||||
"max_retries": self.max_retries,
|
||||
"use_prompt_caching": self.use_prompt_caching,
|
||||
**merged_kwargs
|
||||
}
|
||||
|
||||
# Run agent loop iteration
|
||||
result = await self.agent_loop(
|
||||
**loop_kwargs,
|
||||
_on_api_start=self._on_api_start,
|
||||
_on_api_end=self._on_api_end,
|
||||
_on_usage=self._on_usage,
|
||||
_on_screenshot=self._on_screenshot,
|
||||
)
|
||||
result = get_json(result)
|
||||
|
||||
# Lifecycle hook: Postprocess messages after the LLM call
|
||||
# Use cases:
|
||||
# - PII deanonymization (if you want tool calls to see PII)
|
||||
result["output"] = await self._on_llm_end(result.get("output", []))
|
||||
await self._on_responses(loop_kwargs, result)
|
||||
|
||||
# Yield agent response
|
||||
yield result
|
||||
|
||||
# Add agent response to new_items
|
||||
new_items += result.get("output")
|
||||
|
||||
# Handle computer actions
|
||||
for item in result.get("output"):
|
||||
partial_items = await self._handle_item(item, self.computer_handler)
|
||||
new_items += partial_items
|
||||
|
||||
# Yield partial response
|
||||
yield {
|
||||
"output": partial_items,
|
||||
"usage": Usage(
|
||||
prompt_tokens=0,
|
||||
completion_tokens=0,
|
||||
total_tokens=0,
|
||||
)
|
||||
}
|
||||
|
||||
await self._on_run_end(loop_kwargs, old_items, new_items)
|
||||
19
libs/python/agent/agent/callbacks/__init__.py
Normal file
19
libs/python/agent/agent/callbacks/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Callback system for ComputerAgent preprocessing and postprocessing hooks.
|
||||
"""
|
||||
|
||||
from .base import AsyncCallbackHandler
|
||||
from .image_retention import ImageRetentionCallback
|
||||
from .logging import LoggingCallback
|
||||
from .trajectory_saver import TrajectorySaverCallback
|
||||
from .budget_manager import BudgetManagerCallback
|
||||
from .telemetry import TelemetryCallback
|
||||
|
||||
__all__ = [
|
||||
"AsyncCallbackHandler",
|
||||
"ImageRetentionCallback",
|
||||
"LoggingCallback",
|
||||
"TrajectorySaverCallback",
|
||||
"BudgetManagerCallback",
|
||||
"TelemetryCallback",
|
||||
]
|
||||
153
libs/python/agent/agent/callbacks/base.py
Normal file
153
libs/python/agent/agent/callbacks/base.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""
|
||||
Base callback handler interface for ComputerAgent preprocessing and postprocessing hooks.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
|
||||
class AsyncCallbackHandler(ABC):
|
||||
"""
|
||||
Base class for async callback handlers that can preprocess messages before
|
||||
the agent loop and postprocess output after the agent loop.
|
||||
"""
|
||||
|
||||
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
|
||||
"""Called at the start of an agent run loop."""
|
||||
pass
|
||||
|
||||
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
|
||||
"""Called at the end of an agent run loop."""
|
||||
pass
|
||||
|
||||
async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
|
||||
"""Called during agent run loop to determine if execution should continue.
|
||||
|
||||
Args:
|
||||
kwargs: Run arguments
|
||||
old_items: Original messages
|
||||
new_items: New messages generated during run
|
||||
|
||||
Returns:
|
||||
True to continue execution, False to stop
|
||||
"""
|
||||
return True
|
||||
|
||||
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Called before messages are sent to the agent loop.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries to preprocess
|
||||
|
||||
Returns:
|
||||
List of preprocessed message dictionaries
|
||||
"""
|
||||
return messages
|
||||
|
||||
async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Called after the agent loop returns output.
|
||||
|
||||
Args:
|
||||
output: List of output message dictionaries to postprocess
|
||||
|
||||
Returns:
|
||||
List of postprocessed output dictionaries
|
||||
"""
|
||||
return output
|
||||
|
||||
async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Called when a computer call is about to start.
|
||||
|
||||
Args:
|
||||
item: The computer call item dictionary
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Called when a computer call has completed.
|
||||
|
||||
Args:
|
||||
item: The computer call item dictionary
|
||||
result: The result of the computer call
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_function_call_start(self, item: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Called when a function call is about to start.
|
||||
|
||||
Args:
|
||||
item: The function call item dictionary
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Called when a function call has completed.
|
||||
|
||||
Args:
|
||||
item: The function call item dictionary
|
||||
result: The result of the function call
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_text(self, item: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Called when a text message is encountered.
|
||||
|
||||
Args:
|
||||
item: The message item dictionary
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Called when an API call is about to start.
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs being passed to the API call
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
|
||||
"""
|
||||
Called when an API call has completed.
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs that were passed to the API call
|
||||
result: The result of the API call
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_usage(self, usage: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Called when usage information is received.
|
||||
|
||||
Args:
|
||||
usage: The usage information
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
|
||||
"""
|
||||
Called when a screenshot is taken.
|
||||
|
||||
Args:
|
||||
screenshot: The screenshot image
|
||||
name: The name of the screenshot
|
||||
"""
|
||||
pass
|
||||
|
||||
async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Called when responses are received.
|
||||
|
||||
Args:
|
||||
kwargs: The kwargs being passed to the agent loop
|
||||
responses: The responses received
|
||||
"""
|
||||
pass
|
||||
44
libs/python/agent/agent/callbacks/budget_manager.py
Normal file
44
libs/python/agent/agent/callbacks/budget_manager.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from typing import Dict, List, Any
|
||||
from .base import AsyncCallbackHandler
|
||||
|
||||
class BudgetExceededError(Exception):
|
||||
"""Exception raised when budget is exceeded."""
|
||||
pass
|
||||
|
||||
class BudgetManagerCallback(AsyncCallbackHandler):
|
||||
"""Budget manager callback that tracks usage costs and can stop execution when budget is exceeded."""
|
||||
|
||||
def __init__(self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False):
|
||||
"""
|
||||
Initialize BudgetManagerCallback.
|
||||
|
||||
Args:
|
||||
max_budget: Maximum budget allowed
|
||||
reset_after_each_run: Whether to reset budget after each run
|
||||
raise_error: Whether to raise an error when budget is exceeded
|
||||
"""
|
||||
self.max_budget = max_budget
|
||||
self.reset_after_each_run = reset_after_each_run
|
||||
self.raise_error = raise_error
|
||||
self.total_cost = 0.0
|
||||
|
||||
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
|
||||
"""Reset budget if configured to do so."""
|
||||
if self.reset_after_each_run:
|
||||
self.total_cost = 0.0
|
||||
|
||||
async def on_usage(self, usage: Dict[str, Any]) -> None:
|
||||
"""Track usage costs."""
|
||||
if "response_cost" in usage:
|
||||
self.total_cost += usage["response_cost"]
|
||||
|
||||
async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool:
|
||||
"""Check if budget allows continuation."""
|
||||
if self.total_cost >= self.max_budget:
|
||||
if self.raise_error:
|
||||
raise BudgetExceededError(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
|
||||
else:
|
||||
print(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}")
|
||||
return False
|
||||
return True
|
||||
|
||||
139
libs/python/agent/agent/callbacks/image_retention.py
Normal file
139
libs/python/agent/agent/callbacks/image_retention.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""
|
||||
Image retention callback handler that limits the number of recent images in message history.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from .base import AsyncCallbackHandler
|
||||
|
||||
|
||||
class ImageRetentionCallback(AsyncCallbackHandler):
|
||||
"""
|
||||
Callback handler that applies image retention policy to limit the number
|
||||
of recent images in message history to prevent context window overflow.
|
||||
"""
|
||||
|
||||
def __init__(self, only_n_most_recent_images: Optional[int] = None):
|
||||
"""
|
||||
Initialize the image retention callback.
|
||||
|
||||
Args:
|
||||
only_n_most_recent_images: If set, only keep the N most recent images in message history
|
||||
"""
|
||||
self.only_n_most_recent_images = only_n_most_recent_images
|
||||
|
||||
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Apply image retention policy to messages before sending to agent loop.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
List of messages with image retention policy applied
|
||||
"""
|
||||
if self.only_n_most_recent_images is None:
|
||||
return messages
|
||||
|
||||
return self._apply_image_retention(messages)
|
||||
|
||||
def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Apply image retention policy to keep only the N most recent images.
|
||||
|
||||
Removes computer_call_output items with image_url and their corresponding computer_call items,
|
||||
keeping only the most recent N image pairs based on only_n_most_recent_images setting.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
Filtered list of messages with image retention applied
|
||||
"""
|
||||
if self.only_n_most_recent_images is None:
|
||||
return messages
|
||||
|
||||
# First pass: Assign call_id to reasoning items based on the next computer_call
|
||||
messages_with_call_ids = []
|
||||
for i, msg in enumerate(messages):
|
||||
msg_copy = msg.copy() if isinstance(msg, dict) else msg
|
||||
|
||||
# If this is a reasoning item without a call_id, find the next computer_call
|
||||
if (msg_copy.get("type") == "reasoning" and
|
||||
not msg_copy.get("call_id")):
|
||||
# Look ahead for the next computer_call
|
||||
for j in range(i + 1, len(messages)):
|
||||
next_msg = messages[j]
|
||||
if (next_msg.get("type") == "computer_call" and
|
||||
next_msg.get("call_id")):
|
||||
msg_copy["call_id"] = next_msg.get("call_id")
|
||||
break
|
||||
|
||||
messages_with_call_ids.append(msg_copy)
|
||||
|
||||
# Find all computer_call_output items with images and their call_ids
|
||||
image_call_ids = []
|
||||
for msg in reversed(messages_with_call_ids): # Process in reverse to get most recent first
|
||||
if (msg.get("type") == "computer_call_output" and
|
||||
isinstance(msg.get("output"), dict) and
|
||||
"image_url" in msg.get("output", {})):
|
||||
call_id = msg.get("call_id")
|
||||
if call_id and call_id not in image_call_ids:
|
||||
image_call_ids.append(call_id)
|
||||
if len(image_call_ids) >= self.only_n_most_recent_images:
|
||||
break
|
||||
|
||||
# Keep the most recent N image call_ids (reverse to get chronological order)
|
||||
keep_call_ids = set(image_call_ids[:self.only_n_most_recent_images])
|
||||
|
||||
# Filter messages: remove computer_call, computer_call_output, and reasoning for old images
|
||||
filtered_messages = []
|
||||
for msg in messages_with_call_ids:
|
||||
msg_type = msg.get("type")
|
||||
call_id = msg.get("call_id")
|
||||
|
||||
# Remove old computer_call items
|
||||
if msg_type == "computer_call" and call_id not in keep_call_ids:
|
||||
# Check if this call_id corresponds to an image call
|
||||
has_image_output = any(
|
||||
m.get("type") == "computer_call_output" and
|
||||
m.get("call_id") == call_id and
|
||||
isinstance(m.get("output"), dict) and
|
||||
"image_url" in m.get("output", {})
|
||||
for m in messages_with_call_ids
|
||||
)
|
||||
if has_image_output:
|
||||
continue # Skip this computer_call
|
||||
|
||||
# Remove old computer_call_output items with images
|
||||
if (msg_type == "computer_call_output" and
|
||||
call_id not in keep_call_ids and
|
||||
isinstance(msg.get("output"), dict) and
|
||||
"image_url" in msg.get("output", {})):
|
||||
continue # Skip this computer_call_output
|
||||
|
||||
# Remove old reasoning items that are paired with removed computer calls
|
||||
if (msg_type == "reasoning" and
|
||||
call_id and call_id not in keep_call_ids):
|
||||
# Check if this call_id corresponds to an image call that's being removed
|
||||
has_image_output = any(
|
||||
m.get("type") == "computer_call_output" and
|
||||
m.get("call_id") == call_id and
|
||||
isinstance(m.get("output"), dict) and
|
||||
"image_url" in m.get("output", {})
|
||||
for m in messages_with_call_ids
|
||||
)
|
||||
if has_image_output:
|
||||
continue # Skip this reasoning item
|
||||
|
||||
filtered_messages.append(msg)
|
||||
|
||||
# Clean up: Remove call_id from reasoning items before returning
|
||||
final_messages = []
|
||||
for msg in filtered_messages:
|
||||
if msg.get("type") == "reasoning" and "call_id" in msg:
|
||||
# Create a copy without call_id for reasoning items
|
||||
cleaned_msg = {k: v for k, v in msg.items() if k != "call_id"}
|
||||
final_messages.append(cleaned_msg)
|
||||
else:
|
||||
final_messages.append(msg)
|
||||
|
||||
return final_messages
|
||||
247
libs/python/agent/agent/callbacks/logging.py
Normal file
247
libs/python/agent/agent/callbacks/logging.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Logging callback for ComputerAgent that provides configurable logging of agent lifecycle events.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, List, Any, Optional, Union
|
||||
from .base import AsyncCallbackHandler
|
||||
|
||||
|
||||
def sanitize_image_urls(data: Any) -> Any:
|
||||
"""
|
||||
Recursively search for 'image_url' keys and set their values to '[omitted]'.
|
||||
|
||||
Args:
|
||||
data: Any data structure (dict, list, or primitive type)
|
||||
|
||||
Returns:
|
||||
A deep copy of the data with all 'image_url' values replaced with '[omitted]'
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
# Create a copy of the dictionary
|
||||
sanitized = {}
|
||||
for key, value in data.items():
|
||||
if key == "image_url":
|
||||
sanitized[key] = "[omitted]"
|
||||
else:
|
||||
# Recursively sanitize the value
|
||||
sanitized[key] = sanitize_image_urls(value)
|
||||
return sanitized
|
||||
|
||||
elif isinstance(data, list):
|
||||
# Recursively sanitize each item in the list
|
||||
return [sanitize_image_urls(item) for item in data]
|
||||
|
||||
else:
|
||||
# For primitive types (str, int, bool, None, etc.), return as-is
|
||||
return data
|
||||
|
||||
|
||||
class LoggingCallback(AsyncCallbackHandler):
|
||||
"""
|
||||
Callback handler that logs agent lifecycle events with configurable verbosity.
|
||||
|
||||
Logging levels:
|
||||
- DEBUG: All events including API calls, message preprocessing, and detailed outputs
|
||||
- INFO: Major lifecycle events (start/end, messages, outputs)
|
||||
- WARNING: Only warnings and errors
|
||||
- ERROR: Only errors
|
||||
"""
|
||||
|
||||
def __init__(self, logger: Optional[logging.Logger] = None, level: int = logging.INFO):
|
||||
"""
|
||||
Initialize the logging callback.
|
||||
|
||||
Args:
|
||||
logger: Logger instance to use. If None, creates a logger named 'agent.ComputerAgent'
|
||||
level: Logging level (logging.DEBUG, logging.INFO, etc.)
|
||||
"""
|
||||
self.logger = logger or logging.getLogger('agent.ComputerAgent')
|
||||
self.level = level
|
||||
|
||||
# Set up logger if it doesn't have handlers
|
||||
if not self.logger.handlers:
|
||||
handler = logging.StreamHandler()
|
||||
formatter = logging.Formatter(
|
||||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
handler.setFormatter(formatter)
|
||||
self.logger.addHandler(handler)
|
||||
self.logger.setLevel(level)
|
||||
|
||||
def _update_usage(self, usage: Dict[str, Any]) -> None:
|
||||
"""Update total usage statistics."""
|
||||
def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
|
||||
for key, value in source.items():
|
||||
if isinstance(value, dict):
|
||||
if key not in target:
|
||||
target[key] = {}
|
||||
add_dicts(target[key], value)
|
||||
else:
|
||||
if key not in target:
|
||||
target[key] = 0
|
||||
target[key] += value
|
||||
add_dicts(self.total_usage, usage)
|
||||
|
||||
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
|
||||
"""Called before the run starts."""
|
||||
self.total_usage = {}
|
||||
|
||||
async def on_usage(self, usage: Dict[str, Any]) -> None:
|
||||
"""Called when usage information is received."""
|
||||
self._update_usage(usage)
|
||||
|
||||
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
|
||||
"""Called after the run ends."""
|
||||
def format_dict(d, indent=0):
|
||||
lines = []
|
||||
prefix = f" - {' ' * indent}"
|
||||
for key, value in d.items():
|
||||
if isinstance(value, dict):
|
||||
lines.append(f"{prefix}{key}:")
|
||||
lines.extend(format_dict(value, indent + 1))
|
||||
elif isinstance(value, float):
|
||||
lines.append(f"{prefix}{key}: ${value:.4f}")
|
||||
else:
|
||||
lines.append(f"{prefix}{key}: {value}")
|
||||
return lines
|
||||
|
||||
formatted_output = "\n".join(format_dict(self.total_usage))
|
||||
self.logger.info(f"Total usage:\n{formatted_output}")
|
||||
|
||||
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Called before LLM processing starts."""
|
||||
if self.logger.isEnabledFor(logging.INFO):
|
||||
self.logger.info(f"LLM processing started with {len(messages)} messages")
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
|
||||
self.logger.debug(f"LLM input messages: {json.dumps(sanitized_messages, indent=2)}")
|
||||
return messages
|
||||
|
||||
async def on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Called after LLM processing ends."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
sanitized_messages = [sanitize_image_urls(msg) for msg in messages]
|
||||
self.logger.debug(f"LLM output: {json.dumps(sanitized_messages, indent=2)}")
|
||||
return messages
|
||||
|
||||
async def on_computer_call_start(self, item: Dict[str, Any]) -> None:
|
||||
"""Called when a computer call starts."""
|
||||
action = item.get("action", {})
|
||||
action_type = action.get("type", "unknown")
|
||||
action_args = {k: v for k, v in action.items() if k != "type"}
|
||||
|
||||
# INFO level logging for the action
|
||||
self.logger.info(f"Computer: {action_type}({action_args})")
|
||||
|
||||
# DEBUG level logging for full details
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug(f"Computer call started: {json.dumps(action, indent=2)}")
|
||||
|
||||
async def on_computer_call_end(self, item: Dict[str, Any], result: Any) -> None:
|
||||
"""Called when a computer call ends."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
action = item.get("action", "unknown")
|
||||
self.logger.debug(f"Computer call completed: {json.dumps(action, indent=2)}")
|
||||
if result:
|
||||
sanitized_result = sanitize_image_urls(result)
|
||||
self.logger.debug(f"Computer call result: {json.dumps(sanitized_result, indent=2)}")
|
||||
|
||||
async def on_function_call_start(self, item: Dict[str, Any]) -> None:
|
||||
"""Called when a function call starts."""
|
||||
name = item.get("name", "unknown")
|
||||
arguments = item.get("arguments", "{}")
|
||||
|
||||
# INFO level logging for the function call
|
||||
self.logger.info(f"Function: {name}({arguments})")
|
||||
|
||||
# DEBUG level logging for full details
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug(f"Function call started: {name}")
|
||||
|
||||
async def on_function_call_end(self, item: Dict[str, Any], result: Any) -> None:
|
||||
"""Called when a function call ends."""
|
||||
# INFO level logging for function output (similar to function_call_output)
|
||||
if result:
|
||||
# Handle both list and direct result formats
|
||||
if isinstance(result, list) and len(result) > 0:
|
||||
output = result[0].get("output", str(result)) if isinstance(result[0], dict) else str(result[0])
|
||||
else:
|
||||
output = str(result)
|
||||
|
||||
# Truncate long outputs
|
||||
if len(output) > 100:
|
||||
output = output[:100] + "..."
|
||||
|
||||
self.logger.info(f"Output: {output}")
|
||||
|
||||
# DEBUG level logging for full details
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
name = item.get("name", "unknown")
|
||||
self.logger.debug(f"Function call completed: {name}")
|
||||
if result:
|
||||
self.logger.debug(f"Function call result: {json.dumps(result, indent=2)}")
|
||||
|
||||
async def on_text(self, item: Dict[str, Any]) -> None:
|
||||
"""Called when a text message is encountered."""
|
||||
# Get the role to determine if it's Agent or User
|
||||
role = item.get("role", "unknown")
|
||||
content_items = item.get("content", [])
|
||||
|
||||
# Process content items to build display text
|
||||
text_parts = []
|
||||
for content_item in content_items:
|
||||
content_type = content_item.get("type", "output_text")
|
||||
if content_type == "output_text":
|
||||
text_content = content_item.get("text", "")
|
||||
if not text_content.strip():
|
||||
text_parts.append("[empty]")
|
||||
else:
|
||||
# Truncate long text and add ellipsis
|
||||
if len(text_content) > 2048:
|
||||
text_parts.append(text_content[:2048] + "...")
|
||||
else:
|
||||
text_parts.append(text_content)
|
||||
else:
|
||||
# Non-text content, show as [type]
|
||||
text_parts.append(f"[{content_type}]")
|
||||
|
||||
# Join all text parts
|
||||
display_text = ''.join(text_parts) if text_parts else "[empty]"
|
||||
|
||||
# Log with appropriate level and format
|
||||
if role == "assistant":
|
||||
self.logger.info(f"Agent: {display_text}")
|
||||
elif role == "user":
|
||||
self.logger.info(f"User: {display_text}")
|
||||
else:
|
||||
# Fallback for unknown roles, use debug level
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
self.logger.debug(f"Text message ({role}): {display_text}")
|
||||
|
||||
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
|
||||
"""Called when an API call is about to start."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
model = kwargs.get("model", "unknown")
|
||||
self.logger.debug(f"API call starting for model: {model}")
|
||||
# Log sanitized messages if present
|
||||
if "messages" in kwargs:
|
||||
sanitized_messages = sanitize_image_urls(kwargs["messages"])
|
||||
self.logger.debug(f"API call messages: {json.dumps(sanitized_messages, indent=2)}")
|
||||
elif "input" in kwargs:
|
||||
sanitized_input = sanitize_image_urls(kwargs["input"])
|
||||
self.logger.debug(f"API call input: {json.dumps(sanitized_input, indent=2)}")
|
||||
|
||||
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
|
||||
"""Called when an API call has completed."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
model = kwargs.get("model", "unknown")
|
||||
self.logger.debug(f"API call completed for model: {model}")
|
||||
self.logger.debug(f"API call result: {json.dumps(sanitize_image_urls(result), indent=2)}")
|
||||
|
||||
async def on_screenshot(self, item: Union[str, bytes], name: str = "screenshot") -> None:
|
||||
"""Called when a screenshot is taken."""
|
||||
if self.logger.isEnabledFor(logging.DEBUG):
|
||||
image_size = len(item) / 1024
|
||||
self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB")
|
||||
259
libs/python/agent/agent/callbacks/pii_anonymization.py
Normal file
259
libs/python/agent/agent/callbacks/pii_anonymization.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
PII anonymization callback handler using Microsoft Presidio for text and image redaction.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional, Tuple
|
||||
from .base import AsyncCallbackHandler
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
|
||||
try:
|
||||
from presidio_analyzer import AnalyzerEngine
|
||||
from presidio_anonymizer import AnonymizerEngine, DeanonymizeEngine
|
||||
from presidio_anonymizer.entities import RecognizerResult, OperatorConfig
|
||||
from presidio_image_redactor import ImageRedactorEngine
|
||||
from PIL import Image
|
||||
PRESIDIO_AVAILABLE = True
|
||||
except ImportError:
|
||||
PRESIDIO_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class PIIAnonymizationCallback(AsyncCallbackHandler):
|
||||
"""
|
||||
Callback handler that anonymizes PII in text and images using Microsoft Presidio.
|
||||
|
||||
This handler:
|
||||
1. Anonymizes PII in messages before sending to the agent loop
|
||||
2. Deanonymizes PII in tool calls and message outputs after the agent loop
|
||||
3. Redacts PII from images in computer_call_output messages
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
anonymize_text: bool = True,
|
||||
anonymize_images: bool = True,
|
||||
entities_to_anonymize: Optional[List[str]] = None,
|
||||
anonymization_operator: str = "replace",
|
||||
image_redaction_color: Tuple[int, int, int] = (255, 192, 203) # Pink
|
||||
):
|
||||
"""
|
||||
Initialize the PII anonymization callback.
|
||||
|
||||
Args:
|
||||
anonymize_text: Whether to anonymize text content
|
||||
anonymize_images: Whether to redact images
|
||||
entities_to_anonymize: List of entity types to anonymize (None for all)
|
||||
anonymization_operator: Presidio operator to use ("replace", "mask", "redact", etc.)
|
||||
image_redaction_color: RGB color for image redaction
|
||||
"""
|
||||
if not PRESIDIO_AVAILABLE:
|
||||
raise ImportError(
|
||||
"Presidio is not available. Install with: "
|
||||
"pip install presidio-analyzer presidio-anonymizer presidio-image-redactor"
|
||||
)
|
||||
|
||||
self.anonymize_text = anonymize_text
|
||||
self.anonymize_images = anonymize_images
|
||||
self.entities_to_anonymize = entities_to_anonymize
|
||||
self.anonymization_operator = anonymization_operator
|
||||
self.image_redaction_color = image_redaction_color
|
||||
|
||||
# Initialize Presidio engines
|
||||
self.analyzer = AnalyzerEngine()
|
||||
self.anonymizer = AnonymizerEngine()
|
||||
self.deanonymizer = DeanonymizeEngine()
|
||||
self.image_redactor = ImageRedactorEngine()
|
||||
|
||||
# Store anonymization mappings for deanonymization
|
||||
self.anonymization_mappings: Dict[str, Any] = {}
|
||||
|
||||
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Anonymize PII in messages before sending to agent loop.
|
||||
|
||||
Args:
|
||||
messages: List of message dictionaries
|
||||
|
||||
Returns:
|
||||
List of messages with PII anonymized
|
||||
"""
|
||||
if not self.anonymize_text and not self.anonymize_images:
|
||||
return messages
|
||||
|
||||
anonymized_messages = []
|
||||
for msg in messages:
|
||||
anonymized_msg = await self._anonymize_message(msg)
|
||||
anonymized_messages.append(anonymized_msg)
|
||||
|
||||
return anonymized_messages
|
||||
|
||||
async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Deanonymize PII in tool calls and message outputs after agent loop.
|
||||
|
||||
Args:
|
||||
output: List of output dictionaries
|
||||
|
||||
Returns:
|
||||
List of output with PII deanonymized for tool calls
|
||||
"""
|
||||
if not self.anonymize_text:
|
||||
return output
|
||||
|
||||
deanonymized_output = []
|
||||
for item in output:
|
||||
# Only deanonymize tool calls and computer_call messages
|
||||
if item.get("type") in ["computer_call", "computer_call_output"]:
|
||||
deanonymized_item = await self._deanonymize_item(item)
|
||||
deanonymized_output.append(deanonymized_item)
|
||||
else:
|
||||
deanonymized_output.append(item)
|
||||
|
||||
return deanonymized_output
|
||||
|
||||
async def _anonymize_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Anonymize PII in a single message."""
|
||||
msg_copy = message.copy()
|
||||
|
||||
# Anonymize text content
|
||||
if self.anonymize_text:
|
||||
msg_copy = await self._anonymize_text_content(msg_copy)
|
||||
|
||||
# Redact images in computer_call_output
|
||||
if self.anonymize_images and msg_copy.get("type") == "computer_call_output":
|
||||
msg_copy = await self._redact_image_content(msg_copy)
|
||||
|
||||
return msg_copy
|
||||
|
||||
async def _anonymize_text_content(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Anonymize text content in a message."""
|
||||
msg_copy = message.copy()
|
||||
|
||||
# Handle content array
|
||||
content = msg_copy.get("content", [])
|
||||
if isinstance(content, str):
|
||||
anonymized_text, _ = await self._anonymize_text(content)
|
||||
msg_copy["content"] = anonymized_text
|
||||
elif isinstance(content, list):
|
||||
anonymized_content = []
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
text = item.get("text", "")
|
||||
anonymized_text, _ = await self._anonymize_text(text)
|
||||
item_copy = item.copy()
|
||||
item_copy["text"] = anonymized_text
|
||||
anonymized_content.append(item_copy)
|
||||
else:
|
||||
anonymized_content.append(item)
|
||||
msg_copy["content"] = anonymized_content
|
||||
|
||||
return msg_copy
|
||||
|
||||
async def _redact_image_content(self, message: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Redact PII from images in computer_call_output messages."""
|
||||
msg_copy = message.copy()
|
||||
output = msg_copy.get("output", {})
|
||||
|
||||
if isinstance(output, dict) and "image_url" in output:
|
||||
try:
|
||||
# Extract base64 image data
|
||||
image_url = output["image_url"]
|
||||
if image_url.startswith("data:image/"):
|
||||
# Parse data URL
|
||||
header, data = image_url.split(",", 1)
|
||||
image_data = base64.b64decode(data)
|
||||
|
||||
# Load image with PIL
|
||||
image = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# Redact PII from image
|
||||
redacted_image = self.image_redactor.redact(image, self.image_redaction_color)
|
||||
|
||||
# Convert back to base64
|
||||
buffer = io.BytesIO()
|
||||
redacted_image.save(buffer, format="PNG")
|
||||
redacted_data = base64.b64encode(buffer.getvalue()).decode()
|
||||
|
||||
# Update image URL
|
||||
output_copy = output.copy()
|
||||
output_copy["image_url"] = f"data:image/png;base64,{redacted_data}"
|
||||
msg_copy["output"] = output_copy
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to redact image: {e}")
|
||||
|
||||
return msg_copy
|
||||
|
||||
async def _deanonymize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Deanonymize PII in tool calls and computer outputs."""
|
||||
item_copy = item.copy()
|
||||
|
||||
# Handle computer_call arguments
|
||||
if item.get("type") == "computer_call":
|
||||
args = item_copy.get("args", {})
|
||||
if isinstance(args, dict):
|
||||
deanonymized_args = {}
|
||||
for key, value in args.items():
|
||||
if isinstance(value, str):
|
||||
deanonymized_value, _ = await self._deanonymize_text(value)
|
||||
deanonymized_args[key] = deanonymized_value
|
||||
else:
|
||||
deanonymized_args[key] = value
|
||||
item_copy["args"] = deanonymized_args
|
||||
|
||||
return item_copy
|
||||
|
||||
async def _anonymize_text(self, text: str) -> Tuple[str, List[RecognizerResult]]:
|
||||
"""Anonymize PII in text and return the anonymized text and results."""
|
||||
if not text.strip():
|
||||
return text, []
|
||||
|
||||
try:
|
||||
# Analyze text for PII
|
||||
analyzer_results = self.analyzer.analyze(
|
||||
text=text,
|
||||
entities=self.entities_to_anonymize,
|
||||
language="en"
|
||||
)
|
||||
|
||||
if not analyzer_results:
|
||||
return text, []
|
||||
|
||||
# Anonymize the text
|
||||
anonymized_result = self.anonymizer.anonymize(
|
||||
text=text,
|
||||
analyzer_results=analyzer_results,
|
||||
operators={entity_type: OperatorConfig(self.anonymization_operator)
|
||||
for entity_type in set(result.entity_type for result in analyzer_results)}
|
||||
)
|
||||
|
||||
# Store mapping for deanonymization
|
||||
mapping_key = str(hash(text))
|
||||
self.anonymization_mappings[mapping_key] = {
|
||||
"original": text,
|
||||
"anonymized": anonymized_result.text,
|
||||
"results": analyzer_results
|
||||
}
|
||||
|
||||
return anonymized_result.text, analyzer_results
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to anonymize text: {e}")
|
||||
return text, []
|
||||
|
||||
async def _deanonymize_text(self, text: str) -> Tuple[str, bool]:
|
||||
"""Attempt to deanonymize text using stored mappings."""
|
||||
try:
|
||||
# Look for matching anonymized text in mappings
|
||||
for mapping_key, mapping in self.anonymization_mappings.items():
|
||||
if mapping["anonymized"] == text:
|
||||
return mapping["original"], True
|
||||
|
||||
# If no mapping found, return original text
|
||||
return text, False
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to deanonymize text: {e}")
|
||||
return text, False
|
||||
210
libs/python/agent/agent/callbacks/telemetry.py
Normal file
210
libs/python/agent/agent/callbacks/telemetry.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""
|
||||
Telemetry callback handler for Computer-Use Agent (cua-agent)
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import List, Dict, Any, Optional, Union
|
||||
|
||||
from .base import AsyncCallbackHandler
|
||||
from ..telemetry import (
|
||||
record_event,
|
||||
is_telemetry_enabled,
|
||||
set_dimension,
|
||||
SYSTEM_INFO,
|
||||
)
|
||||
|
||||
|
||||
class TelemetryCallback(AsyncCallbackHandler):
|
||||
"""
|
||||
Telemetry callback handler for Computer-Use Agent (cua-agent)
|
||||
|
||||
Tracks agent usage, performance metrics, and optionally trajectory data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent,
|
||||
log_trajectory: bool = False
|
||||
):
|
||||
"""
|
||||
Initialize telemetry callback.
|
||||
|
||||
Args:
|
||||
agent: The ComputerAgent instance
|
||||
log_trajectory: Whether to log full trajectory items (opt-in)
|
||||
"""
|
||||
self.agent = agent
|
||||
self.log_trajectory = log_trajectory
|
||||
|
||||
# Generate session/run IDs
|
||||
self.session_id = str(uuid.uuid4())
|
||||
self.run_id = None
|
||||
|
||||
# Track timing and metrics
|
||||
self.run_start_time = None
|
||||
self.step_count = 0
|
||||
self.step_start_time = None
|
||||
self.total_usage = {
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"response_cost": 0.0
|
||||
}
|
||||
|
||||
# Record agent initialization
|
||||
if is_telemetry_enabled():
|
||||
self._record_agent_initialization()
|
||||
|
||||
def _record_agent_initialization(self) -> None:
|
||||
"""Record agent type/model and session initialization."""
|
||||
agent_info = {
|
||||
"session_id": self.session_id,
|
||||
"agent_type": self.agent.agent_loop.__name__ if hasattr(self.agent, 'agent_loop') else 'unknown',
|
||||
"model": getattr(self.agent, 'model', 'unknown'),
|
||||
**SYSTEM_INFO
|
||||
}
|
||||
|
||||
# Set session-level dimensions
|
||||
set_dimension("session_id", self.session_id)
|
||||
set_dimension("agent_type", agent_info["agent_type"])
|
||||
set_dimension("model", agent_info["model"])
|
||||
|
||||
record_event("agent_session_start", agent_info)
|
||||
|
||||
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
|
||||
"""Called at the start of an agent run loop."""
|
||||
if not is_telemetry_enabled():
|
||||
return
|
||||
|
||||
self.run_id = str(uuid.uuid4())
|
||||
self.run_start_time = time.time()
|
||||
self.step_count = 0
|
||||
|
||||
# Calculate input context size
|
||||
input_context_size = self._calculate_context_size(old_items)
|
||||
|
||||
run_data = {
|
||||
"session_id": self.session_id,
|
||||
"run_id": self.run_id,
|
||||
"start_time": self.run_start_time,
|
||||
"input_context_size": input_context_size,
|
||||
"num_existing_messages": len(old_items)
|
||||
}
|
||||
|
||||
# Log trajectory if opted in
|
||||
if self.log_trajectory:
|
||||
trajectory = self._extract_trajectory(old_items)
|
||||
if trajectory:
|
||||
run_data["uploaded_trajectory"] = trajectory
|
||||
|
||||
set_dimension("run_id", self.run_id)
|
||||
record_event("agent_run_start", run_data)
|
||||
|
||||
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
|
||||
"""Called at the end of an agent run loop."""
|
||||
if not is_telemetry_enabled() or not self.run_start_time:
|
||||
return
|
||||
|
||||
run_duration = time.time() - self.run_start_time
|
||||
|
||||
run_data = {
|
||||
"session_id": self.session_id,
|
||||
"run_id": self.run_id,
|
||||
"end_time": time.time(),
|
||||
"duration_seconds": run_duration,
|
||||
"num_steps": self.step_count,
|
||||
"total_usage": self.total_usage.copy()
|
||||
}
|
||||
|
||||
# Log trajectory if opted in
|
||||
if self.log_trajectory:
|
||||
trajectory = self._extract_trajectory(new_items)
|
||||
if trajectory:
|
||||
run_data["uploaded_trajectory"] = trajectory
|
||||
|
||||
record_event("agent_run_end", run_data)
|
||||
|
||||
async def on_usage(self, usage: Dict[str, Any]) -> None:
|
||||
"""Called when usage information is received."""
|
||||
if not is_telemetry_enabled():
|
||||
return
|
||||
|
||||
# Accumulate usage stats
|
||||
self.total_usage["prompt_tokens"] += usage.get("prompt_tokens", 0)
|
||||
self.total_usage["completion_tokens"] += usage.get("completion_tokens", 0)
|
||||
self.total_usage["total_tokens"] += usage.get("total_tokens", 0)
|
||||
self.total_usage["response_cost"] += usage.get("response_cost", 0.0)
|
||||
|
||||
# Record individual usage event
|
||||
usage_data = {
|
||||
"session_id": self.session_id,
|
||||
"run_id": self.run_id,
|
||||
"step": self.step_count,
|
||||
**usage
|
||||
}
|
||||
|
||||
record_event("agent_usage", usage_data)
|
||||
|
||||
async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
|
||||
"""Called when responses are received."""
|
||||
if not is_telemetry_enabled():
|
||||
return
|
||||
|
||||
self.step_count += 1
|
||||
step_duration = None
|
||||
|
||||
if self.step_start_time:
|
||||
step_duration = time.time() - self.step_start_time
|
||||
|
||||
self.step_start_time = time.time()
|
||||
|
||||
step_data = {
|
||||
"session_id": self.session_id,
|
||||
"run_id": self.run_id,
|
||||
"step": self.step_count,
|
||||
"timestamp": self.step_start_time
|
||||
}
|
||||
|
||||
if step_duration is not None:
|
||||
step_data["duration_seconds"] = step_duration
|
||||
|
||||
record_event("agent_step", step_data)
|
||||
|
||||
def _calculate_context_size(self, items: List[Dict[str, Any]]) -> int:
|
||||
"""Calculate approximate context size in tokens/characters."""
|
||||
total_size = 0
|
||||
|
||||
for item in items:
|
||||
if item.get("type") == "message" and "content" in item:
|
||||
content = item["content"]
|
||||
if isinstance(content, str):
|
||||
total_size += len(content)
|
||||
elif isinstance(content, list):
|
||||
for part in content:
|
||||
if isinstance(part, dict) and "text" in part:
|
||||
total_size += len(part["text"])
|
||||
elif "content" in item and isinstance(item["content"], str):
|
||||
total_size += len(item["content"])
|
||||
|
||||
return total_size
|
||||
|
||||
def _extract_trajectory(self, items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Extract trajectory items that should be logged."""
|
||||
trajectory = []
|
||||
|
||||
for item in items:
|
||||
# Include user messages, assistant messages, reasoning, computer calls, and computer outputs
|
||||
if (
|
||||
item.get("role") == "user" or # User inputs
|
||||
(item.get("type") == "message" and item.get("role") == "assistant") or # Model outputs
|
||||
item.get("type") == "reasoning" or # Reasoning traces
|
||||
item.get("type") == "computer_call" or # Computer actions
|
||||
item.get("type") == "computer_call_output" # Computer outputs
|
||||
):
|
||||
# Create a copy of the item with timestamp
|
||||
trajectory_item = item.copy()
|
||||
trajectory_item["logged_at"] = time.time()
|
||||
trajectory.append(trajectory_item)
|
||||
|
||||
return trajectory
|
||||
305
libs/python/agent/agent/callbacks/trajectory_saver.py
Normal file
305
libs/python/agent/agent/callbacks/trajectory_saver.py
Normal file
@@ -0,0 +1,305 @@
|
||||
"""
|
||||
Trajectory saving callback handler for ComputerAgent.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, Optional, Union, override
|
||||
from PIL import Image, ImageDraw
|
||||
import io
|
||||
from .base import AsyncCallbackHandler
|
||||
|
||||
def sanitize_image_urls(data: Any) -> Any:
|
||||
"""
|
||||
Recursively search for 'image_url' keys and set their values to '[omitted]'.
|
||||
|
||||
Args:
|
||||
data: Any data structure (dict, list, or primitive type)
|
||||
|
||||
Returns:
|
||||
A deep copy of the data with all 'image_url' values replaced with '[omitted]'
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
# Create a copy of the dictionary
|
||||
sanitized = {}
|
||||
for key, value in data.items():
|
||||
if key == "image_url":
|
||||
sanitized[key] = "[omitted]"
|
||||
else:
|
||||
# Recursively sanitize the value
|
||||
sanitized[key] = sanitize_image_urls(value)
|
||||
return sanitized
|
||||
|
||||
elif isinstance(data, list):
|
||||
# Recursively sanitize each item in the list
|
||||
return [sanitize_image_urls(item) for item in data]
|
||||
|
||||
else:
|
||||
# For primitive types (str, int, bool, None, etc.), return as-is
|
||||
return data
|
||||
|
||||
|
||||
class TrajectorySaverCallback(AsyncCallbackHandler):
|
||||
"""
|
||||
Callback handler that saves agent trajectories to disk.
|
||||
|
||||
Saves each run as a separate trajectory with unique ID, and each turn
|
||||
within the trajectory gets its own folder with screenshots and responses.
|
||||
"""
|
||||
|
||||
def __init__(self, trajectory_dir: str):
|
||||
"""
|
||||
Initialize trajectory saver.
|
||||
|
||||
Args:
|
||||
trajectory_dir: Base directory to save trajectories
|
||||
"""
|
||||
self.trajectory_dir = Path(trajectory_dir)
|
||||
self.trajectory_id: Optional[str] = None
|
||||
self.current_turn: int = 0
|
||||
self.current_artifact: int = 0
|
||||
self.model: Optional[str] = None
|
||||
self.total_usage: Dict[str, Any] = {}
|
||||
|
||||
# Ensure trajectory directory exists
|
||||
self.trajectory_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _get_turn_dir(self) -> Path:
|
||||
"""Get the directory for the current turn."""
|
||||
if not self.trajectory_id:
|
||||
raise ValueError("Trajectory not initialized - call _on_run_start first")
|
||||
|
||||
# format: trajectory_id/turn_000
|
||||
turn_dir = self.trajectory_dir / self.trajectory_id / f"turn_{self.current_turn:03d}"
|
||||
turn_dir.mkdir(parents=True, exist_ok=True)
|
||||
return turn_dir
|
||||
|
||||
def _save_artifact(self, name: str, artifact: Union[str, bytes, Dict[str, Any]]) -> None:
|
||||
"""Save an artifact to the current turn directory."""
|
||||
turn_dir = self._get_turn_dir()
|
||||
if isinstance(artifact, bytes):
|
||||
# format: turn_000/0000_name.png
|
||||
artifact_filename = f"{self.current_artifact:04d}_{name}"
|
||||
artifact_path = turn_dir / f"{artifact_filename}.png"
|
||||
with open(artifact_path, "wb") as f:
|
||||
f.write(artifact)
|
||||
else:
|
||||
# format: turn_000/0000_name.json
|
||||
artifact_filename = f"{self.current_artifact:04d}_{name}"
|
||||
artifact_path = turn_dir / f"{artifact_filename}.json"
|
||||
with open(artifact_path, "w") as f:
|
||||
json.dump(sanitize_image_urls(artifact), f, indent=2)
|
||||
self.current_artifact += 1
|
||||
|
||||
def _update_usage(self, usage: Dict[str, Any]) -> None:
|
||||
"""Update total usage statistics."""
|
||||
def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None:
|
||||
for key, value in source.items():
|
||||
if isinstance(value, dict):
|
||||
if key not in target:
|
||||
target[key] = {}
|
||||
add_dicts(target[key], value)
|
||||
else:
|
||||
if key not in target:
|
||||
target[key] = 0
|
||||
target[key] += value
|
||||
add_dicts(self.total_usage, usage)
|
||||
|
||||
@override
|
||||
async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None:
|
||||
"""Initialize trajectory tracking for a new run."""
|
||||
model = kwargs.get("model", "unknown")
|
||||
model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16]
|
||||
if "+" in model:
|
||||
model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short
|
||||
|
||||
# id format: yyyy-mm-dd_model_hhmmss_uuid[:4]
|
||||
now = datetime.now()
|
||||
self.trajectory_id = f"{now.strftime('%Y-%m-%d')}_{model_name_short}_{now.strftime('%H%M%S')}_{str(uuid.uuid4())[:4]}"
|
||||
self.current_turn = 0
|
||||
self.current_artifact = 0
|
||||
self.model = model
|
||||
self.total_usage = {}
|
||||
|
||||
# Create trajectory directory
|
||||
trajectory_path = self.trajectory_dir / self.trajectory_id
|
||||
trajectory_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save trajectory metadata
|
||||
metadata = {
|
||||
"trajectory_id": self.trajectory_id,
|
||||
"created_at": str(uuid.uuid1().time),
|
||||
"status": "running",
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
|
||||
with open(trajectory_path / "metadata.json", "w") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
@override
|
||||
async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None:
|
||||
"""Finalize run tracking by updating metadata with completion status, usage, and new items."""
|
||||
if not self.trajectory_id:
|
||||
return
|
||||
|
||||
# Update metadata with completion status, total usage, and new items
|
||||
trajectory_path = self.trajectory_dir / self.trajectory_id
|
||||
metadata_path = trajectory_path / "metadata.json"
|
||||
|
||||
# Read existing metadata
|
||||
if metadata_path.exists():
|
||||
with open(metadata_path, "r") as f:
|
||||
metadata = json.load(f)
|
||||
else:
|
||||
metadata = {}
|
||||
|
||||
# Update metadata with completion info
|
||||
metadata.update({
|
||||
"status": "completed",
|
||||
"completed_at": str(uuid.uuid1().time),
|
||||
"total_usage": self.total_usage,
|
||||
"new_items": sanitize_image_urls(new_items),
|
||||
"total_turns": self.current_turn
|
||||
})
|
||||
|
||||
# Save updated metadata
|
||||
with open(metadata_path, "w") as f:
|
||||
json.dump(metadata, f, indent=2)
|
||||
|
||||
@override
|
||||
async def on_api_start(self, kwargs: Dict[str, Any]) -> None:
|
||||
if not self.trajectory_id:
|
||||
return
|
||||
|
||||
self._save_artifact("api_start", { "kwargs": kwargs })
|
||||
|
||||
@override
|
||||
async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None:
|
||||
"""Save API call result."""
|
||||
if not self.trajectory_id:
|
||||
return
|
||||
|
||||
self._save_artifact("api_result", { "kwargs": kwargs, "result": result })
|
||||
|
||||
@override
|
||||
async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None:
|
||||
"""Save a screenshot."""
|
||||
if isinstance(screenshot, str):
|
||||
screenshot = base64.b64decode(screenshot)
|
||||
self._save_artifact(name, screenshot)
|
||||
|
||||
@override
|
||||
async def on_usage(self, usage: Dict[str, Any]) -> None:
|
||||
"""Called when usage information is received."""
|
||||
self._update_usage(usage)
|
||||
|
||||
@override
|
||||
async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None:
|
||||
"""Save responses to the current turn directory and update usage statistics."""
|
||||
if not self.trajectory_id:
|
||||
return
|
||||
|
||||
# Save responses
|
||||
turn_dir = self._get_turn_dir()
|
||||
response_data = {
|
||||
"timestamp": str(uuid.uuid1().time),
|
||||
"model": self.model,
|
||||
"kwargs": kwargs,
|
||||
"response": responses
|
||||
}
|
||||
|
||||
self._save_artifact("agent_response", response_data)
|
||||
|
||||
# Increment turn counter
|
||||
self.current_turn += 1
|
||||
|
||||
def _draw_crosshair_on_image(self, image_bytes: bytes, x: int, y: int) -> bytes:
|
||||
"""
|
||||
Draw a red dot and crosshair at the specified coordinates on the image.
|
||||
|
||||
Args:
|
||||
image_bytes: The original image as bytes
|
||||
x: X coordinate for the crosshair
|
||||
y: Y coordinate for the crosshair
|
||||
|
||||
Returns:
|
||||
Modified image as bytes with red dot and crosshair
|
||||
"""
|
||||
# Open the image
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
# Draw crosshair lines (red, 2px thick)
|
||||
crosshair_size = 20
|
||||
line_width = 2
|
||||
color = "red"
|
||||
|
||||
# Horizontal line
|
||||
draw.line([(x - crosshair_size, y), (x + crosshair_size, y)], fill=color, width=line_width)
|
||||
# Vertical line
|
||||
draw.line([(x, y - crosshair_size), (x, y + crosshair_size)], fill=color, width=line_width)
|
||||
|
||||
# Draw center dot (filled circle)
|
||||
dot_radius = 3
|
||||
draw.ellipse([(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color)
|
||||
|
||||
# Convert back to bytes
|
||||
output = io.BytesIO()
|
||||
image.save(output, format='PNG')
|
||||
return output.getvalue()
|
||||
|
||||
@override
|
||||
async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Called when a computer call has completed.
|
||||
Saves screenshots and computer call output.
|
||||
"""
|
||||
if not self.trajectory_id:
|
||||
return
|
||||
|
||||
self._save_artifact("computer_call_result", { "item": item, "result": result })
|
||||
|
||||
# Check if action has x/y coordinates and there's a screenshot in the result
|
||||
action = item.get("action", {})
|
||||
if "x" in action and "y" in action:
|
||||
# Look for screenshot in the result
|
||||
for result_item in result:
|
||||
if (result_item.get("type") == "computer_call_output" and
|
||||
result_item.get("output", {}).get("type") == "input_image"):
|
||||
|
||||
image_url = result_item["output"]["image_url"]
|
||||
|
||||
# Extract base64 image data
|
||||
if image_url.startswith("data:image/"):
|
||||
# Format: data:image/png;base64,<base64_data>
|
||||
base64_data = image_url.split(",", 1)[1]
|
||||
else:
|
||||
# Assume it's just base64 data
|
||||
base64_data = image_url
|
||||
|
||||
try:
|
||||
# Decode the image
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
|
||||
# Draw crosshair at the action coordinates
|
||||
annotated_image = self._draw_crosshair_on_image(
|
||||
image_bytes,
|
||||
int(action["x"]),
|
||||
int(action["y"])
|
||||
)
|
||||
|
||||
# Save as screenshot_action
|
||||
self._save_artifact("screenshot_action", annotated_image)
|
||||
|
||||
except Exception as e:
|
||||
# If annotation fails, just log and continue
|
||||
print(f"Failed to annotate screenshot: {e}")
|
||||
|
||||
break # Only process the first screenshot found
|
||||
|
||||
# Increment turn counter
|
||||
self.current_turn += 1
|
||||
359
libs/python/agent/agent/cli.py
Normal file
359
libs/python/agent/agent/cli.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
CLI chat interface for agent - Computer Use Agent
|
||||
|
||||
Usage:
|
||||
python -m agent.cli <model_string>
|
||||
|
||||
Examples:
|
||||
python -m agent.cli openai/computer-use-preview
|
||||
python -m agent.cli anthropic/claude-3-5-sonnet-20241022
|
||||
python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
|
||||
"""
|
||||
|
||||
try:
|
||||
import asyncio
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
from typing import List, Dict, Any
|
||||
import dotenv
|
||||
from yaspin import yaspin
|
||||
except ImportError:
|
||||
if __name__ == "__main__":
|
||||
raise ImportError(
|
||||
"CLI dependencies not found. "
|
||||
"Please install with: pip install \"cua-agent[cli]\""
|
||||
)
|
||||
|
||||
# Load environment variables
|
||||
dotenv.load_dotenv()
|
||||
|
||||
# Color codes for terminal output
|
||||
class Colors:
|
||||
RESET = '\033[0m'
|
||||
BOLD = '\033[1m'
|
||||
DIM = '\033[2m'
|
||||
|
||||
# Text colors
|
||||
RED = '\033[31m'
|
||||
GREEN = '\033[32m'
|
||||
YELLOW = '\033[33m'
|
||||
BLUE = '\033[34m'
|
||||
MAGENTA = '\033[35m'
|
||||
CYAN = '\033[36m'
|
||||
WHITE = '\033[37m'
|
||||
GRAY = '\033[90m'
|
||||
|
||||
# Background colors
|
||||
BG_RED = '\033[41m'
|
||||
BG_GREEN = '\033[42m'
|
||||
BG_YELLOW = '\033[43m'
|
||||
BG_BLUE = '\033[44m'
|
||||
|
||||
def print_colored(text: str, color: str = "", bold: bool = False, dim: bool = False, end: str = "\n", right: str = ""):
|
||||
"""Print colored text to terminal with optional right-aligned text."""
|
||||
prefix = ""
|
||||
if bold:
|
||||
prefix += Colors.BOLD
|
||||
if dim:
|
||||
prefix += Colors.DIM
|
||||
if color:
|
||||
prefix += color
|
||||
|
||||
if right:
|
||||
# Get terminal width (default to 80 if unable to determine)
|
||||
try:
|
||||
import shutil
|
||||
terminal_width = shutil.get_terminal_size().columns
|
||||
except:
|
||||
terminal_width = 80
|
||||
|
||||
# Add right margin
|
||||
terminal_width -= 1
|
||||
|
||||
# Calculate padding needed
|
||||
# Account for ANSI escape codes not taking visual space
|
||||
visible_left_len = len(text)
|
||||
visible_right_len = len(right)
|
||||
padding = terminal_width - visible_left_len - visible_right_len
|
||||
|
||||
if padding > 0:
|
||||
output = f"{prefix}{text}{' ' * padding}{right}{Colors.RESET}"
|
||||
else:
|
||||
# If not enough space, just put a single space between
|
||||
output = f"{prefix}{text} {right}{Colors.RESET}"
|
||||
else:
|
||||
output = f"{prefix}{text}{Colors.RESET}"
|
||||
|
||||
print(output, end=end)
|
||||
|
||||
|
||||
def print_action(action_type: str, details: Dict[str, Any], total_cost: float):
|
||||
"""Print computer action with nice formatting."""
|
||||
# Format action details
|
||||
args_str = ""
|
||||
if action_type == "click" and "x" in details and "y" in details:
|
||||
args_str = f"({details['x']}, {details['y']})"
|
||||
elif action_type == "type" and "text" in details:
|
||||
text = details["text"]
|
||||
if len(text) > 50:
|
||||
text = text[:47] + "..."
|
||||
args_str = f'"{text}"'
|
||||
elif action_type == "key" and "key" in details:
|
||||
args_str = f"'{details['key']}'"
|
||||
elif action_type == "scroll" and "x" in details and "y" in details:
|
||||
args_str = f"({details['x']}, {details['y']})"
|
||||
|
||||
if total_cost > 0:
|
||||
print_colored(f"🛠️ {action_type}{args_str}", dim=True, right=f"💸 ${total_cost:.2f}")
|
||||
else:
|
||||
print_colored(f"🛠️ {action_type}{args_str}", dim=True)
|
||||
|
||||
def print_welcome(model: str, agent_loop: str, container_name: str):
|
||||
"""Print welcome message."""
|
||||
print_colored(f"Connected to {container_name} ({model}, {agent_loop})")
|
||||
print_colored("Type 'exit' to quit.", dim=True)
|
||||
|
||||
async def ainput(prompt: str = ""):
|
||||
return await asyncio.to_thread(input, prompt)
|
||||
|
||||
async def chat_loop(agent, model: str, container_name: str, initial_prompt: str = "", show_usage: bool = True):
|
||||
"""Main chat loop with the agent."""
|
||||
print_welcome(model, agent.agent_loop.__name__, container_name)
|
||||
|
||||
history = []
|
||||
|
||||
if initial_prompt:
|
||||
history.append({"role": "user", "content": initial_prompt})
|
||||
|
||||
total_cost = 0
|
||||
|
||||
while True:
|
||||
if history[-1].get("role") != "user":
|
||||
# Get user input with prompt
|
||||
print_colored("> ", end="")
|
||||
user_input = await ainput()
|
||||
|
||||
if user_input.lower() in ['exit', 'quit', 'q']:
|
||||
print_colored("\n👋 Goodbye!")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Add user message to history
|
||||
history.append({"role": "user", "content": user_input})
|
||||
|
||||
# Stream responses from the agent with spinner
|
||||
with yaspin(text="Thinking...", spinner="line", attrs=["dark"]) as spinner:
|
||||
spinner.hide()
|
||||
|
||||
async for result in agent.run(history):
|
||||
# Add agent responses to history
|
||||
history.extend(result.get("output", []))
|
||||
|
||||
if show_usage:
|
||||
total_cost += result.get("usage", {}).get("response_cost", 0)
|
||||
|
||||
# Process and display the output
|
||||
for item in result.get("output", []):
|
||||
if item.get("type") == "message":
|
||||
# Display agent text response
|
||||
content = item.get("content", [])
|
||||
for content_part in content:
|
||||
if content_part.get("text"):
|
||||
text = content_part.get("text", "").strip()
|
||||
if text:
|
||||
spinner.hide()
|
||||
print_colored(text)
|
||||
|
||||
elif item.get("type") == "computer_call":
|
||||
# Display computer action
|
||||
action = item.get("action", {})
|
||||
action_type = action.get("type", "")
|
||||
if action_type:
|
||||
spinner.hide()
|
||||
print_action(action_type, action, total_cost)
|
||||
spinner.text = f"Performing {action_type}..."
|
||||
spinner.show()
|
||||
|
||||
elif item.get("type") == "function_call":
|
||||
# Display function call
|
||||
function_name = item.get("name", "")
|
||||
spinner.hide()
|
||||
print_colored(f"🔧 Calling function: {function_name}", dim=True)
|
||||
spinner.text = f"Calling {function_name}..."
|
||||
spinner.show()
|
||||
|
||||
elif item.get("type") == "function_call_output":
|
||||
# Display function output (dimmed)
|
||||
output = item.get("output", "")
|
||||
if output and len(output.strip()) > 0:
|
||||
spinner.hide()
|
||||
print_colored(f"📤 {output}", dim=True)
|
||||
|
||||
spinner.hide()
|
||||
if show_usage and total_cost > 0:
|
||||
print_colored(f"Total cost: ${total_cost:.2f}", dim=True)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main CLI function."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="CUA Agent CLI - Interactive computer use assistant",
|
||||
formatter_class=argparse.RawDescriptionHelpFormatter,
|
||||
epilog="""
|
||||
Examples:
|
||||
python -m agent.cli openai/computer-use-preview
|
||||
python -m agent.cli anthropic/claude-3-5-sonnet-20241022
|
||||
python -m agent.cli omniparser+anthropic/claude-3-5-sonnet-20241022
|
||||
python -m agent.cli huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B
|
||||
"""
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"model",
|
||||
help="Model string (e.g., 'openai/computer-use-preview', 'anthropic/claude-3-5-sonnet-20241022')"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--images",
|
||||
type=int,
|
||||
default=3,
|
||||
help="Number of recent images to keep in context (default: 3)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--trajectory",
|
||||
action="store_true",
|
||||
help="Save trajectory for debugging"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--budget",
|
||||
type=float,
|
||||
help="Maximum budget for the session (in dollars)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
help="Enable verbose logging"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-p", "--prompt",
|
||||
type=str,
|
||||
help="Initial prompt to send to the agent. Leave blank for interactive mode."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-c", "--cache",
|
||||
action="store_true",
|
||||
help="Tell the API to enable caching"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-u", "--usage",
|
||||
action="store_true",
|
||||
help="Show total cost of the agent runs"
|
||||
)
|
||||
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Check for required environment variables
|
||||
container_name = os.getenv("CUA_CONTAINER_NAME")
|
||||
cua_api_key = os.getenv("CUA_API_KEY")
|
||||
|
||||
# Prompt for missing environment variables
|
||||
if not container_name:
|
||||
print_colored("CUA_CONTAINER_NAME not set.", dim=True)
|
||||
print_colored("You can get a CUA container at https://www.trycua.com/", dim=True)
|
||||
container_name = input("Enter your CUA container name: ").strip()
|
||||
if not container_name:
|
||||
print_colored("❌ Container name is required.")
|
||||
sys.exit(1)
|
||||
|
||||
if not cua_api_key:
|
||||
print_colored("CUA_API_KEY not set.", dim=True)
|
||||
cua_api_key = input("Enter your CUA API key: ").strip()
|
||||
if not cua_api_key:
|
||||
print_colored("❌ API key is required.")
|
||||
sys.exit(1)
|
||||
|
||||
# Check for provider-specific API keys based on model
|
||||
provider_api_keys = {
|
||||
"openai/": "OPENAI_API_KEY",
|
||||
"anthropic/": "ANTHROPIC_API_KEY",
|
||||
"omniparser+": "OPENAI_API_KEY",
|
||||
"omniparser+": "ANTHROPIC_API_KEY",
|
||||
}
|
||||
|
||||
# Find matching provider and check for API key
|
||||
for prefix, env_var in provider_api_keys.items():
|
||||
if args.model.startswith(prefix):
|
||||
if not os.getenv(env_var):
|
||||
print_colored(f"{env_var} not set.", dim=True)
|
||||
api_key = input(f"Enter your {env_var.replace('_', ' ').title()}: ").strip()
|
||||
if not api_key:
|
||||
print_colored(f"❌ {env_var.replace('_', ' ').title()} is required.")
|
||||
sys.exit(1)
|
||||
# Set the environment variable for the session
|
||||
os.environ[env_var] = api_key
|
||||
break
|
||||
|
||||
# Import here to avoid import errors if dependencies are missing
|
||||
try:
|
||||
from agent import ComputerAgent
|
||||
from computer import Computer
|
||||
except ImportError as e:
|
||||
print_colored(f"❌ Import error: {e}", Colors.RED, bold=True)
|
||||
print_colored("Make sure agent and computer libraries are installed.", Colors.YELLOW)
|
||||
sys.exit(1)
|
||||
|
||||
# Create computer instance
|
||||
async with Computer(
|
||||
os_type="linux",
|
||||
provider_type="cloud",
|
||||
name=container_name,
|
||||
api_key=cua_api_key
|
||||
) as computer:
|
||||
|
||||
# Create agent
|
||||
agent_kwargs = {
|
||||
"model": args.model,
|
||||
"tools": [computer],
|
||||
"verbosity": 20 if args.verbose else 30, # DEBUG vs WARNING
|
||||
}
|
||||
|
||||
if args.images > 0:
|
||||
agent_kwargs["only_n_most_recent_images"] = args.images
|
||||
|
||||
if args.trajectory:
|
||||
agent_kwargs["trajectory_dir"] = "trajectories"
|
||||
|
||||
if args.budget:
|
||||
agent_kwargs["max_trajectory_budget"] = {
|
||||
"max_budget": args.budget,
|
||||
"raise_error": True,
|
||||
"reset_after_each_run": False
|
||||
}
|
||||
|
||||
if args.cache:
|
||||
agent_kwargs["use_prompt_caching"] = True
|
||||
|
||||
agent = ComputerAgent(**agent_kwargs)
|
||||
|
||||
# Start chat loop
|
||||
await chat_loop(agent, args.model, container_name, args.prompt, args.usage)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(main())
|
||||
except (KeyboardInterrupt, EOFError) as _:
|
||||
print_colored("\n\n👋 Goodbye!")
|
||||
107
libs/python/agent/agent/computer_handler.py
Normal file
107
libs/python/agent/agent/computer_handler.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Computer handler implementation for OpenAI computer-use-preview protocol.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import Dict, List, Any, Literal
|
||||
from .types import Computer
|
||||
|
||||
|
||||
class OpenAIComputerHandler:
|
||||
"""Computer handler that implements the Computer protocol using the computer interface."""
|
||||
|
||||
def __init__(self, computer_interface):
|
||||
"""Initialize with a computer interface (from tool schema)."""
|
||||
self.interface = computer_interface
|
||||
|
||||
async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]:
|
||||
"""Get the current environment type."""
|
||||
# For now, return a default - this could be enhanced to detect actual environment
|
||||
return "windows"
|
||||
|
||||
async def get_dimensions(self) -> tuple[int, int]:
|
||||
"""Get screen dimensions as (width, height)."""
|
||||
screen_size = await self.interface.get_screen_size()
|
||||
return screen_size["width"], screen_size["height"]
|
||||
|
||||
async def screenshot(self) -> str:
|
||||
"""Take a screenshot and return as base64 string."""
|
||||
screenshot_bytes = await self.interface.screenshot()
|
||||
return base64.b64encode(screenshot_bytes).decode('utf-8')
|
||||
|
||||
async def click(self, x: int, y: int, button: str = "left") -> None:
|
||||
"""Click at coordinates with specified button."""
|
||||
if button == "left":
|
||||
await self.interface.left_click(x, y)
|
||||
elif button == "right":
|
||||
await self.interface.right_click(x, y)
|
||||
else:
|
||||
# Default to left click for unknown buttons
|
||||
await self.interface.left_click(x, y)
|
||||
|
||||
async def double_click(self, x: int, y: int) -> None:
|
||||
"""Double click at coordinates."""
|
||||
await self.interface.double_click(x, y)
|
||||
|
||||
async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None:
|
||||
"""Scroll at coordinates with specified scroll amounts."""
|
||||
await self.interface.move_cursor(x, y)
|
||||
await self.interface.scroll(scroll_x, scroll_y)
|
||||
|
||||
async def type(self, text: str) -> None:
|
||||
"""Type text."""
|
||||
await self.interface.type_text(text)
|
||||
|
||||
async def wait(self, ms: int = 1000) -> None:
|
||||
"""Wait for specified milliseconds."""
|
||||
import asyncio
|
||||
await asyncio.sleep(ms / 1000.0)
|
||||
|
||||
async def move(self, x: int, y: int) -> None:
|
||||
"""Move cursor to coordinates."""
|
||||
await self.interface.move_cursor(x, y)
|
||||
|
||||
async def keypress(self, keys: List[str]) -> None:
|
||||
"""Press key combination."""
|
||||
if len(keys) == 1:
|
||||
await self.interface.press_key(keys[0])
|
||||
else:
|
||||
# Handle key combinations
|
||||
await self.interface.hotkey(*keys)
|
||||
|
||||
async def drag(self, path: List[Dict[str, int]]) -> None:
|
||||
"""Drag along specified path."""
|
||||
if not path:
|
||||
return
|
||||
|
||||
# Start drag from first point
|
||||
start = path[0]
|
||||
await self.interface.mouse_down(start["x"], start["y"])
|
||||
|
||||
# Move through path
|
||||
for point in path[1:]:
|
||||
await self.interface.move_cursor(point["x"], point["y"])
|
||||
|
||||
# End drag at last point
|
||||
end = path[-1]
|
||||
await self.interface.mouse_up(end["x"], end["y"])
|
||||
|
||||
async def get_current_url(self) -> str:
|
||||
"""Get current URL (for browser environments)."""
|
||||
# This would need to be implemented based on the specific browser interface
|
||||
# For now, return empty string
|
||||
return ""
|
||||
|
||||
|
||||
def acknowledge_safety_check_callback(message: str) -> bool:
|
||||
"""Safety check callback for user acknowledgment."""
|
||||
response = input(
|
||||
f"Safety Check Warning: {message}\nDo you want to acknowledge and proceed? (y/n): "
|
||||
).lower()
|
||||
return response.strip() == "y"
|
||||
|
||||
|
||||
def check_blocklisted_url(url: str) -> None:
|
||||
"""Check if URL is blocklisted (placeholder implementation)."""
|
||||
# This would contain actual URL checking logic
|
||||
pass
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Core agent components."""
|
||||
|
||||
from .factory import BaseLoop
|
||||
from .messages import (
|
||||
StandardMessageManager,
|
||||
ImageRetentionConfig,
|
||||
)
|
||||
from .callbacks import (
|
||||
CallbackManager,
|
||||
CallbackHandler,
|
||||
BaseCallbackManager,
|
||||
ContentCallback,
|
||||
ToolCallback,
|
||||
APICallback,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseLoop",
|
||||
"CallbackManager",
|
||||
"CallbackHandler",
|
||||
"StandardMessageManager",
|
||||
"ImageRetentionConfig",
|
||||
"BaseCallbackManager",
|
||||
"ContentCallback",
|
||||
"ToolCallback",
|
||||
"APICallback",
|
||||
]
|
||||
@@ -1,210 +0,0 @@
|
||||
"""Main entry point for computer agents."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from computer import Computer
|
||||
from .types import LLM, AgentLoop
|
||||
from .types import AgentResponse
|
||||
from .factory import LoopFactory
|
||||
from .provider_config import DEFAULT_MODELS, ENV_VARS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ComputerAgent:
|
||||
"""A computer agent that can perform automated tasks using natural language instructions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
computer: Computer,
|
||||
model: LLM,
|
||||
loop: AgentLoop,
|
||||
max_retries: int = 3,
|
||||
screenshot_dir: Optional[str] = None,
|
||||
log_dir: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
save_trajectory: bool = True,
|
||||
trajectory_dir: str = "trajectories",
|
||||
only_n_most_recent_images: Optional[int] = None,
|
||||
verbosity: int = logging.INFO,
|
||||
):
|
||||
"""Initialize the ComputerAgent.
|
||||
|
||||
Args:
|
||||
computer: Computer instance. If not provided, one will be created with default settings.
|
||||
max_retries: Maximum number of retry attempts.
|
||||
screenshot_dir: Directory to save screenshots.
|
||||
log_dir: Directory to save logs (set to None to disable logging to files).
|
||||
model: LLM object containing provider and model name. Takes precedence over provider/model_name.
|
||||
provider: The AI provider to use (e.g., LLMProvider.ANTHROPIC). Only used if model is None.
|
||||
api_key: The API key for the provider. If not provided, will look for environment variable.
|
||||
model_name: The model name to use. Only used if model is None.
|
||||
save_trajectory: Whether to save the trajectory.
|
||||
trajectory_dir: Directory to save the trajectory.
|
||||
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests.
|
||||
verbosity: Logging level.
|
||||
"""
|
||||
# Basic agent configuration
|
||||
self.max_retries = max_retries
|
||||
self.computer = computer
|
||||
self.queue = asyncio.Queue()
|
||||
self.screenshot_dir = screenshot_dir
|
||||
self.log_dir = log_dir
|
||||
self._retry_count = 0
|
||||
self._initialized = False
|
||||
self._in_context = False
|
||||
|
||||
# Set logging level
|
||||
logger.setLevel(verbosity)
|
||||
|
||||
# Setup logging
|
||||
if self.log_dir:
|
||||
os.makedirs(self.log_dir, exist_ok=True)
|
||||
logger.info(f"Created logs directory: {self.log_dir}")
|
||||
|
||||
# Setup screenshots directory
|
||||
if self.screenshot_dir:
|
||||
os.makedirs(self.screenshot_dir, exist_ok=True)
|
||||
logger.info(f"Created screenshots directory: {self.screenshot_dir}")
|
||||
|
||||
# Use the provided LLM object
|
||||
self.provider = model.provider
|
||||
actual_model_name = model.name or DEFAULT_MODELS.get(self.provider, "")
|
||||
self.provider_base_url = getattr(model, "provider_base_url", None)
|
||||
|
||||
# Ensure we have a valid model name
|
||||
if not actual_model_name:
|
||||
actual_model_name = DEFAULT_MODELS.get(self.provider, "")
|
||||
if not actual_model_name:
|
||||
raise ValueError(
|
||||
f"No model specified for provider {self.provider} and no default found"
|
||||
)
|
||||
|
||||
# Get API key from environment if not provided
|
||||
actual_api_key = api_key or os.environ.get(ENV_VARS[self.provider], "")
|
||||
# Ollama and OpenAI-compatible APIs typically don't require an API key
|
||||
if (
|
||||
not actual_api_key
|
||||
and str(self.provider) not in ["ollama", "oaicompat"]
|
||||
and ENV_VARS[self.provider] != "none"
|
||||
):
|
||||
raise ValueError(f"No API key provided for {self.provider}")
|
||||
|
||||
# Create the appropriate loop using the factory
|
||||
try:
|
||||
# Let the factory create the appropriate loop with needed components
|
||||
self._loop = LoopFactory.create_loop(
|
||||
loop_type=loop,
|
||||
provider=self.provider,
|
||||
computer=self.computer,
|
||||
model_name=actual_model_name,
|
||||
api_key=actual_api_key,
|
||||
save_trajectory=save_trajectory,
|
||||
trajectory_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
provider_base_url=self.provider_base_url,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to create loop: {str(e)}")
|
||||
raise
|
||||
|
||||
# Initialize the message manager from the loop
|
||||
self.message_manager = self._loop.message_manager
|
||||
|
||||
logger.info(
|
||||
f"ComputerAgent initialized with provider: {self.provider}, model: {actual_model_name}"
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Initialize the agent when used as a context manager."""
|
||||
logger.info("Entering ComputerAgent context")
|
||||
self._in_context = True
|
||||
|
||||
# In case the computer wasn't initialized
|
||||
try:
|
||||
# Initialize the computer only if not already initialized
|
||||
logger.info("Checking if computer is already initialized...")
|
||||
if not self.computer._initialized:
|
||||
logger.info("Initializing computer in __aenter__...")
|
||||
# Use the computer's __aenter__ directly instead of calling run()
|
||||
await self.computer.__aenter__()
|
||||
logger.info("Computer initialized in __aenter__")
|
||||
else:
|
||||
logger.info("Computer already initialized, skipping initialization")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing computer in __aenter__: {str(e)}")
|
||||
raise
|
||||
|
||||
await self.initialize()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Cleanup agent resources if needed."""
|
||||
logger.info("Cleaning up agent resources")
|
||||
self._in_context = False
|
||||
|
||||
# Do any necessary cleanup
|
||||
# We're not shutting down the computer here as it might be shared
|
||||
# Just log that we're exiting
|
||||
if exc_type:
|
||||
logger.error(f"Exiting agent context with error: {exc_type.__name__}: {exc_val}")
|
||||
else:
|
||||
logger.info("Exiting agent context normally")
|
||||
|
||||
# If we have a queue, make sure to signal it's done
|
||||
if hasattr(self, "queue") and self.queue:
|
||||
await self.queue.put(None) # Signal that we're done
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the agent and its components."""
|
||||
if not self._initialized:
|
||||
# Always initialize the computer if available
|
||||
if self.computer and not self.computer._initialized:
|
||||
await self.computer.run()
|
||||
self._initialized = True
|
||||
|
||||
async def run(self, task: str) -> AsyncGenerator[AgentResponse, None]:
|
||||
"""Run a task using the computer agent.
|
||||
|
||||
Args:
|
||||
task: Task description
|
||||
|
||||
Yields:
|
||||
Agent response format
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Running task: {task}")
|
||||
logger.info(
|
||||
f"Message history before task has {len(self.message_manager.messages)} messages"
|
||||
)
|
||||
|
||||
# Initialize the computer if needed
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
# Add task as a user message using the message manager
|
||||
self.message_manager.add_user_message([{"type": "text", "text": task}])
|
||||
logger.info(
|
||||
f"Added task message. Message history now has {len(self.message_manager.messages)} messages"
|
||||
)
|
||||
|
||||
# Pass properly formatted messages to the loop
|
||||
if self._loop is None:
|
||||
logger.error("Loop not initialized properly")
|
||||
yield {"error": "Loop not initialized properly"}
|
||||
return
|
||||
|
||||
# Execute the task and yield results
|
||||
async for result in self._loop.run(self.message_manager.messages):
|
||||
yield result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in agent run method: {str(e)}")
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
@@ -1,217 +0,0 @@
|
||||
"""Base loop definitions."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from computer import Computer
|
||||
from .messages import StandardMessageManager, ImageRetentionConfig
|
||||
from .types import AgentResponse
|
||||
from .experiment import ExperimentManager
|
||||
from .callbacks import CallbackManager, CallbackHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseLoop(ABC):
|
||||
"""Base class for agent loops that handle message processing and tool execution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
computer: Computer,
|
||||
model: str,
|
||||
api_key: str,
|
||||
max_tokens: int = 4096,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
base_dir: Optional[str] = "trajectories",
|
||||
save_trajectory: bool = True,
|
||||
only_n_most_recent_images: Optional[int] = 2,
|
||||
callback_handlers: Optional[List[CallbackHandler]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize base agent loop.
|
||||
|
||||
Args:
|
||||
computer: Computer instance to control
|
||||
model: Model name to use
|
||||
api_key: API key for provider
|
||||
max_tokens: Maximum tokens to generate
|
||||
max_retries: Maximum number of retries
|
||||
retry_delay: Delay between retries in seconds
|
||||
base_dir: Base directory for saving experiment data
|
||||
save_trajectory: Whether to save trajectory data
|
||||
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
|
||||
**kwargs: Additional provider-specific arguments
|
||||
"""
|
||||
self.computer = computer
|
||||
self.model = model
|
||||
self.api_key = api_key
|
||||
self.max_tokens = max_tokens
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
self.base_dir = base_dir
|
||||
self.save_trajectory = save_trajectory
|
||||
self.only_n_most_recent_images = only_n_most_recent_images
|
||||
self._kwargs = kwargs
|
||||
|
||||
# Initialize message manager
|
||||
self.message_manager = StandardMessageManager(
|
||||
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
|
||||
)
|
||||
|
||||
# Initialize experiment manager
|
||||
if self.save_trajectory and self.base_dir:
|
||||
self.experiment_manager = ExperimentManager(
|
||||
base_dir=self.base_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
)
|
||||
# Track directories for convenience
|
||||
self.run_dir = self.experiment_manager.run_dir
|
||||
self.current_turn_dir = self.experiment_manager.current_turn_dir
|
||||
else:
|
||||
self.experiment_manager = None
|
||||
self.run_dir = None
|
||||
self.current_turn_dir = None
|
||||
|
||||
# Initialize basic tracking
|
||||
self.turn_count = 0
|
||||
|
||||
# Initialize callback manager
|
||||
self.callback_manager = CallbackManager(handlers=callback_handlers or [])
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize both the API client and computer interface with retries."""
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
logger.info(
|
||||
f"Starting initialization (attempt {attempt + 1}/{self.max_retries})..."
|
||||
)
|
||||
|
||||
# Initialize API client
|
||||
await self.initialize_client()
|
||||
|
||||
logger.info("Initialization complete.")
|
||||
return
|
||||
except Exception as e:
|
||||
if attempt < self.max_retries - 1:
|
||||
logger.warning(
|
||||
f"Initialization failed (attempt {attempt + 1}/{self.max_retries}): {str(e)}. Retrying..."
|
||||
)
|
||||
await asyncio.sleep(self.retry_delay)
|
||||
else:
|
||||
logger.error(
|
||||
f"Initialization failed after {self.max_retries} attempts: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(f"Failed to initialize: {str(e)}")
|
||||
|
||||
###########################################
|
||||
# ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES
|
||||
###########################################
|
||||
|
||||
@abstractmethod
|
||||
async def initialize_client(self) -> None:
|
||||
"""Initialize the API client and any provider-specific components.
|
||||
|
||||
This method must be implemented by subclasses to set up
|
||||
provider-specific clients and tools.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
Args:
|
||||
messages: List of message objects
|
||||
|
||||
Returns:
|
||||
An async generator that yields agent responses
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def cancel(self) -> None:
|
||||
"""Cancel the currently running agent loop task.
|
||||
|
||||
This method should stop any ongoing processing in the agent loop
|
||||
and clean up resources appropriately.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
###########################################
|
||||
# EXPERIMENT AND TRAJECTORY MANAGEMENT
|
||||
###########################################
|
||||
|
||||
def _setup_experiment_dirs(self) -> None:
|
||||
"""Setup the experiment directory structure."""
|
||||
if self.experiment_manager:
|
||||
# Use the experiment manager to set up directories
|
||||
self.experiment_manager.setup_experiment_dirs()
|
||||
|
||||
# Update local tracking variables
|
||||
self.run_dir = self.experiment_manager.run_dir
|
||||
self.current_turn_dir = self.experiment_manager.current_turn_dir
|
||||
|
||||
def _create_turn_dir(self) -> None:
|
||||
"""Create a new directory for the current turn."""
|
||||
if self.experiment_manager:
|
||||
# Use the experiment manager to create the turn directory
|
||||
self.experiment_manager.create_turn_dir()
|
||||
|
||||
# Update local tracking variables
|
||||
self.current_turn_dir = self.experiment_manager.current_turn_dir
|
||||
self.turn_count = self.experiment_manager.turn_count
|
||||
|
||||
def _log_api_call(
|
||||
self, call_type: str, request: Any, response: Any = None, error: Optional[Exception] = None
|
||||
) -> None:
|
||||
"""Log API call details to file.
|
||||
|
||||
Preserves provider-specific formats for requests and responses to ensure
|
||||
accurate logging for debugging and analysis purposes.
|
||||
|
||||
Args:
|
||||
call_type: Type of API call (e.g., 'request', 'response', 'error')
|
||||
request: The API request data in provider-specific format
|
||||
response: Optional API response data in provider-specific format
|
||||
error: Optional error information
|
||||
"""
|
||||
if self.experiment_manager:
|
||||
# Use the experiment manager to log the API call
|
||||
provider = getattr(self, "provider", "unknown")
|
||||
provider_str = str(provider) if provider else "unknown"
|
||||
|
||||
self.experiment_manager.log_api_call(
|
||||
call_type=call_type,
|
||||
request=request,
|
||||
provider=provider_str,
|
||||
model=self.model,
|
||||
response=response,
|
||||
error=error,
|
||||
)
|
||||
|
||||
def _save_screenshot(self, img_base64: str, action_type: str = "") -> None:
|
||||
"""Save a screenshot to the experiment directory.
|
||||
|
||||
Args:
|
||||
img_base64: Base64 encoded screenshot
|
||||
action_type: Type of action that triggered the screenshot
|
||||
"""
|
||||
if self.experiment_manager:
|
||||
self.experiment_manager.save_screenshot(img_base64, action_type)
|
||||
|
||||
###########################################
|
||||
# EVENT HOOKS / CALLBACKS
|
||||
###########################################
|
||||
|
||||
async def handle_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[dict] = None) -> None:
|
||||
"""Process a screenshot through callback managers
|
||||
|
||||
Args:
|
||||
screenshot_base64: Base64 encoded screenshot
|
||||
action_type: Type of action that triggered the screenshot
|
||||
"""
|
||||
if hasattr(self, 'callback_manager'):
|
||||
await self.callback_manager.on_screenshot(screenshot_base64, action_type, parsed_screen)
|
||||
@@ -1,200 +0,0 @@
|
||||
"""Callback handlers for agent."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Protocol
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ContentCallback(Protocol):
|
||||
"""Protocol for content callbacks."""
|
||||
def __call__(self, content: Dict[str, Any]) -> None: ...
|
||||
|
||||
class ToolCallback(Protocol):
|
||||
"""Protocol for tool callbacks."""
|
||||
def __call__(self, result: Any, tool_id: str) -> None: ...
|
||||
|
||||
class APICallback(Protocol):
|
||||
"""Protocol for API callbacks."""
|
||||
def __call__(self, request: Any, response: Any, error: Optional[Exception] = None) -> None: ...
|
||||
|
||||
class ScreenshotCallback(Protocol):
|
||||
"""Protocol for screenshot callbacks."""
|
||||
def __call__(self, screenshot_base64: str, action_type: str = "") -> Optional[str]: ...
|
||||
|
||||
class BaseCallbackManager(ABC):
|
||||
"""Base class for callback managers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content_callback: ContentCallback,
|
||||
tool_callback: ToolCallback,
|
||||
api_callback: APICallback,
|
||||
):
|
||||
"""Initialize the callback manager.
|
||||
|
||||
Args:
|
||||
content_callback: Callback for content updates
|
||||
tool_callback: Callback for tool execution results
|
||||
api_callback: Callback for API interactions
|
||||
"""
|
||||
self.content_callback = content_callback
|
||||
self.tool_callback = tool_callback
|
||||
self.api_callback = api_callback
|
||||
|
||||
@abstractmethod
|
||||
def on_content(self, content: Any) -> None:
|
||||
"""Handle content updates."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_tool_result(self, result: Any, tool_id: str) -> None:
|
||||
"""Handle tool execution results."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_api_interaction(
|
||||
self,
|
||||
request: Any,
|
||||
response: Any,
|
||||
error: Optional[Exception] = None
|
||||
) -> None:
|
||||
"""Handle API interactions."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class CallbackManager:
|
||||
"""Manager for callback handlers."""
|
||||
|
||||
def __init__(self, handlers: Optional[List["CallbackHandler"]] = None):
|
||||
"""Initialize with optional handlers.
|
||||
|
||||
Args:
|
||||
handlers: List of callback handlers
|
||||
"""
|
||||
self.handlers = handlers or []
|
||||
|
||||
def add_handler(self, handler: "CallbackHandler") -> None:
|
||||
"""Add a callback handler.
|
||||
|
||||
Args:
|
||||
handler: Callback handler to add
|
||||
"""
|
||||
self.handlers.append(handler)
|
||||
|
||||
async def on_action_start(self, action: str, **kwargs) -> None:
|
||||
"""Called when an action starts.
|
||||
|
||||
Args:
|
||||
action: Action name
|
||||
**kwargs: Additional data
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
await handler.on_action_start(action, **kwargs)
|
||||
|
||||
async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
|
||||
"""Called when an action ends.
|
||||
|
||||
Args:
|
||||
action: Action name
|
||||
success: Whether the action was successful
|
||||
**kwargs: Additional data
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
await handler.on_action_end(action, success, **kwargs)
|
||||
|
||||
async def on_error(self, error: Exception, **kwargs) -> None:
|
||||
"""Called when an error occurs.
|
||||
|
||||
Args:
|
||||
error: Exception that occurred
|
||||
**kwargs: Additional data
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
await handler.on_error(error, **kwargs)
|
||||
|
||||
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[dict] = None) -> None:
|
||||
"""Called when a screenshot is taken.
|
||||
|
||||
Args:
|
||||
screenshot_base64: Base64 encoded screenshot
|
||||
action_type: Type of action that triggered the screenshot
|
||||
parsed_screen: Optional output from parsing the screenshot
|
||||
|
||||
Returns:
|
||||
Modified screenshot or original if no modifications
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
await handler.on_screenshot(screenshot_base64, action_type, parsed_screen)
|
||||
|
||||
class CallbackHandler(ABC):
|
||||
"""Base class for callback handlers."""
|
||||
|
||||
@abstractmethod
|
||||
async def on_action_start(self, action: str, **kwargs) -> None:
|
||||
"""Called when an action starts.
|
||||
|
||||
Args:
|
||||
action: Action name
|
||||
**kwargs: Additional data
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
|
||||
"""Called when an action ends.
|
||||
|
||||
Args:
|
||||
action: Action name
|
||||
success: Whether the action was successful
|
||||
**kwargs: Additional data
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def on_error(self, error: Exception, **kwargs) -> None:
|
||||
"""Called when an error occurs.
|
||||
|
||||
Args:
|
||||
error: Exception that occurred
|
||||
**kwargs: Additional data
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def on_screenshot(self, screenshot_base64: str, action_type: str = "", parsed_screen: Optional[dict] = None) -> None:
|
||||
"""Called when a screenshot is taken.
|
||||
|
||||
Args:
|
||||
screenshot_base64: Base64 encoded screenshot
|
||||
action_type: Type of action that triggered the screenshot
|
||||
|
||||
Returns:
|
||||
Optional modified screenshot
|
||||
"""
|
||||
pass
|
||||
|
||||
class DefaultCallbackHandler(CallbackHandler):
|
||||
"""Default implementation of CallbackHandler with no-op methods.
|
||||
|
||||
This class implements all abstract methods from CallbackHandler,
|
||||
allowing subclasses to override only the methods they need.
|
||||
"""
|
||||
|
||||
async def on_action_start(self, action: str, **kwargs) -> None:
|
||||
"""Default no-op implementation."""
|
||||
pass
|
||||
|
||||
async def on_action_end(self, action: str, success: bool, **kwargs) -> None:
|
||||
"""Default no-op implementation."""
|
||||
pass
|
||||
|
||||
async def on_error(self, error: Exception, **kwargs) -> None:
|
||||
"""Default no-op implementation."""
|
||||
pass
|
||||
|
||||
async def on_screenshot(self, screenshot_base64: str, action_type: str = "") -> None:
|
||||
"""Default no-op implementation."""
|
||||
pass
|
||||
@@ -1,249 +0,0 @@
|
||||
"""Core experiment management for agents."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from PIL import Image
|
||||
import json
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExperimentManager:
|
||||
"""Manages experiment directories and logging for the agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_dir: Optional[str] = None,
|
||||
only_n_most_recent_images: Optional[int] = None,
|
||||
):
|
||||
"""Initialize the experiment manager.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory for saving experiment data
|
||||
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
|
||||
"""
|
||||
self.base_dir = base_dir
|
||||
self.only_n_most_recent_images = only_n_most_recent_images
|
||||
self.run_dir = None
|
||||
self.current_turn_dir = None
|
||||
self.turn_count = 0
|
||||
self.screenshot_count = 0
|
||||
# Track all screenshots for potential API request inclusion
|
||||
self.screenshot_paths = []
|
||||
|
||||
# Set up experiment directories if base_dir is provided
|
||||
if self.base_dir:
|
||||
self.setup_experiment_dirs()
|
||||
|
||||
def setup_experiment_dirs(self) -> None:
|
||||
"""Setup the experiment directory structure."""
|
||||
if not self.base_dir:
|
||||
return
|
||||
|
||||
# Create base experiments directory if it doesn't exist
|
||||
os.makedirs(self.base_dir, exist_ok=True)
|
||||
|
||||
# Create timestamped run directory
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
self.run_dir = os.path.join(self.base_dir, timestamp)
|
||||
os.makedirs(self.run_dir, exist_ok=True)
|
||||
logger.info(f"Created run directory: {self.run_dir}")
|
||||
|
||||
# Create first turn directory
|
||||
self.create_turn_dir()
|
||||
|
||||
def create_turn_dir(self) -> None:
|
||||
"""Create a new directory for the current turn."""
|
||||
if not self.run_dir:
|
||||
logger.warning("Cannot create turn directory: run_dir not set")
|
||||
return
|
||||
|
||||
# Increment turn counter
|
||||
self.turn_count += 1
|
||||
|
||||
# Create turn directory with padded number
|
||||
turn_name = f"turn_{self.turn_count:03d}"
|
||||
self.current_turn_dir = os.path.join(self.run_dir, turn_name)
|
||||
os.makedirs(self.current_turn_dir, exist_ok=True)
|
||||
logger.info(f"Created turn directory: {self.current_turn_dir}")
|
||||
|
||||
def sanitize_log_data(self, data: Any) -> Any:
|
||||
"""Sanitize log data by replacing large binary data with placeholders.
|
||||
|
||||
Args:
|
||||
data: Data to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized copy of the data
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
result = {}
|
||||
for k, v in data.items():
|
||||
# Special handling for 'data' field in Anthropic message source
|
||||
if k == "data" and isinstance(v, str) and len(v) > 1000:
|
||||
result[k] = f"[BASE64_DATA_LENGTH_{len(v)}]"
|
||||
# Special handling for the 'media_type' key which indicates we're in an image block
|
||||
elif k == "media_type" and "image" in str(v):
|
||||
result[k] = v
|
||||
# If we're in an image block, look for a sibling 'data' field with base64 content
|
||||
if (
|
||||
"data" in result
|
||||
and isinstance(result["data"], str)
|
||||
and len(result["data"]) > 1000
|
||||
):
|
||||
result["data"] = f"[BASE64_DATA_LENGTH_{len(result['data'])}]"
|
||||
else:
|
||||
result[k] = self.sanitize_log_data(v)
|
||||
return result
|
||||
elif isinstance(data, list):
|
||||
return [self.sanitize_log_data(item) for item in data]
|
||||
elif isinstance(data, str) and len(data) > 1000 and "base64" in data.lower():
|
||||
return f"[BASE64_DATA_LENGTH_{len(data)}]"
|
||||
else:
|
||||
return data
|
||||
|
||||
def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
|
||||
"""Save a screenshot to the experiment directory.
|
||||
|
||||
Args:
|
||||
img_base64: Base64 encoded screenshot
|
||||
action_type: Type of action that triggered the screenshot
|
||||
|
||||
Returns:
|
||||
Path to the saved screenshot or None if there was an error
|
||||
"""
|
||||
if not self.current_turn_dir:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Increment screenshot counter
|
||||
self.screenshot_count += 1
|
||||
|
||||
# Sanitize action_type to ensure valid filename
|
||||
# Replace characters that are not safe for filenames
|
||||
sanitized_action = ""
|
||||
if action_type:
|
||||
# Replace invalid filename characters with underscores
|
||||
sanitized_action = re.sub(r'[\\/*?:"<>|]', "_", action_type)
|
||||
# Limit the length to avoid excessively long filenames
|
||||
sanitized_action = sanitized_action[:50]
|
||||
|
||||
# Create a descriptive filename
|
||||
timestamp = int(datetime.now().timestamp() * 1000)
|
||||
action_suffix = f"_{sanitized_action}" if sanitized_action else ""
|
||||
filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"
|
||||
|
||||
# Save directly to the turn directory
|
||||
filepath = os.path.join(self.current_turn_dir, filename)
|
||||
|
||||
# Save the screenshot
|
||||
img_data = base64.b64decode(img_base64)
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(img_data)
|
||||
|
||||
# Keep track of the file path
|
||||
self.screenshot_paths.append(filepath)
|
||||
|
||||
return filepath
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving screenshot: {str(e)}")
|
||||
return None
|
||||
|
||||
def save_action_visualization(
|
||||
self, img: Image.Image, action_name: str, details: str = ""
|
||||
) -> str:
|
||||
"""Save a visualization of an action.
|
||||
|
||||
Args:
|
||||
img: Image to save
|
||||
action_name: Name of the action
|
||||
details: Additional details about the action
|
||||
|
||||
Returns:
|
||||
Path to the saved image
|
||||
"""
|
||||
if not self.current_turn_dir:
|
||||
return ""
|
||||
|
||||
try:
|
||||
# Create a descriptive filename
|
||||
timestamp = int(datetime.now().timestamp() * 1000)
|
||||
details_suffix = f"_{details}" if details else ""
|
||||
filename = f"vis_{action_name}{details_suffix}_{timestamp}.png"
|
||||
|
||||
# Save directly to the turn directory
|
||||
filepath = os.path.join(self.current_turn_dir, filename)
|
||||
|
||||
# Save the image
|
||||
img.save(filepath)
|
||||
|
||||
# Keep track of the file path
|
||||
self.screenshot_paths.append(filepath)
|
||||
|
||||
return filepath
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving action visualization: {str(e)}")
|
||||
return ""
|
||||
|
||||
def log_api_call(
|
||||
self,
|
||||
call_type: str,
|
||||
request: Any,
|
||||
provider: str = "unknown",
|
||||
model: str = "unknown",
|
||||
response: Any = None,
|
||||
error: Optional[Exception] = None,
|
||||
) -> None:
|
||||
"""Log API call details to file.
|
||||
|
||||
Args:
|
||||
call_type: Type of API call (request, response, error)
|
||||
request: Request data
|
||||
provider: API provider name
|
||||
model: Model name
|
||||
response: Response data (for response logs)
|
||||
error: Error information (for error logs)
|
||||
"""
|
||||
if not self.current_turn_dir:
|
||||
logger.warning("Cannot log API call: current_turn_dir not set")
|
||||
return
|
||||
|
||||
try:
|
||||
# Create a timestamp for the log file
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# Create filename based on log type
|
||||
filename = f"api_call_{timestamp}_{call_type}.json"
|
||||
filepath = os.path.join(self.current_turn_dir, filename)
|
||||
|
||||
# Sanitize data before logging
|
||||
sanitized_request = self.sanitize_log_data(request)
|
||||
sanitized_response = self.sanitize_log_data(response) if response is not None else None
|
||||
|
||||
# Prepare log data
|
||||
log_data = {
|
||||
"timestamp": timestamp,
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"type": call_type,
|
||||
"request": sanitized_request,
|
||||
}
|
||||
|
||||
if sanitized_response is not None:
|
||||
log_data["response"] = sanitized_response
|
||||
if error is not None:
|
||||
log_data["error"] = str(error)
|
||||
|
||||
# Write to file
|
||||
with open(filepath, "w") as f:
|
||||
json.dump(log_data, f, indent=2, default=str)
|
||||
|
||||
logger.info(f"Logged API {call_type} to {filepath}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging API call: {str(e)}")
|
||||
@@ -1,122 +0,0 @@
|
||||
"""Base agent loop implementation."""
|
||||
|
||||
import logging
|
||||
import importlib.util
|
||||
from typing import Dict, Optional, Type, TYPE_CHECKING, Any, cast, Callable, Awaitable
|
||||
|
||||
from computer import Computer
|
||||
from .types import AgentLoop
|
||||
from .base import BaseLoop
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopFactory:
|
||||
"""Factory class for creating agent loops."""
|
||||
|
||||
# Registry to store loop implementations
|
||||
_loop_registry: Dict[AgentLoop, Type[BaseLoop]] = {}
|
||||
|
||||
@classmethod
|
||||
def create_loop(
|
||||
cls,
|
||||
loop_type: AgentLoop,
|
||||
api_key: str,
|
||||
model_name: str,
|
||||
computer: Computer,
|
||||
provider: Any = None,
|
||||
save_trajectory: bool = True,
|
||||
trajectory_dir: str = "trajectories",
|
||||
only_n_most_recent_images: Optional[int] = None,
|
||||
acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
|
||||
provider_base_url: Optional[str] = None,
|
||||
) -> BaseLoop:
|
||||
"""Create and return an appropriate loop instance based on type."""
|
||||
if loop_type == AgentLoop.ANTHROPIC:
|
||||
# Lazy import AnthropicLoop only when needed
|
||||
try:
|
||||
from ..providers.anthropic.loop import AnthropicLoop
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'anthropic' provider is not installed. "
|
||||
"Install it with 'pip install cua-agent[anthropic]'"
|
||||
)
|
||||
|
||||
return AnthropicLoop(
|
||||
api_key=api_key,
|
||||
model=model_name,
|
||||
computer=computer,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
)
|
||||
elif loop_type == AgentLoop.OPENAI:
|
||||
# Lazy import OpenAILoop only when needed
|
||||
try:
|
||||
from ..providers.openai.loop import OpenAILoop
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'openai' provider is not installed. "
|
||||
"Install it with 'pip install cua-agent[openai]'"
|
||||
)
|
||||
|
||||
return OpenAILoop(
|
||||
api_key=api_key,
|
||||
model=model_name,
|
||||
computer=computer,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
acknowledge_safety_check_callback=acknowledge_safety_check_callback,
|
||||
)
|
||||
elif loop_type == AgentLoop.OMNI:
|
||||
# Lazy import OmniLoop and related classes only when needed
|
||||
try:
|
||||
from ..providers.omni.loop import OmniLoop
|
||||
from ..providers.omni.parser import OmniParser
|
||||
from .types import LLMProvider
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'omni' provider is not installed. "
|
||||
"Install it with 'pip install cua-agent[all]'"
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError("Provider is required for OMNI loop type")
|
||||
|
||||
# We know provider is the correct type at this point, so cast it
|
||||
provider_instance = cast(LLMProvider, provider)
|
||||
|
||||
return OmniLoop(
|
||||
provider=provider_instance,
|
||||
api_key=api_key,
|
||||
model=model_name,
|
||||
computer=computer,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
parser=OmniParser(),
|
||||
provider_base_url=provider_base_url,
|
||||
)
|
||||
elif loop_type == AgentLoop.UITARS:
|
||||
# Lazy import UITARSLoop only when needed
|
||||
try:
|
||||
from ..providers.uitars.loop import UITARSLoop
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'uitars' provider is not installed. "
|
||||
"Install it with 'pip install cua-agent[all]'"
|
||||
)
|
||||
|
||||
return UITARSLoop(
|
||||
api_key=api_key,
|
||||
model=model_name,
|
||||
computer=computer,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
provider_base_url=provider_base_url,
|
||||
provider=provider,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported loop type: {loop_type}")
|
||||
@@ -1,332 +0,0 @@
|
||||
"""Message handling utilities for agent."""
|
||||
|
||||
import logging
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageRetentionConfig:
|
||||
"""Configuration for image retention in messages."""
|
||||
|
||||
num_images_to_keep: Optional[int] = None
|
||||
min_removal_threshold: int = 1
|
||||
enable_caching: bool = True
|
||||
|
||||
def should_retain_images(self) -> bool:
|
||||
"""Check if image retention is enabled."""
|
||||
return self.num_images_to_keep is not None and self.num_images_to_keep > 0
|
||||
|
||||
class StandardMessageManager:
|
||||
"""Manages messages in a standardized OpenAI format across different providers."""
|
||||
|
||||
def __init__(self, config: Optional[ImageRetentionConfig] = None):
|
||||
"""Initialize message manager.
|
||||
|
||||
Args:
|
||||
config: Configuration for image retention
|
||||
"""
|
||||
self.messages: List[Dict[str, Any]] = []
|
||||
self.config = config or ImageRetentionConfig()
|
||||
|
||||
def add_user_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
|
||||
"""Add a user message.
|
||||
|
||||
Args:
|
||||
content: Message content (text or multimodal content)
|
||||
"""
|
||||
self.messages.append({"role": "user", "content": content})
|
||||
|
||||
def add_assistant_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
|
||||
"""Add an assistant message.
|
||||
|
||||
Args:
|
||||
content: Message content (text or multimodal content)
|
||||
"""
|
||||
self.messages.append({"role": "assistant", "content": content})
|
||||
|
||||
def add_system_message(self, content: str) -> None:
|
||||
"""Add a system message.
|
||||
|
||||
Args:
|
||||
content: System message content
|
||||
"""
|
||||
self.messages.append({"role": "system", "content": content})
|
||||
|
||||
def get_messages(self) -> List[Dict[str, Any]]:
|
||||
"""Get all messages in standard format.
|
||||
This method applies image retention policy if configured.
|
||||
|
||||
Returns:
|
||||
List of messages
|
||||
"""
|
||||
# If image retention is configured, apply it
|
||||
if self.config.num_images_to_keep is not None:
|
||||
return self._apply_image_retention(self.messages)
|
||||
return self.messages
|
||||
|
||||
def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Apply image retention policy to messages.
|
||||
|
||||
Args:
|
||||
messages: List of messages
|
||||
|
||||
Returns:
|
||||
List of messages with image retention applied
|
||||
"""
|
||||
if not self.config.num_images_to_keep:
|
||||
return messages
|
||||
|
||||
# Find messages with images (both user messages and tool call outputs)
|
||||
image_messages = []
|
||||
for msg in messages:
|
||||
has_image = False
|
||||
|
||||
# Check user messages with images
|
||||
if msg["role"] == "user" and isinstance(msg["content"], list):
|
||||
has_image = any(
|
||||
item.get("type") == "image_url" or item.get("type") == "image"
|
||||
for item in msg["content"]
|
||||
)
|
||||
|
||||
# Check assistant messages with tool calls that have images
|
||||
elif msg["role"] == "assistant" and isinstance(msg["content"], list):
|
||||
for item in msg["content"]:
|
||||
if item.get("type") == "tool_result" and "base64_image" in item:
|
||||
has_image = True
|
||||
break
|
||||
|
||||
if has_image:
|
||||
image_messages.append(msg)
|
||||
|
||||
# If we don't have more images than the limit, return all messages
|
||||
if len(image_messages) <= self.config.num_images_to_keep:
|
||||
return messages
|
||||
|
||||
# Get the most recent N images to keep
|
||||
images_to_keep = image_messages[-self.config.num_images_to_keep :]
|
||||
images_to_remove = image_messages[: -self.config.num_images_to_keep]
|
||||
|
||||
# Create a new message list, removing images from older messages
|
||||
result = []
|
||||
for msg in messages:
|
||||
if msg in images_to_remove:
|
||||
# Remove images from this message but keep the text content
|
||||
if msg["role"] == "user" and isinstance(msg["content"], list):
|
||||
# Keep only text content, remove images
|
||||
new_content = [
|
||||
item for item in msg["content"]
|
||||
if item.get("type") not in ["image_url", "image"]
|
||||
]
|
||||
if new_content: # Only add if there's still content
|
||||
result.append({"role": msg["role"], "content": new_content})
|
||||
elif msg["role"] == "assistant" and isinstance(msg["content"], list):
|
||||
# Remove base64_image from tool_result items
|
||||
new_content = []
|
||||
for item in msg["content"]:
|
||||
if item.get("type") == "tool_result" and "base64_image" in item:
|
||||
# Create a copy without the base64_image
|
||||
new_item = {k: v for k, v in item.items() if k != "base64_image"}
|
||||
new_content.append(new_item)
|
||||
else:
|
||||
new_content.append(item)
|
||||
result.append({"role": msg["role"], "content": new_content})
|
||||
else:
|
||||
# For other message types, keep as is
|
||||
result.append(msg)
|
||||
else:
|
||||
result.append(msg)
|
||||
|
||||
return result
|
||||
|
||||
def to_anthropic_format(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> Tuple[List[Dict[str, Any]], str]:
|
||||
"""Convert standard OpenAI format messages to Anthropic format.
|
||||
|
||||
Args:
|
||||
messages: List of messages in OpenAI format
|
||||
|
||||
Returns:
|
||||
Tuple containing (anthropic_messages, system_content)
|
||||
"""
|
||||
result = []
|
||||
system_content = ""
|
||||
|
||||
# Process messages in order to maintain conversation flow
|
||||
previous_assistant_tool_use_ids = (
|
||||
set()
|
||||
) # Track tool_use_ids in the previous assistant message
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
# Collect system messages for later use
|
||||
system_content += content + "\n"
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
# Track tool_use_ids in this assistant message for the next user message
|
||||
previous_assistant_tool_use_ids = set()
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if (
|
||||
isinstance(item, dict)
|
||||
and item.get("type") == "tool_use"
|
||||
and "id" in item
|
||||
):
|
||||
previous_assistant_tool_use_ids.add(item["id"])
|
||||
|
||||
logger.info(
|
||||
f"Tool use IDs in assistant message #{i}: {previous_assistant_tool_use_ids}"
|
||||
)
|
||||
|
||||
if role in ["user", "assistant"]:
|
||||
anthropic_msg = {"role": role}
|
||||
|
||||
# Convert content based on type
|
||||
if isinstance(content, str):
|
||||
# Simple text content
|
||||
anthropic_msg["content"] = [{"type": "text", "text": content}]
|
||||
elif isinstance(content, list):
|
||||
# Convert complex content
|
||||
anthropic_content = []
|
||||
for item in content:
|
||||
item_type = item.get("type", "")
|
||||
|
||||
if item_type == "text":
|
||||
anthropic_content.append({"type": "text", "text": item.get("text", "")})
|
||||
elif item_type == "image_url":
|
||||
# Convert OpenAI image format to Anthropic
|
||||
image_url = item.get("image_url", {}).get("url", "")
|
||||
if image_url.startswith("data:"):
|
||||
# Extract base64 data and media type
|
||||
match = re.match(r"data:(.+);base64,(.+)", image_url)
|
||||
if match:
|
||||
media_type, data = match.groups()
|
||||
anthropic_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": data,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Regular URL
|
||||
anthropic_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "url",
|
||||
"url": image_url,
|
||||
},
|
||||
}
|
||||
)
|
||||
elif item_type == "tool_use":
|
||||
# Always include tool_use blocks
|
||||
anthropic_content.append(item)
|
||||
elif item_type == "tool_result":
|
||||
# Check if this is a user message AND if the tool_use_id exists in the previous assistant message
|
||||
tool_use_id = item.get("tool_use_id")
|
||||
|
||||
# Only include tool_result if it references a tool_use from the immediately preceding assistant message
|
||||
if (
|
||||
role == "user"
|
||||
and tool_use_id
|
||||
and tool_use_id in previous_assistant_tool_use_ids
|
||||
):
|
||||
anthropic_content.append(item)
|
||||
logger.info(
|
||||
f"Including tool_result with tool_use_id: {tool_use_id}"
|
||||
)
|
||||
else:
|
||||
# Convert to text to preserve information
|
||||
logger.warning(
|
||||
f"Converting tool_result to text. Tool use ID {tool_use_id} not found in previous assistant message"
|
||||
)
|
||||
content_text = "Tool Result: "
|
||||
if "content" in item:
|
||||
if isinstance(item["content"], list):
|
||||
for content_item in item["content"]:
|
||||
if (
|
||||
isinstance(content_item, dict)
|
||||
and content_item.get("type") == "text"
|
||||
):
|
||||
content_text += content_item.get("text", "")
|
||||
elif isinstance(item["content"], str):
|
||||
content_text += item["content"]
|
||||
anthropic_content.append({"type": "text", "text": content_text})
|
||||
|
||||
anthropic_msg["content"] = anthropic_content
|
||||
|
||||
result.append(anthropic_msg)
|
||||
|
||||
return result, system_content
|
||||
|
||||
def from_anthropic_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Convert Anthropic format messages to standard OpenAI format.
|
||||
|
||||
Args:
|
||||
messages: List of messages in Anthropic format
|
||||
|
||||
Returns:
|
||||
List of messages in OpenAI format
|
||||
"""
|
||||
result = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", [])
|
||||
|
||||
if role in ["user", "assistant"]:
|
||||
openai_msg = {"role": role}
|
||||
|
||||
# Simple case: single text block
|
||||
if len(content) == 1 and content[0].get("type") == "text":
|
||||
openai_msg["content"] = content[0].get("text", "")
|
||||
else:
|
||||
# Complex case: multiple blocks or non-text
|
||||
openai_content = []
|
||||
for item in content:
|
||||
item_type = item.get("type", "")
|
||||
|
||||
if item_type == "text":
|
||||
openai_content.append({"type": "text", "text": item.get("text", "")})
|
||||
elif item_type == "image":
|
||||
# Convert Anthropic image to OpenAI format
|
||||
source = item.get("source", {})
|
||||
if source.get("type") == "base64":
|
||||
media_type = source.get("media_type", "image/png")
|
||||
data = source.get("data", "")
|
||||
openai_content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{media_type};base64,{data}"},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# URL
|
||||
openai_content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": source.get("url", "")},
|
||||
}
|
||||
)
|
||||
elif item_type in ["tool_use", "tool_result"]:
|
||||
# Pass through tool-related content
|
||||
openai_content.append(item)
|
||||
|
||||
openai_msg["content"] = openai_content
|
||||
|
||||
result.append(openai_msg)
|
||||
|
||||
return result
|
||||
@@ -1,21 +0,0 @@
|
||||
"""Provider-specific configurations and constants."""
|
||||
|
||||
from .types import LLMProvider
|
||||
|
||||
# Default models for different providers
|
||||
DEFAULT_MODELS = {
|
||||
LLMProvider.OPENAI: "gpt-4o",
|
||||
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
|
||||
LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
|
||||
LLMProvider.OAICOMPAT: "Qwen2.5-VL-7B-Instruct",
|
||||
LLMProvider.MLXVLM: "mlx-community/UI-TARS-1.5-7B-4bit",
|
||||
}
|
||||
|
||||
# Map providers to their environment variable names
|
||||
ENV_VARS = {
|
||||
LLMProvider.OPENAI: "OPENAI_API_KEY",
|
||||
LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
|
||||
LLMProvider.OLLAMA: "none",
|
||||
LLMProvider.OAICOMPAT: "none", # OpenAI-compatible API typically doesn't require an API key
|
||||
LLMProvider.MLXVLM: "none", # MLX VLM typically doesn't require an API key
|
||||
}
|
||||
@@ -1,142 +0,0 @@
|
||||
"""Agent telemetry for tracking anonymous usage and feature usage."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from typing import Dict, Any, Callable
|
||||
|
||||
# Import the core telemetry module
|
||||
TELEMETRY_AVAILABLE = False
|
||||
|
||||
|
||||
# Local fallbacks in case core telemetry isn't available
|
||||
def _noop(*args: Any, **kwargs: Any) -> None:
|
||||
"""No-op function for when telemetry is not available."""
|
||||
pass
|
||||
|
||||
|
||||
# Define default functions with unique names to avoid shadowing
|
||||
_default_record_event = _noop
|
||||
_default_increment_counter = _noop
|
||||
_default_set_dimension = _noop
|
||||
_default_get_telemetry_client = lambda: None
|
||||
_default_flush = _noop
|
||||
_default_is_telemetry_enabled = lambda: False
|
||||
_default_is_telemetry_globally_disabled = lambda: True
|
||||
|
||||
# Set the actual functions to the defaults initially
|
||||
record_event = _default_record_event
|
||||
increment_counter = _default_increment_counter
|
||||
set_dimension = _default_set_dimension
|
||||
get_telemetry_client = _default_get_telemetry_client
|
||||
flush = _default_flush
|
||||
is_telemetry_enabled = _default_is_telemetry_enabled
|
||||
is_telemetry_globally_disabled = _default_is_telemetry_globally_disabled
|
||||
|
||||
logger = logging.getLogger("agent.telemetry")
|
||||
|
||||
try:
|
||||
# Import from core telemetry
|
||||
from core.telemetry import (
|
||||
record_event as core_record_event,
|
||||
increment as core_increment,
|
||||
get_telemetry_client as core_get_telemetry_client,
|
||||
flush as core_flush,
|
||||
is_telemetry_enabled as core_is_telemetry_enabled,
|
||||
is_telemetry_globally_disabled as core_is_telemetry_globally_disabled,
|
||||
)
|
||||
|
||||
# Override the default functions with actual implementations
|
||||
record_event = core_record_event
|
||||
get_telemetry_client = core_get_telemetry_client
|
||||
flush = core_flush
|
||||
is_telemetry_enabled = core_is_telemetry_enabled
|
||||
is_telemetry_globally_disabled = core_is_telemetry_globally_disabled
|
||||
|
||||
def increment_counter(counter_name: str, value: int = 1) -> None:
|
||||
"""Wrapper for increment to maintain backward compatibility."""
|
||||
if is_telemetry_enabled():
|
||||
core_increment(counter_name, value)
|
||||
|
||||
def set_dimension(name: str, value: Any) -> None:
|
||||
"""Set a dimension that will be attached to all events."""
|
||||
logger.debug(f"Setting dimension {name}={value}")
|
||||
|
||||
TELEMETRY_AVAILABLE = True
|
||||
logger.info("Successfully imported telemetry")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import telemetry: {e}")
|
||||
logger.debug("Telemetry not available, using no-op functions")
|
||||
|
||||
# Get system info once to use in telemetry
|
||||
SYSTEM_INFO = {
|
||||
"os": platform.system().lower(),
|
||||
"os_version": platform.release(),
|
||||
"python_version": platform.python_version(),
|
||||
}
|
||||
|
||||
|
||||
def enable_telemetry() -> bool:
|
||||
"""Enable telemetry if available.
|
||||
|
||||
Returns:
|
||||
bool: True if telemetry was successfully enabled, False otherwise
|
||||
"""
|
||||
global TELEMETRY_AVAILABLE, record_event, increment_counter, get_telemetry_client, flush, is_telemetry_enabled, is_telemetry_globally_disabled
|
||||
|
||||
# Check if globally disabled using core function
|
||||
if TELEMETRY_AVAILABLE and is_telemetry_globally_disabled():
|
||||
logger.info("Telemetry is globally disabled via environment variable - cannot enable")
|
||||
return False
|
||||
|
||||
# Already enabled
|
||||
if TELEMETRY_AVAILABLE:
|
||||
return True
|
||||
|
||||
# Try to import and enable
|
||||
try:
|
||||
from core.telemetry import (
|
||||
record_event,
|
||||
increment,
|
||||
get_telemetry_client,
|
||||
flush,
|
||||
is_telemetry_globally_disabled,
|
||||
)
|
||||
|
||||
# Check again after import
|
||||
if is_telemetry_globally_disabled():
|
||||
logger.info("Telemetry is globally disabled via environment variable - cannot enable")
|
||||
return False
|
||||
|
||||
TELEMETRY_AVAILABLE = True
|
||||
logger.info("Telemetry successfully enabled")
|
||||
return True
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not enable telemetry: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def is_telemetry_enabled() -> bool:
|
||||
"""Check if telemetry is enabled.
|
||||
|
||||
Returns:
|
||||
bool: True if telemetry is enabled, False otherwise
|
||||
"""
|
||||
# Use the core function if available, otherwise use our local flag
|
||||
if TELEMETRY_AVAILABLE:
|
||||
from core.telemetry import is_telemetry_enabled as core_is_enabled
|
||||
|
||||
return core_is_enabled()
|
||||
return False
|
||||
|
||||
|
||||
def record_agent_initialization() -> None:
|
||||
"""Record when an agent instance is initialized."""
|
||||
if TELEMETRY_AVAILABLE and is_telemetry_enabled():
|
||||
record_event("agent_initialized", SYSTEM_INFO)
|
||||
|
||||
# Set dimensions that will be attached to all events
|
||||
set_dimension("os", SYSTEM_INFO["os"])
|
||||
set_dimension("os_version", SYSTEM_INFO["os_version"])
|
||||
set_dimension("python_version", SYSTEM_INFO["python_version"])
|
||||
@@ -1,32 +0,0 @@
|
||||
"""Tool-related type definitions."""
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
class ToolInvocationState(StrEnum):
|
||||
"""States for tool invocation."""
|
||||
CALL = 'call'
|
||||
PARTIAL_CALL = 'partial-call'
|
||||
RESULT = 'result'
|
||||
|
||||
class ToolInvocation(BaseModel):
|
||||
"""Tool invocation type."""
|
||||
model_config = ConfigDict(extra='forbid')
|
||||
state: Optional[str] = None
|
||||
toolCallId: str
|
||||
toolName: Optional[str] = None
|
||||
args: Optional[Dict[str, Any]] = None
|
||||
|
||||
class ClientAttachment(BaseModel):
|
||||
"""Client attachment type."""
|
||||
name: str
|
||||
contentType: str
|
||||
url: str
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""Result of a tool execution."""
|
||||
model_config = ConfigDict(extra='forbid')
|
||||
output: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
@@ -1,21 +0,0 @@
|
||||
"""Core tools package."""
|
||||
|
||||
from .base import BaseTool, ToolResult, ToolError, ToolFailure, CLIResult
|
||||
from .bash import BaseBashTool
|
||||
from .collection import ToolCollection
|
||||
from .computer import BaseComputerTool
|
||||
from .edit import BaseEditTool
|
||||
from .manager import BaseToolManager
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"ToolResult",
|
||||
"ToolError",
|
||||
"ToolFailure",
|
||||
"CLIResult",
|
||||
"BaseBashTool",
|
||||
"BaseComputerTool",
|
||||
"BaseEditTool",
|
||||
"ToolCollection",
|
||||
"BaseToolManager",
|
||||
]
|
||||
@@ -1,74 +0,0 @@
|
||||
"""Abstract base classes for tools that can be used with any provider."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Any, Dict
|
||||
|
||||
|
||||
class BaseTool(metaclass=ABCMeta):
|
||||
"""Abstract base class for provider-agnostic tools."""
|
||||
|
||||
name: str
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, **kwargs) -> Any:
|
||||
"""Executes the tool with the given arguments."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to provider-specific API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters specific to the LLM provider
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
class ToolResult:
|
||||
"""Represents the result of a tool execution."""
|
||||
|
||||
output: str | None = None
|
||||
error: str | None = None
|
||||
base64_image: str | None = None
|
||||
system: str | None = None
|
||||
content: list[dict] | None = None
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field.name) for field in fields(self))
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
content=self.content or other.content, # Use first non-None content
|
||||
)
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Returns a new ToolResult with the given fields replaced."""
|
||||
return replace(self, **kwargs)
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
|
||||
"""A ToolResult that can be rendered as a CLI output."""
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
|
||||
"""A ToolResult that represents a failure."""
|
||||
|
||||
|
||||
class ToolError(Exception):
|
||||
"""Raised when a tool encounters an error."""
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
@@ -1,52 +0,0 @@
|
||||
"""Abstract base bash/shell tool implementation."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseTool, ToolResult
|
||||
|
||||
|
||||
class BaseBashTool(BaseTool):
|
||||
"""Base class for bash/shell command execution tools across different providers."""
|
||||
|
||||
name = "bash"
|
||||
logger = logging.getLogger(__name__)
|
||||
computer: Computer
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the BashTool.
|
||||
|
||||
Args:
|
||||
computer: Computer instance, may be used for related operations
|
||||
"""
|
||||
self.computer = computer
|
||||
|
||||
async def run_command(self, command: str) -> Tuple[int, str, str]:
|
||||
"""Run a shell command and return exit code, stdout, and stderr.
|
||||
|
||||
Args:
|
||||
command: Shell command to execute
|
||||
|
||||
Returns:
|
||||
Tuple containing (exit_code, stdout, stderr)
|
||||
"""
|
||||
try:
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
return process.returncode or 0, stdout.decode(), stderr.decode()
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error running command: {str(e)}")
|
||||
return 1, "", str(e)
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, **kwargs) -> ToolResult:
|
||||
"""Execute the tool with the provided arguments."""
|
||||
raise NotImplementedError
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Collection classes for managing multiple tools."""
|
||||
|
||||
from typing import Any, Dict, List, Type
|
||||
|
||||
from .base import (
|
||||
BaseTool,
|
||||
ToolError,
|
||||
ToolFailure,
|
||||
ToolResult,
|
||||
)
|
||||
|
||||
|
||||
class ToolCollection:
|
||||
"""A collection of tools that can be used with any provider."""
|
||||
|
||||
def __init__(self, *tools: BaseTool):
|
||||
self.tools = tools
|
||||
self.tool_map = {tool.name: tool for tool in tools}
|
||||
|
||||
def to_params(self) -> List[Dict[str, Any]]:
|
||||
"""Convert all tools to provider-specific parameters.
|
||||
|
||||
Returns:
|
||||
List of dictionaries with tool parameters
|
||||
"""
|
||||
return [tool.to_params() for tool in self.tools]
|
||||
|
||||
async def run(self, *, name: str, tool_input: Dict[str, Any]) -> ToolResult:
|
||||
"""Run a tool with the given input.
|
||||
|
||||
Args:
|
||||
name: Name of the tool to run
|
||||
tool_input: Input parameters for the tool
|
||||
|
||||
Returns:
|
||||
Result of the tool execution
|
||||
"""
|
||||
tool = self.tool_map.get(name)
|
||||
if not tool:
|
||||
return ToolFailure(error=f"Tool {name} is invalid")
|
||||
try:
|
||||
return await tool(**tool_input)
|
||||
except ToolError as e:
|
||||
return ToolFailure(error=e.message)
|
||||
except Exception as e:
|
||||
return ToolFailure(error=f"Unexpected error in tool {name}: {str(e)}")
|
||||
@@ -1,113 +0,0 @@
|
||||
"""Abstract base computer tool implementation."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from PIL import Image
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseTool, ToolError, ToolResult
|
||||
|
||||
|
||||
class BaseComputerTool(BaseTool):
|
||||
"""Base class for computer interaction tools across different providers."""
|
||||
|
||||
name = "computer"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
display_num: Optional[int] = None
|
||||
computer: Computer
|
||||
|
||||
_screenshot_delay = 1.0 # Default delay for most platforms
|
||||
_scaling_enabled = True
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the ComputerTool.
|
||||
|
||||
Args:
|
||||
computer: Computer instance for screen interactions
|
||||
"""
|
||||
self.computer = computer
|
||||
|
||||
async def initialize_dimensions(self):
|
||||
"""Initialize screen dimensions from the computer interface."""
|
||||
display_size = await self.computer.interface.get_screen_size()
|
||||
self.width = display_size["width"]
|
||||
self.height = display_size["height"]
|
||||
self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
|
||||
|
||||
@property
|
||||
def options(self) -> Dict[str, Any]:
|
||||
"""Get the options for the tool.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool options
|
||||
"""
|
||||
if self.width is None or self.height is None:
|
||||
raise RuntimeError(
|
||||
"Screen dimensions not initialized. Call initialize_dimensions() first."
|
||||
)
|
||||
return {
|
||||
"display_width_px": self.width,
|
||||
"display_height_px": self.height,
|
||||
"display_number": self.display_num,
|
||||
}
|
||||
|
||||
async def resize_screenshot_if_needed(self, screenshot: bytes) -> bytes:
|
||||
"""Resize a screenshot to match the expected dimensions.
|
||||
|
||||
Args:
|
||||
screenshot: Raw screenshot data
|
||||
|
||||
Returns:
|
||||
Resized screenshot data
|
||||
"""
|
||||
if self.width is None or self.height is None:
|
||||
raise ToolError("Screen dimensions not initialized")
|
||||
|
||||
try:
|
||||
img = Image.open(io.BytesIO(screenshot))
|
||||
if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
|
||||
img = img.convert("RGB")
|
||||
|
||||
# Resize if dimensions don't match
|
||||
if img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling image from {img.size} to {self.width}x{self.height} to match screen dimensions"
|
||||
)
|
||||
img = img.resize((self.width, self.height), Image.Resampling.LANCZOS)
|
||||
|
||||
# Save back to bytes
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format="PNG")
|
||||
return buffer.getvalue()
|
||||
|
||||
return screenshot
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during screenshot resizing: {str(e)}")
|
||||
raise ToolError(f"Failed to resize screenshot: {str(e)}")
|
||||
|
||||
async def screenshot(self) -> ToolResult:
|
||||
"""Take a screenshot and return it as a ToolResult with base64-encoded image.
|
||||
|
||||
Returns:
|
||||
ToolResult with the screenshot
|
||||
"""
|
||||
try:
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
screenshot = await self.resize_screenshot_if_needed(screenshot)
|
||||
return ToolResult(base64_image=base64.b64encode(screenshot).decode())
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error taking screenshot: {str(e)}")
|
||||
return ToolResult(error=f"Failed to take screenshot: {str(e)}")
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, **kwargs) -> ToolResult:
|
||||
"""Execute the tool with the provided arguments."""
|
||||
raise NotImplementedError
|
||||
@@ -1,67 +0,0 @@
|
||||
"""Abstract base edit tool implementation."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseTool, ToolError, ToolResult
|
||||
|
||||
|
||||
class BaseEditTool(BaseTool):
|
||||
"""Base class for text editor tools across different providers."""
|
||||
|
||||
name = "edit"
|
||||
logger = logging.getLogger(__name__)
|
||||
computer: Computer
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the EditTool.
|
||||
|
||||
Args:
|
||||
computer: Computer instance, may be used for related operations
|
||||
"""
|
||||
self.computer = computer
|
||||
|
||||
async def read_file(self, path: str) -> str:
|
||||
"""Read a file and return its contents.
|
||||
|
||||
Args:
|
||||
path: Path to the file to read
|
||||
|
||||
Returns:
|
||||
File contents as a string
|
||||
"""
|
||||
try:
|
||||
path_obj = Path(path)
|
||||
if not path_obj.exists():
|
||||
raise ToolError(f"File does not exist: {path}")
|
||||
return path_obj.read_text()
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error reading file: {str(e)}")
|
||||
raise ToolError(f"Failed to read file: {str(e)}")
|
||||
|
||||
async def write_file(self, path: str, content: str) -> None:
|
||||
"""Write content to a file.
|
||||
|
||||
Args:
|
||||
path: Path to the file to write
|
||||
content: Content to write to the file
|
||||
"""
|
||||
try:
|
||||
path_obj = Path(path)
|
||||
# Create parent directories if they don't exist
|
||||
path_obj.parent.mkdir(parents=True, exist_ok=True)
|
||||
path_obj.write_text(content)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error writing file: {str(e)}")
|
||||
raise ToolError(f"Failed to write file: {str(e)}")
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, **kwargs) -> ToolResult:
|
||||
"""Execute the tool with the provided arguments."""
|
||||
raise NotImplementedError
|
||||
@@ -1,56 +0,0 @@
|
||||
"""Tool manager for initializing and running tools."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseTool, ToolResult
|
||||
from .collection import ToolCollection
|
||||
|
||||
|
||||
class BaseToolManager(ABC):
|
||||
"""Base class for tool managers across different providers."""
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the tool manager.
|
||||
|
||||
Args:
|
||||
computer: Computer instance for computer-related tools
|
||||
"""
|
||||
self.computer = computer
|
||||
self.tools: ToolCollection | None = None
|
||||
|
||||
@abstractmethod
|
||||
def _initialize_tools(self) -> ToolCollection:
|
||||
"""Initialize all available tools."""
|
||||
...
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize tool-specific requirements and create tool collection."""
|
||||
await self._initialize_tools_specific()
|
||||
self.tools = self._initialize_tools()
|
||||
|
||||
@abstractmethod
|
||||
async def _initialize_tools_specific(self) -> None:
|
||||
"""Initialize provider-specific tool requirements."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_tool_params(self) -> List[Dict[str, Any]]:
|
||||
"""Get tool parameters for API calls."""
|
||||
...
|
||||
|
||||
async def execute_tool(self, name: str, tool_input: Dict[str, Any]) -> ToolResult:
|
||||
"""Execute a tool with the given input.
|
||||
|
||||
Args:
|
||||
name: Name of the tool to execute
|
||||
tool_input: Input parameters for the tool
|
||||
|
||||
Returns:
|
||||
Result of the tool execution
|
||||
"""
|
||||
if self.tools is None:
|
||||
raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
return await self.tools.run(name=name, tool_input=tool_input)
|
||||
@@ -1,88 +0,0 @@
|
||||
"""Core type definitions."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||||
from enum import StrEnum
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class AgentLoop(StrEnum):
|
||||
"""Enumeration of available loop types."""
|
||||
|
||||
ANTHROPIC = "anthropic" # Anthropic implementation
|
||||
OMNI = "omni" # OmniLoop implementation
|
||||
OPENAI = "openai" # OpenAI implementation
|
||||
OLLAMA = "ollama" # OLLAMA implementation
|
||||
UITARS = "uitars" # UI-TARS implementation
|
||||
# Add more loop types as needed
|
||||
|
||||
|
||||
class LLMProvider(StrEnum):
|
||||
"""Supported LLM providers."""
|
||||
|
||||
ANTHROPIC = "anthropic"
|
||||
OPENAI = "openai"
|
||||
OLLAMA = "ollama"
|
||||
OAICOMPAT = "oaicompat"
|
||||
MLXVLM= "mlxvlm"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLM:
|
||||
"""Configuration for LLM model and provider."""
|
||||
|
||||
provider: LLMProvider
|
||||
name: Optional[str] = None
|
||||
provider_base_url: Optional[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
"""Set default model name if not provided."""
|
||||
if self.name is None:
|
||||
from .provider_config import DEFAULT_MODELS
|
||||
|
||||
self.name = DEFAULT_MODELS.get(self.provider)
|
||||
|
||||
# Set default provider URL if none provided
|
||||
if self.provider_base_url is None and self.provider == LLMProvider.OAICOMPAT:
|
||||
# Default for vLLM
|
||||
self.provider_base_url = "http://localhost:8000/v1"
|
||||
# Common alternatives:
|
||||
# - LM Studio: "http://localhost:1234/v1"
|
||||
# - LocalAI: "http://localhost:8080/v1"
|
||||
# - Ollama with OpenAI compatible API: "http://localhost:11434/v1"
|
||||
|
||||
|
||||
# For backward compatibility
|
||||
LLMModel = LLM
|
||||
Model = LLM
|
||||
|
||||
|
||||
class AgentResponse(TypedDict, total=False):
|
||||
"""Agent response format."""
|
||||
|
||||
id: str
|
||||
object: str
|
||||
created_at: int
|
||||
status: str
|
||||
error: Optional[str]
|
||||
incomplete_details: Optional[Any]
|
||||
instructions: Optional[Any]
|
||||
max_output_tokens: Optional[int]
|
||||
model: str
|
||||
output: List[Dict[str, Any]]
|
||||
parallel_tool_calls: bool
|
||||
previous_response_id: Optional[str]
|
||||
reasoning: Dict[str, str]
|
||||
store: bool
|
||||
temperature: float
|
||||
text: Dict[str, Dict[str, str]]
|
||||
tool_choice: str
|
||||
tools: List[Dict[str, Union[str, int]]]
|
||||
top_p: float
|
||||
truncation: str
|
||||
usage: Dict[str, Any]
|
||||
user: Optional[str]
|
||||
metadata: Dict[str, Any]
|
||||
response: Dict[str, List[Dict[str, Any]]]
|
||||
# Additional fields for error responses
|
||||
role: str
|
||||
content: Union[str, List[Dict[str, Any]]]
|
||||
@@ -1,197 +0,0 @@
|
||||
"""Core visualization utilities for agents."""
|
||||
|
||||
import logging
|
||||
import base64
|
||||
from typing import Dict, Tuple
|
||||
from PIL import Image, ImageDraw
|
||||
from io import BytesIO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
|
||||
"""Visualize a click action by drawing a circle on the screenshot.
|
||||
|
||||
Args:
|
||||
x: X coordinate of the click
|
||||
y: Y coordinate of the click
|
||||
img_base64: Base64-encoded screenshot
|
||||
|
||||
Returns:
|
||||
PIL Image with visualization
|
||||
"""
|
||||
try:
|
||||
# Decode the base64 image
|
||||
image_data = base64.b64decode(img_base64)
|
||||
img = Image.open(BytesIO(image_data))
|
||||
|
||||
# Create a copy to draw on
|
||||
draw_img = img.copy()
|
||||
draw = ImageDraw.Draw(draw_img)
|
||||
|
||||
# Draw a circle at the click location
|
||||
radius = 15
|
||||
draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], outline="red", width=3)
|
||||
|
||||
# Draw crosshairs
|
||||
line_length = 20
|
||||
draw.line([(x - line_length, y), (x + line_length, y)], fill="red", width=3)
|
||||
draw.line([(x, y - line_length), (x, y + line_length)], fill="red", width=3)
|
||||
|
||||
return draw_img
|
||||
except Exception as e:
|
||||
logger.error(f"Error visualizing click: {str(e)}")
|
||||
# Return a blank image as fallback
|
||||
return Image.new("RGB", (800, 600), "white")
|
||||
|
||||
|
||||
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
|
||||
"""Visualize a scroll action by drawing arrows on the screenshot.
|
||||
|
||||
Args:
|
||||
direction: Direction of scroll ('up' or 'down')
|
||||
clicks: Number of scroll clicks
|
||||
img_base64: Base64-encoded screenshot
|
||||
|
||||
Returns:
|
||||
PIL Image with visualization
|
||||
"""
|
||||
try:
|
||||
# Decode the base64 image
|
||||
image_data = base64.b64decode(img_base64)
|
||||
img = Image.open(BytesIO(image_data))
|
||||
|
||||
# Create a copy to draw on
|
||||
draw_img = img.copy()
|
||||
draw = ImageDraw.Draw(draw_img)
|
||||
|
||||
# Calculate parameters for visualization
|
||||
width, height = img.size
|
||||
center_x = width // 2
|
||||
|
||||
# Draw arrows to indicate scrolling
|
||||
arrow_length = min(100, height // 4)
|
||||
arrow_width = 30
|
||||
num_arrows = min(clicks, 3) # Don't draw too many arrows
|
||||
|
||||
# Calculate starting position
|
||||
if direction == "down":
|
||||
start_y = height // 3
|
||||
arrow_dir = 1 # Down
|
||||
else:
|
||||
start_y = height * 2 // 3
|
||||
arrow_dir = -1 # Up
|
||||
|
||||
# Draw the arrows
|
||||
for i in range(num_arrows):
|
||||
y_pos = start_y + (i * arrow_length * arrow_dir * 0.7)
|
||||
arrow_top = (center_x, y_pos)
|
||||
arrow_bottom = (center_x, y_pos + arrow_length * arrow_dir)
|
||||
|
||||
# Draw the main line
|
||||
draw.line([arrow_top, arrow_bottom], fill="red", width=5)
|
||||
|
||||
# Draw the arrowhead
|
||||
arrowhead_size = 20
|
||||
if direction == "down":
|
||||
draw.line(
|
||||
[
|
||||
(center_x - arrow_width // 2, arrow_bottom[1] - arrowhead_size),
|
||||
arrow_bottom,
|
||||
(center_x + arrow_width // 2, arrow_bottom[1] - arrowhead_size),
|
||||
],
|
||||
fill="red",
|
||||
width=5,
|
||||
)
|
||||
else:
|
||||
draw.line(
|
||||
[
|
||||
(center_x - arrow_width // 2, arrow_bottom[1] + arrowhead_size),
|
||||
arrow_bottom,
|
||||
(center_x + arrow_width // 2, arrow_bottom[1] + arrowhead_size),
|
||||
],
|
||||
fill="red",
|
||||
width=5,
|
||||
)
|
||||
|
||||
return draw_img
|
||||
except Exception as e:
|
||||
logger.error(f"Error visualizing scroll: {str(e)}")
|
||||
# Return a blank image as fallback
|
||||
return Image.new("RGB", (800, 600), "white")
|
||||
|
||||
|
||||
def calculate_element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]:
|
||||
"""Calculate the center point of a UI element.
|
||||
|
||||
Args:
|
||||
bbox: Bounding box dictionary with x1, y1, x2, y2 coordinates (0-1 normalized)
|
||||
width: Screen width in pixels
|
||||
height: Screen height in pixels
|
||||
|
||||
Returns:
|
||||
(x, y) tuple with pixel coordinates
|
||||
"""
|
||||
center_x = int((bbox["x1"] + bbox["x2"]) / 2 * width)
|
||||
center_y = int((bbox["y1"] + bbox["y2"]) / 2 * height)
|
||||
return center_x, center_y
|
||||
|
||||
|
||||
class VisualizationHelper:
|
||||
"""Helper class for visualizing agent actions."""
|
||||
|
||||
def __init__(self, agent):
|
||||
"""Initialize visualization helper.
|
||||
|
||||
Args:
|
||||
agent: Reference to the agent that will use this helper
|
||||
"""
|
||||
self.agent = agent
|
||||
|
||||
def visualize_action(self, x: int, y: int, img_base64: str) -> None:
|
||||
"""Visualize a click action by drawing on the screenshot."""
|
||||
if (
|
||||
not self.agent.save_trajectory
|
||||
or not hasattr(self.agent, "experiment_manager")
|
||||
or not self.agent.experiment_manager
|
||||
):
|
||||
return
|
||||
|
||||
try:
|
||||
# Use the visualization utility
|
||||
img = visualize_click(x, y, img_base64)
|
||||
|
||||
# Save the visualization
|
||||
self.agent.experiment_manager.save_action_visualization(img, "click", f"x{x}_y{y}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error visualizing action: {str(e)}")
|
||||
|
||||
def visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
|
||||
"""Visualize a scroll action by drawing arrows on the screenshot."""
|
||||
if (
|
||||
not self.agent.save_trajectory
|
||||
or not hasattr(self.agent, "experiment_manager")
|
||||
or not self.agent.experiment_manager
|
||||
):
|
||||
return
|
||||
|
||||
try:
|
||||
# Use the visualization utility
|
||||
img = visualize_scroll(direction, clicks, img_base64)
|
||||
|
||||
# Save the visualization
|
||||
self.agent.experiment_manager.save_action_visualization(
|
||||
img, "scroll", f"{direction}_{clicks}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error visualizing scroll: {str(e)}")
|
||||
|
||||
def save_action_visualization(
|
||||
self, img: Image.Image, action_name: str, details: str = ""
|
||||
) -> str:
|
||||
"""Save a visualization of an action."""
|
||||
if hasattr(self.agent, "experiment_manager") and self.agent.experiment_manager:
|
||||
return self.agent.experiment_manager.save_action_visualization(
|
||||
img, action_name, details
|
||||
)
|
||||
return ""
|
||||
90
libs/python/agent/agent/decorators.py
Normal file
90
libs/python/agent/agent/decorators.py
Normal file
@@ -0,0 +1,90 @@
|
||||
"""
|
||||
Decorators for agent - agent_loop decorator
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Dict, List, Any, Callable, Optional
|
||||
from functools import wraps
|
||||
|
||||
from .types import AgentLoopInfo
|
||||
|
||||
# Global registry
|
||||
_agent_loops: List[AgentLoopInfo] = []
|
||||
|
||||
def agent_loop(models: str, priority: int = 0):
|
||||
"""
|
||||
Decorator to register an agent loop function.
|
||||
|
||||
Args:
|
||||
models: Regex pattern to match supported models
|
||||
priority: Priority for loop selection (higher = more priority)
|
||||
"""
|
||||
def decorator(func: Callable):
|
||||
# Validate function signature
|
||||
sig = inspect.signature(func)
|
||||
required_params = {'messages', 'model'}
|
||||
func_params = set(sig.parameters.keys())
|
||||
|
||||
if not required_params.issubset(func_params):
|
||||
missing = required_params - func_params
|
||||
raise ValueError(f"Agent loop function must have parameters: {missing}")
|
||||
|
||||
# Register the loop
|
||||
loop_info = AgentLoopInfo(
|
||||
func=func,
|
||||
models_regex=models,
|
||||
priority=priority
|
||||
)
|
||||
_agent_loops.append(loop_info)
|
||||
|
||||
# Sort by priority (highest first)
|
||||
_agent_loops.sort(key=lambda x: x.priority, reverse=True)
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
# Wrap the function in an asyncio.Queue for cancellation support
|
||||
queue = asyncio.Queue()
|
||||
task = None
|
||||
|
||||
try:
|
||||
# Create a task that can be cancelled
|
||||
async def run_loop():
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
await queue.put(('result', result))
|
||||
except Exception as e:
|
||||
await queue.put(('error', e))
|
||||
|
||||
task = asyncio.create_task(run_loop())
|
||||
|
||||
# Wait for result or cancellation
|
||||
event_type, data = await queue.get()
|
||||
|
||||
if event_type == 'error':
|
||||
raise data
|
||||
return data
|
||||
|
||||
except asyncio.CancelledError:
|
||||
if task:
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
raise
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
def get_agent_loops() -> List[AgentLoopInfo]:
|
||||
"""Get all registered agent loops"""
|
||||
return _agent_loops.copy()
|
||||
|
||||
def find_agent_loop(model: str) -> Optional[AgentLoopInfo]:
|
||||
"""Find the best matching agent loop for a model"""
|
||||
for loop_info in _agent_loops:
|
||||
if loop_info.matches_model(model):
|
||||
return loop_info
|
||||
return None
|
||||
11
libs/python/agent/agent/loops/__init__.py
Normal file
11
libs/python/agent/agent/loops/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
Agent loops for agent
|
||||
"""
|
||||
|
||||
# Import the loops to register them
|
||||
from . import anthropic
|
||||
from . import openai
|
||||
from . import uitars
|
||||
from . import omniparser
|
||||
|
||||
__all__ = ["anthropic", "openai", "uitars", "omniparser"]
|
||||
1367
libs/python/agent/agent/loops/anthropic.py
Normal file
1367
libs/python/agent/agent/loops/anthropic.py
Normal file
File diff suppressed because it is too large
Load Diff
339
libs/python/agent/agent/loops/omniparser.py
Normal file
339
libs/python/agent/agent/loops/omniparser.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""
|
||||
OpenAI computer-use-preview agent loop implementation using liteLLM
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
|
||||
import litellm
|
||||
import inspect
|
||||
import base64
|
||||
|
||||
from ..decorators import agent_loop
|
||||
from ..types import Messages, AgentResponse, Tools
|
||||
|
||||
SOM_TOOL_SCHEMA = {
|
||||
"type": "function",
|
||||
"name": "computer",
|
||||
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. Use the element's ID number to interact with any element instead of pixel coordinates.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"screenshot",
|
||||
"click",
|
||||
"double_click",
|
||||
"drag",
|
||||
"type",
|
||||
"keypress",
|
||||
"scroll",
|
||||
"move",
|
||||
"wait",
|
||||
"get_current_url",
|
||||
"get_dimensions",
|
||||
"get_environment"
|
||||
],
|
||||
"description": "The action to perform"
|
||||
},
|
||||
"element_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
|
||||
},
|
||||
"start_element_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the element to start dragging from (required for drag action)"
|
||||
},
|
||||
"end_element_id": {
|
||||
"type": "integer",
|
||||
"description": "The ID of the element to drag to (required for drag action)"
|
||||
},
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "The text to type (required for type action)"
|
||||
},
|
||||
"keys": {
|
||||
"type": "string",
|
||||
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
|
||||
},
|
||||
"button": {
|
||||
"type": "string",
|
||||
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
|
||||
},
|
||||
"scroll_x": {
|
||||
"type": "integer",
|
||||
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
|
||||
},
|
||||
"scroll_y": {
|
||||
"type": "integer",
|
||||
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
|
||||
},
|
||||
},
|
||||
"required": [
|
||||
"action"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
OMNIPARSER_AVAILABLE = False
|
||||
try:
|
||||
from som import OmniParser
|
||||
OMNIPARSER_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
OMNIPARSER_SINGLETON = None
|
||||
|
||||
def get_parser():
|
||||
global OMNIPARSER_SINGLETON
|
||||
if OMNIPARSER_SINGLETON is None:
|
||||
OMNIPARSER_SINGLETON = OmniParser()
|
||||
return OMNIPARSER_SINGLETON
|
||||
|
||||
def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
"""Get the last computer_call_output message from a messages list.
|
||||
|
||||
Args:
|
||||
messages: List of messages to search through
|
||||
|
||||
Returns:
|
||||
The last computer_call_output message dict, or None if not found
|
||||
"""
|
||||
for message in reversed(messages):
|
||||
if isinstance(message, dict) and message.get("type") == "computer_call_output":
|
||||
return message
|
||||
return None
|
||||
|
||||
def _prepare_tools_for_omniparser(tool_schemas: List[Dict[str, Any]]) -> Tuple[Tools, dict]:
|
||||
"""Prepare tools for OpenAI API format"""
|
||||
omniparser_tools = []
|
||||
id2xy = dict()
|
||||
|
||||
for schema in tool_schemas:
|
||||
if schema["type"] == "computer":
|
||||
omniparser_tools.append(SOM_TOOL_SCHEMA)
|
||||
if "id2xy" in schema:
|
||||
id2xy = schema["id2xy"]
|
||||
else:
|
||||
schema["id2xy"] = id2xy
|
||||
elif schema["type"] == "function":
|
||||
# Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
|
||||
# Schema should be: {type, name, description, parameters}
|
||||
omniparser_tools.append({ "type": "function", **schema["function"] })
|
||||
|
||||
return omniparser_tools, id2xy
|
||||
|
||||
async def replace_function_with_computer_call(item: Dict[str, Any], id2xy: Dict[int, Tuple[float, float]]):
|
||||
item_type = item.get("type")
|
||||
|
||||
def _get_xy(element_id: Optional[int]) -> Union[Tuple[float, float], Tuple[None, None]]:
|
||||
if element_id is None:
|
||||
return (None, None)
|
||||
return id2xy.get(element_id, (None, None))
|
||||
|
||||
if item_type == "function_call":
|
||||
fn_name = item.get("name")
|
||||
fn_args = json.loads(item.get("arguments", "{}"))
|
||||
|
||||
item_id = item.get("id")
|
||||
call_id = item.get("call_id")
|
||||
|
||||
if fn_name == "computer":
|
||||
action = fn_args.get("action")
|
||||
element_id = fn_args.get("element_id")
|
||||
start_element_id = fn_args.get("start_element_id")
|
||||
end_element_id = fn_args.get("end_element_id")
|
||||
text = fn_args.get("text")
|
||||
keys = fn_args.get("keys")
|
||||
button = fn_args.get("button")
|
||||
scroll_x = fn_args.get("scroll_x")
|
||||
scroll_y = fn_args.get("scroll_y")
|
||||
|
||||
x, y = _get_xy(element_id)
|
||||
start_x, start_y = _get_xy(start_element_id)
|
||||
end_x, end_y = _get_xy(end_element_id)
|
||||
|
||||
action_args = {
|
||||
"type": action,
|
||||
"x": x,
|
||||
"y": y,
|
||||
"start_x": start_x,
|
||||
"start_y": start_y,
|
||||
"end_x": end_x,
|
||||
"end_y": end_y,
|
||||
"text": text,
|
||||
"keys": keys,
|
||||
"button": button,
|
||||
"scroll_x": scroll_x,
|
||||
"scroll_y": scroll_y
|
||||
}
|
||||
# Remove None values to keep the JSON clean
|
||||
action_args = {k: v for k, v in action_args.items() if v is not None}
|
||||
|
||||
return [{
|
||||
"type": "computer_call",
|
||||
"action": action_args,
|
||||
"id": item_id,
|
||||
"call_id": call_id,
|
||||
"status": "completed"
|
||||
}]
|
||||
|
||||
return [item]
|
||||
|
||||
async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[Tuple[float, float], int]):
|
||||
"""
|
||||
Convert computer_call back to function_call format.
|
||||
Also handles computer_call_output -> function_call_output conversion.
|
||||
|
||||
Args:
|
||||
item: The item to convert
|
||||
xy2id: Mapping from (x, y) coordinates to element IDs
|
||||
"""
|
||||
item_type = item.get("type")
|
||||
|
||||
def _get_element_id(x: Optional[float], y: Optional[float]) -> Optional[int]:
|
||||
"""Get element ID from coordinates, return None if coordinates are None"""
|
||||
if x is None or y is None:
|
||||
return None
|
||||
return xy2id.get((x, y))
|
||||
|
||||
if item_type == "computer_call":
|
||||
action_data = item.get("action", {})
|
||||
|
||||
# Extract coordinates and convert back to element IDs
|
||||
element_id = _get_element_id(action_data.get("x"), action_data.get("y"))
|
||||
start_element_id = _get_element_id(action_data.get("start_x"), action_data.get("start_y"))
|
||||
end_element_id = _get_element_id(action_data.get("end_x"), action_data.get("end_y"))
|
||||
|
||||
# Build function arguments
|
||||
fn_args = {
|
||||
"action": action_data.get("type"),
|
||||
"element_id": element_id,
|
||||
"start_element_id": start_element_id,
|
||||
"end_element_id": end_element_id,
|
||||
"text": action_data.get("text"),
|
||||
"keys": action_data.get("keys"),
|
||||
"button": action_data.get("button"),
|
||||
"scroll_x": action_data.get("scroll_x"),
|
||||
"scroll_y": action_data.get("scroll_y")
|
||||
}
|
||||
|
||||
# Remove None values to keep the JSON clean
|
||||
fn_args = {k: v for k, v in fn_args.items() if v is not None}
|
||||
|
||||
return [{
|
||||
"type": "function_call",
|
||||
"name": "computer",
|
||||
"arguments": json.dumps(fn_args),
|
||||
"id": item.get("id"),
|
||||
"call_id": item.get("call_id"),
|
||||
"status": "completed",
|
||||
|
||||
# Fall back to string representation
|
||||
"content": f"Used tool: {action_data.get("type")}({json.dumps(fn_args)})"
|
||||
}]
|
||||
|
||||
elif item_type == "computer_call_output":
|
||||
# Simple conversion: computer_call_output -> function_call_output
|
||||
return [{
|
||||
"type": "function_call_output",
|
||||
"call_id": item.get("call_id"),
|
||||
"content": [item.get("output")],
|
||||
"id": item.get("id"),
|
||||
"status": "completed"
|
||||
}]
|
||||
|
||||
return [item]
|
||||
|
||||
|
||||
@agent_loop(models=r"omniparser\+.*|omni\+.*", priority=10)
|
||||
async def omniparser_loop(
|
||||
messages: Messages,
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
|
||||
"""
|
||||
OpenAI computer-use-preview agent loop using liteLLM responses.
|
||||
|
||||
Supports OpenAI's computer use preview models.
|
||||
"""
|
||||
if not OMNIPARSER_AVAILABLE:
|
||||
raise ValueError("omniparser loop requires som to be installed. Install it with `pip install cua-som`.")
|
||||
|
||||
tools = tools or []
|
||||
|
||||
llm_model = model.split('+')[-1]
|
||||
|
||||
# Prepare tools for OpenAI API
|
||||
openai_tools, id2xy = _prepare_tools_for_omniparser(tools)
|
||||
|
||||
# Find last computer_call_output
|
||||
last_computer_call_output = get_last_computer_call_output(messages)
|
||||
if last_computer_call_output:
|
||||
image_url = last_computer_call_output.get("output", {}).get("image_url", "")
|
||||
image_data = image_url.split(",")[-1]
|
||||
if image_data:
|
||||
parser = get_parser()
|
||||
result = parser.parse(image_data)
|
||||
if _on_screenshot:
|
||||
await _on_screenshot(result.annotated_image_base64, "annotated_image")
|
||||
for element in result.elements:
|
||||
id2xy[element.id] = ((element.bbox.x1 + element.bbox.x2) / 2, (element.bbox.y1 + element.bbox.y2) / 2)
|
||||
|
||||
# handle computer calls -> function calls
|
||||
new_messages = []
|
||||
for message in messages:
|
||||
if not isinstance(message, dict):
|
||||
message = message.__dict__
|
||||
new_messages += await replace_computer_call_with_function(message, id2xy)
|
||||
messages = new_messages
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": llm_model,
|
||||
"input": messages,
|
||||
"tools": openai_tools if openai_tools else None,
|
||||
"stream": stream,
|
||||
"reasoning": {"summary": "concise"},
|
||||
"truncation": "auto",
|
||||
"num_retries": max_retries,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
print(str(api_kwargs)[:1000])
|
||||
|
||||
# Use liteLLM responses
|
||||
response = await litellm.aresponses(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
# Extract usage information
|
||||
response.usage = {
|
||||
**response.usage.model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(response.usage)
|
||||
|
||||
# handle som function calls -> xy computer calls
|
||||
new_output = []
|
||||
for i in range(len(response.output)):
|
||||
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy)
|
||||
response.output = new_output
|
||||
|
||||
return response
|
||||
95
libs/python/agent/agent/loops/openai.py
Normal file
95
libs/python/agent/agent/loops/openai.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
OpenAI computer-use-preview agent loop implementation using liteLLM
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, List, Any, AsyncGenerator, Union, Optional
|
||||
import litellm
|
||||
|
||||
from ..decorators import agent_loop
|
||||
from ..types import Messages, AgentResponse, Tools
|
||||
|
||||
def _map_computer_tool_to_openai(computer_tool: Any) -> Dict[str, Any]:
|
||||
"""Map a computer tool to OpenAI's computer-use-preview tool schema"""
|
||||
return {
|
||||
"type": "computer_use_preview",
|
||||
"display_width": getattr(computer_tool, 'display_width', 1024),
|
||||
"display_height": getattr(computer_tool, 'display_height', 768),
|
||||
"environment": getattr(computer_tool, 'environment', "linux") # mac, windows, linux, browser
|
||||
}
|
||||
|
||||
|
||||
def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools:
|
||||
"""Prepare tools for OpenAI API format"""
|
||||
openai_tools = []
|
||||
|
||||
for schema in tool_schemas:
|
||||
if schema["type"] == "computer":
|
||||
# Map computer tool to OpenAI format
|
||||
openai_tools.append(_map_computer_tool_to_openai(schema["computer"]))
|
||||
elif schema["type"] == "function":
|
||||
# Function tools use OpenAI-compatible schema directly (liteLLM expects this format)
|
||||
# Schema should be: {type, name, description, parameters}
|
||||
openai_tools.append({ "type": "function", **schema["function"] })
|
||||
|
||||
return openai_tools
|
||||
|
||||
|
||||
@agent_loop(models=r".*computer-use-preview.*", priority=10)
|
||||
async def openai_computer_use_loop(
|
||||
messages: Messages,
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
|
||||
"""
|
||||
OpenAI computer-use-preview agent loop using liteLLM responses.
|
||||
|
||||
Supports OpenAI's computer use preview models.
|
||||
"""
|
||||
tools = tools or []
|
||||
|
||||
# Prepare tools for OpenAI API
|
||||
openai_tools = _prepare_tools_for_openai(tools)
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
"input": messages,
|
||||
"tools": openai_tools if openai_tools else None,
|
||||
"stream": stream,
|
||||
"reasoning": {"summary": "concise"},
|
||||
"truncation": "auto",
|
||||
"num_retries": max_retries,
|
||||
**kwargs
|
||||
}
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
# Use liteLLM responses
|
||||
response = await litellm.aresponses(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
# Extract usage information
|
||||
response.usage = {
|
||||
**response.usage.model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(response.usage)
|
||||
|
||||
return response
|
||||
688
libs/python/agent/agent/loops/uitars.py
Normal file
688
libs/python/agent/agent/loops/uitars.py
Normal file
@@ -0,0 +1,688 @@
|
||||
"""
|
||||
UITARS agent loop implementation using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from ctypes import cast
|
||||
import json
|
||||
import base64
|
||||
import math
|
||||
import re
|
||||
import ast
|
||||
from typing import Dict, List, Any, AsyncGenerator, Union, Optional
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
import litellm
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig
|
||||
from litellm.responses.utils import Usage
|
||||
from openai.types.responses.response_computer_tool_call_param import ActionType, ResponseComputerToolCallParam
|
||||
from openai.types.responses.response_input_param import ComputerCallOutput
|
||||
from openai.types.responses.response_output_message_param import ResponseOutputMessageParam
|
||||
from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam, Summary
|
||||
|
||||
from ..decorators import agent_loop
|
||||
from ..types import Messages, AgentResponse, Tools
|
||||
from ..responses import (
|
||||
make_reasoning_item,
|
||||
make_output_text_item,
|
||||
make_click_item,
|
||||
make_double_click_item,
|
||||
make_drag_item,
|
||||
make_keypress_item,
|
||||
make_scroll_item,
|
||||
make_type_item,
|
||||
make_wait_item,
|
||||
make_input_image_item
|
||||
)
|
||||
|
||||
# Constants from reference code
|
||||
IMAGE_FACTOR = 28
|
||||
MIN_PIXELS = 100 * 28 * 28
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MAX_RATIO = 200
|
||||
|
||||
FINISH_WORD = "finished"
|
||||
WAIT_WORD = "wait"
|
||||
ENV_FAIL_WORD = "error_env"
|
||||
CALL_USER = "call_user"
|
||||
|
||||
# Action space prompt for UITARS
|
||||
UITARS_ACTION_SPACE = """
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
"""
|
||||
|
||||
UITARS_PROMPT_TEMPLATE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
{action_space}
|
||||
|
||||
## Note
|
||||
- Use {language} in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
|
||||
def round_by_factor(number: float, factor: int) -> int:
|
||||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||
return round(number / factor) * factor
|
||||
|
||||
|
||||
def ceil_by_factor(number: float, factor: int) -> int:
|
||||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.ceil(number / factor) * factor
|
||||
|
||||
|
||||
def floor_by_factor(number: float, factor: int) -> int:
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
|
||||
def smart_resize(
|
||||
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Rescales the image so that the following conditions are met:
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
"""
|
||||
if max(height, width) / min(height, width) > MAX_RATIO:
|
||||
raise ValueError(
|
||||
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
||||
)
|
||||
h_bar = max(factor, round_by_factor(height, factor))
|
||||
w_bar = max(factor, round_by_factor(width, factor))
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = floor_by_factor(height / beta, factor)
|
||||
w_bar = floor_by_factor(width / beta, factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = ceil_by_factor(height * beta, factor)
|
||||
w_bar = ceil_by_factor(width * beta, factor)
|
||||
return h_bar, w_bar
|
||||
|
||||
|
||||
def escape_single_quotes(text):
|
||||
"""Escape single quotes in text for safe string formatting."""
|
||||
pattern = r"(?<!\\)'"
|
||||
return re.sub(pattern, r"\\'", text)
|
||||
|
||||
|
||||
def parse_action(action_str):
|
||||
"""Parse action string into structured format."""
|
||||
try:
|
||||
node = ast.parse(action_str, mode='eval')
|
||||
if not isinstance(node, ast.Expression):
|
||||
raise ValueError("Not an expression")
|
||||
|
||||
call = node.body
|
||||
if not isinstance(call, ast.Call):
|
||||
raise ValueError("Not a function call")
|
||||
|
||||
# Get function name
|
||||
if isinstance(call.func, ast.Name):
|
||||
func_name = call.func.id
|
||||
elif isinstance(call.func, ast.Attribute):
|
||||
func_name = call.func.attr
|
||||
else:
|
||||
func_name = None
|
||||
|
||||
# Get keyword arguments
|
||||
kwargs = {}
|
||||
for kw in call.keywords:
|
||||
key = kw.arg
|
||||
if isinstance(kw.value, ast.Constant):
|
||||
value = kw.value.value
|
||||
elif isinstance(kw.value, ast.Str): # Compatibility with older Python
|
||||
value = kw.value.s
|
||||
else:
|
||||
value = None
|
||||
kwargs[key] = value
|
||||
|
||||
return {
|
||||
'function': func_name,
|
||||
'args': kwargs
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to parse action '{action_str}': {e}")
|
||||
return None
|
||||
|
||||
|
||||
def parse_uitars_response(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
|
||||
"""Parse UITARS model response into structured actions."""
|
||||
text = text.strip()
|
||||
|
||||
# Extract thought
|
||||
thought = None
|
||||
if text.startswith("Thought:"):
|
||||
thought_match = re.search(r"Thought: (.+?)(?=\s*Action:|$)", text, re.DOTALL)
|
||||
if thought_match:
|
||||
thought = thought_match.group(1).strip()
|
||||
|
||||
# Extract action
|
||||
if "Action:" not in text:
|
||||
raise ValueError("No Action found in response")
|
||||
|
||||
action_str = text.split("Action:")[-1].strip()
|
||||
|
||||
# Handle special case for type actions
|
||||
if "type(content" in action_str:
|
||||
def escape_quotes(match):
|
||||
return match.group(1)
|
||||
|
||||
pattern = r"type\(content='(.*?)'\)"
|
||||
content = re.sub(pattern, escape_quotes, action_str)
|
||||
action_str = escape_single_quotes(content)
|
||||
action_str = "type(content='" + action_str + "')"
|
||||
|
||||
|
||||
# Parse the action
|
||||
parsed_action = parse_action(action_str.replace("\n", "\\n").lstrip())
|
||||
if parsed_action is None:
|
||||
raise ValueError(f"Action can't parse: {action_str}")
|
||||
|
||||
action_type = parsed_action["function"]
|
||||
params = parsed_action["args"]
|
||||
|
||||
# Process parameters
|
||||
action_inputs = {}
|
||||
for param_name, param in params.items():
|
||||
if param == "":
|
||||
continue
|
||||
param = str(param).lstrip()
|
||||
action_inputs[param_name.strip()] = param
|
||||
|
||||
# Handle coordinate parameters
|
||||
if "start_box" in param_name or "end_box" in param_name:
|
||||
# Parse coordinates like '(x,y)' or '(x1,y1,x2,y2)'
|
||||
numbers = param.replace("(", "").replace(")", "").split(",")
|
||||
float_numbers = [float(num.strip()) / 1000 for num in numbers] # Normalize to 0-1 range
|
||||
|
||||
if len(float_numbers) == 2:
|
||||
# Single point, duplicate for box format
|
||||
float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
|
||||
|
||||
action_inputs[param_name.strip()] = str(float_numbers)
|
||||
|
||||
return [{
|
||||
"thought": thought,
|
||||
"action_type": action_type,
|
||||
"action_inputs": action_inputs,
|
||||
"text": text
|
||||
}]
|
||||
|
||||
|
||||
def convert_to_computer_actions(parsed_responses: List[Dict[str, Any]], image_width: int, image_height: int) -> List[ResponseComputerToolCallParam | ResponseOutputMessageParam]:
|
||||
"""Convert parsed UITARS responses to computer actions."""
|
||||
computer_actions = []
|
||||
|
||||
for response in parsed_responses:
|
||||
action_type = response.get("action_type")
|
||||
action_inputs = response.get("action_inputs", {})
|
||||
|
||||
if action_type == "finished":
|
||||
finished_text = action_inputs.get("content", "Task completed successfully.")
|
||||
computer_actions.append(make_output_text_item(finished_text))
|
||||
break
|
||||
|
||||
elif action_type == "wait":
|
||||
computer_actions.append(make_wait_item())
|
||||
|
||||
elif action_type == "call_user":
|
||||
computer_actions.append(make_output_text_item("I need assistance from the user to proceed with this task."))
|
||||
|
||||
elif action_type in ["click", "left_single"]:
|
||||
start_box = action_inputs.get("start_box")
|
||||
if start_box:
|
||||
coords = eval(start_box)
|
||||
x = int((coords[0] + coords[2]) / 2 * image_width)
|
||||
y = int((coords[1] + coords[3]) / 2 * image_height)
|
||||
|
||||
computer_actions.append(make_click_item(x, y, "left"))
|
||||
|
||||
elif action_type == "double_click":
|
||||
start_box = action_inputs.get("start_box")
|
||||
if start_box:
|
||||
coords = eval(start_box)
|
||||
x = int((coords[0] + coords[2]) / 2 * image_width)
|
||||
y = int((coords[1] + coords[3]) / 2 * image_height)
|
||||
|
||||
computer_actions.append(make_double_click_item(x, y))
|
||||
|
||||
elif action_type == "right_click":
|
||||
start_box = action_inputs.get("start_box")
|
||||
if start_box:
|
||||
coords = eval(start_box)
|
||||
x = int((coords[0] + coords[2]) / 2 * image_width)
|
||||
y = int((coords[1] + coords[3]) / 2 * image_height)
|
||||
|
||||
computer_actions.append(make_click_item(x, y, "right"))
|
||||
|
||||
elif action_type == "type":
|
||||
content = action_inputs.get("content", "")
|
||||
computer_actions.append(make_type_item(content))
|
||||
|
||||
elif action_type == "hotkey":
|
||||
key = action_inputs.get("key", "")
|
||||
keys = key.split()
|
||||
computer_actions.append(make_keypress_item(keys))
|
||||
|
||||
elif action_type == "press":
|
||||
key = action_inputs.get("key", "")
|
||||
computer_actions.append(make_keypress_item([key]))
|
||||
|
||||
elif action_type == "scroll":
|
||||
start_box = action_inputs.get("start_box")
|
||||
direction = action_inputs.get("direction", "down")
|
||||
|
||||
if start_box:
|
||||
coords = eval(start_box)
|
||||
x = int((coords[0] + coords[2]) / 2 * image_width)
|
||||
y = int((coords[1] + coords[3]) / 2 * image_height)
|
||||
else:
|
||||
x, y = image_width // 2, image_height // 2
|
||||
|
||||
scroll_y = 5 if "up" in direction.lower() else -5
|
||||
computer_actions.append(make_scroll_item(x, y, 0, scroll_y))
|
||||
|
||||
elif action_type == "drag":
|
||||
start_box = action_inputs.get("start_box")
|
||||
end_box = action_inputs.get("end_box")
|
||||
|
||||
if start_box and end_box:
|
||||
start_coords = eval(start_box)
|
||||
end_coords = eval(end_box)
|
||||
|
||||
start_x = int((start_coords[0] + start_coords[2]) / 2 * image_width)
|
||||
start_y = int((start_coords[1] + start_coords[3]) / 2 * image_height)
|
||||
end_x = int((end_coords[0] + end_coords[2]) / 2 * image_width)
|
||||
end_y = int((end_coords[1] + end_coords[3]) / 2 * image_height)
|
||||
|
||||
path = [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}]
|
||||
computer_actions.append(make_drag_item(path))
|
||||
|
||||
return computer_actions
|
||||
|
||||
|
||||
def pil_to_base64(image: Image.Image) -> str:
|
||||
"""Convert PIL image to base64 string."""
|
||||
buffer = BytesIO()
|
||||
image.save(buffer, format="PNG")
|
||||
return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
def process_image_for_uitars(image_data: str, max_pixels: int = MAX_PIXELS, min_pixels: int = MIN_PIXELS) -> tuple[Image.Image, int, int]:
|
||||
"""Process image for UITARS model input."""
|
||||
# Decode base64 image
|
||||
if image_data.startswith('data:image'):
|
||||
image_data = image_data.split(',')[1]
|
||||
|
||||
image_bytes = base64.b64decode(image_data)
|
||||
image = Image.open(BytesIO(image_bytes))
|
||||
|
||||
original_width, original_height = image.size
|
||||
|
||||
# Resize image according to UITARS requirements
|
||||
if image.width * image.height > max_pixels:
|
||||
resize_factor = math.sqrt(max_pixels / (image.width * image.height))
|
||||
width = int(image.width * resize_factor)
|
||||
height = int(image.height * resize_factor)
|
||||
image = image.resize((width, height))
|
||||
|
||||
if image.width * image.height < min_pixels:
|
||||
resize_factor = math.sqrt(min_pixels / (image.width * image.height))
|
||||
width = math.ceil(image.width * resize_factor)
|
||||
height = math.ceil(image.height * resize_factor)
|
||||
image = image.resize((width, height))
|
||||
|
||||
if image.mode != "RGB":
|
||||
image = image.convert("RGB")
|
||||
|
||||
return image, original_width, original_height
|
||||
|
||||
|
||||
def sanitize_message(msg: Any) -> Any:
|
||||
"""Return a copy of the message with image_url ommited within content parts"""
|
||||
if isinstance(msg, dict):
|
||||
result = {}
|
||||
for key, value in msg.items():
|
||||
if key == "content" and isinstance(value, list):
|
||||
result[key] = [
|
||||
{k: v for k, v in item.items() if k != "image_url"} if isinstance(item, dict) else item
|
||||
for item in value
|
||||
]
|
||||
else:
|
||||
result[key] = value
|
||||
return result
|
||||
elif isinstance(msg, list):
|
||||
return [sanitize_message(item) for item in msg]
|
||||
else:
|
||||
return msg
|
||||
|
||||
|
||||
def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert UITARS internal message format back to LiteLLM format.
|
||||
|
||||
This function processes reasoning, computer_call, and computer_call_output messages
|
||||
and converts them to the appropriate LiteLLM assistant message format.
|
||||
|
||||
Args:
|
||||
messages: List of UITARS internal messages
|
||||
|
||||
Returns:
|
||||
List of LiteLLM formatted messages
|
||||
"""
|
||||
litellm_messages = []
|
||||
current_assistant_content = []
|
||||
|
||||
for message in messages:
|
||||
if isinstance(message, dict):
|
||||
message_type = message.get("type")
|
||||
|
||||
if message_type == "reasoning":
|
||||
# Extract reasoning text from summary
|
||||
summary = message.get("summary", [])
|
||||
if summary and isinstance(summary, list):
|
||||
for summary_item in summary:
|
||||
if isinstance(summary_item, dict) and summary_item.get("type") == "summary_text":
|
||||
reasoning_text = summary_item.get("text", "")
|
||||
if reasoning_text:
|
||||
current_assistant_content.append(f"Thought: {reasoning_text}")
|
||||
|
||||
elif message_type == "computer_call":
|
||||
# Convert computer action to UITARS action format
|
||||
action = message.get("action", {})
|
||||
action_type = action.get("type")
|
||||
|
||||
if action_type == "click":
|
||||
x, y = action.get("x", 0), action.get("y", 0)
|
||||
button = action.get("button", "left")
|
||||
if button == "left":
|
||||
action_text = f"Action: click(start_box='({x},{y})')"
|
||||
elif button == "right":
|
||||
action_text = f"Action: right_single(start_box='({x},{y})')"
|
||||
else:
|
||||
action_text = f"Action: click(start_box='({x},{y})')"
|
||||
|
||||
elif action_type == "double_click":
|
||||
x, y = action.get("x", 0), action.get("y", 0)
|
||||
action_text = f"Action: left_double(start_box='({x},{y})')"
|
||||
|
||||
elif action_type == "drag":
|
||||
start_x, start_y = action.get("start_x", 0), action.get("start_y", 0)
|
||||
end_x, end_y = action.get("end_x", 0), action.get("end_y", 0)
|
||||
action_text = f"Action: drag(start_box='({start_x},{start_y})', end_box='({end_x},{end_y})')"
|
||||
|
||||
elif action_type == "key":
|
||||
key = action.get("key", "")
|
||||
action_text = f"Action: hotkey(key='{key}')"
|
||||
|
||||
elif action_type == "type":
|
||||
text = action.get("text", "")
|
||||
# Escape single quotes in the text
|
||||
escaped_text = escape_single_quotes(text)
|
||||
action_text = f"Action: type(content='{escaped_text}')"
|
||||
|
||||
elif action_type == "scroll":
|
||||
x, y = action.get("x", 0), action.get("y", 0)
|
||||
direction = action.get("direction", "down")
|
||||
action_text = f"Action: scroll(start_box='({x},{y})', direction='{direction}')"
|
||||
|
||||
elif action_type == "wait":
|
||||
action_text = "Action: wait()"
|
||||
|
||||
else:
|
||||
# Fallback for unknown action types
|
||||
action_text = f"Action: {action_type}({action})"
|
||||
|
||||
current_assistant_content.append(action_text)
|
||||
|
||||
# When we hit a computer_call_output, finalize the current assistant message
|
||||
if current_assistant_content:
|
||||
litellm_messages.append({
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "\n".join(current_assistant_content)}]
|
||||
})
|
||||
current_assistant_content = []
|
||||
|
||||
elif message_type == "computer_call_output":
|
||||
# Add screenshot from computer call output
|
||||
output = message.get("output", {})
|
||||
if isinstance(output, dict) and output.get("type") == "input_image":
|
||||
image_url = output.get("image_url", "")
|
||||
if image_url:
|
||||
litellm_messages.append({
|
||||
"role": "user",
|
||||
"content": [{"type": "image_url", "image_url": {"url": image_url}}]
|
||||
})
|
||||
|
||||
elif message.get("role") == "user":
|
||||
# # Handle user messages
|
||||
# content = message.get("content", "")
|
||||
# if isinstance(content, str):
|
||||
# litellm_messages.append({
|
||||
# "role": "user",
|
||||
# "content": content
|
||||
# })
|
||||
# elif isinstance(content, list):
|
||||
# litellm_messages.append({
|
||||
# "role": "user",
|
||||
# "content": content
|
||||
# })
|
||||
pass
|
||||
|
||||
# Add any remaining assistant content
|
||||
if current_assistant_content:
|
||||
litellm_messages.append({
|
||||
"role": "assistant",
|
||||
"content": current_assistant_content
|
||||
})
|
||||
|
||||
return litellm_messages
|
||||
|
||||
@agent_loop(models=r"(?i).*ui-?tars.*", priority=10)
|
||||
async def uitars_loop(
|
||||
messages: Messages,
|
||||
model: str,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
stream: bool = False,
|
||||
computer_handler=None,
|
||||
use_prompt_caching: Optional[bool] = False,
|
||||
_on_api_start=None,
|
||||
_on_api_end=None,
|
||||
_on_usage=None,
|
||||
_on_screenshot=None,
|
||||
**kwargs
|
||||
) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]:
|
||||
"""
|
||||
UITARS agent loop using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B model.
|
||||
|
||||
Supports UITARS vision-language models for computer control.
|
||||
"""
|
||||
tools = tools or []
|
||||
|
||||
# Create response items
|
||||
response_items = []
|
||||
|
||||
# Find computer tool for screen dimensions
|
||||
computer_tool = None
|
||||
for tool_schema in tools:
|
||||
if tool_schema["type"] == "computer":
|
||||
computer_tool = tool_schema["computer"]
|
||||
break
|
||||
|
||||
# Get screen dimensions
|
||||
screen_width, screen_height = 1024, 768
|
||||
if computer_tool:
|
||||
try:
|
||||
screen_width, screen_height = await computer_tool.get_dimensions()
|
||||
except:
|
||||
pass
|
||||
|
||||
# Process messages to extract instruction and image
|
||||
instruction = ""
|
||||
image_data = None
|
||||
|
||||
# Convert messages to list if string
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
# Extract instruction and latest screenshot
|
||||
for message in reversed(messages):
|
||||
if isinstance(message, dict):
|
||||
content = message.get("content", "")
|
||||
|
||||
# Handle different content formats
|
||||
if isinstance(content, str):
|
||||
if not instruction and message.get("role") == "user":
|
||||
instruction = content
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text" and not instruction:
|
||||
instruction = item.get("text", "")
|
||||
elif item.get("type") == "image_url" and not image_data:
|
||||
image_url = item.get("image_url", {})
|
||||
if isinstance(image_url, dict):
|
||||
image_data = image_url.get("url", "")
|
||||
else:
|
||||
image_data = image_url
|
||||
|
||||
# Also check for computer_call_output with screenshots
|
||||
if message.get("type") == "computer_call_output" and not image_data:
|
||||
output = message.get("output", {})
|
||||
if isinstance(output, dict) and output.get("type") == "input_image":
|
||||
image_data = output.get("image_url", "")
|
||||
|
||||
if instruction and image_data:
|
||||
break
|
||||
|
||||
if not instruction:
|
||||
instruction = "Help me complete this task by analyzing the screen and taking appropriate actions."
|
||||
|
||||
# Create prompt
|
||||
user_prompt = UITARS_PROMPT_TEMPLATE.format(
|
||||
instruction=instruction,
|
||||
action_space=UITARS_ACTION_SPACE,
|
||||
language="English"
|
||||
)
|
||||
|
||||
# Convert conversation history to LiteLLM format
|
||||
history_messages = convert_uitars_messages_to_litellm(messages)
|
||||
|
||||
# Prepare messages for liteLLM
|
||||
litellm_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant."
|
||||
}
|
||||
]
|
||||
|
||||
# Add current user instruction with screenshot
|
||||
current_user_message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": user_prompt},
|
||||
]
|
||||
}
|
||||
litellm_messages.append(current_user_message)
|
||||
|
||||
# Process image for UITARS
|
||||
if not image_data:
|
||||
# Take screenshot if none found in messages
|
||||
if computer_handler:
|
||||
image_data = await computer_handler.screenshot()
|
||||
await _on_screenshot(image_data, "screenshot_before")
|
||||
|
||||
# Add screenshot to output items so it can be retained in history
|
||||
response_items.append(make_input_image_item(image_data))
|
||||
else:
|
||||
raise ValueError("No screenshot found in messages and no computer_handler provided")
|
||||
processed_image, original_width, original_height = process_image_for_uitars(image_data)
|
||||
encoded_image = pil_to_base64(processed_image)
|
||||
|
||||
# Add conversation history
|
||||
if history_messages:
|
||||
litellm_messages.extend(history_messages)
|
||||
else:
|
||||
litellm_messages.append({
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
|
||||
]
|
||||
})
|
||||
|
||||
# Prepare API call kwargs
|
||||
api_kwargs = {
|
||||
"model": model,
|
||||
"messages": litellm_messages,
|
||||
"max_tokens": kwargs.get("max_tokens", 500),
|
||||
"temperature": kwargs.get("temperature", 0.0),
|
||||
"do_sample": kwargs.get("temperature", 0.0) > 0.0,
|
||||
"num_retries": max_retries,
|
||||
**{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]}
|
||||
}
|
||||
|
||||
# Call API start hook
|
||||
if _on_api_start:
|
||||
await _on_api_start(api_kwargs)
|
||||
|
||||
# Call liteLLM with UITARS model
|
||||
response = await litellm.acompletion(**api_kwargs)
|
||||
|
||||
# Call API end hook
|
||||
if _on_api_end:
|
||||
await _on_api_end(api_kwargs, response)
|
||||
|
||||
# Extract response content
|
||||
response_content = response.choices[0].message.content.strip() # type: ignore
|
||||
|
||||
# Parse UITARS response
|
||||
parsed_responses = parse_uitars_response(response_content, original_width, original_height)
|
||||
|
||||
# Convert to computer actions
|
||||
computer_actions = convert_to_computer_actions(parsed_responses, original_width, original_height)
|
||||
|
||||
# Add computer actions to response items
|
||||
thought = parsed_responses[0].get("thought", "")
|
||||
if thought:
|
||||
response_items.append(make_reasoning_item(thought))
|
||||
response_items.extend(computer_actions)
|
||||
|
||||
# Extract usage information
|
||||
response_usage = {
|
||||
**LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(),
|
||||
"response_cost": response._hidden_params.get("response_cost", 0.0),
|
||||
}
|
||||
if _on_usage:
|
||||
await _on_usage(response_usage)
|
||||
|
||||
# Create agent response
|
||||
agent_response = {
|
||||
"output": response_items,
|
||||
"usage": response_usage
|
||||
}
|
||||
|
||||
return agent_response
|
||||
@@ -1,4 +0,0 @@
|
||||
"""Provider implementations for different AI services."""
|
||||
|
||||
# Import specific providers only when needed to avoid circular imports
|
||||
__all__ = [] # Let each provider module handle its own exports
|
||||
@@ -1,6 +0,0 @@
|
||||
"""Anthropic provider implementation."""
|
||||
|
||||
from .loop import AnthropicLoop
|
||||
from .types import LLMProvider
|
||||
|
||||
__all__ = ["AnthropicLoop", "LLMProvider"]
|
||||
@@ -1,360 +0,0 @@
|
||||
from typing import Any, List, Dict, cast
|
||||
import httpx
|
||||
import asyncio
|
||||
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
|
||||
from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaToolUnionParam
|
||||
from ..types import LLMProvider
|
||||
from .logging import log_api_interaction
|
||||
import random
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class APIConnectionError(Exception):
|
||||
"""Error raised when there are connection issues with the API."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class BaseAnthropicClient:
|
||||
"""Base class for Anthropic API clients."""
|
||||
|
||||
MAX_RETRIES = 10
|
||||
INITIAL_RETRY_DELAY = 1.0
|
||||
MAX_RETRY_DELAY = 60.0
|
||||
JITTER_FACTOR = 0.1
|
||||
|
||||
async def create_message(
|
||||
self,
|
||||
*,
|
||||
messages: list[BetaMessageParam],
|
||||
system: list[Any],
|
||||
tools: list[BetaToolUnionParam],
|
||||
max_tokens: int,
|
||||
betas: list[str],
|
||||
) -> BetaMessage:
|
||||
"""Create a message using the Anthropic API."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _make_api_call_with_retries(self, api_call):
|
||||
"""Make an API call with exponential backoff retry logic.
|
||||
|
||||
Args:
|
||||
api_call: Async function that makes the actual API call
|
||||
|
||||
Returns:
|
||||
API response
|
||||
|
||||
Raises:
|
||||
APIConnectionError: If all retries fail
|
||||
"""
|
||||
retry_count = 0
|
||||
last_error = None
|
||||
|
||||
while retry_count < self.MAX_RETRIES:
|
||||
try:
|
||||
return await api_call()
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
retry_count += 1
|
||||
|
||||
if retry_count == self.MAX_RETRIES:
|
||||
break
|
||||
|
||||
# Calculate delay with exponential backoff and jitter
|
||||
delay = min(
|
||||
self.INITIAL_RETRY_DELAY * (2 ** (retry_count - 1)), self.MAX_RETRY_DELAY
|
||||
)
|
||||
# Add jitter to avoid thundering herd
|
||||
jitter = delay * self.JITTER_FACTOR * (2 * random.random() - 1)
|
||||
final_delay = delay + jitter
|
||||
|
||||
logger.info(
|
||||
f"Retrying request (attempt {retry_count}/{self.MAX_RETRIES}) "
|
||||
f"in {final_delay:.2f} seconds after error: {str(e)}"
|
||||
)
|
||||
await asyncio.sleep(final_delay)
|
||||
|
||||
raise APIConnectionError(
|
||||
f"Failed after {self.MAX_RETRIES} retries. " f"Last error: {str(last_error)}"
|
||||
)
|
||||
|
||||
async def run_interleaved(
|
||||
self, messages: List[Dict[str, Any]], system: str, max_tokens: int = 4096
|
||||
) -> Any:
|
||||
"""Run the Anthropic API with the Claude model, supports interleaved tool calling.
|
||||
|
||||
Args:
|
||||
messages: List of message objects
|
||||
system: System prompt
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
Returns:
|
||||
API response
|
||||
"""
|
||||
# Add the tool_result check/fix logic here
|
||||
fixed_messages = self._fix_missing_tool_results(messages)
|
||||
|
||||
# Get model name from concrete implementation if available
|
||||
model_name = getattr(self, "model", "unknown model")
|
||||
logger.info(f"Running Anthropic API call with model {model_name}")
|
||||
|
||||
retry_count = 0
|
||||
|
||||
while retry_count < self.MAX_RETRIES:
|
||||
try:
|
||||
# Call the Anthropic API through create_message which is implemented by subclasses
|
||||
# Convert system str to the list format expected by create_message
|
||||
system_list = [system]
|
||||
|
||||
# Convert message format if needed - concrete implementations may do further conversion
|
||||
response = await self.create_message(
|
||||
messages=cast(list[BetaMessageParam], fixed_messages),
|
||||
system=system_list,
|
||||
tools=[], # Tools are included in the messages
|
||||
max_tokens=max_tokens,
|
||||
betas=["tools-2023-12-13"],
|
||||
)
|
||||
logger.info(f"Anthropic API call successful")
|
||||
return response
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
wait_time = self.INITIAL_RETRY_DELAY * (
|
||||
2 ** (retry_count - 1)
|
||||
) # Exponential backoff
|
||||
logger.info(
|
||||
f"Retrying request (attempt {retry_count}/{self.MAX_RETRIES}) in {wait_time:.2f} seconds after error: {str(e)}"
|
||||
)
|
||||
await asyncio.sleep(wait_time)
|
||||
|
||||
# If we get here, all retries failed
|
||||
raise RuntimeError(f"Failed to call Anthropic API after {self.MAX_RETRIES} attempts")
|
||||
|
||||
def _fix_missing_tool_results(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Check for and fix any missing tool_result blocks after tool_use blocks.
|
||||
|
||||
Args:
|
||||
messages: List of message objects
|
||||
|
||||
Returns:
|
||||
Fixed messages with proper tool_result blocks
|
||||
"""
|
||||
fixed_messages = []
|
||||
pending_tool_uses = {} # Map of tool_use IDs to their details
|
||||
|
||||
for i, message in enumerate(messages):
|
||||
# Track any tool_use blocks in this message
|
||||
if message.get("role") == "assistant" and "content" in message:
|
||||
content = message.get("content", [])
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
||||
tool_id = block.get("id")
|
||||
if tool_id:
|
||||
pending_tool_uses[tool_id] = {
|
||||
"name": block.get("name", ""),
|
||||
"input": block.get("input", {}),
|
||||
}
|
||||
|
||||
# Check if this message handles any pending tool_use blocks
|
||||
if message.get("role") == "user" and "content" in message:
|
||||
# Check for tool_result blocks in this message
|
||||
content = message.get("content", [])
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
tool_id = block.get("tool_use_id")
|
||||
if tool_id in pending_tool_uses:
|
||||
# This tool_result handles a pending tool_use
|
||||
pending_tool_uses.pop(tool_id)
|
||||
|
||||
# Add the message to our fixed list
|
||||
fixed_messages.append(message)
|
||||
|
||||
# If this is an assistant message with tool_use blocks and there are
|
||||
# pending tool uses that need to be resolved before the next assistant message
|
||||
if (
|
||||
i + 1 < len(messages)
|
||||
and message.get("role") == "assistant"
|
||||
and messages[i + 1].get("role") == "assistant"
|
||||
and pending_tool_uses
|
||||
):
|
||||
|
||||
# We need to insert a user message with tool_results for all pending tool_uses
|
||||
tool_results = []
|
||||
for tool_id, tool_info in pending_tool_uses.items():
|
||||
tool_results.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_id,
|
||||
"content": {
|
||||
"type": "error",
|
||||
"message": "Tool execution was skipped or failed",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# Insert a synthetic user message with the tool results
|
||||
if tool_results:
|
||||
fixed_messages.append({"role": "user", "content": tool_results})
|
||||
|
||||
# Clear pending tools since we've added results for them
|
||||
pending_tool_uses = {}
|
||||
|
||||
# Check if there are any remaining pending tool_uses at the end of the conversation
|
||||
if pending_tool_uses and fixed_messages and fixed_messages[-1].get("role") == "assistant":
|
||||
# Add a final user message with tool results for any pending tool_uses
|
||||
tool_results = []
|
||||
for tool_id, tool_info in pending_tool_uses.items():
|
||||
tool_results.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_id,
|
||||
"content": {
|
||||
"type": "error",
|
||||
"message": "Tool execution was skipped or failed",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if tool_results:
|
||||
fixed_messages.append({"role": "user", "content": tool_results})
|
||||
|
||||
return fixed_messages
|
||||
|
||||
|
||||
class AnthropicDirectClient(BaseAnthropicClient):
|
||||
"""Direct Anthropic API client implementation."""
|
||||
|
||||
def __init__(self, api_key: str, model: str):
|
||||
self.model = model
|
||||
self.client = Anthropic(api_key=api_key, http_client=self._create_http_client())
|
||||
|
||||
def _create_http_client(self) -> httpx.Client:
|
||||
"""Create an HTTP client with appropriate settings."""
|
||||
return httpx.Client(
|
||||
verify=True,
|
||||
timeout=httpx.Timeout(connect=30.0, read=300.0, write=30.0, pool=30.0),
|
||||
transport=httpx.HTTPTransport(
|
||||
retries=3,
|
||||
verify=True,
|
||||
limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
|
||||
),
|
||||
)
|
||||
|
||||
async def create_message(
|
||||
self,
|
||||
*,
|
||||
messages: list[BetaMessageParam],
|
||||
system: list[Any],
|
||||
tools: list[BetaToolUnionParam],
|
||||
max_tokens: int,
|
||||
betas: list[str],
|
||||
) -> BetaMessage:
|
||||
"""Create a message using the direct Anthropic API with retry logic."""
|
||||
|
||||
async def api_call():
|
||||
response = self.client.beta.messages.with_raw_response.create(
|
||||
max_tokens=max_tokens,
|
||||
messages=messages,
|
||||
model=self.model,
|
||||
system=system,
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
log_api_interaction(response.http_response.request, response.http_response, None)
|
||||
return response.parse()
|
||||
|
||||
try:
|
||||
return await self._make_api_call_with_retries(api_call)
|
||||
except Exception as e:
|
||||
log_api_interaction(None, None, e)
|
||||
raise
|
||||
|
||||
|
||||
class AnthropicVertexClient(BaseAnthropicClient):
|
||||
"""Google Cloud Vertex AI implementation of Anthropic client."""
|
||||
|
||||
def __init__(self, model: str):
|
||||
self.model = model
|
||||
self.client = AnthropicVertex()
|
||||
|
||||
async def create_message(
|
||||
self,
|
||||
*,
|
||||
messages: list[BetaMessageParam],
|
||||
system: list[Any],
|
||||
tools: list[BetaToolUnionParam],
|
||||
max_tokens: int,
|
||||
betas: list[str],
|
||||
) -> BetaMessage:
|
||||
"""Create a message using Vertex AI with retry logic."""
|
||||
|
||||
async def api_call():
|
||||
response = self.client.beta.messages.with_raw_response.create(
|
||||
max_tokens=max_tokens,
|
||||
messages=messages,
|
||||
model=self.model,
|
||||
system=system,
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
log_api_interaction(response.http_response.request, response.http_response, None)
|
||||
return response.parse()
|
||||
|
||||
try:
|
||||
return await self._make_api_call_with_retries(api_call)
|
||||
except Exception as e:
|
||||
log_api_interaction(None, None, e)
|
||||
raise
|
||||
|
||||
|
||||
class AnthropicBedrockClient(BaseAnthropicClient):
|
||||
"""AWS Bedrock implementation of Anthropic client."""
|
||||
|
||||
def __init__(self, model: str):
|
||||
self.model = model
|
||||
self.client = AnthropicBedrock()
|
||||
|
||||
async def create_message(
|
||||
self,
|
||||
*,
|
||||
messages: list[BetaMessageParam],
|
||||
system: list[Any],
|
||||
tools: list[BetaToolUnionParam],
|
||||
max_tokens: int,
|
||||
betas: list[str],
|
||||
) -> BetaMessage:
|
||||
"""Create a message using AWS Bedrock with retry logic."""
|
||||
|
||||
async def api_call():
|
||||
response = self.client.beta.messages.with_raw_response.create(
|
||||
max_tokens=max_tokens,
|
||||
messages=messages,
|
||||
model=self.model,
|
||||
system=system,
|
||||
tools=tools,
|
||||
betas=betas,
|
||||
)
|
||||
log_api_interaction(response.http_response.request, response.http_response, None)
|
||||
return response.parse()
|
||||
|
||||
try:
|
||||
return await self._make_api_call_with_retries(api_call)
|
||||
except Exception as e:
|
||||
log_api_interaction(None, None, e)
|
||||
raise
|
||||
|
||||
|
||||
class AnthropicClientFactory:
|
||||
"""Factory for creating appropriate Anthropic client implementations."""
|
||||
|
||||
@staticmethod
|
||||
def create_client(provider: LLMProvider, api_key: str, model: str) -> BaseAnthropicClient:
|
||||
"""Create an appropriate client based on the provider."""
|
||||
if provider == LLMProvider.ANTHROPIC:
|
||||
return AnthropicDirectClient(api_key, model)
|
||||
elif provider == LLMProvider.VERTEX:
|
||||
return AnthropicVertexClient(model)
|
||||
elif provider == LLMProvider.BEDROCK:
|
||||
return AnthropicBedrockClient(model)
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
@@ -1,150 +0,0 @@
|
||||
"""API logging functionality."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
import httpx
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _filter_base64_images(content: Any) -> Any:
|
||||
"""Filter out base64 image data from content.
|
||||
|
||||
Args:
|
||||
content: Content to filter
|
||||
|
||||
Returns:
|
||||
Filtered content with base64 data replaced by placeholder
|
||||
"""
|
||||
if isinstance(content, dict):
|
||||
filtered = {}
|
||||
for key, value in content.items():
|
||||
if (
|
||||
isinstance(value, dict)
|
||||
and value.get("type") == "image"
|
||||
and value.get("source", {}).get("type") == "base64"
|
||||
):
|
||||
# Replace base64 data with placeholder
|
||||
filtered[key] = {
|
||||
**value,
|
||||
"source": {
|
||||
**value["source"],
|
||||
"data": "<base64_image_data>"
|
||||
}
|
||||
}
|
||||
else:
|
||||
filtered[key] = _filter_base64_images(value)
|
||||
return filtered
|
||||
elif isinstance(content, list):
|
||||
return [_filter_base64_images(item) for item in content]
|
||||
return content
|
||||
|
||||
def log_api_interaction(
|
||||
request: httpx.Request | None,
|
||||
response: httpx.Response | object | None,
|
||||
error: Exception | None,
|
||||
log_dir: Path = Path("/tmp/claude_logs")
|
||||
) -> None:
|
||||
"""Log API request, response, and any errors in a structured way.
|
||||
|
||||
Args:
|
||||
request: The HTTP request if available
|
||||
response: The HTTP response or response object
|
||||
error: Any error that occurred
|
||||
log_dir: Directory to store log files
|
||||
"""
|
||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")
|
||||
|
||||
# Helper function to safely decode JSON content
|
||||
def safe_json_decode(content):
|
||||
if not content:
|
||||
return None
|
||||
try:
|
||||
if isinstance(content, bytes):
|
||||
return json.loads(content.decode())
|
||||
elif isinstance(content, str):
|
||||
return json.loads(content)
|
||||
elif isinstance(content, dict):
|
||||
return content
|
||||
return None
|
||||
except json.JSONDecodeError:
|
||||
return {"error": "Could not decode JSON", "raw": str(content)}
|
||||
|
||||
# Process request content
|
||||
request_content = None
|
||||
if request and request.content:
|
||||
request_content = safe_json_decode(request.content)
|
||||
request_content = _filter_base64_images(request_content)
|
||||
|
||||
# Process response content
|
||||
response_content = None
|
||||
if response:
|
||||
if isinstance(response, httpx.Response):
|
||||
try:
|
||||
response_content = response.json()
|
||||
except json.JSONDecodeError:
|
||||
response_content = {"error": "Could not decode JSON", "raw": response.text}
|
||||
else:
|
||||
response_content = safe_json_decode(response)
|
||||
response_content = _filter_base64_images(response_content)
|
||||
|
||||
log_entry = {
|
||||
"timestamp": timestamp,
|
||||
"request": {
|
||||
"method": request.method if request else None,
|
||||
"url": str(request.url) if request else None,
|
||||
"headers": dict(request.headers) if request else None,
|
||||
"content": request_content,
|
||||
} if request else None,
|
||||
"response": {
|
||||
"status_code": response.status_code if isinstance(response, httpx.Response) else None,
|
||||
"headers": dict(response.headers) if isinstance(response, httpx.Response) else None,
|
||||
"content": response_content,
|
||||
} if response else None,
|
||||
"error": {
|
||||
"type": type(error).__name__ if error else None,
|
||||
"message": str(error) if error else None,
|
||||
} if error else None
|
||||
}
|
||||
|
||||
# Log to file with timestamp in filename
|
||||
log_dir.mkdir(exist_ok=True)
|
||||
log_file = log_dir / f"claude_api_{timestamp.replace(' ', '_').replace(':', '-')}.json"
|
||||
|
||||
with open(log_file, 'w') as f:
|
||||
json.dump(log_entry, f, indent=2)
|
||||
|
||||
# Also log a summary to the console
|
||||
if error:
|
||||
logger.error(f"API Error at {timestamp}: {error}")
|
||||
else:
|
||||
logger.info(
|
||||
f"API Call at {timestamp}: "
|
||||
f"{request.method if request else 'No request'} -> "
|
||||
f"{response.status_code if isinstance(response, httpx.Response) else 'No response'}"
|
||||
)
|
||||
|
||||
# Log if there are any images in the content
|
||||
if response_content:
|
||||
image_count = count_images(response_content)
|
||||
if image_count > 0:
|
||||
logger.info(f"Response contains {image_count} images")
|
||||
|
||||
def count_images(content: dict | list | Any) -> int:
|
||||
"""Count the number of images in the content.
|
||||
|
||||
Args:
|
||||
content: Content to search for images
|
||||
|
||||
Returns:
|
||||
Number of images found
|
||||
"""
|
||||
if isinstance(content, dict):
|
||||
if content.get("type") == "image":
|
||||
return 1
|
||||
return sum(count_images(v) for v in content.values())
|
||||
elif isinstance(content, list):
|
||||
return sum(count_images(item) for item in content)
|
||||
return 0
|
||||
@@ -1,140 +0,0 @@
|
||||
"""API call handling for Anthropic provider."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from anthropic.types.beta import (
|
||||
BetaMessage,
|
||||
BetaMessageParam,
|
||||
BetaTextBlockParam,
|
||||
)
|
||||
|
||||
from .types import LLMProvider
|
||||
from .prompts import SYSTEM_PROMPT
|
||||
|
||||
# Constants
|
||||
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
|
||||
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicAPIHandler:
|
||||
"""Handles API calls to Anthropic's API with structured error handling and retries."""
|
||||
|
||||
def __init__(self, loop):
|
||||
"""Initialize the API handler.
|
||||
|
||||
Args:
|
||||
loop: Reference to the parent loop instance that provides context
|
||||
"""
|
||||
self.loop = loop
|
||||
|
||||
async def make_api_call(
|
||||
self, messages: List[BetaMessageParam], system_prompt: str = SYSTEM_PROMPT
|
||||
) -> BetaMessage:
|
||||
"""Make API call to Anthropic with retry logic.
|
||||
|
||||
Args:
|
||||
messages: List of messages to send to the API
|
||||
system_prompt: System prompt to use (default: SYSTEM_PROMPT)
|
||||
|
||||
Returns:
|
||||
API response
|
||||
|
||||
Raises:
|
||||
RuntimeError: If API call fails after all retries
|
||||
"""
|
||||
if self.loop.client is None:
|
||||
raise RuntimeError("Client not initialized. Call initialize_client() first.")
|
||||
if self.loop.tool_manager is None:
|
||||
raise RuntimeError("Tool manager not initialized. Call initialize_client() first.")
|
||||
|
||||
last_error = None
|
||||
|
||||
# Add detailed debug logging to examine messages
|
||||
logger.info(f"Sending {len(messages)} messages to Anthropic API")
|
||||
|
||||
# Log tool use IDs and tool result IDs for debugging
|
||||
tool_use_ids = set()
|
||||
tool_result_ids = set()
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
logger.info(f"Message {i}: role={msg.get('role')}")
|
||||
if isinstance(msg.get("content"), list):
|
||||
for content_block in msg.get("content", []):
|
||||
if isinstance(content_block, dict):
|
||||
block_type = content_block.get("type")
|
||||
if block_type == "tool_use" and "id" in content_block:
|
||||
tool_id = content_block.get("id")
|
||||
tool_use_ids.add(tool_id)
|
||||
logger.info(f" - Found tool_use with ID: {tool_id}")
|
||||
elif block_type == "tool_result" and "tool_use_id" in content_block:
|
||||
result_id = content_block.get("tool_use_id")
|
||||
tool_result_ids.add(result_id)
|
||||
logger.info(f" - Found tool_result referencing ID: {result_id}")
|
||||
|
||||
# Check for mismatches
|
||||
missing_tool_uses = tool_result_ids - tool_use_ids
|
||||
if missing_tool_uses:
|
||||
logger.warning(
|
||||
f"Found tool_result IDs without matching tool_use IDs: {missing_tool_uses}"
|
||||
)
|
||||
|
||||
for attempt in range(self.loop.max_retries):
|
||||
try:
|
||||
# Log request
|
||||
request_data = {
|
||||
"messages": messages,
|
||||
"max_tokens": self.loop.max_tokens,
|
||||
"system": system_prompt,
|
||||
}
|
||||
# Let ExperimentManager handle sanitization
|
||||
self.loop._log_api_call("request", request_data)
|
||||
|
||||
# Setup betas and system
|
||||
system = BetaTextBlockParam(
|
||||
type="text",
|
||||
text=system_prompt,
|
||||
)
|
||||
|
||||
betas = [COMPUTER_USE_BETA_FLAG]
|
||||
# Add prompt caching if enabled in the message manager's config
|
||||
if self.loop.message_manager.config.enable_caching:
|
||||
betas.append(PROMPT_CACHING_BETA_FLAG)
|
||||
system["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
# Make API call
|
||||
response = await self.loop.client.create_message(
|
||||
messages=messages,
|
||||
system=[system],
|
||||
tools=self.loop.tool_manager.get_tool_params(),
|
||||
max_tokens=self.loop.max_tokens,
|
||||
betas=betas,
|
||||
)
|
||||
|
||||
# Let ExperimentManager handle sanitization
|
||||
self.loop._log_api_call("response", request_data, response)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.error(
|
||||
f"Error in API call (attempt {attempt + 1}/{self.loop.max_retries}): {str(e)}"
|
||||
)
|
||||
self.loop._log_api_call("error", {"messages": messages}, error=e)
|
||||
|
||||
if attempt < self.loop.max_retries - 1:
|
||||
await asyncio.sleep(
|
||||
self.loop.retry_delay * (attempt + 1)
|
||||
) # Exponential backoff
|
||||
continue
|
||||
|
||||
# If we get here, all retries failed
|
||||
error_message = f"API call failed after {self.loop.max_retries} attempts"
|
||||
if last_error:
|
||||
error_message += f": {str(last_error)}"
|
||||
|
||||
logger.error(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Anthropic callbacks package."""
|
||||
|
||||
from .manager import CallbackManager
|
||||
|
||||
__all__ = ["CallbackManager"]
|
||||
@@ -1,65 +0,0 @@
|
||||
from typing import Callable, Protocol
|
||||
import httpx
|
||||
from anthropic.types.beta import BetaContentBlockParam
|
||||
from ..tools import ToolResult
|
||||
|
||||
|
||||
class APICallback(Protocol):
|
||||
"""Protocol for API callbacks."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
request: httpx.Request | None,
|
||||
response: httpx.Response | object | None,
|
||||
error: Exception | None,
|
||||
) -> None: ...
|
||||
|
||||
|
||||
class ContentCallback(Protocol):
|
||||
"""Protocol for content callbacks."""
|
||||
|
||||
def __call__(self, content: BetaContentBlockParam) -> None: ...
|
||||
|
||||
|
||||
class ToolCallback(Protocol):
|
||||
"""Protocol for tool callbacks."""
|
||||
|
||||
def __call__(self, result: ToolResult, tool_id: str) -> None: ...
|
||||
|
||||
|
||||
class CallbackManager:
|
||||
"""Manages various callbacks for the agent system."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content_callback: ContentCallback,
|
||||
tool_callback: ToolCallback,
|
||||
api_callback: APICallback,
|
||||
):
|
||||
"""Initialize the callback manager.
|
||||
|
||||
Args:
|
||||
content_callback: Callback for content updates
|
||||
tool_callback: Callback for tool execution results
|
||||
api_callback: Callback for API interactions
|
||||
"""
|
||||
self.content_callback = content_callback
|
||||
self.tool_callback = tool_callback
|
||||
self.api_callback = api_callback
|
||||
|
||||
def on_content(self, content: BetaContentBlockParam) -> None:
|
||||
"""Handle content updates."""
|
||||
self.content_callback(content)
|
||||
|
||||
def on_tool_result(self, result: ToolResult, tool_id: str) -> None:
|
||||
"""Handle tool execution results."""
|
||||
self.tool_callback(result, tool_id)
|
||||
|
||||
def on_api_interaction(
|
||||
self,
|
||||
request: httpx.Request | None,
|
||||
response: httpx.Response | object | None,
|
||||
error: Exception | None,
|
||||
) -> None:
|
||||
"""Handle API interactions."""
|
||||
self.api_callback(request, response, error)
|
||||
@@ -1,568 +0,0 @@
|
||||
"""Anthropic-specific agent loop implementation."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, cast
|
||||
from anthropic.types.beta import (
|
||||
BetaMessage,
|
||||
BetaMessageParam,
|
||||
BetaTextBlock,
|
||||
BetaContentBlockParam,
|
||||
)
|
||||
import base64
|
||||
from datetime import datetime
|
||||
|
||||
# Computer
|
||||
from computer import Computer
|
||||
|
||||
# Base imports
|
||||
from ...core.base import BaseLoop
|
||||
from ...core.messages import StandardMessageManager, ImageRetentionConfig
|
||||
from ...core.types import AgentResponse
|
||||
|
||||
# Anthropic provider-specific imports
|
||||
from .api.client import AnthropicClientFactory, BaseAnthropicClient
|
||||
from .tools.manager import ToolManager
|
||||
from .prompts import SYSTEM_PROMPT
|
||||
from .types import LLMProvider
|
||||
from .tools import ToolResult
|
||||
from .utils import to_anthropic_format, to_agent_response_format
|
||||
|
||||
# Import the new modules we created
|
||||
from .api_handler import AnthropicAPIHandler
|
||||
from .response_handler import AnthropicResponseHandler
|
||||
from .callbacks.manager import CallbackManager
|
||||
|
||||
# Constants
|
||||
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
|
||||
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicLoop(BaseLoop):
|
||||
"""Anthropic-specific implementation of the agent loop.
|
||||
|
||||
This class extends BaseLoop to provide specialized support for Anthropic's Claude models
|
||||
with their unique tool-use capabilities, custom message formatting, and
|
||||
callback-driven approach to handling responses.
|
||||
"""
|
||||
|
||||
###########################################
|
||||
# INITIALIZATION AND CONFIGURATION
|
||||
###########################################
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
computer: Computer,
|
||||
model: str = "claude-3-7-sonnet-20250219",
|
||||
only_n_most_recent_images: Optional[int] = 2,
|
||||
base_dir: Optional[str] = "trajectories",
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
save_trajectory: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the Anthropic loop.
|
||||
|
||||
Args:
|
||||
api_key: Anthropic API key
|
||||
model: Model name (fixed to claude-3-7-sonnet-20250219)
|
||||
computer: Computer instance
|
||||
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
|
||||
base_dir: Base directory for saving experiment data
|
||||
max_retries: Maximum number of retries for API calls
|
||||
retry_delay: Delay between retries in seconds
|
||||
save_trajectory: Whether to save trajectory data
|
||||
"""
|
||||
# Initialize base class with core config
|
||||
super().__init__(
|
||||
computer=computer,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
base_dir=base_dir,
|
||||
save_trajectory=save_trajectory,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Initialize message manager
|
||||
self.message_manager = StandardMessageManager(
|
||||
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
|
||||
)
|
||||
|
||||
# Anthropic-specific attributes
|
||||
self.provider = LLMProvider.ANTHROPIC
|
||||
self.client = None
|
||||
self.retry_count = 0
|
||||
self.tool_manager = None
|
||||
self.callback_manager = None
|
||||
self.queue = asyncio.Queue() # Initialize queue
|
||||
self.loop_task = None # Store the loop task for cancellation
|
||||
|
||||
# Initialize handlers
|
||||
self.api_handler = AnthropicAPIHandler(self)
|
||||
self.response_handler = AnthropicResponseHandler(self)
|
||||
|
||||
###########################################
|
||||
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def initialize_client(self) -> None:
|
||||
"""Initialize the Anthropic API client and tools.
|
||||
|
||||
Implements abstract method from BaseLoop to set up the Anthropic-specific
|
||||
client, tool manager, message manager, and callback handlers.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Initializing Anthropic client with model {self.model}...")
|
||||
|
||||
# Initialize client
|
||||
self.client = AnthropicClientFactory.create_client(
|
||||
provider=self.provider, api_key=self.api_key, model=self.model
|
||||
)
|
||||
|
||||
# Initialize callback manager with our callback handlers
|
||||
self.callback_manager = CallbackManager(
|
||||
content_callback=self._handle_content,
|
||||
tool_callback=self._handle_tool_result,
|
||||
api_callback=self._handle_api_interaction,
|
||||
)
|
||||
|
||||
# Initialize tool manager
|
||||
self.tool_manager = ToolManager(self.computer)
|
||||
await self.tool_manager.initialize()
|
||||
|
||||
logger.info(f"Initialized Anthropic client with model {self.model}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing Anthropic client: {str(e)}")
|
||||
self.client = None
|
||||
raise RuntimeError(f"Failed to initialize Anthropic client: {str(e)}")
|
||||
|
||||
###########################################
|
||||
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
Args:
|
||||
messages: List of message objects in standard OpenAI format
|
||||
|
||||
Yields:
|
||||
Agent response format
|
||||
"""
|
||||
try:
|
||||
logger.info("Starting Anthropic loop run")
|
||||
|
||||
# Create queue for response streaming
|
||||
queue = asyncio.Queue()
|
||||
|
||||
# Ensure client is initialized
|
||||
if self.client is None or self.tool_manager is None:
|
||||
logger.info("Initializing client...")
|
||||
await self.initialize_client()
|
||||
if self.client is None:
|
||||
raise RuntimeError("Failed to initialize client")
|
||||
logger.info("Client initialized successfully")
|
||||
|
||||
# Start loop in background task
|
||||
self.loop_task = asyncio.create_task(self._run_loop(queue, messages))
|
||||
|
||||
# Process and yield messages as they arrive
|
||||
while True:
|
||||
try:
|
||||
item = await queue.get()
|
||||
if item is None: # Stop signal
|
||||
break
|
||||
yield item
|
||||
queue.task_done()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing queue item: {str(e)}")
|
||||
continue
|
||||
|
||||
# Wait for loop to complete
|
||||
await self.loop_task
|
||||
|
||||
# Send completion message
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": "Task completed successfully.",
|
||||
"metadata": {"title": "✅ Complete"},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing task: {str(e)}")
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
|
||||
async def cancel(self) -> None:
|
||||
"""Cancel the currently running agent loop task.
|
||||
|
||||
This method stops the ongoing processing in the agent loop
|
||||
by cancelling the loop_task if it exists and is running.
|
||||
"""
|
||||
if self.loop_task and not self.loop_task.done():
|
||||
logger.info("Cancelling Anthropic loop task")
|
||||
self.loop_task.cancel()
|
||||
try:
|
||||
# Wait for the task to be cancelled with a timeout
|
||||
await asyncio.wait_for(self.loop_task, timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timeout while waiting for loop task to cancel")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Loop task cancelled successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error while cancelling loop task: {str(e)}")
|
||||
finally:
|
||||
# Put None in the queue to signal any waiting consumers to stop
|
||||
await self.queue.put(None)
|
||||
logger.info("Anthropic loop task cancelled")
|
||||
else:
|
||||
logger.info("No active Anthropic loop task to cancel")
|
||||
|
||||
###########################################
|
||||
# AGENT LOOP IMPLEMENTATION
|
||||
###########################################
|
||||
|
||||
async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
Args:
|
||||
queue: Queue for response streaming
|
||||
messages: List of messages in standard OpenAI format
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
# Capture screenshot
|
||||
try:
|
||||
# Take screenshot - always returns raw PNG bytes
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
logger.info("Screenshot captured successfully")
|
||||
|
||||
# Convert PNG bytes to base64
|
||||
base64_image = base64.b64encode(screenshot).decode("utf-8")
|
||||
logger.info(f"Screenshot converted to base64 (size: {len(base64_image)} bytes)")
|
||||
|
||||
# Save screenshot if requested
|
||||
if self.save_trajectory and self.experiment_manager:
|
||||
try:
|
||||
self._save_screenshot(base64_image, action_type="state")
|
||||
logger.info("Screenshot saved to trajectory")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving screenshot: {str(e)}")
|
||||
|
||||
# Create screenshot message
|
||||
screen_info_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": base64_image,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
# Add screenshot to messages
|
||||
messages.append(screen_info_msg)
|
||||
logger.info("Screenshot message added to conversation")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error capturing or processing screenshot: {str(e)}")
|
||||
raise
|
||||
|
||||
# Create new turn directory for this API call
|
||||
self._create_turn_dir()
|
||||
|
||||
|
||||
# Apply image retention policy
|
||||
self.message_manager.messages = messages.copy()
|
||||
prepared_messages = self.message_manager.get_messages()
|
||||
# Convert standard messages to Anthropic format using utility function
|
||||
anthropic_messages, system_content = to_anthropic_format(prepared_messages)
|
||||
|
||||
# Use API handler to make API call with Anthropic format
|
||||
response = await self.api_handler.make_api_call(
|
||||
messages=cast(List[BetaMessageParam], anthropic_messages),
|
||||
system_prompt=system_content or SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
# Use response handler to handle the response and get new messages
|
||||
new_messages, should_continue = await self.response_handler.handle_response(
|
||||
response, messages
|
||||
)
|
||||
|
||||
# Add new messages to the parent's message history
|
||||
messages.extend(new_messages)
|
||||
|
||||
openai_compatible_response = await to_agent_response_format(
|
||||
response,
|
||||
messages,
|
||||
model=self.model,
|
||||
)
|
||||
# Log standardized response for ease of parsing
|
||||
self._log_api_call("agent_response", request=None, response=openai_compatible_response)
|
||||
await queue.put(openai_compatible_response)
|
||||
|
||||
if not should_continue:
|
||||
break
|
||||
|
||||
# Signal completion
|
||||
await queue.put(None)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _run_loop: {str(e)}")
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error in agent loop: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
)
|
||||
await queue.put(None)
|
||||
|
||||
###########################################
|
||||
# RESPONSE AND CALLBACK HANDLING
|
||||
###########################################
|
||||
|
||||
async def _handle_response(self, response: BetaMessage, messages: List[Dict[str, Any]]) -> bool:
|
||||
"""Handle a response from the Anthropic API.
|
||||
|
||||
Args:
|
||||
response: The response from the Anthropic API
|
||||
messages: The message history
|
||||
|
||||
Returns:
|
||||
bool: Whether to continue the conversation
|
||||
"""
|
||||
try:
|
||||
# Convert response to standard format
|
||||
openai_compatible_response = await to_agent_response_format(
|
||||
response,
|
||||
messages,
|
||||
model=self.model,
|
||||
)
|
||||
|
||||
# Put the response on the queue
|
||||
await self.queue.put(openai_compatible_response)
|
||||
|
||||
if self.callback_manager is None:
|
||||
raise RuntimeError(
|
||||
"Callback manager not initialized. Call initialize_client() first."
|
||||
)
|
||||
|
||||
# Handle tool use blocks and collect ALL results before adding to messages
|
||||
tool_result_content = []
|
||||
has_tool_use = False
|
||||
|
||||
for content_block in response.content:
|
||||
# Notify callback of content
|
||||
self.callback_manager.on_content(cast(BetaContentBlockParam, content_block))
|
||||
|
||||
# Handle tool use - carefully check and access attributes
|
||||
if hasattr(content_block, "type") and content_block.type == "tool_use":
|
||||
has_tool_use = True
|
||||
if self.tool_manager is None:
|
||||
raise RuntimeError(
|
||||
"Tool manager not initialized. Call initialize_client() first."
|
||||
)
|
||||
|
||||
# Safely get attributes
|
||||
tool_name = getattr(content_block, "name", "")
|
||||
tool_input = getattr(content_block, "input", {})
|
||||
tool_id = getattr(content_block, "id", "")
|
||||
|
||||
result = await self.tool_manager.execute_tool(
|
||||
name=tool_name,
|
||||
tool_input=cast(Dict[str, Any], tool_input),
|
||||
)
|
||||
|
||||
# Create tool result
|
||||
tool_result = self._make_tool_result(cast(ToolResult, result), tool_id)
|
||||
tool_result_content.append(tool_result)
|
||||
|
||||
# Notify callback of tool result
|
||||
self.callback_manager.on_tool_result(cast(ToolResult, result), tool_id)
|
||||
|
||||
# If we had any tool_use blocks, we MUST add the tool_result message
|
||||
# even if there were errors or no actual results
|
||||
if has_tool_use:
|
||||
# If somehow we have no tool results but had tool uses, add synthetic error results
|
||||
if not tool_result_content:
|
||||
logger.warning(
|
||||
"Had tool uses but no tool results, adding synthetic error results"
|
||||
)
|
||||
for content_block in response.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "tool_use":
|
||||
tool_id = getattr(content_block, "id", "")
|
||||
if tool_id:
|
||||
tool_result_content.append(
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_id,
|
||||
"content": {
|
||||
"type": "error",
|
||||
"text": "Tool execution was skipped or failed",
|
||||
},
|
||||
"is_error": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Add ALL tool results as a SINGLE user message
|
||||
messages.append({"role": "user", "content": tool_result_content})
|
||||
return True
|
||||
else:
|
||||
# No tool uses, we're done
|
||||
self.callback_manager.on_content({"type": "text", "text": "<DONE>"})
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling response: {str(e)}")
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
}
|
||||
)
|
||||
return False
|
||||
|
||||
def _response_to_blocks(self, response: BetaMessage) -> List[Dict[str, Any]]:
|
||||
"""Convert Anthropic API response to standard blocks format.
|
||||
|
||||
Args:
|
||||
response: API response message
|
||||
|
||||
Returns:
|
||||
List of content blocks in standard format
|
||||
"""
|
||||
result = []
|
||||
for block in response.content:
|
||||
if isinstance(block, BetaTextBlock):
|
||||
result.append({"type": "text", "text": block.text})
|
||||
elif hasattr(block, "type") and block.type == "tool_use":
|
||||
# Safely access attributes after confirming it's a tool_use
|
||||
result.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": getattr(block, "id", ""),
|
||||
"name": getattr(block, "name", ""),
|
||||
"input": getattr(block, "input", {}),
|
||||
}
|
||||
)
|
||||
else:
|
||||
# For other block types, convert to dict
|
||||
block_dict = {}
|
||||
for key, value in vars(block).items():
|
||||
if not key.startswith("_"):
|
||||
block_dict[key] = value
|
||||
result.append(block_dict)
|
||||
|
||||
return result
|
||||
|
||||
def _make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]:
|
||||
"""Convert a tool result to standard format.
|
||||
|
||||
Args:
|
||||
result: Tool execution result
|
||||
tool_use_id: ID of the tool use
|
||||
|
||||
Returns:
|
||||
Formatted tool result
|
||||
"""
|
||||
if result.content:
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"content": result.content,
|
||||
"tool_use_id": tool_use_id,
|
||||
"is_error": bool(result.error),
|
||||
}
|
||||
|
||||
tool_result_content = []
|
||||
is_error = False
|
||||
|
||||
if result.error:
|
||||
is_error = True
|
||||
tool_result_content = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": self._maybe_prepend_system_tool_result(result, result.error),
|
||||
}
|
||||
]
|
||||
else:
|
||||
if result.output:
|
||||
tool_result_content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": self._maybe_prepend_system_tool_result(result, result.output),
|
||||
}
|
||||
)
|
||||
if result.base64_image:
|
||||
tool_result_content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{result.base64_image}"},
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"content": tool_result_content,
|
||||
"tool_use_id": tool_use_id,
|
||||
"is_error": is_error,
|
||||
}
|
||||
|
||||
def _maybe_prepend_system_tool_result(self, result: ToolResult, result_text: str) -> str:
|
||||
"""Prepend system information to tool result if available.
|
||||
|
||||
Args:
|
||||
result: Tool execution result
|
||||
result_text: Text to prepend to
|
||||
|
||||
Returns:
|
||||
Text with system information prepended if available
|
||||
"""
|
||||
if result.system:
|
||||
result_text = f"<s>{result.system}</s>\n{result_text}"
|
||||
return result_text
|
||||
|
||||
###########################################
|
||||
# CALLBACK HANDLERS
|
||||
###########################################
|
||||
|
||||
def _handle_content(self, content):
|
||||
"""Handle content updates from the assistant."""
|
||||
if content.get("type") == "text":
|
||||
text = content.get("text", "")
|
||||
if text == "<DONE>":
|
||||
return
|
||||
logger.info(f"Assistant: {text}")
|
||||
|
||||
def _handle_tool_result(self, result, tool_id):
|
||||
"""Handle tool execution results."""
|
||||
if result.error:
|
||||
logger.error(f"Tool {tool_id} error: {result.error}")
|
||||
else:
|
||||
logger.info(f"Tool {tool_id} output: {result.output}")
|
||||
|
||||
def _handle_api_interaction(
|
||||
self, request: Any, response: Any, error: Optional[Exception]
|
||||
) -> None:
|
||||
"""Handle API interactions."""
|
||||
if error:
|
||||
logger.error(f"API error: {error}")
|
||||
self._log_api_call("error", request, error=error)
|
||||
else:
|
||||
logger.debug(f"API request: {request}")
|
||||
if response:
|
||||
self._log_api_call("response", request, response)
|
||||
else:
|
||||
self._log_api_call("request", request)
|
||||
@@ -1,23 +0,0 @@
|
||||
"""System prompts for Anthropic provider."""
|
||||
|
||||
from datetime import datetime
|
||||
import platform
|
||||
|
||||
today = datetime.today()
|
||||
today = f"{today.strftime('%A, %B')} {today.day}, {today.year}"
|
||||
|
||||
SYSTEM_PROMPT = f"""<SYSTEM_CAPABILITY>
|
||||
* You are utilising a macOS virtual machine using ARM architecture with internet access and Safari as default browser.
|
||||
* You can feel free to install macOS applications with your bash tool. Use curl instead of wget.
|
||||
* Using bash tool you can start GUI applications. GUI apps run with bash tool will appear within your desktop environment, but they may take some time to appear. Take a screenshot to confirm it did.
|
||||
* When using your bash tool with commands that are expected to output very large quantities of text, redirect into a tmp file and use str_replace_editor or `grep -n -B <lines before> -A <lines after> <query> <filename>` to confirm output.
|
||||
* When viewing a page it can be helpful to zoom out so that you can see everything on the page. Either that, or make sure you scroll down to see everything before deciding something isn't available.
|
||||
* When using your computer function calls, they take a while to run and send back to you. Where possible/feasible, try to chain multiple of these calls all into one function calls request.
|
||||
* The current date is {today}.
|
||||
</SYSTEM_CAPABILITY>
|
||||
|
||||
<IMPORTANT>
|
||||
* Plan at maximum 1 step each time, and evaluate the result of each step before proceeding. Hold back if you're not sure about the result of the step.
|
||||
* If you're not sure about the location of an application, use start the app using the bash tool.
|
||||
* If the item you are looking at is a pdf, if after taking a single screenshot of the pdf it seems that you want to read the entire document instead of trying to continue to read the pdf from your screenshots + navigation, determine the URL, use curl to download the pdf, install and use pdftotext to convert it to a text file, and then read that text file directly with your StrReplaceEditTool.
|
||||
</IMPORTANT>"""
|
||||
@@ -1,226 +0,0 @@
|
||||
"""Response and tool handling for Anthropic provider."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Tuple, cast
|
||||
|
||||
from anthropic.types.beta import (
|
||||
BetaMessage,
|
||||
BetaTextBlock,
|
||||
BetaContentBlockParam,
|
||||
)
|
||||
|
||||
from .tools import ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicResponseHandler:
|
||||
"""Handles Anthropic API responses and tool execution results."""
|
||||
|
||||
def __init__(self, loop):
|
||||
"""Initialize the response handler.
|
||||
|
||||
Args:
|
||||
loop: Reference to the parent loop instance that provides context
|
||||
"""
|
||||
self.loop = loop
|
||||
|
||||
async def handle_response(
|
||||
self, response: BetaMessage, messages: List[Dict[str, Any]]
|
||||
) -> Tuple[List[Dict[str, Any]], bool]:
|
||||
"""Handle the Anthropic API response.
|
||||
|
||||
Args:
|
||||
response: API response
|
||||
messages: List of messages for context
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- List of new messages to be added
|
||||
- Boolean indicating if the loop should continue
|
||||
"""
|
||||
try:
|
||||
new_messages = []
|
||||
|
||||
# Convert response to parameter format
|
||||
response_params = self.response_to_params(response)
|
||||
|
||||
# Collect all existing tool_use IDs from previous messages for validation
|
||||
existing_tool_use_ids = set()
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant" and isinstance(msg.get("content"), list):
|
||||
for block in msg.get("content", []):
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "tool_use"
|
||||
and "id" in block
|
||||
):
|
||||
existing_tool_use_ids.add(block["id"])
|
||||
|
||||
# Also add new tool_use IDs from the current response
|
||||
current_tool_use_ids = set()
|
||||
for block in response_params:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use" and "id" in block:
|
||||
current_tool_use_ids.add(block["id"])
|
||||
existing_tool_use_ids.add(block["id"])
|
||||
|
||||
logger.info(f"Existing tool_use IDs in conversation: {existing_tool_use_ids}")
|
||||
logger.info(f"New tool_use IDs in current response: {current_tool_use_ids}")
|
||||
|
||||
# Create assistant message
|
||||
new_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": response_params,
|
||||
}
|
||||
)
|
||||
|
||||
if self.loop.callback_manager is None:
|
||||
raise RuntimeError(
|
||||
"Callback manager not initialized. Call initialize_client() first."
|
||||
)
|
||||
|
||||
# Handle tool use blocks and collect results
|
||||
tool_result_content = []
|
||||
for content_block in response_params:
|
||||
# Notify callback of content
|
||||
self.loop.callback_manager.on_content(cast(BetaContentBlockParam, content_block))
|
||||
|
||||
# Handle tool use
|
||||
if content_block.get("type") == "tool_use":
|
||||
if self.loop.tool_manager is None:
|
||||
raise RuntimeError(
|
||||
"Tool manager not initialized. Call initialize_client() first."
|
||||
)
|
||||
|
||||
# Execute the tool
|
||||
result = await self.loop.tool_manager.execute_tool(
|
||||
name=content_block["name"],
|
||||
tool_input=cast(Dict[str, Any], content_block["input"]),
|
||||
)
|
||||
|
||||
# Verify the tool_use ID exists in the conversation (which it should now)
|
||||
tool_use_id = content_block["id"]
|
||||
if tool_use_id in existing_tool_use_ids:
|
||||
# Create tool result and add to content
|
||||
tool_result = self.make_tool_result(cast(ToolResult, result), tool_use_id)
|
||||
tool_result_content.append(tool_result)
|
||||
|
||||
# Notify callback of tool result
|
||||
self.loop.callback_manager.on_tool_result(
|
||||
cast(ToolResult, result), content_block["id"]
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Tool use ID {tool_use_id} not found in previous messages. Skipping tool result."
|
||||
)
|
||||
|
||||
# If no tool results, we're done
|
||||
if not tool_result_content:
|
||||
# Signal completion
|
||||
self.loop.callback_manager.on_content({"type": "text", "text": "<DONE>"})
|
||||
return new_messages, False
|
||||
|
||||
# Add tool results as user message
|
||||
new_messages.append({"content": tool_result_content, "role": "user"})
|
||||
return new_messages, True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling response: {str(e)}")
|
||||
new_messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
}
|
||||
)
|
||||
return new_messages, False
|
||||
|
||||
def response_to_params(
|
||||
self,
|
||||
response: BetaMessage,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert API response to message parameters.
|
||||
|
||||
Args:
|
||||
response: API response message
|
||||
|
||||
Returns:
|
||||
List of content blocks
|
||||
"""
|
||||
result = []
|
||||
for block in response.content:
|
||||
if isinstance(block, BetaTextBlock):
|
||||
result.append({"type": "text", "text": block.text})
|
||||
else:
|
||||
result.append(cast(Dict[str, Any], block.model_dump()))
|
||||
return result
|
||||
|
||||
def make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]:
|
||||
"""Convert a tool result to API format.
|
||||
|
||||
Args:
|
||||
result: Tool execution result
|
||||
tool_use_id: ID of the tool use
|
||||
|
||||
Returns:
|
||||
Formatted tool result
|
||||
"""
|
||||
if result.content:
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"content": result.content,
|
||||
"tool_use_id": tool_use_id,
|
||||
"is_error": bool(result.error),
|
||||
}
|
||||
|
||||
tool_result_content = []
|
||||
is_error = False
|
||||
|
||||
if result.error:
|
||||
is_error = True
|
||||
tool_result_content = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": self.maybe_prepend_system_tool_result(result, result.error),
|
||||
}
|
||||
]
|
||||
else:
|
||||
if result.output:
|
||||
tool_result_content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": self.maybe_prepend_system_tool_result(result, result.output),
|
||||
}
|
||||
)
|
||||
if result.base64_image:
|
||||
tool_result_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": result.base64_image,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"content": tool_result_content,
|
||||
"tool_use_id": tool_use_id,
|
||||
"is_error": is_error,
|
||||
}
|
||||
|
||||
def maybe_prepend_system_tool_result(self, result: ToolResult, result_text: str) -> str:
|
||||
"""Prepend system information to tool result if available.
|
||||
|
||||
Args:
|
||||
result: Tool execution result
|
||||
result_text: Text to prepend to
|
||||
|
||||
Returns:
|
||||
Text with system information prepended if available
|
||||
"""
|
||||
if result.system:
|
||||
result_text = f"<s>{result.system}</s>\n{result_text}"
|
||||
return result_text
|
||||
@@ -1,33 +0,0 @@
|
||||
"""Anthropic-specific tools for agent."""
|
||||
|
||||
from .base import (
|
||||
BaseAnthropicTool,
|
||||
ToolResult,
|
||||
ToolError,
|
||||
ToolFailure,
|
||||
CLIResult,
|
||||
AnthropicToolResult,
|
||||
AnthropicToolError,
|
||||
AnthropicToolFailure,
|
||||
AnthropicCLIResult,
|
||||
)
|
||||
from .bash import BashTool
|
||||
from .computer import ComputerTool
|
||||
from .edit import EditTool
|
||||
from .manager import ToolManager
|
||||
|
||||
__all__ = [
|
||||
"BaseAnthropicTool",
|
||||
"ToolResult",
|
||||
"ToolError",
|
||||
"ToolFailure",
|
||||
"CLIResult",
|
||||
"AnthropicToolResult",
|
||||
"AnthropicToolError",
|
||||
"AnthropicToolFailure",
|
||||
"AnthropicCLIResult",
|
||||
"BashTool",
|
||||
"ComputerTool",
|
||||
"EditTool",
|
||||
"ToolManager",
|
||||
]
|
||||
@@ -1,88 +0,0 @@
|
||||
"""Anthropic-specific tool base classes."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Any, Dict
|
||||
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
|
||||
from ....core.tools.base import BaseTool
|
||||
|
||||
|
||||
class BaseAnthropicTool(BaseTool, metaclass=ABCMeta):
|
||||
"""Abstract base class for Anthropic-defined tools."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the base Anthropic tool."""
|
||||
# No specific initialization needed yet, but included for future extensibility
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, **kwargs) -> Any:
|
||||
"""Executes the tool with the given arguments."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to Anthropic-specific API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters for Anthropic API
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
class ToolResult:
|
||||
"""Represents the result of a tool execution."""
|
||||
|
||||
output: str | None = None
|
||||
error: str | None = None
|
||||
base64_image: str | None = None
|
||||
system: str | None = None
|
||||
content: list[dict] | None = None
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field.name) for field in fields(self))
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
content=self.content or other.content, # Use first non-None content
|
||||
)
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Returns a new ToolResult with the given fields replaced."""
|
||||
return replace(self, **kwargs)
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
|
||||
"""A ToolResult that can be rendered as a CLI output."""
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
|
||||
"""A ToolResult that represents a failure."""
|
||||
|
||||
|
||||
class ToolError(Exception):
|
||||
"""Raised when a tool encounters an error."""
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
|
||||
# Re-export the core tool classes with Anthropic-specific names for backward compatibility
|
||||
AnthropicToolResult = ToolResult
|
||||
AnthropicToolError = ToolError
|
||||
AnthropicToolFailure = ToolFailure
|
||||
AnthropicCLIResult = CLIResult
|
||||
@@ -1,66 +0,0 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import ClassVar, Literal, Dict, Any
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
|
||||
from ....core.tools.bash import BaseBashTool
|
||||
|
||||
|
||||
class BashTool(BaseBashTool, BaseAnthropicTool):
|
||||
"""
|
||||
A tool that allows the agent to run bash commands.
|
||||
The tool parameters are defined by Anthropic and are not editable.
|
||||
"""
|
||||
|
||||
name: ClassVar[Literal["bash"]] = "bash"
|
||||
api_type: ClassVar[Literal["bash_20250124"]] = "bash_20250124"
|
||||
_timeout: float = 120.0 # seconds
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the bash tool.
|
||||
|
||||
Args:
|
||||
computer: Computer instance for executing commands
|
||||
"""
|
||||
# Initialize the base bash tool first
|
||||
BaseBashTool.__init__(self, computer)
|
||||
# Then initialize the Anthropic tool
|
||||
BaseAnthropicTool.__init__(self)
|
||||
# Initialize bash session
|
||||
|
||||
async def __call__(self, command: str | None = None, restart: bool = False, **kwargs):
|
||||
"""Execute a bash command.
|
||||
|
||||
Args:
|
||||
command: The command to execute
|
||||
restart: Whether to restart the shell (not used with computer interface)
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
|
||||
Raises:
|
||||
ToolError: If command execution fails
|
||||
"""
|
||||
if restart:
|
||||
return ToolResult(system="Restart not needed with computer interface.")
|
||||
|
||||
if command is None:
|
||||
raise ToolError("no command provided.")
|
||||
|
||||
try:
|
||||
async with asyncio.timeout(self._timeout):
|
||||
result = await self.computer.interface.run_command(command)
|
||||
return CLIResult(output=result.stdout or "", error=result.stderr or "")
|
||||
except asyncio.TimeoutError as e:
|
||||
raise ToolError(f"Command timed out after {self._timeout} seconds") from e
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to execute command: {str(e)}")
|
||||
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters
|
||||
"""
|
||||
return {"name": self.name, "type": self.api_type}
|
||||
@@ -1,34 +0,0 @@
|
||||
"""Collection classes for managing multiple tools."""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
|
||||
from .base import (
|
||||
BaseAnthropicTool,
|
||||
ToolError,
|
||||
ToolFailure,
|
||||
ToolResult,
|
||||
)
|
||||
|
||||
|
||||
class ToolCollection:
|
||||
"""A collection of anthropic-defined tools."""
|
||||
|
||||
def __init__(self, *tools: BaseAnthropicTool):
|
||||
self.tools = tools
|
||||
self.tool_map = {tool.to_params()["name"]: tool for tool in tools}
|
||||
|
||||
def to_params(
|
||||
self,
|
||||
) -> list[BetaToolUnionParam]:
|
||||
return cast(list[BetaToolUnionParam], [tool.to_params() for tool in self.tools])
|
||||
|
||||
async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
|
||||
tool = self.tool_map.get(name)
|
||||
if not tool:
|
||||
return ToolFailure(error=f"Tool {name} is invalid")
|
||||
try:
|
||||
return await tool(**tool_input)
|
||||
except ToolError as e:
|
||||
return ToolFailure(error=e.message)
|
||||
@@ -1,396 +0,0 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import logging
|
||||
from enum import StrEnum
|
||||
from pathlib import Path
|
||||
from typing import Literal, TypedDict, Any, Dict
|
||||
import subprocess
|
||||
from PIL import Image
|
||||
from datetime import datetime
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseAnthropicTool, ToolError, ToolResult
|
||||
from .run import run
|
||||
from ....core.tools.computer import BaseComputerTool
|
||||
|
||||
TYPING_DELAY_MS = 12
|
||||
TYPING_GROUP_SIZE = 50
|
||||
|
||||
Action = Literal[
|
||||
"key",
|
||||
"type",
|
||||
"mouse_move",
|
||||
"left_click",
|
||||
"left_click_drag",
|
||||
"right_click",
|
||||
"middle_click",
|
||||
"double_click",
|
||||
"screenshot",
|
||||
"cursor_position",
|
||||
"scroll",
|
||||
]
|
||||
|
||||
|
||||
class Resolution(TypedDict):
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
class ScalingSource(StrEnum):
|
||||
COMPUTER = "computer"
|
||||
API = "api"
|
||||
|
||||
|
||||
class ComputerToolOptions(TypedDict):
|
||||
display_height_px: int
|
||||
display_width_px: int
|
||||
display_number: int | None
|
||||
|
||||
|
||||
def chunks(s: str, chunk_size: int) -> list[str]:
|
||||
return [s[i : i + chunk_size] for i in range(0, len(s), chunk_size)]
|
||||
|
||||
|
||||
class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
"""
|
||||
A tool that allows the agent to interact with the screen, keyboard, and mouse of the current macOS computer.
|
||||
The tool parameters are defined by Anthropic and are not editable.
|
||||
"""
|
||||
|
||||
name: Literal["computer"] = "computer"
|
||||
api_type: Literal["computer_20250124"] = "computer_20250124"
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
display_num: int | None = None
|
||||
computer: Computer # The CUA Computer instance
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_screenshot_delay = 1.0 # macOS is generally faster than X11
|
||||
_scaling_enabled = True
|
||||
|
||||
@property
|
||||
def options(self) -> ComputerToolOptions:
|
||||
if self.width is None or self.height is None:
|
||||
raise RuntimeError(
|
||||
"Screen dimensions not initialized. Call initialize_dimensions() first."
|
||||
)
|
||||
return {
|
||||
"display_width_px": self.width,
|
||||
"display_height_px": self.height,
|
||||
"display_number": self.display_num,
|
||||
}
|
||||
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters
|
||||
"""
|
||||
return {"name": self.name, "type": self.api_type, **self.options}
|
||||
|
||||
def __init__(self, computer):
|
||||
# Initialize the base computer tool first
|
||||
BaseComputerTool.__init__(self, computer)
|
||||
# Then initialize the Anthropic tool
|
||||
BaseAnthropicTool.__init__(self)
|
||||
|
||||
# Additional initialization
|
||||
self.width = None # Will be initialized from computer interface
|
||||
self.height = None # Will be initialized from computer interface
|
||||
self.display_num = None
|
||||
|
||||
async def initialize_dimensions(self):
|
||||
"""Initialize screen dimensions from the computer interface."""
|
||||
display_size = await self.computer.interface.get_screen_size()
|
||||
self.width = display_size["width"]
|
||||
self.height = display_size["height"]
|
||||
assert isinstance(self.width, int) and isinstance(self.height, int)
|
||||
self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
action: Action,
|
||||
text: str | None = None,
|
||||
coordinate: tuple[int, int] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
# Ensure dimensions are initialized
|
||||
if self.width is None or self.height is None:
|
||||
await self.initialize_dimensions()
|
||||
if self.width is None or self.height is None:
|
||||
raise ToolError("Failed to initialize screen dimensions")
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to initialize dimensions: {e}")
|
||||
|
||||
if action in ("mouse_move", "left_click_drag"):
|
||||
if coordinate is None:
|
||||
raise ToolError(f"coordinate is required for {action}")
|
||||
if text is not None:
|
||||
raise ToolError(f"text is not accepted for {action}")
|
||||
if not isinstance(coordinate, (list, tuple)) or len(coordinate) != 2:
|
||||
raise ToolError(f"{coordinate} must be a tuple of length 2")
|
||||
if not all(isinstance(i, int) and i >= 0 for i in coordinate):
|
||||
raise ToolError(f"{coordinate} must be a tuple of non-negative ints")
|
||||
|
||||
try:
|
||||
x, y = coordinate
|
||||
self.logger.info(f"Handling {action} action:")
|
||||
self.logger.info(f" Coordinates: ({x}, {y})")
|
||||
|
||||
# Take pre-action screenshot to get current dimensions
|
||||
pre_screenshot = await self.computer.interface.screenshot()
|
||||
pre_img = Image.open(io.BytesIO(pre_screenshot))
|
||||
|
||||
# Scale image to match screen dimensions if needed
|
||||
if pre_img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling image from {pre_img.size} to {self.width}x{self.height} to match screen dimensions"
|
||||
)
|
||||
if not isinstance(self.width, int) or not isinstance(self.height, int):
|
||||
raise ToolError("Screen dimensions must be integers")
|
||||
size = (int(self.width), int(self.height))
|
||||
pre_img = pre_img.resize(size, Image.Resampling.LANCZOS)
|
||||
|
||||
self.logger.info(f" Current dimensions: {pre_img.width}x{pre_img.height}")
|
||||
|
||||
if action == "mouse_move":
|
||||
self.logger.info(f"Moving cursor to ({x}, {y})")
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
elif action == "left_click_drag":
|
||||
# Get the start coordinate from kwargs
|
||||
start_coordinate = kwargs.get("start_coordinate")
|
||||
if not start_coordinate:
|
||||
raise ToolError("start_coordinate is required for left_click_drag action")
|
||||
|
||||
start_x, start_y = start_coordinate
|
||||
end_x, end_y = x, y
|
||||
|
||||
self.logger.info(f"Dragging from ({start_x}, {start_y}) to ({end_x}, {end_y})")
|
||||
await self.computer.interface.move_cursor(start_x, start_y)
|
||||
await self.computer.interface.drag_to(end_x, end_y)
|
||||
|
||||
# Wait briefly for any UI changes
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take post-action screenshot
|
||||
post_screenshot = await self.computer.interface.screenshot()
|
||||
post_img = Image.open(io.BytesIO(post_screenshot))
|
||||
|
||||
# Scale post-action image if needed
|
||||
if post_img.size != (self.width, self.height):
|
||||
self.logger.info(
|
||||
f"Scaling post-action image from {post_img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
post_img = post_img.resize((self.width, self.height), Image.Resampling.LANCZOS)
|
||||
buffer = io.BytesIO()
|
||||
post_img.save(buffer, format="PNG")
|
||||
post_screenshot = buffer.getvalue()
|
||||
|
||||
return ToolResult(
|
||||
output=f"{'Moved cursor to' if action == 'mouse_move' else 'Dragged to'} {x},{y}",
|
||||
base64_image=base64.b64encode(post_screenshot).decode(),
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during {action} action: {str(e)}")
|
||||
raise ToolError(f"Failed to perform {action}: {str(e)}")
|
||||
|
||||
elif action in ("left_click", "right_click", "double_click"):
|
||||
if coordinate:
|
||||
x, y = coordinate
|
||||
self.logger.info(f"Handling {action} action:")
|
||||
self.logger.info(f" Coordinates: ({x}, {y})")
|
||||
|
||||
try:
|
||||
# Perform the click action
|
||||
if action == "left_click":
|
||||
self.logger.info(f"Clicking at ({x}, {y})")
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
await self.computer.interface.left_click()
|
||||
elif action == "right_click":
|
||||
self.logger.info(f"Right clicking at ({x}, {y})")
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
await self.computer.interface.right_click()
|
||||
elif action == "double_click":
|
||||
self.logger.info(f"Double clicking at ({x}, {y})")
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
await self.computer.interface.double_click()
|
||||
|
||||
# Wait briefly for any UI changes
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return ToolResult(
|
||||
output=f"Performed {action} at ({x}, {y})",
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during {action} action: {str(e)}")
|
||||
raise ToolError(f"Failed to perform {action}: {str(e)}")
|
||||
else:
|
||||
try:
|
||||
# Perform the click action
|
||||
if action == "left_click":
|
||||
self.logger.info("Performing left click at current position")
|
||||
await self.computer.interface.left_click()
|
||||
elif action == "right_click":
|
||||
self.logger.info("Performing right click at current position")
|
||||
await self.computer.interface.right_click()
|
||||
elif action == "double_click":
|
||||
self.logger.info("Performing double click at current position")
|
||||
await self.computer.interface.double_click()
|
||||
|
||||
# Wait briefly for any UI changes
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return ToolResult(
|
||||
output=f"Performed {action} at current position",
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during {action} action: {str(e)}")
|
||||
raise ToolError(f"Failed to perform {action}: {str(e)}")
|
||||
|
||||
elif action in ("key", "type"):
|
||||
if text is None:
|
||||
raise ToolError(f"text is required for {action}")
|
||||
if coordinate is not None:
|
||||
raise ToolError(f"coordinate is not accepted for {action}")
|
||||
if not isinstance(text, str):
|
||||
raise ToolError(f"{text} must be a string")
|
||||
|
||||
try:
|
||||
if action == "key":
|
||||
# Special handling for page up/down on macOS
|
||||
if text.lower() in ["pagedown", "page_down", "page down"]:
|
||||
self.logger.info("Converting page down to fn+down for macOS")
|
||||
await self.computer.interface.hotkey("fn", "down")
|
||||
output_text = "fn+down"
|
||||
elif text.lower() in ["pageup", "page_up", "page up"]:
|
||||
self.logger.info("Converting page up to fn+up for macOS")
|
||||
await self.computer.interface.hotkey("fn", "up")
|
||||
output_text = "fn+up"
|
||||
elif text == "fn+down":
|
||||
self.logger.info("Using fn+down combination")
|
||||
await self.computer.interface.hotkey("fn", "down")
|
||||
output_text = text
|
||||
elif text == "fn+up":
|
||||
self.logger.info("Using fn+up combination")
|
||||
await self.computer.interface.hotkey("fn", "up")
|
||||
output_text = text
|
||||
elif "+" in text:
|
||||
# Handle hotkey combinations
|
||||
keys = text.split("+")
|
||||
self.logger.info(f"Pressing hotkey combination: {text}")
|
||||
await self.computer.interface.hotkey(*keys)
|
||||
output_text = text
|
||||
else:
|
||||
# Handle single key press
|
||||
self.logger.info(f"Pressing key: {text}")
|
||||
try:
|
||||
await self.computer.interface.press_key(text)
|
||||
output_text = text
|
||||
except ValueError as e:
|
||||
raise ToolError(f"Invalid key: {text}. {str(e)}")
|
||||
|
||||
# Wait briefly for UI changes
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return ToolResult(
|
||||
output=f"Pressed key: {output_text}",
|
||||
)
|
||||
|
||||
elif action == "type":
|
||||
self.logger.info(f"Typing text: {text}")
|
||||
await self.computer.interface.type_text(text)
|
||||
|
||||
# Wait briefly for UI changes
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return ToolResult(
|
||||
output=f"Typed text: {text}",
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during {action} action: {str(e)}")
|
||||
raise ToolError(f"Failed to perform {action}: {str(e)}")
|
||||
|
||||
elif action == "scroll":
|
||||
# Implement scroll action
|
||||
direction = kwargs.get("direction", "down")
|
||||
amount = kwargs.get("amount", 10)
|
||||
|
||||
if direction not in ["up", "down"]:
|
||||
raise ToolError(f"Invalid scroll direction: {direction}. Must be 'up' or 'down'.")
|
||||
|
||||
try:
|
||||
if direction == "down":
|
||||
# Scroll down (Page Down on macOS)
|
||||
self.logger.info(f"Scrolling down, amount: {amount}")
|
||||
await self.computer.interface.scroll_down(amount)
|
||||
else:
|
||||
# Scroll up (Page Up on macOS)
|
||||
self.logger.info(f"Scrolling up, amount: {amount}")
|
||||
await self.computer.interface.scroll_up(amount)
|
||||
|
||||
# Wait briefly for UI changes
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return ToolResult(
|
||||
output=f"Scrolled {direction} by {amount} steps",
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during scroll action: {str(e)}")
|
||||
raise ToolError(f"Failed to perform scroll: {str(e)}")
|
||||
|
||||
elif action == "screenshot":
|
||||
# Take screenshot
|
||||
return await self.screenshot()
|
||||
elif action == "cursor_position":
|
||||
pos = await self.computer.interface.get_cursor_position()
|
||||
x, y = pos # Unpack the tuple
|
||||
return ToolResult(output=f"X={int(x)},Y={int(y)}")
|
||||
raise ToolError(f"Invalid action: {action}")
|
||||
|
||||
async def screenshot(self):
|
||||
"""Take a screenshot and return it as a base64-encoded string."""
|
||||
try:
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
img = Image.open(io.BytesIO(screenshot))
|
||||
|
||||
# Scale image if needed
|
||||
if img.size != (self.width, self.height):
|
||||
self.logger.info(f"Scaling image from {img.size} to {self.width}x{self.height}")
|
||||
if not isinstance(self.width, int) or not isinstance(self.height, int):
|
||||
raise ToolError("Screen dimensions must be integers")
|
||||
size = (int(self.width), int(self.height))
|
||||
img = img.resize(size, Image.Resampling.LANCZOS)
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format="PNG")
|
||||
screenshot = buffer.getvalue()
|
||||
|
||||
return ToolResult(base64_image=base64.b64encode(screenshot).decode())
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error taking screenshot: {str(e)}")
|
||||
return ToolResult(error=f"Failed to take screenshot: {str(e)}")
|
||||
|
||||
async def shell(self, command: str, take_screenshot=False) -> ToolResult:
|
||||
"""Run a shell command and return the output, error, and optionally a screenshot."""
|
||||
try:
|
||||
_, stdout, stderr = await run(command)
|
||||
base64_image = None
|
||||
|
||||
if take_screenshot:
|
||||
# delay to let things settle before taking a screenshot
|
||||
await asyncio.sleep(self._screenshot_delay)
|
||||
screenshot_result = await self.screenshot()
|
||||
if screenshot_result.error:
|
||||
return ToolResult(
|
||||
output=stdout,
|
||||
error=f"{stderr}\nScreenshot error: {screenshot_result.error}",
|
||||
)
|
||||
base64_image = screenshot_result.base64_image
|
||||
|
||||
return ToolResult(output=stdout, error=stderr, base64_image=base64_image)
|
||||
|
||||
except Exception as e:
|
||||
return ToolResult(error=f"Shell command failed: {str(e)}")
|
||||
@@ -1,326 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Literal, get_args, Dict, Any
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
|
||||
from ....core.tools.edit import BaseEditTool
|
||||
from .run import maybe_truncate
|
||||
|
||||
Command = Literal[
|
||||
"view",
|
||||
"create",
|
||||
"str_replace",
|
||||
"insert",
|
||||
"undo_edit",
|
||||
]
|
||||
SNIPPET_LINES: int = 4
|
||||
|
||||
|
||||
class EditTool(BaseEditTool, BaseAnthropicTool):
|
||||
"""
|
||||
An filesystem editor tool that allows the agent to view, create, and edit files.
|
||||
The tool parameters are defined by Anthropic and are not editable.
|
||||
"""
|
||||
|
||||
api_type: Literal["text_editor_20250124"] = "text_editor_20250124"
|
||||
name: Literal["str_replace_editor"] = "str_replace_editor"
|
||||
_timeout: float = 30.0 # seconds
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the edit tool.
|
||||
|
||||
Args:
|
||||
computer: Computer instance for file operations
|
||||
"""
|
||||
# Initialize the base edit tool first
|
||||
BaseEditTool.__init__(self, computer)
|
||||
# Then initialize the Anthropic tool
|
||||
BaseAnthropicTool.__init__(self)
|
||||
|
||||
# Edit history for the current session
|
||||
self.edit_history = defaultdict(list)
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
command: Command,
|
||||
path: str,
|
||||
file_text: str | None = None,
|
||||
view_range: list[int] | None = None,
|
||||
old_str: str | None = None,
|
||||
new_str: str | None = None,
|
||||
insert_line: int | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
_path = Path(path)
|
||||
await self.validate_path(command, _path)
|
||||
|
||||
if command == "view":
|
||||
return await self.view(_path, view_range)
|
||||
elif command == "create":
|
||||
if file_text is None:
|
||||
raise ToolError("Parameter `file_text` is required for command: create")
|
||||
await self.write_file(_path, file_text)
|
||||
self.edit_history[_path].append(file_text)
|
||||
return ToolResult(output=f"File created successfully at: {_path}")
|
||||
elif command == "str_replace":
|
||||
if old_str is None:
|
||||
raise ToolError("Parameter `old_str` is required for command: str_replace")
|
||||
return await self.str_replace(_path, old_str, new_str)
|
||||
elif command == "insert":
|
||||
if insert_line is None:
|
||||
raise ToolError("Parameter `insert_line` is required for command: insert")
|
||||
if new_str is None:
|
||||
raise ToolError("Parameter `new_str` is required for command: insert")
|
||||
return await self.insert(_path, insert_line, new_str)
|
||||
elif command == "undo_edit":
|
||||
return await self.undo_edit(_path)
|
||||
|
||||
raise ToolError(
|
||||
f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
|
||||
)
|
||||
|
||||
async def validate_path(self, command: str, path: Path):
|
||||
"""Check that the path/command combination is valid."""
|
||||
# Check if its an absolute path
|
||||
if not path.is_absolute():
|
||||
suggested_path = Path("") / path
|
||||
raise ToolError(
|
||||
f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
|
||||
)
|
||||
|
||||
# Check if path exists using bash commands
|
||||
try:
|
||||
result = await self.computer.interface.run_command(
|
||||
f'[ -e "{str(path)}" ] && echo "exists" || echo "not exists"'
|
||||
)
|
||||
exists = result.stdout.strip() == "exists"
|
||||
|
||||
if exists:
|
||||
result = await self.computer.interface.run_command(
|
||||
f'[ -d "{str(path)}" ] && echo "dir" || echo "file"'
|
||||
)
|
||||
is_dir = result.stdout.strip() == "dir"
|
||||
else:
|
||||
is_dir = False
|
||||
|
||||
# Check path validity
|
||||
if not exists and command != "create":
|
||||
raise ToolError(f"The path {path} does not exist. Please provide a valid path.")
|
||||
if exists and command == "create":
|
||||
raise ToolError(
|
||||
f"File already exists at: {path}. Cannot overwrite files using command `create`."
|
||||
)
|
||||
if is_dir and command != "view":
|
||||
raise ToolError(
|
||||
f"The path {path} is a directory and only the `view` command can be used on directories"
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to validate path: {str(e)}")
|
||||
|
||||
async def view(self, path: Path, view_range: list[int] | None = None):
|
||||
"""Implement the view command"""
|
||||
try:
|
||||
# Check if path is a directory
|
||||
result = await self.computer.interface.run_command(
|
||||
f'[ -d "{str(path)}" ] && echo "dir" || echo "file"'
|
||||
)
|
||||
is_dir = result.stdout.strip() == "dir"
|
||||
|
||||
if is_dir:
|
||||
if view_range:
|
||||
raise ToolError(
|
||||
"The `view_range` parameter is not allowed when `path` points to a directory."
|
||||
)
|
||||
|
||||
# List directory contents using ls
|
||||
result = await self.computer.interface.run_command(f'ls -la "{str(path)}"')
|
||||
contents = result.stdout
|
||||
if contents:
|
||||
stdout = f"Here's the files and directories in {path}:\n{contents}\n"
|
||||
else:
|
||||
stdout = f"Directory {path} is empty\n"
|
||||
return CLIResult(output=stdout)
|
||||
|
||||
# Read file content using cat
|
||||
file_content = await self.read_file(path)
|
||||
init_line = 1
|
||||
|
||||
if view_range:
|
||||
if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
|
||||
raise ToolError("Invalid `view_range`. It should be a list of two integers.")
|
||||
|
||||
file_lines = file_content.split("\n")
|
||||
n_lines_file = len(file_lines)
|
||||
init_line, final_line = view_range
|
||||
|
||||
if init_line < 1 or init_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
|
||||
)
|
||||
if final_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
|
||||
)
|
||||
if final_line != -1 and final_line < init_line:
|
||||
raise ToolError(
|
||||
f"Invalid `view_range`: {view_range}. Its second element `{final_line}` should be larger or equal than its first `{init_line}`"
|
||||
)
|
||||
|
||||
if final_line == -1:
|
||||
file_content = "\n".join(file_lines[init_line - 1 :])
|
||||
else:
|
||||
file_content = "\n".join(file_lines[init_line - 1 : final_line])
|
||||
|
||||
return CLIResult(output=self._make_output(file_content, str(path), init_line=init_line))
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to view path: {str(e)}")
|
||||
|
||||
async def str_replace(self, path: Path, old_str: str, new_str: str | None):
|
||||
"""Implement the str_replace command"""
|
||||
# Read the file content
|
||||
file_content = await self.read_file(path)
|
||||
file_content = file_content.expandtabs()
|
||||
old_str = old_str.expandtabs()
|
||||
new_str = new_str.expandtabs() if new_str is not None else ""
|
||||
|
||||
# Check if old_str is unique in the file
|
||||
occurrences = file_content.count(old_str)
|
||||
if occurrences == 0:
|
||||
raise ToolError(
|
||||
f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
|
||||
)
|
||||
elif occurrences > 1:
|
||||
file_content_lines = file_content.split("\n")
|
||||
lines = [idx + 1 for idx, line in enumerate(file_content_lines) if old_str in line]
|
||||
raise ToolError(
|
||||
f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
|
||||
)
|
||||
|
||||
# Replace old_str with new_str
|
||||
new_file_content = file_content.replace(old_str, new_str)
|
||||
|
||||
# Write the new content to the file
|
||||
await self.write_file(path, new_file_content)
|
||||
|
||||
# Save the content to history
|
||||
self.edit_history[path].append(file_content)
|
||||
|
||||
# Create a snippet of the edited section
|
||||
replacement_line = file_content.split(old_str)[0].count("\n")
|
||||
start_line = max(0, replacement_line - SNIPPET_LINES)
|
||||
end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
|
||||
snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])
|
||||
|
||||
# Prepare the success message
|
||||
success_msg = f"The file {path} has been edited. "
|
||||
success_msg += self._make_output(snippet, f"a snippet of {path}", start_line + 1)
|
||||
success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
|
||||
|
||||
return CLIResult(output=success_msg)
|
||||
|
||||
async def insert(self, path: Path, insert_line: int, new_str: str):
|
||||
"""Implement the insert command"""
|
||||
file_text = await self.read_file(path)
|
||||
file_text = file_text.expandtabs()
|
||||
new_str = new_str.expandtabs()
|
||||
file_text_lines = file_text.split("\n")
|
||||
n_lines_file = len(file_text_lines)
|
||||
|
||||
if insert_line < 0 or insert_line > n_lines_file:
|
||||
raise ToolError(
|
||||
f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
|
||||
)
|
||||
|
||||
new_str_lines = new_str.split("\n")
|
||||
new_file_text_lines = (
|
||||
file_text_lines[:insert_line] + new_str_lines + file_text_lines[insert_line:]
|
||||
)
|
||||
snippet_lines = (
|
||||
file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
|
||||
+ new_str_lines
|
||||
+ file_text_lines[insert_line : insert_line + SNIPPET_LINES]
|
||||
)
|
||||
|
||||
new_file_text = "\n".join(new_file_text_lines)
|
||||
snippet = "\n".join(snippet_lines)
|
||||
|
||||
await self.write_file(path, new_file_text)
|
||||
self.edit_history[path].append(file_text)
|
||||
|
||||
success_msg = f"The file {path} has been edited. "
|
||||
success_msg += self._make_output(
|
||||
snippet, "a snippet of the edited file", max(1, insert_line - SNIPPET_LINES + 1)
|
||||
)
|
||||
success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
|
||||
return CLIResult(output=success_msg)
|
||||
|
||||
async def undo_edit(self, path: Path):
|
||||
"""Implement the undo_edit command"""
|
||||
if not self.edit_history[path]:
|
||||
raise ToolError(f"No edit history found for {path}.")
|
||||
|
||||
old_text = self.edit_history[path].pop()
|
||||
await self.write_file(path, old_text)
|
||||
|
||||
return CLIResult(
|
||||
output=f"Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}"
|
||||
)
|
||||
|
||||
async def read_file(self, path: Path) -> str:
|
||||
"""Read the content of a file using cat command."""
|
||||
try:
|
||||
result = await self.computer.interface.run_command(f'cat "{str(path)}"')
|
||||
if result.stderr: # If there's stderr output
|
||||
raise ToolError(f"Error reading file: {result.stderr}")
|
||||
return result.stdout
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to read {path}: {str(e)}")
|
||||
|
||||
async def write_file(self, path: Path, content: str):
|
||||
"""Write content to a file using echo and redirection."""
|
||||
try:
|
||||
# Create parent directories if they don't exist
|
||||
parent = path.parent
|
||||
if parent != Path("/"):
|
||||
await self.computer.interface.run_command(f'mkdir -p "{str(parent)}"')
|
||||
|
||||
# Write content to file using echo and heredoc to preserve formatting
|
||||
cmd = f"""cat > "{str(path)}" << 'EOFCUA'
|
||||
{content}
|
||||
EOFCUA"""
|
||||
result = await self.computer.interface.run_command(cmd)
|
||||
if result.stderr: # If there's stderr output
|
||||
raise ToolError(f"Error writing file: {result.stderr}")
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to write to {path}: {str(e)}")
|
||||
|
||||
def _make_output(
|
||||
self,
|
||||
file_content: str,
|
||||
file_descriptor: str,
|
||||
init_line: int = 1,
|
||||
expand_tabs: bool = True,
|
||||
) -> str:
|
||||
"""Generate output for the CLI based on the content of a file."""
|
||||
file_content = maybe_truncate(file_content)
|
||||
if expand_tabs:
|
||||
file_content = file_content.expandtabs()
|
||||
file_content = "\n".join(
|
||||
[f"{i + init_line:6}\t{line}" for i, line in enumerate(file_content.split("\n"))]
|
||||
)
|
||||
return (
|
||||
f"Here's the result of running `cat -n` on {file_descriptor}:\n" + file_content + "\n"
|
||||
)
|
||||
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters
|
||||
"""
|
||||
return {
|
||||
"name": self.name,
|
||||
"type": self.api_type,
|
||||
}
|
||||
@@ -1,54 +0,0 @@
|
||||
from typing import Any, Dict, List, cast
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
from computer.computer import Computer
|
||||
|
||||
from ....core.tools import BaseToolManager, ToolResult
|
||||
from ....core.tools.collection import ToolCollection
|
||||
|
||||
from .bash import BashTool
|
||||
from .computer import ComputerTool
|
||||
from .edit import EditTool
|
||||
|
||||
|
||||
class ToolManager(BaseToolManager):
|
||||
"""Manages Anthropic-specific tool initialization and execution."""
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the tool manager.
|
||||
|
||||
Args:
|
||||
computer: Computer instance for computer-related tools
|
||||
"""
|
||||
super().__init__(computer)
|
||||
# Initialize Anthropic-specific tools
|
||||
self.computer_tool = ComputerTool(self.computer)
|
||||
self.bash_tool = BashTool(self.computer)
|
||||
self.edit_tool = EditTool(self.computer)
|
||||
|
||||
def _initialize_tools(self) -> ToolCollection:
|
||||
"""Initialize all available tools."""
|
||||
return ToolCollection(self.computer_tool, self.bash_tool, self.edit_tool)
|
||||
|
||||
async def _initialize_tools_specific(self) -> None:
|
||||
"""Initialize Anthropic-specific tool requirements."""
|
||||
await self.computer_tool.initialize_dimensions()
|
||||
|
||||
def get_tool_params(self) -> List[BetaToolUnionParam]:
|
||||
"""Get tool parameters for Anthropic API calls."""
|
||||
if self.tools is None:
|
||||
raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
return cast(List[BetaToolUnionParam], self.tools.to_params())
|
||||
|
||||
async def execute_tool(self, name: str, tool_input: dict[str, Any]) -> ToolResult:
|
||||
"""Execute a tool with the given input.
|
||||
|
||||
Args:
|
||||
name: Name of the tool to execute
|
||||
tool_input: Input parameters for the tool
|
||||
|
||||
Returns:
|
||||
Result of the tool execution
|
||||
"""
|
||||
if self.tools is None:
|
||||
raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
return await self.tools.run(name=name, tool_input=tool_input)
|
||||
@@ -1,42 +0,0 @@
|
||||
"""Utility to run shell commands asynchronously with a timeout."""
|
||||
|
||||
import asyncio
|
||||
|
||||
TRUNCATED_MESSAGE: str = "<response clipped><NOTE>To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for.</NOTE>"
|
||||
MAX_RESPONSE_LEN: int = 16000
|
||||
|
||||
|
||||
def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
|
||||
"""Truncate content and append a notice if content exceeds the specified length."""
|
||||
return (
|
||||
content
|
||||
if not truncate_after or len(content) <= truncate_after
|
||||
else content[:truncate_after] + TRUNCATED_MESSAGE
|
||||
)
|
||||
|
||||
|
||||
async def run(
|
||||
cmd: str,
|
||||
timeout: float | None = 120.0, # seconds
|
||||
truncate_after: int | None = MAX_RESPONSE_LEN,
|
||||
):
|
||||
"""Run a shell command asynchronously with a timeout."""
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
return (
|
||||
process.returncode or 0,
|
||||
maybe_truncate(stdout.decode(), truncate_after=truncate_after),
|
||||
maybe_truncate(stderr.decode(), truncate_after=truncate_after),
|
||||
)
|
||||
except asyncio.TimeoutError as exc:
|
||||
try:
|
||||
process.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
raise TimeoutError(
|
||||
f"Command '{cmd}' timed out after {timeout} seconds"
|
||||
) from exc
|
||||
@@ -1,16 +0,0 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class LLMProvider(StrEnum):
|
||||
"""Enum for supported API providers."""
|
||||
|
||||
ANTHROPIC = "anthropic"
|
||||
BEDROCK = "bedrock"
|
||||
VERTEX = "vertex"
|
||||
|
||||
|
||||
PROVIDER_TO_DEFAULT_MODEL_NAME: dict[LLMProvider, str] = {
|
||||
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
|
||||
LLMProvider.BEDROCK: "anthropic.claude-3-7-sonnet-20250219-v2:0",
|
||||
LLMProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
|
||||
}
|
||||
@@ -1,367 +0,0 @@
|
||||
"""Utility functions for Anthropic message handling."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from anthropic.types.beta import BetaMessage
|
||||
from ...core.types import AgentResponse
|
||||
from datetime import datetime
|
||||
|
||||
# Configure module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def to_anthropic_format(
|
||||
messages: List[Dict[str, Any]],
|
||||
) -> Tuple[List[Dict[str, Any]], str]:
|
||||
"""Convert standard OpenAI format messages to Anthropic format.
|
||||
|
||||
Args:
|
||||
messages: List of messages in OpenAI format
|
||||
|
||||
Returns:
|
||||
Tuple containing (anthropic_messages, system_content)
|
||||
"""
|
||||
result = []
|
||||
system_content = ""
|
||||
|
||||
# Process messages in order to maintain conversation flow
|
||||
previous_assistant_tool_use_ids = set() # Track tool_use_ids in the previous assistant message
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
# Collect system messages for later use
|
||||
system_content += content + "\n"
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
# Track tool_use_ids in this assistant message for the next user message
|
||||
previous_assistant_tool_use_ids = set()
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "tool_use" and "id" in item:
|
||||
previous_assistant_tool_use_ids.add(item["id"])
|
||||
|
||||
if role in ["user", "assistant"]:
|
||||
anthropic_msg = {"role": role}
|
||||
|
||||
# Convert content based on type
|
||||
if isinstance(content, str):
|
||||
# Simple text content
|
||||
anthropic_msg["content"] = [{"type": "text", "text": content}]
|
||||
elif isinstance(content, list):
|
||||
# Convert complex content
|
||||
anthropic_content = []
|
||||
for item in content:
|
||||
item_type = item.get("type", "")
|
||||
|
||||
if item_type == "text":
|
||||
anthropic_content.append({"type": "text", "text": item.get("text", "")})
|
||||
elif item_type == "image_url":
|
||||
# Convert OpenAI image format to Anthropic
|
||||
image_url = item.get("image_url", {}).get("url", "")
|
||||
if image_url.startswith("data:"):
|
||||
# Extract base64 data and media type
|
||||
match = re.match(r"data:(.+);base64,(.+)", image_url)
|
||||
if match:
|
||||
media_type, data = match.groups()
|
||||
anthropic_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": data,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Regular URL
|
||||
anthropic_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "url",
|
||||
"url": image_url,
|
||||
},
|
||||
}
|
||||
)
|
||||
elif item_type == "tool_use":
|
||||
# Always include tool_use blocks
|
||||
anthropic_content.append(item)
|
||||
elif item_type == "tool_result":
|
||||
# Check if this is a user message AND if the tool_use_id exists in the previous assistant message
|
||||
tool_use_id = item.get("tool_use_id")
|
||||
|
||||
# Only include tool_result if it references a tool_use from the immediately preceding assistant message
|
||||
if (
|
||||
role == "user"
|
||||
and tool_use_id
|
||||
and tool_use_id in previous_assistant_tool_use_ids
|
||||
):
|
||||
anthropic_content.append(item)
|
||||
else:
|
||||
content_text = "Tool Result: "
|
||||
if "content" in item:
|
||||
if isinstance(item["content"], list):
|
||||
for content_item in item["content"]:
|
||||
if (
|
||||
isinstance(content_item, dict)
|
||||
and content_item.get("type") == "text"
|
||||
):
|
||||
content_text += content_item.get("text", "")
|
||||
elif isinstance(item["content"], str):
|
||||
content_text += item["content"]
|
||||
anthropic_content.append({"type": "text", "text": content_text})
|
||||
|
||||
anthropic_msg["content"] = anthropic_content
|
||||
|
||||
result.append(anthropic_msg)
|
||||
|
||||
return result, system_content
|
||||
|
||||
|
||||
def from_anthropic_format(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Convert Anthropic format messages to standard OpenAI format.
|
||||
|
||||
Args:
|
||||
messages: List of messages in Anthropic format
|
||||
|
||||
Returns:
|
||||
List of messages in OpenAI format
|
||||
"""
|
||||
result = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", [])
|
||||
|
||||
if role in ["user", "assistant"]:
|
||||
openai_msg = {"role": role}
|
||||
|
||||
# Simple case: single text block
|
||||
if len(content) == 1 and content[0].get("type") == "text":
|
||||
openai_msg["content"] = content[0].get("text", "")
|
||||
else:
|
||||
# Complex case: multiple blocks or non-text
|
||||
openai_content = []
|
||||
for item in content:
|
||||
item_type = item.get("type", "")
|
||||
|
||||
if item_type == "text":
|
||||
openai_content.append({"type": "text", "text": item.get("text", "")})
|
||||
elif item_type == "image":
|
||||
# Convert Anthropic image to OpenAI format
|
||||
source = item.get("source", {})
|
||||
if source.get("type") == "base64":
|
||||
media_type = source.get("media_type", "image/png")
|
||||
data = source.get("data", "")
|
||||
openai_content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{media_type};base64,{data}"},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# URL
|
||||
openai_content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": source.get("url", "")},
|
||||
}
|
||||
)
|
||||
elif item_type in ["tool_use", "tool_result"]:
|
||||
# Pass through tool-related content
|
||||
openai_content.append(item)
|
||||
|
||||
openai_msg["content"] = openai_content
|
||||
|
||||
result.append(openai_msg)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
async def to_agent_response_format(
|
||||
response: BetaMessage,
|
||||
messages: List[Dict[str, Any]],
|
||||
parsed_screen: Optional[dict] = None,
|
||||
parser: Optional[Any] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> AgentResponse:
|
||||
"""Convert an Anthropic response to the standard agent response format.
|
||||
|
||||
Args:
|
||||
response: The Anthropic API response (BetaMessage)
|
||||
messages: List of messages in standard format
|
||||
parsed_screen: Optional pre-parsed screen information
|
||||
parser: Optional parser instance for coordinate calculation
|
||||
model: Optional model name
|
||||
|
||||
Returns:
|
||||
A response formatted according to the standard agent response format
|
||||
"""
|
||||
# Create unique IDs for this response
|
||||
response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(response)}"
|
||||
reasoning_id = f"rs_{response_id}"
|
||||
action_id = f"cu_{response_id}"
|
||||
call_id = f"call_{response_id}"
|
||||
|
||||
# Extract content and reasoning from Anthropic response
|
||||
content = []
|
||||
reasoning_text = None
|
||||
action_details = None
|
||||
|
||||
for block in response.content:
|
||||
if block.type == "text":
|
||||
# Use the first text block as reasoning
|
||||
if reasoning_text is None:
|
||||
reasoning_text = block.text
|
||||
content.append({"type": "text", "text": block.text})
|
||||
elif block.type == "tool_use" and block.name == "computer":
|
||||
try:
|
||||
input_dict = cast(Dict[str, Any], block.input)
|
||||
action = input_dict.get("action", "").lower()
|
||||
|
||||
# Extract coordinates from coordinate list if provided
|
||||
coordinates = input_dict.get("coordinate", [100, 100])
|
||||
x, y = coordinates if len(coordinates) == 2 else (100, 100)
|
||||
|
||||
if action == "screenshot":
|
||||
action_details = {
|
||||
"type": "screenshot",
|
||||
}
|
||||
elif action in ["click", "left_click", "right_click", "double_click"]:
|
||||
action_details = {
|
||||
"type": "click",
|
||||
"button": "left" if action in ["click", "left_click"] else "right",
|
||||
"double": action == "double_click",
|
||||
"x": x,
|
||||
"y": y,
|
||||
}
|
||||
elif action == "type":
|
||||
action_details = {
|
||||
"type": "type",
|
||||
"text": input_dict.get("text", ""),
|
||||
}
|
||||
elif action == "key":
|
||||
action_details = {
|
||||
"type": "hotkey",
|
||||
"keys": [input_dict.get("text", "")],
|
||||
}
|
||||
elif action == "scroll":
|
||||
scroll_amount = input_dict.get("scroll_amount", 1)
|
||||
scroll_direction = input_dict.get("scroll_direction", "down")
|
||||
delta_y = scroll_amount if scroll_direction == "down" else -scroll_amount
|
||||
action_details = {
|
||||
"type": "scroll",
|
||||
"x": x,
|
||||
"y": y,
|
||||
"delta_x": 0,
|
||||
"delta_y": delta_y,
|
||||
}
|
||||
elif action == "move":
|
||||
action_details = {
|
||||
"type": "move",
|
||||
"x": x,
|
||||
"y": y,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting action details: {str(e)}")
|
||||
|
||||
# Create output items with reasoning
|
||||
output_items = []
|
||||
if reasoning_text:
|
||||
output_items.append(
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": reasoning_id,
|
||||
"summary": [
|
||||
{
|
||||
"type": "summary_text",
|
||||
"text": reasoning_text,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# Add computer_call item with extracted or default action
|
||||
computer_call = {
|
||||
"type": "computer_call",
|
||||
"id": action_id,
|
||||
"call_id": call_id,
|
||||
"action": action_details or {"type": "none", "description": "No action specified"},
|
||||
"pending_safety_checks": [],
|
||||
"status": "completed",
|
||||
}
|
||||
output_items.append(computer_call)
|
||||
|
||||
# Create the standard response format
|
||||
standard_response = {
|
||||
"id": response_id,
|
||||
"object": "response",
|
||||
"created_at": int(datetime.now().timestamp()),
|
||||
"status": "completed",
|
||||
"error": None,
|
||||
"incomplete_details": None,
|
||||
"instructions": None,
|
||||
"max_output_tokens": None,
|
||||
"model": model or "anthropic-default",
|
||||
"output": output_items,
|
||||
"parallel_tool_calls": True,
|
||||
"previous_response_id": None,
|
||||
"reasoning": {"effort": "medium", "generate_summary": "concise"},
|
||||
"store": True,
|
||||
"temperature": 1.0,
|
||||
"text": {"format": {"type": "text"}},
|
||||
"tool_choice": "auto",
|
||||
"tools": [
|
||||
{
|
||||
"type": "computer_use_preview",
|
||||
"display_height": 768,
|
||||
"display_width": 1024,
|
||||
"environment": "mac",
|
||||
}
|
||||
],
|
||||
"top_p": 1.0,
|
||||
"truncation": "auto",
|
||||
"usage": {
|
||||
"input_tokens": 0,
|
||||
"input_tokens_details": {"cached_tokens": 0},
|
||||
"output_tokens": 0,
|
||||
"output_tokens_details": {"reasoning_tokens": 0},
|
||||
"total_tokens": 0,
|
||||
},
|
||||
"user": None,
|
||||
"metadata": {},
|
||||
"response": {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": content,
|
||||
"tool_calls": [],
|
||||
},
|
||||
"finish_reason": response.stop_reason or "stop",
|
||||
}
|
||||
]
|
||||
},
|
||||
}
|
||||
|
||||
# Add tool calls if present
|
||||
tool_calls = []
|
||||
for block in response.content:
|
||||
if hasattr(block, "type") and block.type == "tool_use":
|
||||
tool_calls.append(
|
||||
{
|
||||
"id": f"call_{block.id}",
|
||||
"type": "function",
|
||||
"function": {"name": block.name, "arguments": block.input},
|
||||
}
|
||||
)
|
||||
if tool_calls:
|
||||
standard_response["response"]["choices"][0]["message"]["tool_calls"] = tool_calls
|
||||
|
||||
return cast(AgentResponse, standard_response)
|
||||
@@ -1,8 +0,0 @@
|
||||
"""Omni provider implementation."""
|
||||
|
||||
from ...core.types import LLMProvider
|
||||
from .image_utils import (
|
||||
decode_base64_image,
|
||||
)
|
||||
|
||||
__all__ = ["LLMProvider", "decode_base64_image"]
|
||||
@@ -1,42 +0,0 @@
|
||||
"""API handling for Omni provider."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .prompts import SYSTEM_PROMPT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OmniAPIHandler:
|
||||
"""Handler for Omni API calls."""
|
||||
|
||||
def __init__(self, loop):
|
||||
"""Initialize the API handler.
|
||||
|
||||
Args:
|
||||
loop: Parent loop instance
|
||||
"""
|
||||
self.loop = loop
|
||||
|
||||
async def make_api_call(
|
||||
self, messages: List[Dict[str, Any]], system_prompt: str = SYSTEM_PROMPT
|
||||
) -> Any:
|
||||
"""Make an API call to the appropriate provider.
|
||||
|
||||
Args:
|
||||
messages: List of messages in standard OpenAI format
|
||||
system_prompt: System prompt to use
|
||||
|
||||
Returns:
|
||||
API response
|
||||
"""
|
||||
if not self.loop._make_api_call:
|
||||
raise RuntimeError("Loop does not have _make_api_call method")
|
||||
|
||||
try:
|
||||
# Use the loop's _make_api_call method with standard messages
|
||||
return await self.loop._make_api_call(messages=messages, system_prompt=system_prompt)
|
||||
except Exception as e:
|
||||
logger.error(f"Error making API call: {str(e)}")
|
||||
raise
|
||||
@@ -1,103 +0,0 @@
|
||||
"""Anthropic API client implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
import asyncio
|
||||
from httpx import ConnectError, ReadTimeout
|
||||
|
||||
from anthropic import AsyncAnthropic, Anthropic
|
||||
from anthropic.types import MessageParam
|
||||
from .base import BaseOmniClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicClient(BaseOmniClient):
|
||||
"""Client for making calls to Anthropic API."""
|
||||
|
||||
def __init__(self, api_key: str, model: str, max_retries: int = 3, retry_delay: float = 1.0):
|
||||
"""Initialize the Anthropic client.
|
||||
|
||||
Args:
|
||||
api_key: Anthropic API key
|
||||
model: Anthropic model name (e.g. "claude-3-opus-20240229")
|
||||
max_retries: Maximum number of retries for API calls
|
||||
retry_delay: Base delay between retries in seconds
|
||||
"""
|
||||
if not model:
|
||||
raise ValueError("Model name must be provided")
|
||||
|
||||
self.client = AsyncAnthropic(api_key=api_key)
|
||||
self.model: str = model # Add explicit type annotation
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
|
||||
def _convert_message_format(self, messages: List[Dict[str, Any]]) -> List[MessageParam]:
|
||||
"""Convert messages from standard format to Anthropic format.
|
||||
|
||||
Args:
|
||||
messages: Messages in standard format
|
||||
|
||||
Returns:
|
||||
Messages in Anthropic format
|
||||
"""
|
||||
anthropic_messages = []
|
||||
|
||||
for message in messages:
|
||||
# Skip messages with empty content
|
||||
if not message.get("content"):
|
||||
continue
|
||||
|
||||
if message["role"] == "user":
|
||||
anthropic_messages.append({"role": "user", "content": message["content"]})
|
||||
elif message["role"] == "assistant":
|
||||
anthropic_messages.append({"role": "assistant", "content": message["content"]})
|
||||
|
||||
# Cast the list to the correct type expected by Anthropic
|
||||
return cast(List[MessageParam], anthropic_messages)
|
||||
|
||||
async def run_interleaved(
|
||||
self, messages: List[Dict[str, Any]], system: str, max_tokens: int
|
||||
) -> Any:
|
||||
"""Run model with interleaved conversation format.
|
||||
|
||||
Args:
|
||||
messages: List of messages to process
|
||||
system: System prompt
|
||||
max_tokens: Maximum tokens to generate
|
||||
|
||||
Returns:
|
||||
Model response
|
||||
"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# Convert messages to Anthropic format
|
||||
anthropic_messages = self._convert_message_format(messages)
|
||||
|
||||
response = await self.client.messages.create(
|
||||
model=self.model,
|
||||
max_tokens=max_tokens,
|
||||
temperature=0,
|
||||
system=system,
|
||||
messages=anthropic_messages,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except (ConnectError, ReadTimeout) as e:
|
||||
last_error = e
|
||||
logger.warning(
|
||||
f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
|
||||
)
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in Anthropic API call: {str(e)}")
|
||||
raise RuntimeError(f"Anthropic API call failed: {str(e)}")
|
||||
|
||||
# If we get here, all retries failed
|
||||
raise RuntimeError(f"Connection error after {self.max_retries} retries: {str(last_error)}")
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Base client implementation for Omni providers."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseOmniClient:
|
||||
"""Base class for provider-specific clients."""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
|
||||
"""Initialize base client.
|
||||
|
||||
Args:
|
||||
api_key: Optional API key
|
||||
model: Optional model name
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
|
||||
async def run_interleaved(
|
||||
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Run interleaved chat completion.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
system: System prompt
|
||||
max_tokens: Optional max tokens override
|
||||
|
||||
Returns:
|
||||
Response dict
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -1,195 +0,0 @@
|
||||
"""OpenAI-compatible client implementation."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
import aiohttp
|
||||
import re
|
||||
from .base import BaseOmniClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# OpenAI-compatible client for the OmniLoop
|
||||
class OAICompatClient(BaseOmniClient):
|
||||
"""OpenAI-compatible API client implementation.
|
||||
|
||||
This client can be used with any service that implements the OpenAI API protocol, including:
|
||||
- vLLM
|
||||
- LM Studio
|
||||
- LocalAI
|
||||
- Ollama (with OpenAI compatibility)
|
||||
- Text Generation WebUI
|
||||
- Any other service with OpenAI API compatibility
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model: str = "Qwen2.5-VL-7B-Instruct",
|
||||
provider_base_url: Optional[str] = "http://localhost:8000/v1",
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.0,
|
||||
):
|
||||
"""Initialize the OpenAI-compatible client.
|
||||
|
||||
Args:
|
||||
api_key: Not used for local endpoints, usually set to "EMPTY"
|
||||
model: Model name to use
|
||||
provider_base_url: API base URL. Typically in the format "http://localhost:PORT/v1"
|
||||
Examples:
|
||||
- vLLM: "http://localhost:8000/v1"
|
||||
- LM Studio: "http://localhost:1234/v1"
|
||||
- LocalAI: "http://localhost:8080/v1"
|
||||
- Ollama: "http://localhost:11434/v1"
|
||||
max_tokens: Maximum tokens to generate
|
||||
temperature: Generation temperature
|
||||
"""
|
||||
super().__init__(api_key=api_key or "EMPTY", model=model)
|
||||
self.api_key = api_key or "EMPTY" # Local endpoints typically don't require an API key
|
||||
self.model = model
|
||||
self.provider_base_url = (
|
||||
provider_base_url or "http://localhost:8000/v1"
|
||||
) # Use default if None
|
||||
self.max_tokens = max_tokens
|
||||
self.temperature = temperature
|
||||
|
||||
def _extract_base64_image(self, text: str) -> Optional[str]:
|
||||
"""Extract base64 image data from an HTML img tag."""
|
||||
pattern = r'data:image/[^;]+;base64,([^"]+)'
|
||||
match = re.search(pattern, text)
|
||||
return match.group(1) if match else None
|
||||
|
||||
def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Create a loggable version of messages with image data truncated."""
|
||||
loggable_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg.get("content"), list):
|
||||
new_content = []
|
||||
for content in msg["content"]:
|
||||
if content.get("type") == "image":
|
||||
new_content.append(
|
||||
{"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
|
||||
)
|
||||
else:
|
||||
new_content.append(content)
|
||||
loggable_messages.append({"role": msg["role"], "content": new_content})
|
||||
else:
|
||||
loggable_messages.append(msg)
|
||||
return loggable_messages
|
||||
|
||||
async def run_interleaved(
|
||||
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Run interleaved chat completion.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
system: System prompt
|
||||
max_tokens: Optional max tokens override
|
||||
|
||||
Returns:
|
||||
Response dict
|
||||
"""
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
final_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{ "type": "text", "text": system }
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Process messages
|
||||
for item in messages:
|
||||
if isinstance(item, dict):
|
||||
if isinstance(item["content"], list):
|
||||
# Content is already in the correct format
|
||||
final_messages.append(item)
|
||||
else:
|
||||
# Single string content, check for image
|
||||
base64_img = self._extract_base64_image(item["content"])
|
||||
if base64_img:
|
||||
message = {
|
||||
"role": item["role"],
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
message = {
|
||||
"role": item["role"],
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": item["content"]
|
||||
}],
|
||||
}
|
||||
final_messages.append(message)
|
||||
else:
|
||||
# String content, check for image
|
||||
base64_img = self._extract_base64_image(item)
|
||||
if base64_img:
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
message = {"role": "user", "content": [{"type": "text", "text": item}]}
|
||||
final_messages.append(message)
|
||||
|
||||
payload = {"model": self.model, "messages": final_messages, "temperature": self.temperature}
|
||||
payload["max_tokens"] = max_tokens or self.max_tokens
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Use default base URL if none provided
|
||||
base_url = self.provider_base_url or "http://localhost:8000/v1"
|
||||
|
||||
# Check if the base URL already includes the chat/completions endpoint
|
||||
|
||||
endpoint_url = base_url
|
||||
if not endpoint_url.endswith("/chat/completions"):
|
||||
# If URL is RunPod format, make it OpenAI compatible
|
||||
if endpoint_url.startswith("https://api.runpod.ai/v2/"):
|
||||
# Extract RunPod endpoint ID
|
||||
parts = endpoint_url.split("/")
|
||||
if len(parts) >= 5:
|
||||
runpod_id = parts[4]
|
||||
endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
|
||||
# If the URL ends with /v1, append /chat/completions
|
||||
elif endpoint_url.endswith("/v1"):
|
||||
endpoint_url = f"{endpoint_url}/chat/completions"
|
||||
# If the URL doesn't end with /v1, make sure it has a proper structure
|
||||
elif not endpoint_url.endswith("/"):
|
||||
endpoint_url = f"{endpoint_url}/chat/completions"
|
||||
else:
|
||||
endpoint_url = f"{endpoint_url}chat/completions"
|
||||
|
||||
# Log the endpoint URL for debugging
|
||||
logger.debug(f"Using endpoint URL: {endpoint_url}")
|
||||
|
||||
async with session.post(endpoint_url, headers=headers, json=payload) as response:
|
||||
response_json = await response.json()
|
||||
|
||||
if response.status != 200:
|
||||
error_msg = response_json.get("error", {}).get(
|
||||
"message", str(response_json)
|
||||
)
|
||||
logger.error(f"Error in API call: {error_msg}")
|
||||
raise Exception(f"API error: {error_msg}")
|
||||
|
||||
return response_json
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in API call: {str(e)}")
|
||||
raise
|
||||
@@ -1,122 +0,0 @@
|
||||
"""Ollama API client implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
import asyncio
|
||||
from httpx import ConnectError, ReadTimeout
|
||||
|
||||
from ollama import AsyncClient, Options
|
||||
from ollama import Message
|
||||
from .base import BaseOmniClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OllamaClient(BaseOmniClient):
|
||||
"""Client for making calls to Ollama API."""
|
||||
|
||||
def __init__(self, api_key: str, model: str, max_retries: int = 3, retry_delay: float = 1.0):
|
||||
"""Initialize the Ollama client.
|
||||
|
||||
Args:
|
||||
api_key: Not used
|
||||
model: Ollama model name (e.g. "gemma3:4b-it-q4_K_M")
|
||||
max_retries: Maximum number of retries for API calls
|
||||
retry_delay: Base delay between retries in seconds
|
||||
"""
|
||||
if not model:
|
||||
raise ValueError("Model name must be provided")
|
||||
|
||||
self.client = AsyncClient(
|
||||
host="http://localhost:11434",
|
||||
)
|
||||
self.model: str = model # Add explicit type annotation
|
||||
self.max_retries = max_retries
|
||||
self.retry_delay = retry_delay
|
||||
|
||||
def _convert_message_format(self, system: str, messages: List[Dict[str, Any]]) -> List[Any]:
|
||||
"""Convert messages from standard format to Ollama format.
|
||||
|
||||
Args:
|
||||
messages: Messages in standard format
|
||||
|
||||
Returns:
|
||||
Messages in Ollama format
|
||||
"""
|
||||
ollama_messages = []
|
||||
|
||||
# Add system message
|
||||
ollama_messages.append(
|
||||
{
|
||||
"role": "system",
|
||||
"content": system,
|
||||
}
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
# Skip messages with empty content
|
||||
if not message.get("content"):
|
||||
continue
|
||||
content = message.get("content", [{}])[0]
|
||||
isImage = content.get("type", "") == "image_url"
|
||||
isText = content.get("type", "") == "text"
|
||||
if isText:
|
||||
data = content.get("text", "")
|
||||
ollama_messages.append({"role": message["role"], "content": data})
|
||||
if isImage:
|
||||
data = content.get("image_url", {}).get("url", "")
|
||||
# remove header
|
||||
data = data.removeprefix("data:image/png;base64,")
|
||||
ollama_messages.append(
|
||||
{"role": message["role"], "content": "Use this image", "images": [data]}
|
||||
)
|
||||
|
||||
# Cast the list to the correct type expected by Ollama
|
||||
return cast(List[Any], ollama_messages)
|
||||
|
||||
async def run_interleaved(
|
||||
self, messages: List[Dict[str, Any]], system: str, max_tokens: int
|
||||
) -> Any:
|
||||
"""Run model with interleaved conversation format.
|
||||
|
||||
Args:
|
||||
messages: List of messages to process
|
||||
system: System prompt
|
||||
max_tokens: Not used
|
||||
|
||||
Returns:
|
||||
Model response
|
||||
"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# Convert messages to Ollama format
|
||||
ollama_messages = self._convert_message_format(system, messages)
|
||||
|
||||
response = await self.client.chat(
|
||||
model=self.model,
|
||||
options=Options(
|
||||
temperature=0,
|
||||
),
|
||||
messages=ollama_messages,
|
||||
format="json",
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except (ConnectError, ReadTimeout) as e:
|
||||
last_error = e
|
||||
logger.warning(
|
||||
f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
|
||||
)
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in Ollama API call: {str(e)}")
|
||||
raise RuntimeError(f"Ollama API call failed: {str(e)}")
|
||||
|
||||
# If we get here, all retries failed
|
||||
raise RuntimeError(f"Connection error after {self.max_retries} retries: {str(last_error)}")
|
||||
@@ -1,155 +0,0 @@
|
||||
"""OpenAI client implementation."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
import aiohttp
|
||||
import re
|
||||
from datetime import datetime
|
||||
from .base import BaseOmniClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenAI specific client for the OmniLoop
|
||||
|
||||
|
||||
class OpenAIClient(BaseOmniClient):
|
||||
"""OpenAI vision API client implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model: str = "gpt-4o",
|
||||
provider_base_url: str = "https://api.openai.com/v1",
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.0,
|
||||
):
|
||||
"""Initialize the OpenAI client.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key
|
||||
model: Model to use
|
||||
provider_base_url: API endpoint
|
||||
max_tokens: Maximum tokens to generate
|
||||
temperature: Generation temperature
|
||||
"""
|
||||
super().__init__(api_key=api_key, model=model)
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("No OpenAI API key provided")
|
||||
|
||||
self.model = model
|
||||
self.provider_base_url = provider_base_url
|
||||
self.max_tokens = max_tokens
|
||||
self.temperature = temperature
|
||||
|
||||
def _extract_base64_image(self, text: str) -> Optional[str]:
|
||||
"""Extract base64 image data from an HTML img tag."""
|
||||
pattern = r'data:image/[^;]+;base64,([^"]+)'
|
||||
match = re.search(pattern, text)
|
||||
return match.group(1) if match else None
|
||||
|
||||
def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Create a loggable version of messages with image data truncated."""
|
||||
loggable_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg.get("content"), list):
|
||||
new_content = []
|
||||
for content in msg["content"]:
|
||||
if content.get("type") == "image":
|
||||
new_content.append(
|
||||
{"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
|
||||
)
|
||||
else:
|
||||
new_content.append(content)
|
||||
loggable_messages.append({"role": msg["role"], "content": new_content})
|
||||
else:
|
||||
loggable_messages.append(msg)
|
||||
return loggable_messages
|
||||
|
||||
async def run_interleaved(
|
||||
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Run interleaved chat completion.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
system: System prompt
|
||||
max_tokens: Optional max tokens override
|
||||
|
||||
Returns:
|
||||
Response dict
|
||||
"""
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
final_messages = [{"role": "system", "content": system}]
|
||||
|
||||
# Process messages
|
||||
for item in messages:
|
||||
if isinstance(item, dict):
|
||||
if isinstance(item["content"], list):
|
||||
# Content is already in the correct format
|
||||
final_messages.append(item)
|
||||
else:
|
||||
# Single string content, check for image
|
||||
base64_img = self._extract_base64_image(item["content"])
|
||||
if base64_img:
|
||||
message = {
|
||||
"role": item["role"],
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
message = {
|
||||
"role": item["role"],
|
||||
"content": [{"type": "text", "text": item["content"]}],
|
||||
}
|
||||
final_messages.append(message)
|
||||
else:
|
||||
# String content, check for image
|
||||
base64_img = self._extract_base64_image(item)
|
||||
if base64_img:
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
message = {"role": "user", "content": [{"type": "text", "text": item}]}
|
||||
final_messages.append(message)
|
||||
|
||||
payload = {"model": self.model, "messages": final_messages, "temperature": self.temperature}
|
||||
|
||||
if "o1" in self.model or "o3-mini" in self.model:
|
||||
payload["reasoning_effort"] = "low"
|
||||
payload["max_completion_tokens"] = max_tokens or self.max_tokens
|
||||
else:
|
||||
payload["max_tokens"] = max_tokens or self.max_tokens
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.provider_base_url}/chat/completions", headers=headers, json=payload
|
||||
) as response:
|
||||
response_json = await response.json()
|
||||
|
||||
if response.status != 200:
|
||||
error_msg = response_json.get("error", {}).get(
|
||||
"message", str(response_json)
|
||||
)
|
||||
logger.error(f"Error in OpenAI API call: {error_msg}")
|
||||
raise Exception(f"OpenAI API error: {error_msg}")
|
||||
|
||||
return response_json
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in OpenAI API call: {str(e)}")
|
||||
raise
|
||||
@@ -1,25 +0,0 @@
|
||||
import base64
|
||||
|
||||
def is_image_path(text: str) -> bool:
|
||||
"""Check if a text string is an image file path.
|
||||
|
||||
Args:
|
||||
text: Text string to check
|
||||
|
||||
Returns:
|
||||
True if text ends with image extension, False otherwise
|
||||
"""
|
||||
image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
|
||||
return text.endswith(image_extensions)
|
||||
|
||||
def encode_image(image_path: str) -> str:
|
||||
"""Encode image file to base64.
|
||||
|
||||
Args:
|
||||
image_path: Path to image file
|
||||
|
||||
Returns:
|
||||
Base64 encoded image string
|
||||
"""
|
||||
with open(image_path, "rb") as image_file:
|
||||
return base64.b64encode(image_file.read()).decode("utf-8")
|
||||
@@ -1,34 +0,0 @@
|
||||
"""Image processing utilities for the Cua provider."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import re
|
||||
from io import BytesIO
|
||||
from typing import Optional, Tuple
|
||||
from PIL import Image
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def decode_base64_image(img_base64: str) -> Optional[Image.Image]:
|
||||
"""Decode a base64 encoded image to a PIL Image.
|
||||
|
||||
Args:
|
||||
img_base64: Base64 encoded image, may include data URL prefix
|
||||
|
||||
Returns:
|
||||
PIL Image or None if decoding fails
|
||||
"""
|
||||
try:
|
||||
# Remove data URL prefix if present
|
||||
if img_base64.startswith("data:image"):
|
||||
img_base64 = img_base64.split(",")[1]
|
||||
|
||||
# Decode base64 to bytes
|
||||
img_data = base64.b64decode(img_base64)
|
||||
|
||||
# Convert bytes to PIL Image
|
||||
return Image.open(BytesIO(img_data))
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding base64 image: {str(e)}")
|
||||
return None
|
||||
@@ -1,990 +0,0 @@
|
||||
"""Omni-specific agent loop implementation."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import asyncio
|
||||
from httpx import ConnectError, ReadTimeout
|
||||
from typing import cast
|
||||
|
||||
from .parser import OmniParser, ParseResult
|
||||
from ...core.base import BaseLoop
|
||||
from ...core.visualization import VisualizationHelper
|
||||
from ...core.messages import StandardMessageManager, ImageRetentionConfig
|
||||
from .utils import to_openai_agent_response_format
|
||||
from ...core.types import AgentResponse
|
||||
from computer import Computer
|
||||
from ...core.types import LLMProvider
|
||||
from .clients.openai import OpenAIClient
|
||||
from .clients.anthropic import AnthropicClient
|
||||
from .clients.ollama import OllamaClient
|
||||
from .clients.oaicompat import OAICompatClient
|
||||
from .prompts import SYSTEM_PROMPT
|
||||
from .api_handler import OmniAPIHandler
|
||||
from .tools.manager import ToolManager
|
||||
from .tools import ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def extract_data(input_string: str, data_type: str) -> str:
|
||||
"""Extract content from code blocks."""
|
||||
pattern = f"```{data_type}" + r"(.*?)(```|$)"
|
||||
matches = re.findall(pattern, input_string, re.DOTALL)
|
||||
return matches[0][0].strip() if matches else input_string
|
||||
|
||||
|
||||
class OmniLoop(BaseLoop):
|
||||
"""Omni-specific implementation of the agent loop.
|
||||
|
||||
This class extends BaseLoop to provide support for multimodal models
|
||||
from various providers (OpenAI, Anthropic, etc.) with UI parsing
|
||||
and desktop automation capabilities.
|
||||
"""
|
||||
|
||||
###########################################
|
||||
# INITIALIZATION AND CONFIGURATION
|
||||
###########################################
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parser: OmniParser,
|
||||
provider: LLMProvider,
|
||||
api_key: str,
|
||||
model: str,
|
||||
computer: Computer,
|
||||
only_n_most_recent_images: Optional[int] = 2,
|
||||
base_dir: Optional[str] = "trajectories",
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
save_trajectory: bool = True,
|
||||
provider_base_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the loop.
|
||||
|
||||
Args:
|
||||
parser: Parser instance
|
||||
provider: API provider
|
||||
api_key: API key
|
||||
model: Model name
|
||||
computer: Computer instance
|
||||
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
|
||||
base_dir: Base directory for saving experiment data
|
||||
max_retries: Maximum number of retries for API calls
|
||||
retry_delay: Delay between retries in seconds
|
||||
save_trajectory: Whether to save trajectory data
|
||||
provider_base_url: Base URL for the API provider (used for OAICOMPAT)
|
||||
"""
|
||||
# Set parser and provider before initializing base class
|
||||
self.parser = parser
|
||||
self.provider = provider
|
||||
self.provider_base_url = provider_base_url
|
||||
|
||||
# Initialize message manager with image retention config
|
||||
self.message_manager = StandardMessageManager(
|
||||
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
|
||||
)
|
||||
|
||||
# Initialize base class (which will set up experiment manager)
|
||||
super().__init__(
|
||||
computer=computer,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
base_dir=base_dir,
|
||||
save_trajectory=save_trajectory,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Set API client attributes
|
||||
self.client = None
|
||||
self.retry_count = 0
|
||||
self.loop_task = None # Store the loop task for cancellation
|
||||
|
||||
# Initialize handlers
|
||||
self.api_handler = OmniAPIHandler(loop=self)
|
||||
self.viz_helper = VisualizationHelper(agent=self)
|
||||
|
||||
# Initialize tool manager
|
||||
self.tool_manager = ToolManager(computer=computer, provider=provider)
|
||||
|
||||
logger.info("OmniLoop initialized with StandardMessageManager")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the loop by setting up tools and clients."""
|
||||
# Initialize base class
|
||||
await super().initialize()
|
||||
|
||||
# Initialize tool manager with error handling
|
||||
try:
|
||||
logger.info("Initializing tool manager...")
|
||||
await self.tool_manager.initialize()
|
||||
logger.info("Tool manager initialized successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tool manager: {str(e)}")
|
||||
logger.warning("Will attempt to initialize tools on first use.")
|
||||
|
||||
# Initialize API clients based on provider
|
||||
if self.provider == LLMProvider.ANTHROPIC:
|
||||
self.client = AnthropicClient(
|
||||
api_key=self.api_key,
|
||||
model=self.model,
|
||||
)
|
||||
elif self.provider == LLMProvider.OPENAI:
|
||||
self.client = OpenAIClient(
|
||||
api_key=self.api_key,
|
||||
model=self.model,
|
||||
)
|
||||
elif self.provider == LLMProvider.OLLAMA:
|
||||
self.client = OllamaClient(
|
||||
api_key=self.api_key,
|
||||
model=self.model,
|
||||
)
|
||||
elif self.provider == LLMProvider.OAICOMPAT:
|
||||
self.client = OAICompatClient(
|
||||
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
|
||||
model=self.model,
|
||||
provider_base_url=self.provider_base_url,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
###########################################
|
||||
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def initialize_client(self) -> None:
|
||||
"""Initialize the appropriate client based on provider.
|
||||
|
||||
Implements abstract method from BaseLoop to set up the specific
|
||||
provider client (OpenAI, Anthropic, etc.).
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Initializing {self.provider} client with model {self.model}...")
|
||||
|
||||
if self.provider == LLMProvider.OPENAI:
|
||||
self.client = OpenAIClient(api_key=self.api_key, model=self.model)
|
||||
elif self.provider == LLMProvider.ANTHROPIC:
|
||||
self.client = AnthropicClient(
|
||||
api_key=self.api_key,
|
||||
model=self.model,
|
||||
max_retries=self.max_retries,
|
||||
retry_delay=self.retry_delay,
|
||||
)
|
||||
elif self.provider == LLMProvider.OLLAMA:
|
||||
self.client = OllamaClient(
|
||||
api_key=self.api_key,
|
||||
model=self.model,
|
||||
)
|
||||
elif self.provider == LLMProvider.OAICOMPAT:
|
||||
self.client = OAICompatClient(
|
||||
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
|
||||
model=self.model,
|
||||
provider_base_url=self.provider_base_url,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
logger.info(f"Initialized {self.provider} client with model {self.model}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing client: {str(e)}")
|
||||
self.client = None
|
||||
raise RuntimeError(f"Failed to initialize client: {str(e)}")
|
||||
|
||||
###########################################
|
||||
# API CALL HANDLING
|
||||
###########################################
|
||||
|
||||
async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any:
|
||||
"""Make API call to provider with retry logic."""
|
||||
# Create new turn directory for this API call
|
||||
self._create_turn_dir()
|
||||
|
||||
request_data = None
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# Ensure client is initialized
|
||||
if self.client is None:
|
||||
logger.info(
|
||||
f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..."
|
||||
)
|
||||
await self.initialize_client()
|
||||
if self.client is None:
|
||||
raise RuntimeError("Failed to initialize client")
|
||||
|
||||
# Get messages in standard format from the message manager
|
||||
self.message_manager.messages = messages.copy()
|
||||
prepared_messages = self.message_manager.get_messages()
|
||||
|
||||
# Special handling for Anthropic
|
||||
if self.provider == LLMProvider.ANTHROPIC:
|
||||
# Convert to Anthropic format
|
||||
anthropic_messages, anthropic_system = self.message_manager.to_anthropic_format(
|
||||
prepared_messages
|
||||
)
|
||||
|
||||
# Filter out any empty/invalid messages
|
||||
filtered_messages = [
|
||||
msg
|
||||
for msg in anthropic_messages
|
||||
if msg.get("role") in ["user", "assistant"]
|
||||
]
|
||||
|
||||
# Ensure there's at least one message for Anthropic
|
||||
if not filtered_messages:
|
||||
logger.warning(
|
||||
"No valid messages found for Anthropic API call. Adding a default user message."
|
||||
)
|
||||
filtered_messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Please help with this task."}
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Combine system prompts if needed
|
||||
final_system_prompt = anthropic_system or system_prompt
|
||||
|
||||
# Log request
|
||||
request_data = {
|
||||
"messages": filtered_messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"system": final_system_prompt,
|
||||
}
|
||||
|
||||
self._log_api_call("request", request_data)
|
||||
|
||||
# Make API call
|
||||
response = await self.client.run_interleaved(
|
||||
messages=filtered_messages,
|
||||
system=final_system_prompt,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
else:
|
||||
# For OpenAI and others, use standard format directly
|
||||
# Log request
|
||||
request_data = {
|
||||
"messages": prepared_messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"system": system_prompt,
|
||||
}
|
||||
|
||||
self._log_api_call("request", request_data)
|
||||
|
||||
# Make API call
|
||||
response = await self.client.run_interleaved(
|
||||
messages=prepared_messages,
|
||||
system=system_prompt,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
# Log success response
|
||||
self._log_api_call("response", request_data, response)
|
||||
|
||||
return response
|
||||
|
||||
except (ConnectError, ReadTimeout) as e:
|
||||
last_error = e
|
||||
logger.warning(
|
||||
f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
|
||||
)
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff
|
||||
# Reset client on connection errors to force re-initialization
|
||||
self.client = None
|
||||
continue
|
||||
|
||||
except RuntimeError as e:
|
||||
# Handle client initialization errors specifically
|
||||
last_error = e
|
||||
self._log_api_call("error", request_data, error=e)
|
||||
logger.error(
|
||||
f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
|
||||
)
|
||||
if attempt < self.max_retries - 1:
|
||||
# Reset client to force re-initialization
|
||||
self.client = None
|
||||
await asyncio.sleep(self.retry_delay)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
# Log unexpected error
|
||||
last_error = e
|
||||
self._log_api_call("error", request_data, error=e)
|
||||
logger.error(f"Unexpected error in API call: {str(e)}")
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(self.retry_delay)
|
||||
continue
|
||||
|
||||
# If we get here, all retries failed
|
||||
error_message = f"API call failed after {self.max_retries} attempts"
|
||||
if last_error:
|
||||
error_message += f": {str(last_error)}"
|
||||
|
||||
logger.error(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
###########################################
|
||||
# RESPONSE AND ACTION HANDLING
|
||||
###########################################
|
||||
|
||||
async def _handle_response(
|
||||
self, response: Any, messages: List[Dict[str, Any]], parsed_screen: ParseResult
|
||||
) -> Tuple[bool, bool]:
|
||||
"""Handle API response.
|
||||
|
||||
Args:
|
||||
response: API response
|
||||
messages: List of messages to update
|
||||
parsed_screen: Current parsed screen information
|
||||
|
||||
Returns:
|
||||
Tuple of (should_continue, action_screenshot_saved)
|
||||
"""
|
||||
action_screenshot_saved = False
|
||||
|
||||
# Helper function to safely add assistant messages using the message manager
|
||||
def add_assistant_message(content):
|
||||
if isinstance(content, str):
|
||||
# Convert string to proper format
|
||||
formatted_content = [{"type": "text", "text": content}]
|
||||
self.message_manager.add_assistant_message(formatted_content)
|
||||
logger.info("Added formatted text assistant message")
|
||||
elif isinstance(content, list):
|
||||
# Already in proper format
|
||||
self.message_manager.add_assistant_message(content)
|
||||
logger.info("Added structured assistant message")
|
||||
else:
|
||||
# Default case - convert to string
|
||||
formatted_content = [{"type": "text", "text": str(content)}]
|
||||
self.message_manager.add_assistant_message(formatted_content)
|
||||
logger.info("Added converted assistant message")
|
||||
|
||||
try:
|
||||
# Step 1: Normalize response to standard format based on provider
|
||||
standard_content = []
|
||||
raw_text = None
|
||||
|
||||
# Convert response to standardized content based on provider
|
||||
if self.provider == LLMProvider.ANTHROPIC:
|
||||
if hasattr(response, "content") and isinstance(response.content, list):
|
||||
# Convert Anthropic response to standard format
|
||||
for block in response.content:
|
||||
if hasattr(block, "type"):
|
||||
if block.type == "text":
|
||||
standard_content.append({"type": "text", "text": block.text})
|
||||
# Store raw text for JSON parsing
|
||||
if raw_text is None:
|
||||
raw_text = block.text
|
||||
else:
|
||||
raw_text += "\n" + block.text
|
||||
else:
|
||||
# Add other block types
|
||||
block_dict = {}
|
||||
for key, value in vars(block).items():
|
||||
if not key.startswith("_"):
|
||||
block_dict[key] = value
|
||||
standard_content.append(block_dict)
|
||||
else:
|
||||
logger.warning("Invalid Anthropic response format")
|
||||
return True, action_screenshot_saved
|
||||
elif self.provider == LLMProvider.OLLAMA:
|
||||
try:
|
||||
raw_text = response["message"]["content"]
|
||||
standard_content = [{"type": "text", "text": raw_text}]
|
||||
except (KeyError, TypeError, IndexError) as e:
|
||||
logger.error(f"Invalid response format: {str(e)}")
|
||||
return True, action_screenshot_saved
|
||||
elif self.provider == LLMProvider.OAICOMPAT:
|
||||
try:
|
||||
# OpenAI-compatible response format
|
||||
raw_text = response["choices"][0]["message"]["content"]
|
||||
standard_content = [{"type": "text", "text": raw_text}]
|
||||
except (KeyError, TypeError, IndexError) as e:
|
||||
logger.error(f"Invalid response format: {str(e)}")
|
||||
return True, action_screenshot_saved
|
||||
else:
|
||||
# Assume OpenAI or compatible format
|
||||
try:
|
||||
raw_text = response["choices"][0]["message"]["content"]
|
||||
standard_content = [{"type": "text", "text": raw_text}]
|
||||
except (KeyError, TypeError, IndexError) as e:
|
||||
logger.error(f"Invalid response format: {str(e)}")
|
||||
return True, action_screenshot_saved
|
||||
|
||||
# Step 2: Add the normalized response to message history
|
||||
add_assistant_message(standard_content)
|
||||
|
||||
# Step 3: Extract JSON from the content for action execution
|
||||
parsed_content = None
|
||||
|
||||
# If we have raw text, try to extract JSON from it
|
||||
if raw_text:
|
||||
# Try different approaches to extract JSON
|
||||
try:
|
||||
# First try to parse the whole content as JSON
|
||||
parsed_content = json.loads(raw_text)
|
||||
logger.info("Successfully parsed whole content as JSON")
|
||||
except json.JSONDecodeError:
|
||||
try:
|
||||
# Try to find JSON block
|
||||
json_content = extract_data(raw_text, "json")
|
||||
parsed_content = json.loads(json_content)
|
||||
logger.info("Successfully parsed JSON from code block")
|
||||
except (json.JSONDecodeError, IndexError):
|
||||
try:
|
||||
# Look for JSON object pattern
|
||||
import re # Local import to ensure availability
|
||||
|
||||
json_pattern = r"\{[^}]+\}"
|
||||
json_match = re.search(json_pattern, raw_text)
|
||||
if json_match:
|
||||
json_str = json_match.group(0)
|
||||
parsed_content = json.loads(json_str)
|
||||
logger.info("Successfully parsed JSON from text")
|
||||
else:
|
||||
logger.error(f"No JSON found in content")
|
||||
return True, action_screenshot_saved
|
||||
except json.JSONDecodeError as e:
|
||||
# Try to sanitize the JSON string and retry
|
||||
try:
|
||||
# Remove or replace invalid control characters
|
||||
import re # Local import to ensure availability
|
||||
|
||||
sanitized_text = re.sub(r"[\x00-\x1F\x7F]", "", raw_text)
|
||||
# Try parsing again with sanitized text
|
||||
parsed_content = json.loads(sanitized_text)
|
||||
logger.info(
|
||||
"Successfully parsed JSON after sanitizing control characters"
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse JSON from text: {str(e)}")
|
||||
return True, action_screenshot_saved
|
||||
|
||||
# Step 4: Process the parsed content if available
|
||||
if parsed_content:
|
||||
# Clean up Box ID format
|
||||
if "Box ID" in parsed_content and isinstance(parsed_content["Box ID"], str):
|
||||
parsed_content["Box ID"] = parsed_content["Box ID"].replace("Box #", "")
|
||||
|
||||
# Add any explanatory text as reasoning if not present
|
||||
if "Explanation" not in parsed_content and raw_text:
|
||||
# Extract any text before the JSON as reasoning
|
||||
text_before_json = raw_text.split("{")[0].strip()
|
||||
if text_before_json:
|
||||
parsed_content["Explanation"] = text_before_json
|
||||
|
||||
# Log the parsed content for debugging
|
||||
logger.info(f"Parsed content: {json.dumps(parsed_content, indent=2)}")
|
||||
|
||||
# Step 5: Execute the action
|
||||
try:
|
||||
# Execute action using the common helper method
|
||||
should_continue, action_screenshot_saved = (
|
||||
await self._execute_action_with_tools(
|
||||
parsed_content, cast(ParseResult, parsed_screen)
|
||||
)
|
||||
)
|
||||
|
||||
# Check if task is complete
|
||||
if parsed_content.get("Action") == "None":
|
||||
return False, action_screenshot_saved
|
||||
return should_continue, action_screenshot_saved
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing action: {str(e)}")
|
||||
# Update the last assistant message with error
|
||||
error_message = [{"type": "text", "text": f"Error executing action: {str(e)}"}]
|
||||
# Replace the last assistant message with the error
|
||||
self.message_manager.add_assistant_message(error_message)
|
||||
return False, action_screenshot_saved
|
||||
|
||||
return True, action_screenshot_saved
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling response: {str(e)}")
|
||||
# Add error message using the message manager
|
||||
error_message = [{"type": "text", "text": f"Error: {str(e)}"}]
|
||||
self.message_manager.add_assistant_message(error_message)
|
||||
raise
|
||||
|
||||
###########################################
|
||||
# SCREEN PARSING - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def _get_parsed_screen_som(self, save_screenshot: bool = True) -> ParseResult:
|
||||
"""Get parsed screen information with Screen Object Model.
|
||||
|
||||
Extends the base class method to use the OmniParser to parse the screen
|
||||
and extract UI elements.
|
||||
|
||||
Args:
|
||||
save_screenshot: Whether to save the screenshot (set to False when screenshots will be saved elsewhere)
|
||||
|
||||
Returns:
|
||||
ParseResult containing screen information and elements
|
||||
"""
|
||||
try:
|
||||
# Use the parser's parse_screen method which handles the screenshot internally
|
||||
parsed_screen = await self.parser.parse_screen(computer=self.computer)
|
||||
|
||||
# Log information about the parsed results
|
||||
logger.info(
|
||||
f"Parsed screen with {len(parsed_screen.elements) if parsed_screen.elements else 0} elements"
|
||||
)
|
||||
|
||||
# Save screenshot if requested and if we have image data
|
||||
if save_screenshot and self.save_trajectory and parsed_screen.annotated_image_base64:
|
||||
try:
|
||||
# Extract just the image data (remove data:image/png;base64, prefix)
|
||||
img_data = parsed_screen.annotated_image_base64
|
||||
if "," in img_data:
|
||||
img_data = img_data.split(",")[1]
|
||||
|
||||
# Process screenshot through hooks and save if needed
|
||||
await self.handle_screenshot(img_data, action_type="state", parsed_screen=parsed_screen)
|
||||
|
||||
# Save with a generic "state" action type to indicate this is the current screen state
|
||||
self._save_screenshot(img_data, action_type="state")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving screenshot: {str(e)}")
|
||||
|
||||
return parsed_screen
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting parsed screen: {str(e)}")
|
||||
raise
|
||||
|
||||
def _get_system_prompt(self) -> str:
|
||||
"""Get the system prompt for the model."""
|
||||
return SYSTEM_PROMPT
|
||||
|
||||
###########################################
|
||||
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
Args:
|
||||
messages: List of messages in standard OpenAI format
|
||||
|
||||
Yields:
|
||||
Agent response format
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting OmniLoop run with {len(messages)} messages")
|
||||
|
||||
# Initialize the message manager with the provided messages
|
||||
self.message_manager.messages = messages.copy()
|
||||
|
||||
# Create queue for response streaming
|
||||
queue = asyncio.Queue()
|
||||
|
||||
# Start loop in background task
|
||||
self.loop_task = asyncio.create_task(self._run_loop(queue, messages))
|
||||
|
||||
# Process and yield messages as they arrive
|
||||
while True:
|
||||
try:
|
||||
item = await queue.get()
|
||||
if item is None: # Stop signal
|
||||
break
|
||||
yield item
|
||||
queue.task_done()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing queue item: {str(e)}")
|
||||
continue
|
||||
|
||||
# Wait for loop to complete
|
||||
await self.loop_task
|
||||
|
||||
# Send completion message
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": "Task completed successfully.",
|
||||
"metadata": {"title": "✅ Complete"},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run method: {str(e)}")
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
|
||||
async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Internal method to run the agent loop with provided messages.
|
||||
|
||||
Args:
|
||||
queue: Queue to put responses into
|
||||
messages: List of messages in standard OpenAI format
|
||||
"""
|
||||
# Continue running until explicitly told to stop
|
||||
running = True
|
||||
turn_created = False
|
||||
# Track if an action-specific screenshot has been saved this turn
|
||||
action_screenshot_saved = False
|
||||
|
||||
attempt = 0
|
||||
max_attempts = 3
|
||||
|
||||
while running and attempt < max_attempts:
|
||||
try:
|
||||
# Create a new turn directory if it's not already created
|
||||
if not turn_created:
|
||||
self._create_turn_dir()
|
||||
turn_created = True
|
||||
|
||||
# Ensure client is initialized
|
||||
if self.client is None:
|
||||
logger.info("Initializing client...")
|
||||
await self.initialize_client()
|
||||
if self.client is None:
|
||||
raise RuntimeError("Failed to initialize client")
|
||||
logger.info("Client initialized successfully")
|
||||
|
||||
# Get up-to-date screen information
|
||||
parsed_screen = await self._get_parsed_screen_som()
|
||||
|
||||
# Process screen info and update messages in standard format
|
||||
try:
|
||||
# Get image from parsed screen
|
||||
image = parsed_screen.annotated_image_base64 or None
|
||||
if image:
|
||||
# Save elements as JSON if we have a turn directory
|
||||
if self.current_turn_dir and hasattr(parsed_screen, "elements"):
|
||||
elements_path = os.path.join(self.current_turn_dir, "elements.json")
|
||||
with open(elements_path, "w") as f:
|
||||
# Convert elements to dicts for JSON serialization
|
||||
elements_json = [
|
||||
elem.model_dump() for elem in parsed_screen.elements
|
||||
]
|
||||
json.dump(elements_json, f, indent=2)
|
||||
logger.info(f"Saved elements to {elements_path}")
|
||||
|
||||
# Remove data URL prefix if present
|
||||
if "," in image:
|
||||
image = image.split(",")[1]
|
||||
|
||||
# Add screenshot to message history using message manager
|
||||
self.message_manager.add_user_message(
|
||||
[
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{image}"},
|
||||
}
|
||||
]
|
||||
)
|
||||
logger.info("Added screenshot to message history")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing screen info: {str(e)}")
|
||||
raise
|
||||
|
||||
# Get system prompt
|
||||
system_prompt = self._get_system_prompt()
|
||||
|
||||
# Make API call with retries using the APIHandler
|
||||
response = await self.api_handler.make_api_call(
|
||||
self.message_manager.messages, system_prompt
|
||||
)
|
||||
|
||||
# Handle the response (may execute actions)
|
||||
# Returns: (should_continue, action_screenshot_saved)
|
||||
should_continue, new_screenshot_saved = await self._handle_response(
|
||||
response, self.message_manager.messages, parsed_screen
|
||||
)
|
||||
|
||||
# Update whether an action screenshot was saved this turn
|
||||
action_screenshot_saved = action_screenshot_saved or new_screenshot_saved
|
||||
|
||||
# Create OpenAI-compatible response format using utility function
|
||||
openai_compatible_response = await to_openai_agent_response_format(
|
||||
response=response,
|
||||
messages=self.message_manager.messages,
|
||||
model=self.model,
|
||||
parsed_screen=parsed_screen,
|
||||
parser=self.parser
|
||||
)
|
||||
# Log standardized response for ease of parsing
|
||||
self._log_api_call("agent_response", request=None, response=openai_compatible_response)
|
||||
|
||||
# Put the response in the queue
|
||||
await queue.put(openai_compatible_response)
|
||||
|
||||
# Check if we should continue this conversation
|
||||
running = should_continue
|
||||
|
||||
# Create a new turn directory if we're continuing
|
||||
if running:
|
||||
turn_created = False
|
||||
|
||||
# Reset attempt counter on success
|
||||
attempt = 0
|
||||
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
error_msg = f"Error in _run_loop method (attempt {attempt}/{max_attempts}): {str(e)}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# If this is our last attempt, provide more info about the error
|
||||
if attempt >= max_attempts:
|
||||
logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")
|
||||
|
||||
await queue.put({
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
})
|
||||
|
||||
# Create a brief delay before retrying
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
# Signal that we're done
|
||||
await queue.put(None)
|
||||
|
||||
async def cancel(self) -> None:
|
||||
"""Cancel the currently running agent loop task.
|
||||
|
||||
This method stops the ongoing processing in the agent loop
|
||||
by cancelling the loop_task if it exists and is running.
|
||||
"""
|
||||
if self.loop_task and not self.loop_task.done():
|
||||
logger.info("Cancelling Omni loop task")
|
||||
self.loop_task.cancel()
|
||||
try:
|
||||
# Wait for the task to be cancelled with a timeout
|
||||
await asyncio.wait_for(self.loop_task, timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timeout while waiting for loop task to cancel")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Loop task cancelled successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error while cancelling loop task: {str(e)}")
|
||||
finally:
|
||||
logger.info("Omni loop task cancelled")
|
||||
else:
|
||||
logger.info("No active Omni loop task to cancel")
|
||||
|
||||
async def process_model_response(self, response_text: str) -> Optional[Dict[str, Any]]:
|
||||
"""Process model response to extract tool calls.
|
||||
|
||||
Args:
|
||||
response_text: Model response text
|
||||
|
||||
Returns:
|
||||
Extracted tool information, or None if no tool call was found
|
||||
"""
|
||||
try:
|
||||
# Ensure tools are initialized before use
|
||||
await self._ensure_tools_initialized()
|
||||
|
||||
# Look for tool use in the response
|
||||
if "function_call" in response_text or "tool_use" in response_text:
|
||||
# The extract_tool_call method should be implemented in the OmniAPIHandler
|
||||
# For now, we'll just use a simple approach
|
||||
# This will be replaced with the proper implementation
|
||||
tool_info = None
|
||||
if "function_call" in response_text:
|
||||
# Extract function call params
|
||||
try:
|
||||
# Simple extraction - in real code this would be more robust
|
||||
import json
|
||||
import re
|
||||
|
||||
match = re.search(r'"function_call"\s*:\s*{([^}]+)}', response_text)
|
||||
if match:
|
||||
function_text = "{" + match.group(1) + "}"
|
||||
tool_info = json.loads(function_text)
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting function call: {str(e)}")
|
||||
|
||||
if tool_info:
|
||||
try:
|
||||
# Execute the tool
|
||||
result = await self.tool_manager.execute_tool(
|
||||
name=tool_info.get("name"), tool_input=tool_info.get("arguments", {})
|
||||
)
|
||||
# Handle the result
|
||||
return {"tool_result": result}
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"Error executing tool '{tool_info.get('name', 'unknown')}': {str(e)}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
return {"tool_result": ToolResult(error=error_msg)}
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing tool call: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
async def process_response_with_tools(
|
||||
self, response_text: str, parsed_screen: Optional[ParseResult] = None
|
||||
) -> Tuple[bool, str]:
|
||||
"""Process model response and execute tools.
|
||||
|
||||
Args:
|
||||
response_text: Model response text
|
||||
parsed_screen: Current parsed screen information (optional)
|
||||
|
||||
Returns:
|
||||
Tuple of (action_taken, observation)
|
||||
"""
|
||||
logger.info("Processing response with tools")
|
||||
|
||||
# Process the response to extract tool calls
|
||||
tool_result = await self.process_model_response(response_text)
|
||||
|
||||
if tool_result and "tool_result" in tool_result:
|
||||
# A tool was executed
|
||||
result = tool_result["tool_result"]
|
||||
if result.error:
|
||||
return False, f"ERROR: {result.error}"
|
||||
else:
|
||||
return True, result.output or "Tool executed successfully"
|
||||
|
||||
# No action or tool call found
|
||||
return False, "No action taken - no tool call detected in response"
|
||||
|
||||
###########################################
|
||||
# UTILITY METHODS
|
||||
###########################################
|
||||
|
||||
async def _ensure_tools_initialized(self) -> None:
|
||||
"""Ensure the tool manager and tools are initialized before use."""
|
||||
if not hasattr(self.tool_manager, "tools") or self.tool_manager.tools is None:
|
||||
logger.info("Tools not initialized. Initializing now...")
|
||||
await self.tool_manager.initialize()
|
||||
logger.info("Tools initialized successfully.")
|
||||
|
||||
async def _execute_action_with_tools(
|
||||
self, action_data: Dict[str, Any], parsed_screen: ParseResult
|
||||
) -> Tuple[bool, bool]:
|
||||
"""Execute an action using the tools-based approach.
|
||||
|
||||
Args:
|
||||
action_data: Dictionary containing action details
|
||||
parsed_screen: Current parsed screen information
|
||||
|
||||
Returns:
|
||||
Tuple of (should_continue, action_screenshot_saved)
|
||||
"""
|
||||
action_screenshot_saved = False
|
||||
action_type = None # Initialize for possible use in post-action screenshot
|
||||
|
||||
try:
|
||||
# Extract the action
|
||||
parsed_action = action_data.get("Action", "").lower()
|
||||
|
||||
# Only process if we have a valid action
|
||||
if not parsed_action or parsed_action == "none":
|
||||
return False, action_screenshot_saved
|
||||
|
||||
# Convert the parsed content to a format suitable for the tools system
|
||||
tool_name = "computer" # Default to computer tool
|
||||
tool_args = {"action": parsed_action}
|
||||
|
||||
# Add specific arguments based on action type
|
||||
if parsed_action in ["left_click", "right_click", "double_click", "move_cursor"]:
|
||||
# Calculate coordinates from Box ID using parser
|
||||
try:
|
||||
box_id = int(action_data["Box ID"])
|
||||
x, y = await self.parser.calculate_click_coordinates(
|
||||
box_id, cast(ParseResult, parsed_screen)
|
||||
)
|
||||
tool_args["x"] = x
|
||||
tool_args["y"] = y
|
||||
|
||||
# Visualize action if screenshot is available
|
||||
if parsed_screen and parsed_screen.annotated_image_base64:
|
||||
img_data = parsed_screen.annotated_image_base64
|
||||
# Remove data URL prefix if present
|
||||
if img_data.startswith("data:image"):
|
||||
img_data = img_data.split(",")[1]
|
||||
# Save visualization for coordinate-based actions
|
||||
self.viz_helper.visualize_action(x, y, img_data)
|
||||
action_screenshot_saved = True
|
||||
|
||||
except (ValueError, KeyError) as e:
|
||||
logger.error(f"Error processing Box ID: {str(e)}")
|
||||
return False, action_screenshot_saved
|
||||
|
||||
elif parsed_action == "type_text":
|
||||
tool_args["text"] = action_data.get("Value", "")
|
||||
# For type_text, store the value in the action type for screenshot naming
|
||||
action_type = f"type_{tool_args['text'][:20]}" # Truncate if too long
|
||||
|
||||
elif parsed_action == "press_key":
|
||||
tool_args["key"] = action_data.get("Value", "")
|
||||
action_type = f"press_{tool_args['key']}"
|
||||
|
||||
elif parsed_action == "hotkey":
|
||||
value = action_data.get("Value", "")
|
||||
if isinstance(value, list):
|
||||
tool_args["keys"] = value
|
||||
action_type = f"hotkey_{'_'.join(value)}"
|
||||
else:
|
||||
# Split string format like "command+space" into a list
|
||||
keys = [k.strip() for k in value.lower().split("+")]
|
||||
tool_args["keys"] = keys
|
||||
action_type = f"hotkey_{value.replace('+', '_')}"
|
||||
|
||||
elif parsed_action in ["scroll_down", "scroll_up"]:
|
||||
clicks = int(action_data.get("amount", 1))
|
||||
tool_args["amount"] = clicks
|
||||
action_type = f"scroll_{parsed_action.split('_')[1]}_{clicks}"
|
||||
|
||||
# Visualize scrolling if screenshot is available
|
||||
if parsed_screen and parsed_screen.annotated_image_base64:
|
||||
img_data = parsed_screen.annotated_image_base64
|
||||
# Remove data URL prefix if present
|
||||
if img_data.startswith("data:image"):
|
||||
img_data = img_data.split(",")[1]
|
||||
direction = "down" if parsed_action == "scroll_down" else "up"
|
||||
# For scrolling, we save the visualization
|
||||
self.viz_helper.visualize_scroll(direction, clicks, img_data)
|
||||
action_screenshot_saved = True
|
||||
|
||||
# Ensure tools are initialized before use
|
||||
await self._ensure_tools_initialized()
|
||||
|
||||
# Execute tool with prepared arguments
|
||||
result = await self.tool_manager.execute_tool(name=tool_name, tool_input=tool_args)
|
||||
|
||||
# Take a new screenshot after the action if we haven't already saved one
|
||||
if not action_screenshot_saved:
|
||||
try:
|
||||
# Get a new screenshot after the action
|
||||
new_parsed_screen = await self._get_parsed_screen_som(save_screenshot=False)
|
||||
if new_parsed_screen and new_parsed_screen.annotated_image_base64:
|
||||
img_data = new_parsed_screen.annotated_image_base64
|
||||
# Remove data URL prefix if present
|
||||
if img_data.startswith("data:image"):
|
||||
img_data = img_data.split(",")[1]
|
||||
# Save with action type if defined, otherwise use the action name
|
||||
if action_type:
|
||||
self._save_screenshot(img_data, action_type=action_type)
|
||||
else:
|
||||
self._save_screenshot(img_data, action_type=parsed_action)
|
||||
action_screenshot_saved = True
|
||||
except Exception as screenshot_error:
|
||||
logger.error(f"Error taking post-action screenshot: {str(screenshot_error)}")
|
||||
|
||||
# Continue the loop if the action is not "None"
|
||||
return True, action_screenshot_saved
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing action: {str(e)}")
|
||||
# Update the last assistant message with error
|
||||
error_message = [{"type": "text", "text": f"Error executing action: {str(e)}"}]
|
||||
# Replace the last assistant message with the error
|
||||
self.message_manager.add_assistant_message(error_message)
|
||||
return False, action_screenshot_saved
|
||||
@@ -1,307 +0,0 @@
|
||||
"""Parser implementation for the Omni provider."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import base64
|
||||
import torch
|
||||
|
||||
# Import from the SOM package
|
||||
from som import OmniParser as OmniDetectParser
|
||||
from som.models import ParseResult, ParserMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OmniParser:
|
||||
"""Parser for handling responses from multiple providers."""
|
||||
|
||||
# Class-level shared OmniDetectParser instance
|
||||
_shared_parser = None
|
||||
|
||||
def __init__(self, force_device: Optional[str] = None):
|
||||
"""Initialize the OmniParser.
|
||||
|
||||
Args:
|
||||
force_device: Optional device to force for detection (cpu/cuda/mps)
|
||||
"""
|
||||
self.response_buffer = []
|
||||
|
||||
# Use shared parser if available, otherwise create a new one
|
||||
if OmniParser._shared_parser is None:
|
||||
logger.info("Initializing shared OmniDetectParser...")
|
||||
|
||||
# Determine the best device to use
|
||||
device = force_device
|
||||
if not device:
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif (
|
||||
hasattr(torch, "backends")
|
||||
and hasattr(torch.backends, "mps")
|
||||
and torch.backends.mps.is_available()
|
||||
):
|
||||
device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
logger.info(f"Using device: {device} for OmniDetectParser")
|
||||
self.detect_parser = OmniDetectParser(force_device=device)
|
||||
|
||||
# Preload the detection model to avoid repeated loading
|
||||
try:
|
||||
# Access the detector to trigger model loading
|
||||
detector = self.detect_parser.detector
|
||||
if detector.model is None:
|
||||
logger.info("Preloading detection model...")
|
||||
detector.load_model()
|
||||
logger.info("Detection model preloaded successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error preloading detection model: {str(e)}")
|
||||
|
||||
# Store as shared instance
|
||||
OmniParser._shared_parser = self.detect_parser
|
||||
else:
|
||||
logger.info("Using existing shared OmniDetectParser")
|
||||
self.detect_parser = OmniParser._shared_parser
|
||||
|
||||
async def parse_screen(self, computer: Any) -> ParseResult:
|
||||
"""Parse a screenshot and extract screen information.
|
||||
|
||||
Args:
|
||||
computer: Computer instance
|
||||
|
||||
Returns:
|
||||
ParseResult with screen elements and image data
|
||||
"""
|
||||
try:
|
||||
# Get screenshot from computer
|
||||
logger.info("Taking screenshot...")
|
||||
screenshot = await computer.interface.screenshot()
|
||||
|
||||
# Log screenshot info
|
||||
logger.info(f"Screenshot type: {type(screenshot)}")
|
||||
logger.info(f"Screenshot is bytes: {isinstance(screenshot, bytes)}")
|
||||
logger.info(f"Screenshot is str: {isinstance(screenshot, str)}")
|
||||
logger.info(f"Screenshot length: {len(screenshot) if screenshot else 0}")
|
||||
|
||||
# If screenshot is a string (likely base64), convert it to bytes
|
||||
if isinstance(screenshot, str):
|
||||
try:
|
||||
screenshot = base64.b64decode(screenshot)
|
||||
logger.info("Successfully converted base64 string to bytes")
|
||||
logger.info(f"Decoded bytes length: {len(screenshot)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error decoding base64: {str(e)}")
|
||||
logger.error(f"First 100 chars of screenshot string: {screenshot[:100]}")
|
||||
|
||||
# Pass screenshot to OmniDetectParser
|
||||
logger.info("Passing screenshot to OmniDetectParser...")
|
||||
parse_result = self.detect_parser.parse(
|
||||
screenshot_data=screenshot, box_threshold=0.3, iou_threshold=0.1, use_ocr=True
|
||||
)
|
||||
logger.info("Screenshot parsed successfully")
|
||||
logger.info(f"Parse result has {len(parse_result.elements)} elements")
|
||||
|
||||
# Log element IDs for debugging
|
||||
for i, elem in enumerate(parse_result.elements):
|
||||
logger.info(
|
||||
f"Element {i+1} (ID: {elem.id}): {elem.type} with confidence {elem.confidence:.3f}"
|
||||
)
|
||||
|
||||
return parse_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing screen: {str(e)}")
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# Create a minimal valid result for error cases
|
||||
return ParseResult(
|
||||
elements=[],
|
||||
screen_info=None,
|
||||
annotated_image_base64="",
|
||||
parsed_content_list=[{"error": str(e)}],
|
||||
metadata=ParserMetadata(
|
||||
image_size=(0, 0),
|
||||
num_icons=0,
|
||||
num_text=0,
|
||||
device="cpu",
|
||||
ocr_enabled=False,
|
||||
latency=0.0,
|
||||
),
|
||||
)
|
||||
|
||||
def parse_tool_call(self, response: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""Parse a tool call from the response.
|
||||
|
||||
Args:
|
||||
response: Response from the provider
|
||||
|
||||
Returns:
|
||||
Parsed tool call or None if no tool call found
|
||||
"""
|
||||
try:
|
||||
# Handle Anthropic format
|
||||
if "tool_calls" in response:
|
||||
tool_call = response["tool_calls"][0]
|
||||
return {
|
||||
"name": tool_call["function"]["name"],
|
||||
"arguments": tool_call["function"]["arguments"],
|
||||
}
|
||||
|
||||
# Handle OpenAI format
|
||||
if "function_call" in response:
|
||||
return {
|
||||
"name": response["function_call"]["name"],
|
||||
"arguments": response["function_call"]["arguments"],
|
||||
}
|
||||
|
||||
# Handle Groq format (OpenAI-compatible)
|
||||
if "choices" in response and response["choices"]:
|
||||
choice = response["choices"][0]
|
||||
if "function_call" in choice:
|
||||
return {
|
||||
"name": choice["function_call"]["name"],
|
||||
"arguments": choice["function_call"]["arguments"],
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing tool call: {str(e)}")
|
||||
return None
|
||||
|
||||
def parse_response(self, response: Dict[str, Any]) -> Tuple[str, Dict[str, Any]]:
|
||||
"""Parse a response from any provider.
|
||||
|
||||
Args:
|
||||
response: Response from the provider
|
||||
|
||||
Returns:
|
||||
Tuple of (content, metadata)
|
||||
"""
|
||||
try:
|
||||
content = ""
|
||||
metadata = {}
|
||||
|
||||
# Handle Anthropic format
|
||||
if "content" in response and isinstance(response["content"], list):
|
||||
for item in response["content"]:
|
||||
if item["type"] == "text":
|
||||
content += item["text"]
|
||||
|
||||
# Handle OpenAI format
|
||||
elif "choices" in response and response["choices"]:
|
||||
content = response["choices"][0]["message"]["content"]
|
||||
|
||||
# Handle direct content
|
||||
elif isinstance(response.get("content"), str):
|
||||
content = response["content"]
|
||||
|
||||
# Extract metadata if present
|
||||
if "metadata" in response:
|
||||
metadata = response["metadata"]
|
||||
|
||||
return content, metadata
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing response: {str(e)}")
|
||||
return str(e), {"error": True}
|
||||
|
||||
def format_for_provider(
|
||||
self, messages: List[Dict[str, Any]], provider: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Format messages for a specific provider.
|
||||
|
||||
Args:
|
||||
messages: List of messages to format
|
||||
provider: Provider to format for
|
||||
|
||||
Returns:
|
||||
Formatted messages
|
||||
"""
|
||||
try:
|
||||
formatted = []
|
||||
|
||||
for msg in messages:
|
||||
formatted_msg = {"role": msg["role"]}
|
||||
|
||||
# Handle content formatting
|
||||
if isinstance(msg["content"], list):
|
||||
# For providers that support multimodal
|
||||
if provider in ["anthropic", "openai"]:
|
||||
formatted_msg["content"] = msg["content"]
|
||||
else:
|
||||
# Extract text only for other providers
|
||||
text_content = next(
|
||||
(item["text"] for item in msg["content"] if item["type"] == "text"), ""
|
||||
)
|
||||
formatted_msg["content"] = text_content
|
||||
else:
|
||||
formatted_msg["content"] = msg["content"]
|
||||
|
||||
formatted.append(formatted_msg)
|
||||
|
||||
return formatted
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error formatting messages: {str(e)}")
|
||||
return messages # Return original messages on error
|
||||
|
||||
async def calculate_click_coordinates(
|
||||
self, box_id: int, parsed_screen: ParseResult
|
||||
) -> Tuple[int, int]:
|
||||
"""Calculate click coordinates based on box ID.
|
||||
|
||||
Args:
|
||||
box_id: The ID of the box to click
|
||||
parsed_screen: The parsed screen information
|
||||
|
||||
Returns:
|
||||
Tuple of (x, y) coordinates
|
||||
|
||||
Raises:
|
||||
ValueError: If box_id is invalid or missing from parsed screen
|
||||
"""
|
||||
# First try to use structured elements data
|
||||
logger.info(f"Elements count: {len(parsed_screen.elements)}")
|
||||
|
||||
# Try to find element with matching ID
|
||||
for element in parsed_screen.elements:
|
||||
if element.id == box_id:
|
||||
logger.info(f"Found element with ID {box_id}: {element}")
|
||||
bbox = element.bbox
|
||||
|
||||
# Get screen dimensions from the metadata if available, or fallback
|
||||
width = parsed_screen.metadata.width if parsed_screen.metadata else 1920
|
||||
height = parsed_screen.metadata.height if parsed_screen.metadata else 1080
|
||||
logger.info(f"Screen dimensions: width={width}, height={height}")
|
||||
|
||||
# Create a dictionary from the element's bbox for calculate_element_center
|
||||
bbox_dict = {"x1": bbox.x1, "y1": bbox.y1, "x2": bbox.x2, "y2": bbox.y2}
|
||||
from ...core.visualization import calculate_element_center
|
||||
|
||||
center_x, center_y = calculate_element_center(bbox_dict, width, height)
|
||||
logger.info(f"Calculated center: ({center_x}, {center_y})")
|
||||
|
||||
# Validate coordinates - if they're (0,0) or unreasonably small,
|
||||
# use a default position in the center of the screen
|
||||
if center_x == 0 and center_y == 0:
|
||||
logger.warning("Got (0,0) coordinates, using fallback position")
|
||||
center_x = width // 2
|
||||
center_y = height // 2
|
||||
logger.info(f"Using fallback center: ({center_x}, {center_y})")
|
||||
|
||||
return center_x, center_y
|
||||
|
||||
# If we couldn't find the box, use center of screen
|
||||
logger.error(
|
||||
f"Box ID {box_id} not found in structured elements (count={len(parsed_screen.elements)})"
|
||||
)
|
||||
|
||||
# Use center of screen as fallback
|
||||
width = parsed_screen.metadata.width if parsed_screen.metadata else 1920
|
||||
height = parsed_screen.metadata.height if parsed_screen.metadata else 1080
|
||||
logger.warning(f"Using fallback position in center of screen ({width//2}, {height//2})")
|
||||
return width // 2, height // 2
|
||||
@@ -1,64 +0,0 @@
|
||||
"""Prompts for the Omni agent."""
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
You are using a macOS device.
|
||||
You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
|
||||
|
||||
You may be given some history plan and actions, this is the response from the previous loop.
|
||||
You should carefully consider your plan base on the task, screenshot, and history actions.
|
||||
|
||||
Your available "Next Action" only include:
|
||||
- type_text: types a string of text.
|
||||
- left_click: move mouse to box id and left clicks.
|
||||
- right_click: move mouse to box id and right clicks.
|
||||
- double_click: move mouse to box id and double clicks.
|
||||
- move_cursor: move mouse to box id.
|
||||
- scroll_up: scrolls the screen up to view previous content.
|
||||
- scroll_down: scrolls the screen down, when the desired button is not visible, or you need to see more content.
|
||||
- hotkey: press a sequence of keys.
|
||||
- wait: waits for 1 second for the device to load or respond.
|
||||
|
||||
Based on the visual information from the screenshot image and the detected bounding boxes, please determine the next action, the Box ID you should operate on (if action is one of 'type', 'hover', 'scroll_up', 'scroll_down', 'wait', there should be no Box ID field), and the value (if the action is 'type') in order to complete the task.
|
||||
|
||||
Output format:
|
||||
{
|
||||
"Explanation": str, # describe what is in the current screen, taking into account the history, then describe your step-by-step thoughts on how to achieve the task, choose one action from available actions at a time.
|
||||
"Action": "action_type, action description" | "None" # one action at a time, describe it in short and precisely.
|
||||
"Box ID": n,
|
||||
"Value": "xxx" # only provide value field if the action is type, else don't include value key
|
||||
}
|
||||
|
||||
One Example:
|
||||
{
|
||||
"Explanation": "The current screen shows google result of amazon, in previous action I have searched amazon on google. Then I need to click on the first search results to go to amazon.com.",
|
||||
"Action": "left_click",
|
||||
"Box ID": 4
|
||||
}
|
||||
|
||||
Another Example:
|
||||
{
|
||||
"Explanation": "The current screen shows the front page of amazon. There is no previous action. Therefore I need to type "Apple watch" in the search bar.",
|
||||
"Action": "type_text",
|
||||
"Box ID": 2,
|
||||
"Value": "Apple watch"
|
||||
}
|
||||
|
||||
Another Example:
|
||||
{
|
||||
"Explanation": "I am starting a Spotlight search to find the Safari browser.",
|
||||
"Action": "hotkey",
|
||||
"Value": "command+space"
|
||||
}
|
||||
|
||||
IMPORTANT NOTES:
|
||||
1. You should only give a single action at a time.
|
||||
2. The Box ID is the id of the element you should operate on, it is a number. Its background color corresponds to the color of the bounding box of the element.
|
||||
3. You should give an analysis to the current screen, and reflect on what has been done by looking at the history, then describe your step-by-step thoughts on how to achieve the task.
|
||||
4. Attach the next action prediction in the "Action" field.
|
||||
5. For starting applications, always use the "hotkey" action with command+space for starting a Spotlight search.
|
||||
6. When the task is completed, don't complete additional actions. You should say "Action": "None" in the json field.
|
||||
7. The tasks involve buying multiple products or navigating through multiple pages. You should break it into subgoals and complete each subgoal one by one in the order of the instructions.
|
||||
8. Avoid choosing the same action/elements multiple times in a row, if it happens, reflect to yourself, what may have gone wrong, and predict a different action.
|
||||
9. Reflect whether the element is clickable or not, for example reflect if it is an hyperlink or a button or a normal text.
|
||||
10. If you are prompted with login information page or captcha page, or you think it need user's permission to do the next action, you should say "Action": "None" in the json field.
|
||||
"""
|
||||
@@ -1,30 +0,0 @@
|
||||
"""Omni provider tools - compatible with multiple LLM providers."""
|
||||
|
||||
from ....core.tools import BaseTool, ToolResult, ToolError, ToolFailure, CLIResult
|
||||
from .base import BaseOmniTool
|
||||
from .computer import ComputerTool
|
||||
from .bash import BashTool
|
||||
from .manager import ToolManager
|
||||
|
||||
# Re-export the tools with Omni-specific names for backward compatibility
|
||||
OmniToolResult = ToolResult
|
||||
OmniToolError = ToolError
|
||||
OmniToolFailure = ToolFailure
|
||||
OmniCLIResult = CLIResult
|
||||
|
||||
# We'll export specific tools once implemented
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"BaseOmniTool",
|
||||
"ToolResult",
|
||||
"ToolError",
|
||||
"ToolFailure",
|
||||
"CLIResult",
|
||||
"OmniToolResult",
|
||||
"OmniToolError",
|
||||
"OmniToolFailure",
|
||||
"OmniCLIResult",
|
||||
"ComputerTool",
|
||||
"BashTool",
|
||||
"ToolManager",
|
||||
]
|
||||
@@ -1,29 +0,0 @@
|
||||
"""Omni-specific tool base classes."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Dict
|
||||
|
||||
from ....core.tools.base import BaseTool
|
||||
|
||||
|
||||
class BaseOmniTool(BaseTool, metaclass=ABCMeta):
|
||||
"""Abstract base class for Omni provider tools."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the base Omni tool."""
|
||||
# No specific initialization needed yet, but included for future extensibility
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, **kwargs) -> Any:
|
||||
"""Executes the tool with the given arguments."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to Omni provider-specific API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters for the specific API
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -1,74 +0,0 @@
|
||||
"""Bash tool for Omni provider."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from computer import Computer
|
||||
from ....core.tools import ToolResult, ToolError
|
||||
from .base import BaseOmniTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BashTool(BaseOmniTool):
|
||||
"""Tool for executing bash commands."""
|
||||
|
||||
name = "bash"
|
||||
description = "Execute bash commands on the system"
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the bash tool.
|
||||
|
||||
Args:
|
||||
computer: Computer instance
|
||||
"""
|
||||
super().__init__()
|
||||
self.computer = computer
|
||||
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters
|
||||
"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "The bash command to execute",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def __call__(self, **kwargs) -> ToolResult:
|
||||
"""Execute bash command.
|
||||
|
||||
Args:
|
||||
**kwargs: Command parameters
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
try:
|
||||
command = kwargs.get("command", "")
|
||||
if not command:
|
||||
return ToolResult(error="No command specified")
|
||||
|
||||
# The true implementation would use the actual method to run terminal commands
|
||||
# Since we're getting linter errors, we'll just implement a placeholder that will
|
||||
# be replaced with the correct implementation when this tool is fully integrated
|
||||
logger.info(f"Would execute command: {command}")
|
||||
return ToolResult(output=f"Command executed (placeholder): {command}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in bash tool: {str(e)}")
|
||||
return ToolResult(error=f"Error: {str(e)}")
|
||||
@@ -1,179 +0,0 @@
|
||||
"""Computer tool for Omni provider."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
import json
|
||||
|
||||
from computer import Computer
|
||||
from ....core.tools import ToolResult, ToolError
|
||||
from .base import BaseOmniTool
|
||||
from ..parser import ParseResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ComputerTool(BaseOmniTool):
|
||||
"""Tool for interacting with the computer UI."""
|
||||
|
||||
name = "computer"
|
||||
description = "Interact with the computer's graphical user interface"
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the computer tool.
|
||||
|
||||
Args:
|
||||
computer: Computer instance
|
||||
"""
|
||||
super().__init__()
|
||||
self.computer = computer
|
||||
# Default to standard screen dimensions (will be set more accurately during initialization)
|
||||
self.screen_dimensions = {"width": 1440, "height": 900}
|
||||
|
||||
async def initialize_dimensions(self) -> None:
|
||||
"""Initialize screen dimensions."""
|
||||
# For now, we'll use default values
|
||||
# In the future, we can implement proper screen dimension detection
|
||||
logger.info(f"Using default screen dimensions: {self.screen_dimensions}")
|
||||
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters
|
||||
"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"action": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"left_click",
|
||||
"right_click",
|
||||
"double_click",
|
||||
"move_cursor",
|
||||
"drag_to",
|
||||
"type_text",
|
||||
"press_key",
|
||||
"hotkey",
|
||||
"scroll_up",
|
||||
"scroll_down",
|
||||
],
|
||||
"description": "The action to perform",
|
||||
},
|
||||
"x": {
|
||||
"type": "number",
|
||||
"description": "X coordinate for click or cursor movement",
|
||||
},
|
||||
"y": {
|
||||
"type": "number",
|
||||
"description": "Y coordinate for click or cursor movement",
|
||||
},
|
||||
"box_id": {
|
||||
"type": "integer",
|
||||
"description": "ID of the UI element to interact with",
|
||||
},
|
||||
"text": {
|
||||
"type": "string",
|
||||
"description": "Text to type",
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Key to press",
|
||||
},
|
||||
"keys": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"description": "Keys to press as hotkey combination",
|
||||
},
|
||||
"amount": {
|
||||
"type": "integer",
|
||||
"description": "Amount to scroll",
|
||||
},
|
||||
"duration": {
|
||||
"type": "number",
|
||||
"description": "Duration for drag operations",
|
||||
},
|
||||
},
|
||||
"required": ["action"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
async def __call__(self, **kwargs) -> ToolResult:
|
||||
"""Execute computer action.
|
||||
|
||||
Args:
|
||||
**kwargs: Action parameters
|
||||
|
||||
Returns:
|
||||
Tool execution result
|
||||
"""
|
||||
try:
|
||||
action = kwargs.get("action", "").lower()
|
||||
if not action:
|
||||
return ToolResult(error="No action specified")
|
||||
|
||||
# Execute the action on the computer
|
||||
method = getattr(self.computer.interface, action, None)
|
||||
if not method:
|
||||
return ToolResult(error=f"Unsupported action: {action}")
|
||||
|
||||
# Prepare arguments based on action type
|
||||
args = {}
|
||||
if action in ["left_click", "right_click", "double_click", "move_cursor"]:
|
||||
x = kwargs.get("x")
|
||||
y = kwargs.get("y")
|
||||
if x is None or y is None:
|
||||
box_id = kwargs.get("box_id")
|
||||
if box_id is None:
|
||||
return ToolResult(error="Box ID or coordinates required")
|
||||
# Get coordinates from box_id implementation would be here
|
||||
# For now, return error
|
||||
return ToolResult(error="Box ID-based clicking not implemented yet")
|
||||
args["x"] = x
|
||||
args["y"] = y
|
||||
elif action == "drag_to":
|
||||
x = kwargs.get("x")
|
||||
y = kwargs.get("y")
|
||||
if x is None or y is None:
|
||||
return ToolResult(error="Coordinates required for drag_to")
|
||||
args.update(
|
||||
{
|
||||
"x": x,
|
||||
"y": y,
|
||||
"button": kwargs.get("button", "left"),
|
||||
"duration": float(kwargs.get("duration", 0.5)),
|
||||
}
|
||||
)
|
||||
elif action == "type_text":
|
||||
text = kwargs.get("text")
|
||||
if not text:
|
||||
return ToolResult(error="Text required for type_text")
|
||||
args["text"] = text
|
||||
elif action == "press_key":
|
||||
key = kwargs.get("key")
|
||||
if not key:
|
||||
return ToolResult(error="Key required for press_key")
|
||||
args["key"] = key
|
||||
elif action == "hotkey":
|
||||
keys = kwargs.get("keys")
|
||||
if not keys:
|
||||
return ToolResult(error="Keys required for hotkey")
|
||||
# Call with positional arguments instead of kwargs
|
||||
await method(*keys)
|
||||
return ToolResult(output=f"Hotkey executed: {'+'.join(keys)}")
|
||||
elif action in ["scroll_down", "scroll_up"]:
|
||||
args["clicks"] = int(kwargs.get("amount", 1))
|
||||
|
||||
# Execute action with prepared arguments
|
||||
await method(**args)
|
||||
return ToolResult(output=f"Action {action} executed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing computer action: {str(e)}")
|
||||
return ToolResult(error=f"Error: {str(e)}")
|
||||
@@ -1,61 +0,0 @@
|
||||
"""Tool manager for the Omni provider."""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
from computer.computer import Computer
|
||||
|
||||
from ....core.tools import BaseToolManager, ToolResult
|
||||
from ....core.tools.collection import ToolCollection
|
||||
from .computer import ComputerTool
|
||||
from .bash import BashTool
|
||||
from ....core.types import LLMProvider
|
||||
|
||||
|
||||
class ToolManager(BaseToolManager):
|
||||
"""Manages Omni provider tool initialization and execution."""
|
||||
|
||||
def __init__(self, computer: Computer, provider: LLMProvider):
|
||||
"""Initialize the tool manager.
|
||||
|
||||
Args:
|
||||
computer: Computer instance for computer-related tools
|
||||
provider: The LLM provider being used
|
||||
"""
|
||||
super().__init__(computer)
|
||||
self.provider = provider
|
||||
# Initialize Omni-specific tools
|
||||
self.computer_tool = ComputerTool(self.computer)
|
||||
self.bash_tool = BashTool(self.computer)
|
||||
|
||||
def _initialize_tools(self) -> ToolCollection:
|
||||
"""Initialize all available tools."""
|
||||
return ToolCollection(self.computer_tool, self.bash_tool)
|
||||
|
||||
async def _initialize_tools_specific(self) -> None:
|
||||
"""Initialize Omni provider-specific tool requirements."""
|
||||
await self.computer_tool.initialize_dimensions()
|
||||
|
||||
def get_tool_params(self) -> List[Dict[str, Any]]:
|
||||
"""Get tool parameters for API calls.
|
||||
|
||||
Returns:
|
||||
List of tool parameters for the current provider's API
|
||||
"""
|
||||
if self.tools is None:
|
||||
raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
|
||||
return self.tools.to_params()
|
||||
|
||||
async def execute_tool(self, name: str, tool_input: dict[str, Any]) -> ToolResult:
|
||||
"""Execute a tool with the given input.
|
||||
|
||||
Args:
|
||||
name: Name of the tool to execute
|
||||
tool_input: Input parameters for the tool
|
||||
|
||||
Returns:
|
||||
Result of the tool execution
|
||||
"""
|
||||
if self.tools is None:
|
||||
raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
|
||||
return await self.tools.run(name=name, tool_input=tool_input)
|
||||
@@ -1,236 +0,0 @@
|
||||
"""Main entry point for computer agents."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from som.models import ParseResult
|
||||
from ...core.types import AgentResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def to_openai_agent_response_format(
|
||||
response: Any,
|
||||
messages: List[Dict[str, Any]],
|
||||
parsed_screen: Optional[ParseResult] = None,
|
||||
parser: Optional[Any] = None,
|
||||
model: Optional[str] = None,
|
||||
) -> AgentResponse:
|
||||
"""Create an OpenAI computer use agent compatible response format.
|
||||
|
||||
Args:
|
||||
response: The original API response
|
||||
messages: List of messages in standard OpenAI format
|
||||
parsed_screen: Optional pre-parsed screen information
|
||||
parser: Optional parser instance for coordinate calculation
|
||||
model: Optional model name
|
||||
|
||||
Returns:
|
||||
A response formatted according to OpenAI's computer use agent standard, including:
|
||||
- All standard OpenAI computer use agent fields
|
||||
- Original response in response.choices[0].message
|
||||
- Full message history in messages field
|
||||
"""
|
||||
from datetime import datetime
|
||||
import time
|
||||
|
||||
# Create a unique ID for this response
|
||||
response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(response)}"
|
||||
reasoning_id = f"rs_{response_id}"
|
||||
action_id = f"cu_{response_id}"
|
||||
call_id = f"call_{response_id}"
|
||||
|
||||
# Extract the last assistant message
|
||||
assistant_msg = None
|
||||
for msg in reversed(messages):
|
||||
if msg["role"] == "assistant":
|
||||
assistant_msg = msg
|
||||
break
|
||||
|
||||
if not assistant_msg:
|
||||
# If no assistant message found, create a default one
|
||||
assistant_msg = {"role": "assistant", "content": "No response available"}
|
||||
|
||||
# Initialize output array
|
||||
output_items = []
|
||||
|
||||
# Extract reasoning and action details from the response
|
||||
content = assistant_msg["content"]
|
||||
reasoning_text = None
|
||||
action_details = None
|
||||
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
try:
|
||||
# Try to parse JSON from text block
|
||||
text_content = item.get("text", "")
|
||||
parsed_json = json.loads(text_content)
|
||||
|
||||
# Get reasoning text
|
||||
if reasoning_text is None:
|
||||
reasoning_text = parsed_json.get("Explanation", "")
|
||||
|
||||
# Extract action details
|
||||
action = parsed_json.get("Action", "").lower()
|
||||
text_input = parsed_json.get("Text", "")
|
||||
value = parsed_json.get("Value", "") # Also handle Value field
|
||||
box_id = parsed_json.get("Box ID") # Extract Box ID
|
||||
|
||||
if action in ["click", "left_click"]:
|
||||
# Always calculate coordinates from Box ID for click actions
|
||||
x, y = 100, 100 # Default fallback values
|
||||
|
||||
if parsed_screen and box_id is not None and parser is not None:
|
||||
try:
|
||||
box_id_int = (
|
||||
box_id
|
||||
if isinstance(box_id, int)
|
||||
else int(str(box_id)) if str(box_id).isdigit() else None
|
||||
)
|
||||
if box_id_int is not None:
|
||||
# Use the parser's method to calculate coordinates
|
||||
x, y = await parser.calculate_click_coordinates(
|
||||
box_id_int, parsed_screen
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error extracting coordinates for Box ID {box_id}: {str(e)}"
|
||||
)
|
||||
|
||||
action_details = {
|
||||
"type": "click",
|
||||
"button": "left",
|
||||
"box_id": (
|
||||
(
|
||||
box_id
|
||||
if isinstance(box_id, int)
|
||||
else int(box_id) if str(box_id).isdigit() else None
|
||||
)
|
||||
if box_id is not None
|
||||
else None
|
||||
),
|
||||
"x": x,
|
||||
"y": y,
|
||||
}
|
||||
elif action in ["type", "type_text"] and (text_input or value):
|
||||
action_details = {
|
||||
"type": "type",
|
||||
"text": text_input or value,
|
||||
}
|
||||
elif action == "hotkey" and value:
|
||||
action_details = {
|
||||
"type": "hotkey",
|
||||
"keys": value,
|
||||
}
|
||||
elif action == "scroll":
|
||||
# Use default coordinates for scrolling
|
||||
delta_x = 0
|
||||
delta_y = 0
|
||||
# Try to extract scroll delta values from content if available
|
||||
scroll_data = parsed_json.get("Scroll", {})
|
||||
if scroll_data:
|
||||
delta_x = scroll_data.get("delta_x", 0)
|
||||
delta_y = scroll_data.get("delta_y", 0)
|
||||
action_details = {
|
||||
"type": "scroll",
|
||||
"x": 100,
|
||||
"y": 100,
|
||||
"scroll_x": delta_x,
|
||||
"scroll_y": delta_y,
|
||||
}
|
||||
elif action == "none":
|
||||
# Handle case when action is None (task completion)
|
||||
action_details = {"type": "none", "description": "Task completed"}
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON, just use as reasoning text
|
||||
if reasoning_text is None:
|
||||
reasoning_text = ""
|
||||
reasoning_text += item.get("text", "")
|
||||
|
||||
# Add reasoning item if we have text content
|
||||
if reasoning_text:
|
||||
output_items.append(
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": reasoning_id,
|
||||
"summary": [
|
||||
{
|
||||
"type": "summary_text",
|
||||
"text": reasoning_text[:200], # Truncate to reasonable length
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# If no action details extracted, use default
|
||||
if not action_details:
|
||||
action_details = {
|
||||
"type": "click",
|
||||
"button": "left",
|
||||
"x": 100,
|
||||
"y": 100,
|
||||
}
|
||||
|
||||
# Add computer_call item
|
||||
computer_call = {
|
||||
"type": "computer_call",
|
||||
"id": action_id,
|
||||
"call_id": call_id,
|
||||
"action": action_details,
|
||||
"pending_safety_checks": [],
|
||||
"status": "completed",
|
||||
}
|
||||
output_items.append(computer_call)
|
||||
|
||||
# Extract user and assistant messages from the history
|
||||
user_messages = []
|
||||
assistant_messages = []
|
||||
for msg in messages:
|
||||
if msg["role"] == "user":
|
||||
user_messages.append(msg)
|
||||
elif msg["role"] == "assistant":
|
||||
assistant_messages.append(msg)
|
||||
|
||||
# Create the OpenAI-compatible response format with all expected fields
|
||||
return {
|
||||
"id": response_id,
|
||||
"object": "response",
|
||||
"created_at": int(time.time()),
|
||||
"status": "completed",
|
||||
"error": None,
|
||||
"incomplete_details": None,
|
||||
"instructions": None,
|
||||
"max_output_tokens": None,
|
||||
"model": model or "unknown",
|
||||
"output": output_items,
|
||||
"parallel_tool_calls": True,
|
||||
"previous_response_id": None,
|
||||
"reasoning": {"effort": "medium", "generate_summary": "concise"},
|
||||
"store": True,
|
||||
"temperature": 1.0,
|
||||
"text": {"format": {"type": "text"}},
|
||||
"tool_choice": "auto",
|
||||
"tools": [
|
||||
{
|
||||
"type": "computer_use_preview",
|
||||
"display_height": 768,
|
||||
"display_width": 1024,
|
||||
"environment": "mac",
|
||||
}
|
||||
],
|
||||
"top_p": 1.0,
|
||||
"truncation": "auto",
|
||||
"usage": {
|
||||
"input_tokens": 0, # Placeholder values
|
||||
"input_tokens_details": {"cached_tokens": 0},
|
||||
"output_tokens": 0, # Placeholder values
|
||||
"output_tokens_details": {"reasoning_tokens": 0},
|
||||
"total_tokens": 0, # Placeholder values
|
||||
},
|
||||
"user": None,
|
||||
"metadata": {},
|
||||
# Include the original response for backward compatibility
|
||||
"response": {"choices": [{"message": assistant_msg, "finish_reason": "stop"}]},
|
||||
}
|
||||
@@ -1,6 +0,0 @@
|
||||
"""OpenAI Agent Response API provider for computer control."""
|
||||
|
||||
from .types import LLMProvider
|
||||
from .loop import OpenAILoop
|
||||
|
||||
__all__ = ["OpenAILoop", "LLMProvider"]
|
||||
@@ -1,456 +0,0 @@
|
||||
"""API handler for the OpenAI provider."""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .loop import OpenAILoop
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIAPIHandler:
|
||||
"""Handler for OpenAI API interactions."""
|
||||
|
||||
def __init__(self, loop: "OpenAILoop"):
|
||||
"""Initialize the API handler.
|
||||
|
||||
Args:
|
||||
loop: OpenAI loop instance
|
||||
"""
|
||||
self.loop = loop
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
self.api_base = "https://api.openai.com/v1"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Add organization if specified
|
||||
org_id = os.getenv("OPENAI_ORG")
|
||||
if org_id:
|
||||
self.headers["OpenAI-Organization"] = org_id
|
||||
|
||||
logger.info("Initialized OpenAI API handler")
|
||||
|
||||
async def send_initial_request(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
display_width: str,
|
||||
display_height: str,
|
||||
previous_response_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send an initial request to the OpenAI API with a screenshot.
|
||||
|
||||
Args:
|
||||
messages: List of message objects in standard format
|
||||
display_width: Width of the display in pixels
|
||||
display_height: Height of the display in pixels
|
||||
previous_response_id: Optional ID of the previous response to link requests
|
||||
|
||||
Returns:
|
||||
API response
|
||||
"""
|
||||
# Convert display dimensions to integers
|
||||
try:
|
||||
width = int(display_width)
|
||||
height = int(display_height)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Failed to convert display dimensions to integers: {str(e)}")
|
||||
raise ValueError(
|
||||
f"Display dimensions must be integers: width={display_width}, height={display_height}"
|
||||
)
|
||||
|
||||
# Extract the latest text message and screenshot from messages
|
||||
latest_text = None
|
||||
latest_screenshot = None
|
||||
|
||||
for msg in reversed(messages):
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
|
||||
content = msg.get("content", [])
|
||||
|
||||
if isinstance(content, str) and not latest_text:
|
||||
latest_text = content
|
||||
continue
|
||||
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# Look for text if we don't have it yet
|
||||
if not latest_text and item.get("type") == "text" and "text" in item:
|
||||
latest_text = item.get("text", "")
|
||||
|
||||
# Look for an image if we don't have it yet
|
||||
if not latest_screenshot and item.get("type") == "image":
|
||||
source = item.get("source", {})
|
||||
if source.get("type") == "base64" and "data" in source:
|
||||
latest_screenshot = source["data"]
|
||||
|
||||
# Prepare the input array
|
||||
input_array = []
|
||||
|
||||
# Add the text message if found
|
||||
if latest_text:
|
||||
input_array.append({"role": "user", "content": latest_text})
|
||||
|
||||
# Add the screenshot if found and no previous_response_id is provided
|
||||
if latest_screenshot and not previous_response_id:
|
||||
input_array.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{latest_screenshot}",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# Prepare the request payload - using minimal format from docs
|
||||
payload = {
|
||||
"model": "computer-use-preview",
|
||||
"tools": [
|
||||
{
|
||||
"type": "computer_use_preview",
|
||||
"display_width": width,
|
||||
"display_height": height,
|
||||
"environment": "mac", # We're on macOS
|
||||
}
|
||||
],
|
||||
"input": input_array,
|
||||
"reasoning": {
|
||||
"generate_summary": "concise",
|
||||
},
|
||||
"truncation": "auto",
|
||||
}
|
||||
|
||||
# Add previous_response_id if provided
|
||||
if previous_response_id:
|
||||
payload["previous_response_id"] = previous_response_id
|
||||
|
||||
# Log the request using the BaseLoop's log_api_call method
|
||||
self.loop._log_api_call("request", payload)
|
||||
|
||||
# Log for debug purposes
|
||||
logger.info("Sending initial request to OpenAI API")
|
||||
logger.debug(f"Request payload: {self._sanitize_response(payload)}")
|
||||
|
||||
# Send the request
|
||||
response = requests.post(
|
||||
f"{self.api_base}/responses",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_message = f"OpenAI API error: {response.status_code} {response.text}"
|
||||
logger.error(error_message)
|
||||
# Log the error using the BaseLoop's log_api_call method
|
||||
self.loop._log_api_call("error", payload, error=Exception(error_message))
|
||||
raise Exception(error_message)
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# Log the response using the BaseLoop's log_api_call method
|
||||
self.loop._log_api_call("response", payload, response_data)
|
||||
|
||||
# Log for debug purposes
|
||||
logger.info("Received response from OpenAI API")
|
||||
logger.debug(f"Response data: {self._sanitize_response(response_data)}")
|
||||
|
||||
return response_data
|
||||
|
||||
async def send_computer_call_request(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
display_width: str,
|
||||
display_height: str,
|
||||
previous_response_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a request to the OpenAI API with computer_call_output.
|
||||
|
||||
Args:
|
||||
messages: List of message objects in standard format
|
||||
display_width: Width of the display in pixels
|
||||
display_height: Height of the display in pixels
|
||||
system_prompt: System prompt to include
|
||||
previous_response_id: ID of the previous response to link requests
|
||||
|
||||
Returns:
|
||||
API response
|
||||
"""
|
||||
# Convert display dimensions to integers
|
||||
try:
|
||||
width = int(display_width)
|
||||
height = int(display_height)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Failed to convert display dimensions to integers: {str(e)}")
|
||||
raise ValueError(
|
||||
f"Display dimensions must be integers: width={display_width}, height={display_height}"
|
||||
)
|
||||
|
||||
# Find the most recent computer_call_output with call_id
|
||||
call_id = None
|
||||
screenshot_base64 = None
|
||||
|
||||
# Look for call_id and screenshot in messages
|
||||
for msg in reversed(messages):
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
|
||||
# Check if the message itself has a call_id
|
||||
if "call_id" in msg and not call_id:
|
||||
call_id = msg["call_id"]
|
||||
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# Look for call_id
|
||||
if not call_id and "call_id" in item:
|
||||
call_id = item["call_id"]
|
||||
|
||||
# Look for screenshot in computer_call_output
|
||||
if not screenshot_base64 and item.get("type") == "computer_call_output":
|
||||
output = item.get("output", {})
|
||||
if isinstance(output, dict) and "image_url" in output:
|
||||
image_url = output.get("image_url", "")
|
||||
if image_url.startswith("data:image/png;base64,"):
|
||||
screenshot_base64 = image_url[len("data:image/png;base64,") :]
|
||||
|
||||
# Look for screenshot in image type
|
||||
if not screenshot_base64 and item.get("type") == "image":
|
||||
source = item.get("source", {})
|
||||
if source.get("type") == "base64" and "data" in source:
|
||||
screenshot_base64 = source["data"]
|
||||
|
||||
if not call_id or not screenshot_base64:
|
||||
logger.error("Missing call_id or screenshot for computer_call_output")
|
||||
logger.error(f"Last message: {messages[-1] if messages else None}")
|
||||
raise ValueError("Cannot create computer call request: missing call_id or screenshot")
|
||||
|
||||
# Prepare the request payload using minimal format from docs
|
||||
payload = {
|
||||
"model": "computer-use-preview",
|
||||
"previous_response_id": previous_response_id,
|
||||
"tools": [
|
||||
{
|
||||
"type": "computer_use_preview",
|
||||
"display_width": width,
|
||||
"display_height": height,
|
||||
"environment": "mac", # We're on macOS
|
||||
}
|
||||
],
|
||||
"input": [
|
||||
{
|
||||
"type": "computer_call_output",
|
||||
"call_id": call_id,
|
||||
"output": {
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{screenshot_base64}",
|
||||
},
|
||||
}
|
||||
],
|
||||
"truncation": "auto",
|
||||
}
|
||||
|
||||
# Log the request using the BaseLoop's log_api_call method
|
||||
self.loop._log_api_call("request", payload)
|
||||
|
||||
# Log for debug purposes
|
||||
logger.info("Sending computer call request to OpenAI API")
|
||||
logger.debug(f"Request payload: {self._sanitize_response(payload)}")
|
||||
|
||||
# Send the request
|
||||
response = requests.post(
|
||||
f"{self.api_base}/responses",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_message = f"OpenAI API error: {response.status_code} {response.text}"
|
||||
logger.error(error_message)
|
||||
# Log the error using the BaseLoop's log_api_call method
|
||||
self.loop._log_api_call("error", payload, error=Exception(error_message))
|
||||
raise Exception(error_message)
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# Log the response using the BaseLoop's log_api_call method
|
||||
self.loop._log_api_call("response", payload, response_data)
|
||||
|
||||
# Log for debug purposes
|
||||
logger.info("Received response from OpenAI API")
|
||||
logger.debug(f"Response data: {self._sanitize_response(response_data)}")
|
||||
|
||||
return response_data
|
||||
|
||||
def _format_messages_for_agent_response(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Format messages for the OpenAI Agent Response API.
|
||||
|
||||
The Agent Response API requires specific content types:
|
||||
- For user messages: use "input_text", "input_image", etc.
|
||||
- For assistant messages: use "output_text" only
|
||||
|
||||
Additionally, when using the computer tool, only one image can be sent.
|
||||
|
||||
Args:
|
||||
messages: List of standard messages
|
||||
|
||||
Returns:
|
||||
Messages formatted for the Agent Response API
|
||||
"""
|
||||
formatted_messages = []
|
||||
has_image = False # Track if we've already included an image
|
||||
|
||||
# We need to process messages in reverse to ensure we keep the most recent image
|
||||
# but preserve the original order in the final output
|
||||
reversed_messages = list(reversed(messages))
|
||||
temp_formatted = []
|
||||
|
||||
for msg in reversed_messages:
|
||||
if not msg:
|
||||
continue
|
||||
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
|
||||
logger.debug(f"Processing message - Role: {role}, Content type: {type(content)}")
|
||||
if isinstance(content, list):
|
||||
logger.debug(
|
||||
f"List content items: {[item.get('type') for item in content if isinstance(item, dict)]}"
|
||||
)
|
||||
|
||||
if isinstance(content, str):
|
||||
# For string content, create a message with the appropriate text type
|
||||
if role == "user":
|
||||
temp_formatted.append(
|
||||
{"role": role, "content": [{"type": "input_text", "text": content}]}
|
||||
)
|
||||
elif role == "assistant":
|
||||
# For assistant, we need explicit output_text
|
||||
temp_formatted.append(
|
||||
{"role": role, "content": [{"type": "output_text", "text": content}]}
|
||||
)
|
||||
elif role == "system":
|
||||
# System messages need to be formatted as input_text as well
|
||||
temp_formatted.append(
|
||||
{"role": role, "content": [{"type": "input_text", "text": content}]}
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
# For list content, convert each item to the correct type based on role
|
||||
formatted_content = []
|
||||
has_image_in_this_message = False
|
||||
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
item_type = item.get("type")
|
||||
|
||||
if role == "user":
|
||||
# Handle user message formatting
|
||||
if item_type == "text" or item_type == "input_text":
|
||||
# Text from user is input_text
|
||||
formatted_content.append(
|
||||
{"type": "input_text", "text": item.get("text", "")}
|
||||
)
|
||||
elif (item_type == "image" or item_type == "image_url") and not has_image:
|
||||
# Only include the first/most recent image we encounter
|
||||
if item_type == "image":
|
||||
# Image from user is input_image
|
||||
source = item.get("source", {})
|
||||
if source.get("type") == "base64" and "data" in source:
|
||||
formatted_content.append(
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{source['data']}",
|
||||
}
|
||||
)
|
||||
has_image = True
|
||||
has_image_in_this_message = True
|
||||
elif item_type == "image_url":
|
||||
# Convert "image_url" to "input_image"
|
||||
formatted_content.append(
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": item.get("image_url", {}).get("url", ""),
|
||||
}
|
||||
)
|
||||
has_image = True
|
||||
has_image_in_this_message = True
|
||||
elif role == "assistant":
|
||||
# Handle assistant message formatting - only output_text is supported
|
||||
if item_type == "text" or item_type == "output_text":
|
||||
formatted_content.append(
|
||||
{"type": "output_text", "text": item.get("text", "")}
|
||||
)
|
||||
|
||||
if formatted_content:
|
||||
# If this message had an image, mark it for inclusion
|
||||
temp_formatted.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": formatted_content,
|
||||
"_had_image": has_image_in_this_message, # Temporary marker
|
||||
}
|
||||
)
|
||||
|
||||
# Reverse back to original order and cleanup
|
||||
for msg in reversed(temp_formatted):
|
||||
# Remove our temporary marker
|
||||
if "_had_image" in msg:
|
||||
del msg["_had_image"]
|
||||
formatted_messages.append(msg)
|
||||
|
||||
# Log summary for debugging
|
||||
num_images = sum(
|
||||
1
|
||||
for msg in formatted_messages
|
||||
for item in (msg.get("content", []) if isinstance(msg.get("content"), list) else [])
|
||||
if isinstance(item, dict) and item.get("type") == "input_image"
|
||||
)
|
||||
logger.info(f"Formatted {len(messages)} messages for OpenAI API with {num_images} images")
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _sanitize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Sanitize response for logging by removing large image data.
|
||||
|
||||
Args:
|
||||
response: Response to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized response
|
||||
"""
|
||||
from .utils import sanitize_message
|
||||
|
||||
# Deep copy to avoid modifying the original
|
||||
sanitized = response.copy()
|
||||
|
||||
# Sanitize output items if present
|
||||
if "output" in sanitized and isinstance(sanitized["output"], list):
|
||||
sanitized["output"] = [sanitize_message(item) for item in sanitized["output"]]
|
||||
|
||||
return sanitized
|
||||
@@ -1,472 +0,0 @@
|
||||
"""OpenAI Agent Response API provider implementation."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional, AsyncGenerator, Callable, Awaitable, TYPE_CHECKING
|
||||
|
||||
from computer import Computer
|
||||
from ...core.base import BaseLoop
|
||||
from ...core.types import AgentResponse
|
||||
from ...core.messages import StandardMessageManager, ImageRetentionConfig
|
||||
|
||||
from .api_handler import OpenAIAPIHandler
|
||||
from .response_handler import OpenAIResponseHandler
|
||||
from .tools.manager import ToolManager
|
||||
from .types import LLMProvider, ResponseItemType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAILoop(BaseLoop):
|
||||
"""OpenAI-specific implementation of the agent loop.
|
||||
|
||||
This class extends BaseLoop to provide specialized support for OpenAI's Agent Response API
|
||||
with computer control capabilities.
|
||||
"""
|
||||
|
||||
###########################################
|
||||
# INITIALIZATION AND CONFIGURATION
|
||||
###########################################
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
computer: Computer,
|
||||
model: str = "computer-use-preview",
|
||||
only_n_most_recent_images: Optional[int] = 2,
|
||||
base_dir: Optional[str] = "trajectories",
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
save_trajectory: bool = True,
|
||||
acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the OpenAI loop.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key
|
||||
model: Model name (ignored, always uses computer-use-preview)
|
||||
computer: Computer instance
|
||||
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
|
||||
base_dir: Base directory for saving experiment data
|
||||
max_retries: Maximum number of retries for API calls
|
||||
retry_delay: Delay between retries in seconds
|
||||
save_trajectory: Whether to save trajectory data
|
||||
acknowledge_safety_check_callback: Optional callback for safety check acknowledgment
|
||||
**kwargs: Additional provider-specific arguments
|
||||
"""
|
||||
# Always use computer-use-preview model
|
||||
if model != "computer-use-preview":
|
||||
logger.info(
|
||||
f"Overriding provided model '{model}' with required model 'computer-use-preview'"
|
||||
)
|
||||
|
||||
# Initialize base class with core config
|
||||
super().__init__(
|
||||
computer=computer,
|
||||
model="computer-use-preview", # Always use computer-use-preview
|
||||
api_key=api_key,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
base_dir=base_dir,
|
||||
save_trajectory=save_trajectory,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Initialize message manager
|
||||
self.message_manager = StandardMessageManager(
|
||||
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
|
||||
)
|
||||
|
||||
# OpenAI-specific attributes
|
||||
self.provider = LLMProvider.OPENAI
|
||||
self.client = None
|
||||
self.retry_count = 0
|
||||
self.acknowledge_safety_check_callback = acknowledge_safety_check_callback
|
||||
self.queue = asyncio.Queue() # Initialize queue
|
||||
self.last_response_id = None # Store the last response ID across runs
|
||||
self.loop_task = None # Store the loop task for cancellation
|
||||
|
||||
# Initialize handlers
|
||||
self.api_handler = OpenAIAPIHandler(self)
|
||||
self.response_handler = OpenAIResponseHandler(self)
|
||||
|
||||
# Initialize tool manager with callback
|
||||
self.tool_manager = ToolManager(
|
||||
computer=computer, acknowledge_safety_check_callback=acknowledge_safety_check_callback
|
||||
)
|
||||
|
||||
###########################################
|
||||
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def initialize_client(self) -> None:
|
||||
"""Initialize the OpenAI API client and tools.
|
||||
|
||||
Implements abstract method from BaseLoop to set up the OpenAI-specific
|
||||
client, tool manager, and message manager.
|
||||
"""
|
||||
try:
|
||||
# Initialize tool manager
|
||||
await self.tool_manager.initialize()
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing OpenAI client: {str(e)}")
|
||||
self.client = None
|
||||
raise RuntimeError(f"Failed to initialize OpenAI client: {str(e)}")
|
||||
|
||||
###########################################
|
||||
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
Args:
|
||||
messages: List of message objects in standard format
|
||||
|
||||
Yields:
|
||||
Agent response format
|
||||
"""
|
||||
try:
|
||||
logger.info("Starting OpenAI loop run")
|
||||
|
||||
# Create queue for response streaming
|
||||
self.queue = asyncio.Queue()
|
||||
|
||||
# Ensure tool manager is initialized
|
||||
await self.tool_manager.initialize()
|
||||
|
||||
# Start loop in background task
|
||||
self.loop_task = asyncio.create_task(self._run_loop(self.queue, messages))
|
||||
|
||||
# Process and yield messages as they arrive
|
||||
while True:
|
||||
try:
|
||||
item = await self.queue.get()
|
||||
if item is None: # Stop signal
|
||||
break
|
||||
yield item
|
||||
self.queue.task_done()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing queue item: {str(e)}")
|
||||
continue
|
||||
|
||||
# Wait for loop to complete
|
||||
await self.loop_task
|
||||
|
||||
# Send completion message
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": "Task completed successfully.",
|
||||
"metadata": {"title": "✅ Complete"},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing task: {str(e)}")
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
|
||||
async def cancel(self) -> None:
|
||||
"""Cancel the currently running agent loop task.
|
||||
|
||||
This method stops the ongoing processing in the agent loop
|
||||
by cancelling the loop_task if it exists and is running.
|
||||
"""
|
||||
if self.loop_task and not self.loop_task.done():
|
||||
logger.info("Cancelling OpenAI loop task")
|
||||
self.loop_task.cancel()
|
||||
try:
|
||||
# Wait for the task to be cancelled with a timeout
|
||||
await asyncio.wait_for(self.loop_task, timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timeout while waiting for loop task to cancel")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Loop task cancelled successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error while cancelling loop task: {str(e)}")
|
||||
finally:
|
||||
# Put None in the queue to signal any waiting consumers to stop
|
||||
await self.queue.put(None)
|
||||
logger.info("OpenAI loop task cancelled")
|
||||
else:
|
||||
logger.info("No active OpenAI loop task to cancel")
|
||||
|
||||
###########################################
|
||||
# AGENT LOOP IMPLEMENTATION
|
||||
###########################################
|
||||
|
||||
async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
Args:
|
||||
queue: Queue for response streaming
|
||||
messages: List of messages in standard format
|
||||
"""
|
||||
try:
|
||||
# Use the instance-level last_response_id instead of creating a local variable
|
||||
# This way it persists between runs
|
||||
|
||||
# Capture initial screenshot
|
||||
try:
|
||||
# Take screenshot
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
logger.info("Screenshot captured successfully")
|
||||
|
||||
# Convert to base64 if needed
|
||||
if isinstance(screenshot, bytes):
|
||||
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
|
||||
elif isinstance(screenshot, (bytearray, memoryview)):
|
||||
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
|
||||
else:
|
||||
screenshot_base64 = str(screenshot)
|
||||
|
||||
# Emit screenshot callbacks
|
||||
await self.handle_screenshot(screenshot_base64, action_type="initial_state")
|
||||
self._save_screenshot(screenshot_base64, action_type="state")
|
||||
|
||||
# First add any existing user messages that were passed to run()
|
||||
user_query = None
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
user_content = msg.get("content", "")
|
||||
if isinstance(user_content, str) and user_content:
|
||||
user_query = user_content
|
||||
# Add the user's original query to the message manager
|
||||
self.message_manager.add_user_message(
|
||||
[{"type": "text", "text": user_content}]
|
||||
)
|
||||
break
|
||||
|
||||
# Add screenshot to message manager
|
||||
message_content = [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": screenshot_base64,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Add appropriate text with the screenshot
|
||||
message_content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": user_query,
|
||||
}
|
||||
)
|
||||
|
||||
# Add the screenshot and text to the message manager
|
||||
self.message_manager.add_user_message(message_content)
|
||||
|
||||
# Process user request and convert our standard message format to one OpenAI expects
|
||||
messages = self.message_manager.messages
|
||||
logger.info(f"Starting agent loop with {len(messages)} messages")
|
||||
|
||||
# Create initial turn directory
|
||||
if self.save_trajectory:
|
||||
self._create_turn_dir()
|
||||
|
||||
# Call API
|
||||
screen_size = await self.computer.interface.get_screen_size()
|
||||
response = await self.api_handler.send_initial_request(
|
||||
messages=self.message_manager.get_messages(), # Apply image retention policy
|
||||
display_width=str(screen_size["width"]),
|
||||
display_height=str(screen_size["height"]),
|
||||
previous_response_id=self.last_response_id,
|
||||
)
|
||||
|
||||
# Store response ID for next request
|
||||
# OpenAI API response structure: the ID is in the response dictionary
|
||||
if isinstance(response, dict) and "id" in response:
|
||||
self.last_response_id = response["id"] # Update instance variable
|
||||
logger.info(f"Received response with ID: {self.last_response_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not find response ID in OpenAI response: {type(response)}"
|
||||
)
|
||||
# Don't reset last_response_id to None - keep the previous value if available
|
||||
|
||||
|
||||
# Log standardized response for ease of parsing
|
||||
# Since this is the openAI responses format, we don't need to convert it to agent response format
|
||||
self._log_api_call("agent_response", request=None, response=response)
|
||||
# Process API response
|
||||
await queue.put(response)
|
||||
|
||||
# Loop to continue processing responses until task is complete
|
||||
task_complete = False
|
||||
while not task_complete:
|
||||
# Check if there are any computer calls
|
||||
output_items = response.get("output", []) or []
|
||||
computer_calls = [
|
||||
item for item in output_items if item.get("type") == "computer_call"
|
||||
]
|
||||
|
||||
if not computer_calls:
|
||||
logger.info("No computer calls in response, task may be complete.")
|
||||
task_complete = True
|
||||
continue
|
||||
|
||||
# Process the first computer call
|
||||
computer_call = computer_calls[0]
|
||||
action = computer_call.get("action", {})
|
||||
call_id = computer_call.get("call_id")
|
||||
|
||||
# Check for safety checks
|
||||
pending_safety_checks = computer_call.get("pending_safety_checks", [])
|
||||
acknowledged_safety_checks = []
|
||||
|
||||
if pending_safety_checks:
|
||||
# Log safety checks
|
||||
for check in pending_safety_checks:
|
||||
logger.warning(
|
||||
f"Safety check: {check.get('code')} - {check.get('message')}"
|
||||
)
|
||||
|
||||
# If we have a callback, use it to acknowledge safety checks
|
||||
if self.acknowledge_safety_check_callback:
|
||||
acknowledged = await self.acknowledge_safety_check_callback(
|
||||
pending_safety_checks
|
||||
)
|
||||
if not acknowledged:
|
||||
logger.warning("Safety check acknowledgment failed")
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Safety checks were not acknowledged. Cannot proceed with action.",
|
||||
"metadata": {"title": "⚠️ Safety Warning"},
|
||||
}
|
||||
)
|
||||
continue
|
||||
acknowledged_safety_checks = pending_safety_checks
|
||||
|
||||
# Execute the action
|
||||
try:
|
||||
# Create a new turn directory for this action if saving trajectories
|
||||
if self.save_trajectory:
|
||||
self._create_turn_dir()
|
||||
|
||||
# Execute the tool
|
||||
result = await self.tool_manager.execute_tool("computer", action)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
if isinstance(screenshot, bytes):
|
||||
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
|
||||
elif isinstance(screenshot, (bytearray, memoryview)):
|
||||
screenshot_base64 = base64.b64encode(bytes(screenshot)).decode("utf-8")
|
||||
else:
|
||||
screenshot_base64 = str(screenshot)
|
||||
|
||||
# Process screenshot through hooks
|
||||
action_type = f"after_{action.get('type', 'action')}"
|
||||
await self.handle_screenshot(screenshot_base64, action_type=action_type)
|
||||
self._save_screenshot(screenshot_base64, action_type=action_type)
|
||||
|
||||
# Create computer_call_output
|
||||
computer_call_output = {
|
||||
"type": "computer_call_output",
|
||||
"call_id": call_id,
|
||||
"output": {
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{screenshot_base64}",
|
||||
},
|
||||
}
|
||||
|
||||
# Add acknowledged safety checks if any
|
||||
if acknowledged_safety_checks:
|
||||
computer_call_output["acknowledged_safety_checks"] = (
|
||||
acknowledged_safety_checks
|
||||
)
|
||||
|
||||
# Save to message manager for history
|
||||
self.message_manager.add_system_message(
|
||||
f"[Computer action executed: {action.get('type')}]"
|
||||
)
|
||||
self.message_manager.add_user_message([computer_call_output])
|
||||
|
||||
# For follow-up requests with previous_response_id, we only need to send
|
||||
# the computer_call_output, not the full message history
|
||||
# The API handler will extract this from the message history
|
||||
if isinstance(self.last_response_id, str):
|
||||
response = await self.api_handler.send_computer_call_request(
|
||||
messages=self.message_manager.get_messages(), # Apply image retention policy
|
||||
display_width=str(screen_size["width"]),
|
||||
display_height=str(screen_size["height"]),
|
||||
previous_response_id=self.last_response_id, # Use instance variable
|
||||
)
|
||||
|
||||
# Store response ID for next request
|
||||
if isinstance(response, dict) and "id" in response:
|
||||
self.last_response_id = response["id"] # Update instance variable
|
||||
logger.info(f"Received response with ID: {self.last_response_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not find response ID in OpenAI response: {type(response)}"
|
||||
)
|
||||
# Keep using the previous response ID if we can't find a new one
|
||||
|
||||
# Process the response
|
||||
# await self.response_handler.process_response(response, queue)
|
||||
self._log_api_call("agent_response", request=None, response=response)
|
||||
await queue.put(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing computer action: {str(e)}")
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error executing action: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
)
|
||||
task_complete = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error capturing initial screenshot: {str(e)}")
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error capturing screenshot: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
)
|
||||
await queue.put(None) # Signal that we're done
|
||||
return
|
||||
|
||||
# Signal that we're done
|
||||
await queue.put(None)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _run_loop: {str(e)}")
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
)
|
||||
await queue.put(None) # Signal that we're done
|
||||
|
||||
def get_last_response_id(self) -> Optional[str]:
|
||||
"""Get the last response ID.
|
||||
|
||||
Returns:
|
||||
The last response ID or None if no response has been received
|
||||
"""
|
||||
return self.last_response_id
|
||||
|
||||
def set_last_response_id(self, response_id: str) -> None:
|
||||
"""Set the last response ID.
|
||||
|
||||
Args:
|
||||
response_id: OpenAI response ID to set
|
||||
"""
|
||||
self.last_response_id = response_id
|
||||
logger.info(f"Manually set response ID to: {self.last_response_id}")
|
||||
@@ -1,205 +0,0 @@
|
||||
"""Response handler for the OpenAI provider."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, AsyncGenerator
|
||||
import base64
|
||||
|
||||
from ...core.types import AgentResponse
|
||||
from .types import ResponseItemType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .loop import OpenAILoop
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIResponseHandler:
|
||||
"""Handler for OpenAI API responses."""
|
||||
|
||||
def __init__(self, loop: "OpenAILoop"):
|
||||
"""Initialize the response handler.
|
||||
|
||||
Args:
|
||||
loop: OpenAI loop instance
|
||||
"""
|
||||
self.loop = loop
|
||||
logger.info("Initialized OpenAI response handler")
|
||||
|
||||
async def process_response(self, response: Dict[str, Any], queue: asyncio.Queue) -> None:
|
||||
"""Process the response from the OpenAI API.
|
||||
|
||||
Args:
|
||||
response: Response from the API
|
||||
queue: Queue for response streaming
|
||||
"""
|
||||
try:
|
||||
# Get output items
|
||||
output_items = response.get("output", []) or []
|
||||
|
||||
# Process each output item
|
||||
for item in output_items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
item_type = item.get("type")
|
||||
|
||||
# For computer_call items, we only need to add to the queue
|
||||
# The loop is now handling executing the action and creating the computer_call_output
|
||||
if item_type == ResponseItemType.COMPUTER_CALL:
|
||||
# Send computer_call to queue so it can be processed
|
||||
await queue.put(item)
|
||||
|
||||
elif item_type == ResponseItemType.MESSAGE:
|
||||
# Send message to queue
|
||||
await queue.put(item)
|
||||
|
||||
elif item_type == ResponseItemType.REASONING:
|
||||
# Process reasoning summary
|
||||
summary = None
|
||||
if "summary" in item and isinstance(item["summary"], list):
|
||||
for summary_item in item["summary"]:
|
||||
if (
|
||||
isinstance(summary_item, dict)
|
||||
and summary_item.get("type") == "summary_text"
|
||||
):
|
||||
summary = summary_item.get("text")
|
||||
break
|
||||
|
||||
if summary:
|
||||
# Log the reasoning summary
|
||||
logger.info(f"Reasoning summary: {summary}")
|
||||
|
||||
# Send reasoning summary to queue with a special format
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"[Reasoning: {summary}]",
|
||||
"metadata": {"title": "💭 Reasoning", "is_summary": True},
|
||||
}
|
||||
)
|
||||
|
||||
# Also pass the original reasoning item to the queue for complete context
|
||||
await queue.put(item)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing response: {str(e)}")
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error processing response: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
)
|
||||
|
||||
def _process_message_item(self, item: Dict[str, Any]) -> AgentResponse:
|
||||
"""Process a message item from the response.
|
||||
|
||||
Args:
|
||||
item: Message item from the response
|
||||
|
||||
Returns:
|
||||
Processed message in AgentResponse format
|
||||
"""
|
||||
# Extract content items - add null check
|
||||
content_items = item.get("content", []) or []
|
||||
|
||||
# Extract text from content items - use output_text type from OpenAI
|
||||
text = ""
|
||||
for content_item in content_items:
|
||||
# Skip if content_item is None or not a dict
|
||||
if content_item is None or not isinstance(content_item, dict):
|
||||
continue
|
||||
|
||||
# In OpenAI Agent Response API, text content is in "output_text" type items
|
||||
if content_item.get("type") == "output_text":
|
||||
text += content_item.get("text", "")
|
||||
|
||||
# Create agent response
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": text
|
||||
or "I don't have a response for that right now.", # Provide fallback when text is empty
|
||||
"metadata": {"title": "💬 Response"},
|
||||
}
|
||||
|
||||
async def _process_computer_call(self, item: Dict[str, Any], queue: asyncio.Queue) -> None:
|
||||
"""Process a computer call item from the response.
|
||||
|
||||
Args:
|
||||
item: Computer call item
|
||||
queue: Queue to add responses to
|
||||
"""
|
||||
try:
|
||||
# Log the computer call
|
||||
action = item.get("action", {}) or {}
|
||||
if not isinstance(action, dict):
|
||||
logger.warning(f"Expected dict for action, got {type(action)}")
|
||||
action = {}
|
||||
|
||||
action_type = action.get("type", "unknown")
|
||||
logger.info(f"Processing computer call: {action_type}")
|
||||
|
||||
# Execute the tool call
|
||||
result = await self.loop.tool_manager.execute_tool("computer", action)
|
||||
|
||||
# Add any message to the conversation history and queue
|
||||
if result and result.base64_image:
|
||||
# Update message history with the call output
|
||||
self.loop.message_manager.add_user_message(
|
||||
[{"type": "text", "text": f"[Computer action completed: {action_type}]"}]
|
||||
)
|
||||
|
||||
# Add image to messages (using correct content types for Agent Response API)
|
||||
self.loop.message_manager.add_user_message(
|
||||
[
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": result.base64_image,
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# If browser environment, include URL if available
|
||||
# if (
|
||||
# hasattr(self.loop.computer, "environment")
|
||||
# and self.loop.computer.environment == "browser"
|
||||
# ):
|
||||
# try:
|
||||
# if hasattr(self.loop.computer.interface, "get_current_url"):
|
||||
# current_url = await self.loop.computer.interface.get_current_url()
|
||||
# self.loop.message_manager.add_user_message(
|
||||
# [
|
||||
# {
|
||||
# "type": "text",
|
||||
# "text": f"Current URL: {current_url}",
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to get current URL: {str(e)}")
|
||||
|
||||
# Log successful completion
|
||||
logger.info(f"Computer call {action_type} executed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing computer call: {str(e)}")
|
||||
logger.debug(traceback.format_exc())
|
||||
|
||||
# Add error to conversation
|
||||
self.loop.message_manager.add_user_message(
|
||||
[{"type": "text", "text": f"Error executing computer action: {str(e)}"}]
|
||||
)
|
||||
|
||||
# Send error to queue
|
||||
error_response = {
|
||||
"role": "assistant",
|
||||
"content": f"Error executing computer action: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
await queue.put(error_response)
|
||||
@@ -1,15 +0,0 @@
|
||||
"""OpenAI tools module for computer control."""
|
||||
|
||||
from .manager import ToolManager
|
||||
from .computer import ComputerTool
|
||||
from .base import BaseOpenAITool, ToolResult, ToolError, ToolFailure, CLIResult
|
||||
|
||||
__all__ = [
|
||||
"ToolManager",
|
||||
"ComputerTool",
|
||||
"BaseOpenAITool",
|
||||
"ToolResult",
|
||||
"ToolError",
|
||||
"ToolFailure",
|
||||
"CLIResult",
|
||||
]
|
||||
@@ -1,79 +0,0 @@
|
||||
"""OpenAI-specific tool base classes."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ....core.tools.base import BaseTool
|
||||
|
||||
|
||||
class BaseOpenAITool(BaseTool, metaclass=ABCMeta):
|
||||
"""Abstract base class for OpenAI-defined tools."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the base OpenAI tool."""
|
||||
# No specific initialization needed yet, but included for future extensibility
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, **kwargs) -> Any:
|
||||
"""Executes the tool with the given arguments."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to OpenAI-specific API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters for OpenAI API
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
class ToolResult:
|
||||
"""Represents the result of a tool execution."""
|
||||
|
||||
output: str | None = None
|
||||
error: str | None = None
|
||||
base64_image: str | None = None
|
||||
system: str | None = None
|
||||
content: list[dict] | None = None
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field.name) for field in fields(self))
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
content=self.content or other.content, # Use first non-None content
|
||||
)
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Returns a new ToolResult with the given fields replaced."""
|
||||
return replace(self, **kwargs)
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
|
||||
"""A ToolResult that can be rendered as a CLI output."""
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
|
||||
"""A ToolResult that represents a failure."""
|
||||
|
||||
|
||||
class ToolError(Exception):
|
||||
"""Raised when a tool encounters an error."""
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
@@ -1,326 +0,0 @@
|
||||
"""Computer tool for OpenAI."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from typing import Literal, Any, Dict, Optional, List, Union
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseOpenAITool, ToolError, ToolResult
|
||||
from ....core.tools.computer import BaseComputerTool
|
||||
|
||||
TYPING_DELAY_MS = 12
|
||||
TYPING_GROUP_SIZE = 50
|
||||
|
||||
# Key mapping for special keys
|
||||
KEY_MAPPING = {
|
||||
"enter": "return",
|
||||
"backspace": "delete",
|
||||
"delete": "forwarddelete",
|
||||
"escape": "esc",
|
||||
"pageup": "page_up",
|
||||
"pagedown": "page_down",
|
||||
"arrowup": "up",
|
||||
"arrowdown": "down",
|
||||
"arrowleft": "left",
|
||||
"arrowright": "right",
|
||||
"home": "home",
|
||||
"end": "end",
|
||||
"tab": "tab",
|
||||
"space": "space",
|
||||
"shift": "shift",
|
||||
"control": "control",
|
||||
"alt": "alt",
|
||||
"meta": "command",
|
||||
}
|
||||
|
||||
Action = Literal[
|
||||
"key",
|
||||
"type",
|
||||
"mouse_move",
|
||||
"left_click",
|
||||
"right_click",
|
||||
"double_click",
|
||||
"screenshot",
|
||||
"scroll",
|
||||
"drag",
|
||||
]
|
||||
|
||||
|
||||
class ComputerTool(BaseComputerTool, BaseOpenAITool):
|
||||
"""
|
||||
A tool that allows the agent to interact with the screen, keyboard, and mouse of the current computer.
|
||||
"""
|
||||
|
||||
name: Literal["computer"] = "computer"
|
||||
api_type: Literal["computer_use_preview"] = "computer_use_preview"
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
display_num: Optional[int] = None
|
||||
computer: Computer # The CUA Computer instance
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the computer tool.
|
||||
|
||||
Args:
|
||||
computer: Computer instance
|
||||
"""
|
||||
self.computer = computer
|
||||
self.width = None
|
||||
self.height = None
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize the base computer tool first
|
||||
BaseComputerTool.__init__(self, computer)
|
||||
# Then initialize the OpenAI tool
|
||||
BaseOpenAITool.__init__(self)
|
||||
|
||||
# Additional initialization
|
||||
self.width = None # Will be initialized from computer interface
|
||||
self.height = None # Will be initialized from computer interface
|
||||
self.display_num = None
|
||||
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters
|
||||
"""
|
||||
if self.width is None or self.height is None:
|
||||
raise RuntimeError(
|
||||
"Screen dimensions not initialized. Call initialize_dimensions() first."
|
||||
)
|
||||
return {
|
||||
"type": self.api_type,
|
||||
"display_width": self.width,
|
||||
"display_height": self.height,
|
||||
"display_number": self.display_num,
|
||||
}
|
||||
|
||||
async def initialize_dimensions(self):
|
||||
"""Initialize screen dimensions from the computer interface."""
|
||||
try:
|
||||
display_size = await self.computer.interface.get_screen_size()
|
||||
self.width = display_size["width"]
|
||||
self.height = display_size["height"]
|
||||
assert isinstance(self.width, int) and isinstance(self.height, int)
|
||||
self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
|
||||
except Exception as e:
|
||||
# Fall back to defaults if we can't get accurate dimensions
|
||||
self.width = 1024
|
||||
self.height = 768
|
||||
self.logger.warning(
|
||||
f"Failed to get screen dimensions, using defaults: {self.width}x{self.height}. Error: {e}"
|
||||
)
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
type: str, # OpenAI uses 'type' instead of 'action'
|
||||
text: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
# Ensure dimensions are initialized
|
||||
if self.width is None or self.height is None:
|
||||
await self.initialize_dimensions()
|
||||
if self.width is None or self.height is None:
|
||||
raise ToolError("Failed to initialize screen dimensions")
|
||||
|
||||
if type == "type":
|
||||
if text is None:
|
||||
raise ToolError("text is required for type action")
|
||||
return await self.handle_typing(text)
|
||||
elif type == "click":
|
||||
# Map button to correct action name
|
||||
button = kwargs.get("button")
|
||||
if button is None:
|
||||
raise ToolError("button is required for click action")
|
||||
return await self.handle_click(button, kwargs["x"], kwargs["y"])
|
||||
elif type == "keypress":
|
||||
# Check for keys in kwargs if text is None
|
||||
if text is None:
|
||||
if "keys" in kwargs and isinstance(kwargs["keys"], list):
|
||||
# Pass the keys list directly instead of joining and then splitting
|
||||
return await self.handle_key(kwargs["keys"])
|
||||
else:
|
||||
raise ToolError("Either 'text' or 'keys' is required for keypress action")
|
||||
return await self.handle_key(text)
|
||||
elif type == "mouse_move":
|
||||
if "coordinates" not in kwargs:
|
||||
raise ToolError("coordinates is required for mouse_move action")
|
||||
return await self.handle_mouse_move(
|
||||
kwargs["coordinates"][0], kwargs["coordinates"][1]
|
||||
)
|
||||
elif type == "scroll":
|
||||
# Get x, y coordinates directly from kwargs
|
||||
x = kwargs.get("x")
|
||||
y = kwargs.get("y")
|
||||
if x is None or y is None:
|
||||
raise ToolError("x and y coordinates are required for scroll action")
|
||||
scroll_x = kwargs.get("scroll_x", 0) // 50
|
||||
scroll_y = kwargs.get("scroll_y", 0) // 50
|
||||
return await self.handle_scroll(x, y, scroll_x, scroll_y)
|
||||
elif type == "drag":
|
||||
path = kwargs.get("path")
|
||||
if not path or not isinstance(path, list) or len(path) < 2:
|
||||
raise ToolError("path is required for drag action and must contain at least 2 points")
|
||||
return await self.handle_drag(path)
|
||||
elif type == "screenshot":
|
||||
return await self.screenshot()
|
||||
elif type == "wait":
|
||||
duration = kwargs.get("duration", 1.0)
|
||||
await asyncio.sleep(duration)
|
||||
return await self.screenshot()
|
||||
else:
|
||||
raise ToolError(f"Unsupported action: {type}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in ComputerTool.__call__: {str(e)}")
|
||||
raise ToolError(f"Failed to execute {type}: {str(e)}")
|
||||
|
||||
async def handle_click(self, button: str, x: int, y: int) -> ToolResult:
|
||||
"""Handle mouse clicks."""
|
||||
try:
|
||||
# Perform the click based on button type
|
||||
if button == "left":
|
||||
await self.computer.interface.left_click(x, y)
|
||||
elif button == "right":
|
||||
await self.computer.interface.right_click(x, y)
|
||||
elif button == "double":
|
||||
await self.computer.interface.double_click(x, y)
|
||||
else:
|
||||
raise ToolError(f"Unsupported button type: {button}")
|
||||
|
||||
# Wait briefly for UI to update
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
return ToolResult(
|
||||
output=f"Performed {button} click at ({x}, {y})",
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_click: {str(e)}")
|
||||
raise ToolError(f"Failed to perform {button} click at ({x}, {y}): {str(e)}")
|
||||
|
||||
async def handle_typing(self, text: str) -> ToolResult:
|
||||
"""Handle typing text with a small delay between characters."""
|
||||
try:
|
||||
# Type the text with a small delay
|
||||
await self.computer.interface.type_text(text)
|
||||
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
return ToolResult(output=f"Typed: {text}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_typing: {str(e)}")
|
||||
raise ToolError(f"Failed to type '{text}': {str(e)}")
|
||||
|
||||
async def handle_key(self, key: Union[str, List[str]]) -> ToolResult:
|
||||
"""Handle key press, supporting both single keys and combinations.
|
||||
|
||||
Args:
|
||||
key: Either a string (e.g. "ctrl+c") or a list of keys (e.g. ["ctrl", "c"])
|
||||
"""
|
||||
try:
|
||||
# Check if key is already a list
|
||||
if isinstance(key, list):
|
||||
keys = [k.strip().lower() for k in key]
|
||||
else:
|
||||
# Split key string into list if it's a combination (e.g. "ctrl+c")
|
||||
keys = [k.strip().lower() for k in key.split("+")]
|
||||
|
||||
# Map each key
|
||||
mapped_keys = [KEY_MAPPING.get(k, k) for k in keys]
|
||||
|
||||
if len(mapped_keys) > 1:
|
||||
# For key combinations (like Ctrl+C)
|
||||
await self.computer.interface.hotkey(*mapped_keys)
|
||||
else:
|
||||
# Single key press
|
||||
await self.computer.interface.press_key(mapped_keys[0])
|
||||
|
||||
# Wait briefly
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
return ToolResult(output=f"Pressed key: {key}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_key: {str(e)}")
|
||||
raise ToolError(f"Failed to press key '{key}': {str(e)}")
|
||||
|
||||
async def handle_mouse_move(self, x: int, y: int) -> ToolResult:
|
||||
"""Handle mouse movement."""
|
||||
try:
|
||||
# Move cursor to position
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
|
||||
# Wait briefly
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
return ToolResult(output=f"Moved cursor to ({x}, {y})")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_mouse_move: {str(e)}")
|
||||
raise ToolError(f"Failed to move cursor to ({x}, {y}): {str(e)}")
|
||||
|
||||
async def handle_scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> ToolResult:
|
||||
"""Handle scrolling."""
|
||||
try:
|
||||
# Move cursor to position first
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
|
||||
# Scroll based on direction
|
||||
if scroll_y > 0:
|
||||
await self.computer.interface.scroll_down(abs(scroll_y))
|
||||
elif scroll_y < 0:
|
||||
await self.computer.interface.scroll_up(abs(scroll_y))
|
||||
|
||||
# Wait for UI to update
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return ToolResult(output=f"Scrolled at ({x}, {y}) by ({scroll_x}, {scroll_y})")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_scroll: {str(e)}")
|
||||
raise ToolError(f"Failed to scroll at ({x}, {y}): {str(e)}")
|
||||
|
||||
async def handle_drag(self, path: List[Dict[str, int]]) -> ToolResult:
|
||||
"""Handle mouse drag operation using a path of coordinates.
|
||||
|
||||
Args:
|
||||
path: List of coordinate points {"x": int, "y": int} defining the drag path
|
||||
|
||||
Returns:
|
||||
ToolResult with the operation result and screenshot
|
||||
"""
|
||||
try:
|
||||
# Convert from [{"x": x, "y": y}, ...] format to [(x, y), ...] format
|
||||
points = [(p["x"], p["y"]) for p in path]
|
||||
|
||||
# Perform drag action
|
||||
if len(points) == 2:
|
||||
await self.computer.interface.move_cursor(points[0][0], points[0][1])
|
||||
await self.computer.interface.drag_to(points[1][0], points[1][1])
|
||||
else:
|
||||
await self.computer.interface.drag(points, button="left")
|
||||
|
||||
# Wait for UI to update
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return ToolResult(
|
||||
output=f"Dragged from ({path[0]['x']}, {path[0]['y']}) to ({path[-1]['x']}, {path[-1]['y']})",
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_drag: {str(e)}")
|
||||
raise ToolError(f"Failed to perform drag operation: {str(e)}")
|
||||
|
||||
async def screenshot(self) -> ToolResult:
|
||||
"""Take a screenshot."""
|
||||
try:
|
||||
# Take screenshot
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(output="Screenshot taken", base64_image=base64_screenshot)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in screenshot: {str(e)}")
|
||||
raise ToolError(f"Failed to take screenshot: {str(e)}")
|
||||
@@ -1,106 +0,0 @@
|
||||
"""Tool manager for the OpenAI provider."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Callable, Awaitable, Union
|
||||
|
||||
from computer import Computer
|
||||
from ..types import ComputerAction, ResponseItemType
|
||||
from .computer import ComputerTool
|
||||
from ....core.tools.base import ToolResult, ToolFailure
|
||||
from ....core.tools.collection import ToolCollection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""Manager for computer tools in the OpenAI agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
computer: Computer,
|
||||
acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
|
||||
):
|
||||
"""Initialize the tool manager.
|
||||
|
||||
Args:
|
||||
computer: Computer instance
|
||||
acknowledge_safety_check_callback: Optional callback for safety check acknowledgment
|
||||
"""
|
||||
self.computer = computer
|
||||
self.acknowledge_safety_check_callback = acknowledge_safety_check_callback
|
||||
self._initialized = False
|
||||
self.computer_tool = ComputerTool(computer)
|
||||
self.tools = None
|
||||
logger.info("Initialized OpenAI ToolManager")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the tool manager."""
|
||||
if not self._initialized:
|
||||
logger.info("Initializing OpenAI ToolManager")
|
||||
|
||||
# Initialize the computer tool
|
||||
await self.computer_tool.initialize_dimensions()
|
||||
|
||||
# Initialize tool collection
|
||||
self.tools = ToolCollection(self.computer_tool)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("OpenAI ToolManager initialized")
|
||||
|
||||
async def get_tools_definition(self) -> List[Dict[str, Any]]:
|
||||
"""Get the tools definition for the OpenAI agent.
|
||||
|
||||
Returns:
|
||||
Tools definition for the OpenAI agent
|
||||
"""
|
||||
if not self.tools:
|
||||
raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
|
||||
# For the OpenAI Agent Response API, we use a special "computer-preview" tool
|
||||
# which provides the correct interface for computer control
|
||||
display_width, display_height = await self._get_computer_dimensions()
|
||||
|
||||
# Get environment, using "mac" as default since we're on macOS
|
||||
environment = getattr(self.computer, "environment", "mac")
|
||||
|
||||
# Ensure environment is one of the allowed values
|
||||
if environment not in ["windows", "mac", "linux", "browser"]:
|
||||
logger.warning(f"Invalid environment value: {environment}, using 'mac' instead")
|
||||
environment = "mac"
|
||||
|
||||
return [
|
||||
{
|
||||
"type": "computer-preview",
|
||||
"display_width": display_width,
|
||||
"display_height": display_height,
|
||||
"environment": environment,
|
||||
}
|
||||
]
|
||||
|
||||
async def _get_computer_dimensions(self) -> tuple[int, int]:
|
||||
"""Get the dimensions of the computer display.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
# If computer tool is initialized, use its dimensions
|
||||
if self.computer_tool.width is not None and self.computer_tool.height is not None:
|
||||
return (self.computer_tool.width, self.computer_tool.height)
|
||||
|
||||
# Try to get from computer.interface if available
|
||||
screen_size = await self.computer.interface.get_screen_size()
|
||||
return (int(screen_size["width"]), int(screen_size["height"]))
|
||||
|
||||
async def execute_tool(self, name: str, tool_input: Dict[str, Any]) -> ToolResult:
|
||||
"""Execute a tool with the given input.
|
||||
|
||||
Args:
|
||||
name: Name of the tool to execute
|
||||
tool_input: Input parameters for the tool
|
||||
|
||||
Returns:
|
||||
Result of the tool execution
|
||||
"""
|
||||
if not self.tools:
|
||||
raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
return await self.tools.run(name=name, tool_input=tool_input)
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Type definitions for the OpenAI provider."""
|
||||
|
||||
from enum import StrEnum, auto
|
||||
from typing import Dict, List, Optional, Union, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class LLMProvider(StrEnum):
|
||||
"""OpenAI LLM provider types."""
|
||||
|
||||
OPENAI = "openai"
|
||||
|
||||
|
||||
class ResponseItemType(StrEnum):
|
||||
"""Types of items in OpenAI Agent Response output."""
|
||||
|
||||
MESSAGE = "message"
|
||||
COMPUTER_CALL = "computer_call"
|
||||
COMPUTER_CALL_OUTPUT = "computer_call_output"
|
||||
REASONING = "reasoning"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputerAction:
|
||||
"""Represents a computer action to be performed."""
|
||||
|
||||
type: str
|
||||
x: Optional[int] = None
|
||||
y: Optional[int] = None
|
||||
text: Optional[str] = None
|
||||
button: Optional[str] = None
|
||||
keys: Optional[List[str]] = None
|
||||
ms: Optional[int] = None
|
||||
scroll_x: Optional[int] = None
|
||||
scroll_y: Optional[int] = None
|
||||
path: Optional[List[Dict[str, int]]] = None
|
||||
@@ -1,98 +0,0 @@
|
||||
"""Utility functions for the OpenAI provider."""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ...core.types import AgentResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def format_images_for_openai(images_base64: List[str]) -> List[Dict[str, Any]]:
|
||||
"""Format images for OpenAI Agent Response API.
|
||||
|
||||
Args:
|
||||
images_base64: List of base64 encoded images
|
||||
|
||||
Returns:
|
||||
List of formatted image items for Agent Response API
|
||||
"""
|
||||
return [
|
||||
{"type": "input_image", "image_url": f"data:image/png;base64,{image}"}
|
||||
for image in images_base64
|
||||
]
|
||||
|
||||
|
||||
def extract_message_content(message: Dict[str, Any]) -> str:
|
||||
"""Extract text content from a message.
|
||||
|
||||
Args:
|
||||
message: Message to extract content from
|
||||
|
||||
Returns:
|
||||
Text content from the message
|
||||
"""
|
||||
if isinstance(message.get("content"), str):
|
||||
return message["content"]
|
||||
|
||||
if isinstance(message.get("content"), list):
|
||||
text = ""
|
||||
role = message.get("role", "user")
|
||||
|
||||
for item in message["content"]:
|
||||
if isinstance(item, dict):
|
||||
# For user messages
|
||||
if role == "user" and item.get("type") == "input_text":
|
||||
text += item.get("text", "")
|
||||
# For standard format
|
||||
elif item.get("type") == "text":
|
||||
text += item.get("text", "")
|
||||
# For assistant messages in Agent Response API format
|
||||
elif item.get("type") == "output_text":
|
||||
text += item.get("text", "")
|
||||
return text
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def sanitize_message(msg: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Sanitize a message for logging by removing large image data.
|
||||
|
||||
Args:
|
||||
msg: Message to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized message
|
||||
"""
|
||||
if not isinstance(msg, dict):
|
||||
return msg
|
||||
|
||||
sanitized = msg.copy()
|
||||
|
||||
# Handle message content
|
||||
if isinstance(sanitized.get("content"), list):
|
||||
sanitized_content = []
|
||||
for item in sanitized["content"]:
|
||||
if isinstance(item, dict):
|
||||
# Handle various image types
|
||||
if item.get("type") == "image_url" and "image_url" in item:
|
||||
sanitized_content.append({"type": "image_url", "image_url": "[omitted]"})
|
||||
elif item.get("type") == "input_image" and "image_url" in item:
|
||||
sanitized_content.append({"type": "input_image", "image_url": "[omitted]"})
|
||||
elif item.get("type") == "image" and "source" in item:
|
||||
sanitized_content.append({"type": "image", "source": "[omitted]"})
|
||||
else:
|
||||
sanitized_content.append(item)
|
||||
else:
|
||||
sanitized_content.append(item)
|
||||
sanitized["content"] = sanitized_content
|
||||
|
||||
# Handle computer_call_output
|
||||
if sanitized.get("type") == "computer_call_output" and "output" in sanitized:
|
||||
output = sanitized["output"]
|
||||
if isinstance(output, dict) and "image_url" in output:
|
||||
sanitized["output"] = {**output, "image_url": "[omitted]"}
|
||||
|
||||
return sanitized
|
||||
@@ -1 +0,0 @@
|
||||
"""UI-TARS Agent provider package."""
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Base client implementation for Omni providers."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseUITarsClient:
|
||||
"""Base class for provider-specific clients."""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
|
||||
"""Initialize base client.
|
||||
|
||||
Args:
|
||||
api_key: Optional API key
|
||||
model: Optional model name
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
|
||||
async def run_interleaved(
|
||||
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Run interleaved chat completion.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
system: System prompt
|
||||
max_tokens: Optional max tokens override
|
||||
|
||||
Returns:
|
||||
Response dict
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -1,263 +0,0 @@
|
||||
"""MLX LVM client implementation."""
|
||||
|
||||
import io
|
||||
import logging
|
||||
import base64
|
||||
import tempfile
|
||||
import os
|
||||
import re
|
||||
import math
|
||||
from typing import Dict, List, Optional, Any, cast, Tuple
|
||||
from PIL import Image
|
||||
|
||||
from .base import BaseUITarsClient
|
||||
import mlx.core as mx
|
||||
from mlx_vlm import load, generate
|
||||
from mlx_vlm.prompt_utils import apply_chat_template
|
||||
from mlx_vlm.utils import load_config
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants for smart_resize
|
||||
IMAGE_FACTOR = 28
|
||||
MIN_PIXELS = 100 * 28 * 28
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MAX_RATIO = 200
|
||||
|
||||
def round_by_factor(number: float, factor: int) -> int:
|
||||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||
return round(number / factor) * factor
|
||||
|
||||
def ceil_by_factor(number: float, factor: int) -> int:
|
||||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.ceil(number / factor) * factor
|
||||
|
||||
def floor_by_factor(number: float, factor: int) -> int:
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
def smart_resize(
|
||||
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Rescales the image so that the following conditions are met:
|
||||
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
"""
|
||||
if max(height, width) / min(height, width) > MAX_RATIO:
|
||||
raise ValueError(
|
||||
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
||||
)
|
||||
h_bar = max(factor, round_by_factor(height, factor))
|
||||
w_bar = max(factor, round_by_factor(width, factor))
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = floor_by_factor(height / beta, factor)
|
||||
w_bar = floor_by_factor(width / beta, factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = ceil_by_factor(height * beta, factor)
|
||||
w_bar = ceil_by_factor(width * beta, factor)
|
||||
return h_bar, w_bar
|
||||
|
||||
class MLXVLMUITarsClient(BaseUITarsClient):
|
||||
"""MLX LVM client implementation class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "mlx-community/UI-TARS-1.5-7B-4bit"
|
||||
):
|
||||
"""Initialize MLX LVM client.
|
||||
|
||||
Args:
|
||||
model: Model name or path (defaults to mlx-community/UI-TARS-1.5-7B-4bit)
|
||||
"""
|
||||
# Load model and processor
|
||||
model_obj, processor = load(
|
||||
model,
|
||||
processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
|
||||
)
|
||||
self.config = load_config(model)
|
||||
self.model = model_obj
|
||||
self.processor = processor
|
||||
self.model_name = model
|
||||
|
||||
def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
|
||||
"""Process coordinates in box tokens based on image resizing using smart_resize approach.
|
||||
|
||||
Args:
|
||||
text: Text containing box tokens
|
||||
original_size: Original image size (width, height)
|
||||
model_size: Model processed image size (width, height)
|
||||
|
||||
Returns:
|
||||
Text with processed coordinates
|
||||
"""
|
||||
# Find all box tokens
|
||||
box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
|
||||
|
||||
def process_coords(match):
|
||||
model_x, model_y = int(match.group(1)), int(match.group(2))
|
||||
# Scale coordinates from model space to original image space
|
||||
# Both original_size and model_size are in (width, height) format
|
||||
new_x = int(model_x * original_size[0] / model_size[0]) # Width
|
||||
new_y = int(model_y * original_size[1] / model_size[1]) # Height
|
||||
return f"<|box_start|>({new_x},{new_y})<|box_end|>"
|
||||
|
||||
return re.sub(box_pattern, process_coords, text)
|
||||
|
||||
async def run_interleaved(
|
||||
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Run interleaved chat completion.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
system: System prompt
|
||||
max_tokens: Optional max tokens override
|
||||
|
||||
Returns:
|
||||
Response dict
|
||||
"""
|
||||
# Ensure the system message is included
|
||||
if not any(msg.get("role") == "system" for msg in messages):
|
||||
messages = [{"role": "system", "content": system}] + messages
|
||||
|
||||
# Create a deep copy of messages to avoid modifying the original
|
||||
processed_messages = messages.copy()
|
||||
|
||||
# Extract images and process messages
|
||||
images = []
|
||||
original_sizes = {} # Track original sizes of images for coordinate mapping
|
||||
model_sizes = {} # Track model processed sizes
|
||||
image_index = 0
|
||||
|
||||
for msg_idx, msg in enumerate(messages):
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
|
||||
# Create a copy of the content list to modify
|
||||
processed_content = []
|
||||
|
||||
for item_idx, item in enumerate(content):
|
||||
if item.get("type") == "image_url":
|
||||
image_url = item.get("image_url", {}).get("url", "")
|
||||
pil_image = None
|
||||
|
||||
if image_url.startswith("data:image/"):
|
||||
# Extract base64 data
|
||||
base64_data = image_url.split(',')[1]
|
||||
# Convert base64 to PIL Image
|
||||
image_data = base64.b64decode(base64_data)
|
||||
pil_image = Image.open(io.BytesIO(image_data))
|
||||
else:
|
||||
# Handle file path or URL
|
||||
pil_image = Image.open(image_url)
|
||||
|
||||
# Store original image size for coordinate mapping
|
||||
original_size = pil_image.size
|
||||
original_sizes[image_index] = original_size
|
||||
|
||||
# Use smart_resize to determine model size
|
||||
# Note: smart_resize expects (height, width) but PIL gives (width, height)
|
||||
height, width = original_size[1], original_size[0]
|
||||
new_height, new_width = smart_resize(height, width)
|
||||
# Store model size in (width, height) format for consistent coordinate processing
|
||||
model_sizes[image_index] = (new_width, new_height)
|
||||
|
||||
# Resize the image using the calculated dimensions from smart_resize
|
||||
resized_image = pil_image.resize((new_width, new_height))
|
||||
images.append(resized_image)
|
||||
image_index += 1
|
||||
|
||||
# Copy items to processed content list
|
||||
processed_content.append(item.copy())
|
||||
|
||||
# Update the processed message content
|
||||
processed_messages[msg_idx] = msg.copy()
|
||||
processed_messages[msg_idx]["content"] = processed_content
|
||||
|
||||
logger.info(f"resized {len(images)} from {original_sizes[0]} to {model_sizes[0]}")
|
||||
|
||||
# Process user text input with box coordinates after image processing
|
||||
# Swap original_size and model_size arguments for inverse transformation
|
||||
for msg_idx, msg in enumerate(processed_messages):
|
||||
if msg.get("role") == "user" and isinstance(msg.get("content"), str):
|
||||
if "<|box_start|>" in msg.get("content") and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
|
||||
orig_size = original_sizes[0]
|
||||
model_size = model_sizes[0]
|
||||
# Swap arguments to perform inverse transformation for user input
|
||||
processed_messages[msg_idx]["content"] = self._process_coordinates(msg["content"], model_size, orig_size)
|
||||
|
||||
try:
|
||||
# Format prompt according to model requirements using the processor directly
|
||||
prompt = self.processor.apply_chat_template(
|
||||
processed_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
tokenizer = cast(PreTrainedTokenizer, self.processor)
|
||||
|
||||
print("generating response...")
|
||||
|
||||
# Generate response
|
||||
text_content, usage = generate(
|
||||
self.model,
|
||||
tokenizer,
|
||||
str(prompt),
|
||||
images,
|
||||
verbose=False,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
|
||||
from pprint import pprint
|
||||
print("DEBUG - AGENT GENERATION --------")
|
||||
pprint(text_content)
|
||||
print("DEBUG - AGENT GENERATION --------")
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating response: {str(e)}")
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": f"Error generating response: {str(e)}"
|
||||
},
|
||||
"finish_reason": "error"
|
||||
}
|
||||
],
|
||||
"model": self.model_name,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Process coordinates in the response back to original image space
|
||||
if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
|
||||
# Get original image size and model size (using the first image)
|
||||
orig_size = original_sizes[0]
|
||||
model_size = model_sizes[0]
|
||||
|
||||
# Check if output contains box tokens that need processing
|
||||
if "<|box_start|>" in text_content:
|
||||
# Process coordinates from model space back to original image space
|
||||
text_content = self._process_coordinates(text_content, orig_size, model_size)
|
||||
|
||||
# Format response to match OpenAI format
|
||||
response = {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text_content
|
||||
},
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"model": self.model_name,
|
||||
"usage": usage
|
||||
}
|
||||
|
||||
return response
|
||||
@@ -1,214 +0,0 @@
|
||||
"""OpenAI-compatible client implementation."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
import aiohttp
|
||||
import re
|
||||
from .base import BaseUITarsClient
|
||||
import asyncio
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# OpenAI-compatible client for the UI_Tars
|
||||
class OAICompatClient(BaseUITarsClient):
|
||||
"""OpenAI-compatible API client implementation.
|
||||
|
||||
This client can be used with any service that implements the OpenAI API protocol, including:
|
||||
- Huggingface Text Generation Interface endpoints
|
||||
- vLLM
|
||||
- LM Studio
|
||||
- LocalAI
|
||||
- Ollama (with OpenAI compatibility)
|
||||
- Text Generation WebUI
|
||||
- Any other service with OpenAI API compatibility
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model: str = "Qwen2.5-VL-7B-Instruct",
|
||||
provider_base_url: Optional[str] = "http://localhost:8000/v1",
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.0,
|
||||
):
|
||||
"""Initialize the OpenAI-compatible client.
|
||||
|
||||
Args:
|
||||
api_key: Not used for local endpoints, usually set to "EMPTY"
|
||||
model: Model name to use
|
||||
provider_base_url: API base URL. Typically in the format "http://localhost:PORT/v1"
|
||||
Examples:
|
||||
- vLLM: "http://localhost:8000/v1"
|
||||
- LM Studio: "http://localhost:1234/v1"
|
||||
- LocalAI: "http://localhost:8080/v1"
|
||||
- Ollama: "http://localhost:11434/v1"
|
||||
max_tokens: Maximum tokens to generate
|
||||
temperature: Generation temperature
|
||||
"""
|
||||
super().__init__(api_key=api_key or "EMPTY", model=model)
|
||||
self.api_key = api_key or "EMPTY" # Local endpoints typically don't require an API key
|
||||
self.model = model
|
||||
self.provider_base_url = (
|
||||
provider_base_url or "http://localhost:8000/v1"
|
||||
) # Use default if None
|
||||
self.max_tokens = max_tokens
|
||||
self.temperature = temperature
|
||||
|
||||
def _extract_base64_image(self, text: str) -> Optional[str]:
|
||||
"""Extract base64 image data from an HTML img tag."""
|
||||
pattern = r'data:image/[^;]+;base64,([^"]+)'
|
||||
match = re.search(pattern, text)
|
||||
return match.group(1) if match else None
|
||||
|
||||
def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Create a loggable version of messages with image data truncated."""
|
||||
loggable_messages = []
|
||||
for msg in messages:
|
||||
if isinstance(msg.get("content"), list):
|
||||
new_content = []
|
||||
for content in msg["content"]:
|
||||
if content.get("type") == "image":
|
||||
new_content.append(
|
||||
{"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
|
||||
)
|
||||
else:
|
||||
new_content.append(content)
|
||||
loggable_messages.append({"role": msg["role"], "content": new_content})
|
||||
else:
|
||||
loggable_messages.append(msg)
|
||||
return loggable_messages
|
||||
|
||||
async def run_interleaved(
|
||||
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Run interleaved chat completion.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
system: System prompt
|
||||
max_tokens: Optional max tokens override
|
||||
|
||||
Returns:
|
||||
Response dict
|
||||
"""
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
final_messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [
|
||||
{ "type": "text", "text": system }
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
# Process messages
|
||||
for item in messages:
|
||||
if isinstance(item, dict):
|
||||
if isinstance(item["content"], list):
|
||||
# Content is already in the correct format
|
||||
final_messages.append(item)
|
||||
else:
|
||||
# Single string content, check for image
|
||||
base64_img = self._extract_base64_image(item["content"])
|
||||
if base64_img:
|
||||
message = {
|
||||
"role": item["role"],
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
message = {
|
||||
"role": item["role"],
|
||||
"content": [{"type": "text", "text": item["content"]}],
|
||||
}
|
||||
final_messages.append(message)
|
||||
else:
|
||||
# String content, check for image
|
||||
base64_img = self._extract_base64_image(item)
|
||||
if base64_img:
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
else:
|
||||
message = {"role": "user", "content": [{"type": "text", "text": item}]}
|
||||
final_messages.append(message)
|
||||
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"messages": final_messages,
|
||||
"max_tokens": max_tokens or self.max_tokens,
|
||||
"temperature": self.temperature,
|
||||
"top_p": 0.7,
|
||||
}
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Use default base URL if none provided
|
||||
base_url = self.provider_base_url or "http://localhost:8000/v1"
|
||||
|
||||
# Check if the base URL already includes the chat/completions endpoint
|
||||
|
||||
endpoint_url = base_url
|
||||
if not endpoint_url.endswith("/chat/completions"):
|
||||
# If URL is RunPod format, make it OpenAI compatible
|
||||
if endpoint_url.startswith("https://api.runpod.ai/v2/"):
|
||||
# Extract RunPod endpoint ID
|
||||
parts = endpoint_url.split("/")
|
||||
if len(parts) >= 5:
|
||||
runpod_id = parts[4]
|
||||
endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
|
||||
# If the URL ends with /v1, append /chat/completions
|
||||
elif endpoint_url.endswith("/v1"):
|
||||
endpoint_url = f"{endpoint_url}/chat/completions"
|
||||
# If the URL doesn't end with /v1, make sure it has a proper structure
|
||||
elif not endpoint_url.endswith("/"):
|
||||
endpoint_url = f"{endpoint_url}/chat/completions"
|
||||
else:
|
||||
endpoint_url = f"{endpoint_url}chat/completions"
|
||||
|
||||
# Log the endpoint URL for debugging
|
||||
logger.debug(f"Using endpoint URL: {endpoint_url}")
|
||||
|
||||
async with session.post(endpoint_url, headers=headers, json=payload) as response:
|
||||
# Log the status and content type
|
||||
logger.debug(f"Status: {response.status}")
|
||||
logger.debug(f"Content-Type: {response.headers.get('Content-Type')}")
|
||||
|
||||
# Get the raw text of the response
|
||||
response_text = await response.text()
|
||||
logger.debug(f"Response content: {response_text}")
|
||||
|
||||
# if 503, then the endpoint is still warming up
|
||||
if response.status == 503:
|
||||
logger.error(f"Endpoint is still warming up, trying again in 30 seconds...")
|
||||
await asyncio.sleep(30)
|
||||
raise Exception(f"Endpoint is still warming up: {response_text}")
|
||||
|
||||
# Try to parse as JSON if the content type is appropriate
|
||||
if "application/json" in response.headers.get('Content-Type', ''):
|
||||
response_json = await response.json()
|
||||
else:
|
||||
raise Exception(f"Response is not JSON format")
|
||||
|
||||
if response.status != 200:
|
||||
logger.error(f"Error in API call: {response_text}")
|
||||
raise Exception(f"API error: {response_text}")
|
||||
|
||||
return response_json
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in API call: {str(e)}")
|
||||
raise
|
||||
@@ -1,660 +0,0 @@
|
||||
"""UI-TARS-specific agent loop implementation."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import re
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
import copy
|
||||
from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator, cast
|
||||
|
||||
from httpx import ConnectError, ReadTimeout
|
||||
|
||||
from ...core.base import BaseLoop
|
||||
from ...core.messages import StandardMessageManager, ImageRetentionConfig
|
||||
from ...core.types import AgentResponse, LLMProvider
|
||||
from ...core.visualization import VisualizationHelper
|
||||
from computer import Computer
|
||||
|
||||
from .utils import add_box_token, parse_actions, parse_action_parameters, to_agent_response_format
|
||||
from .tools.manager import ToolManager
|
||||
from .tools.computer import ToolResult
|
||||
from .prompts import COMPUTER_USE, SYSTEM_PROMPT, MAC_SPECIFIC_NOTES
|
||||
|
||||
from .clients.oaicompat import OAICompatClient
|
||||
from .clients.mlxvlm import MLXVLMUITarsClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class UITARSLoop(BaseLoop):
|
||||
"""UI-TARS-specific implementation of the agent loop.
|
||||
|
||||
This class extends BaseLoop to provide support for the UI-TARS model
|
||||
with computer control capabilities.
|
||||
"""
|
||||
|
||||
###########################################
|
||||
# INITIALIZATION AND CONFIGURATION
|
||||
###########################################
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
computer: Computer,
|
||||
api_key: str,
|
||||
model: str,
|
||||
provider: Optional[LLMProvider] = None,
|
||||
provider_base_url: Optional[str] = "http://localhost:8000/v1",
|
||||
only_n_most_recent_images: Optional[int] = 2,
|
||||
base_dir: Optional[str] = "trajectories",
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
save_trajectory: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the loop.
|
||||
|
||||
Args:
|
||||
computer: Computer instance
|
||||
api_key: API key (may not be needed for local endpoints)
|
||||
model: Model name (e.g., "ui-tars")
|
||||
provider_base_url: Base URL for the API provider
|
||||
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
|
||||
base_dir: Base directory for saving experiment data
|
||||
max_retries: Maximum number of retries for API calls
|
||||
retry_delay: Delay between retries in seconds
|
||||
save_trajectory: Whether to save trajectory data
|
||||
provider: The LLM provider to use (defaults to OAICOMPAT if not specified)
|
||||
"""
|
||||
# Set provider before initializing base class
|
||||
self.provider = provider or LLMProvider.OAICOMPAT
|
||||
self.provider_base_url = provider_base_url
|
||||
|
||||
# Initialize message manager with image retention config
|
||||
self.message_manager = StandardMessageManager(
|
||||
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
|
||||
)
|
||||
|
||||
# Initialize base class (which will set up experiment manager)
|
||||
super().__init__(
|
||||
computer=computer,
|
||||
model=model,
|
||||
api_key=api_key,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
base_dir=base_dir,
|
||||
save_trajectory=save_trajectory,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Set API client attributes
|
||||
self.client = None
|
||||
self.retry_count = 0
|
||||
self.loop_task = None # Store the loop task for cancellation
|
||||
|
||||
# Initialize visualization helper
|
||||
self.viz_helper = VisualizationHelper(agent=self)
|
||||
|
||||
# Initialize tool manager
|
||||
self.tool_manager = ToolManager(computer=computer)
|
||||
|
||||
logger.info("UITARSLoop initialized with StandardMessageManager")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the loop by setting up tools and clients."""
|
||||
# Initialize base class
|
||||
await super().initialize()
|
||||
|
||||
# Initialize tool manager with error handling
|
||||
try:
|
||||
logger.info("Initializing tool manager...")
|
||||
await self.tool_manager.initialize()
|
||||
logger.info("Tool manager initialized successfully.")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing tool manager: {str(e)}")
|
||||
logger.warning("Will attempt to initialize tools on first use.")
|
||||
|
||||
# Initialize client for the selected provider
|
||||
try:
|
||||
await self.initialize_client()
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing client: {str(e)}")
|
||||
raise RuntimeError(f"Failed to initialize client: {str(e)}")
|
||||
|
||||
###########################################
|
||||
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def initialize_client(self) -> None:
|
||||
"""Initialize the appropriate client.
|
||||
|
||||
Implements abstract method from BaseLoop to set up the specific
|
||||
provider client based on the configured provider.
|
||||
"""
|
||||
try:
|
||||
if self.provider == LLMProvider.MLXVLM:
|
||||
logger.info(f"Initializing MLX VLM client for UI-TARS with model {self.model}...")
|
||||
|
||||
self.client = MLXVLMUITarsClient(
|
||||
model=self.model,
|
||||
)
|
||||
|
||||
logger.info(f"Initialized MLX VLM client with model {self.model}")
|
||||
else:
|
||||
# Default to OAICompat client for other providers
|
||||
logger.info(f"Initializing OAICompat client for UI-TARS with model {self.model}...")
|
||||
|
||||
self.client = OAICompatClient(
|
||||
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
|
||||
model=self.model,
|
||||
provider_base_url=self.provider_base_url,
|
||||
)
|
||||
|
||||
logger.info(f"Initialized OAICompat client with model {self.model}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing client: {str(e)}")
|
||||
self.client = None
|
||||
raise RuntimeError(f"Failed to initialize client: {str(e)}")
|
||||
|
||||
###########################################
|
||||
# MESSAGE FORMATTING
|
||||
###########################################
|
||||
|
||||
def to_uitars_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Convert messages to UI-TARS compatible format.
|
||||
|
||||
Args:
|
||||
messages: List of messages in standard format
|
||||
|
||||
Returns:
|
||||
List of messages formatted for UI-TARS
|
||||
"""
|
||||
# Create a copy of the messages to avoid modifying the original
|
||||
uitars_messages = copy.deepcopy(messages)
|
||||
|
||||
# Find the first user message to modify
|
||||
first_user_idx = None
|
||||
instruction = ""
|
||||
|
||||
for idx, msg in enumerate(uitars_messages):
|
||||
if msg.get("role") == "user":
|
||||
first_user_idx = idx
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
instruction = content
|
||||
break
|
||||
elif isinstance(content, list):
|
||||
for item in content:
|
||||
if item.get("type") == "text":
|
||||
instruction = item.get("text", "")
|
||||
break
|
||||
if instruction:
|
||||
break
|
||||
|
||||
# Only modify the first user message if found
|
||||
if first_user_idx is not None and instruction:
|
||||
# Create the computer use prompt
|
||||
user_prompt = COMPUTER_USE.format(
|
||||
instruction='\n'.join([instruction, MAC_SPECIFIC_NOTES]),
|
||||
language="English"
|
||||
)
|
||||
|
||||
# Replace the content of the first user message
|
||||
if isinstance(uitars_messages[first_user_idx].get("content", ""), str):
|
||||
uitars_messages[first_user_idx]["content"] = [{"type": "text", "text": user_prompt}]
|
||||
elif isinstance(uitars_messages[first_user_idx].get("content", ""), list):
|
||||
# Find and replace only the text part, keeping images
|
||||
for i, item in enumerate(uitars_messages[first_user_idx]["content"]):
|
||||
if item.get("type") == "text":
|
||||
uitars_messages[first_user_idx]["content"][i]["text"] = user_prompt
|
||||
break
|
||||
|
||||
# Add box tokens to assistant responses
|
||||
for idx, msg in enumerate(uitars_messages):
|
||||
if msg.get("role") == "assistant":
|
||||
content = msg.get("content", "")
|
||||
if content and isinstance(content, list):
|
||||
for i, part in enumerate(content):
|
||||
if part.get('type') == 'text':
|
||||
uitars_messages[idx]["content"][i]["text"] = add_box_token(part['text'])
|
||||
|
||||
return uitars_messages
|
||||
|
||||
###########################################
|
||||
# API CALL HANDLING
|
||||
###########################################
|
||||
|
||||
async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any:
|
||||
"""Make API call to provider with retry logic."""
|
||||
# Create new turn directory for this API call
|
||||
self._create_turn_dir()
|
||||
|
||||
request_data = None
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# Ensure client is initialized
|
||||
if self.client is None:
|
||||
logger.info(
|
||||
f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..."
|
||||
)
|
||||
await self.initialize_client()
|
||||
if self.client is None:
|
||||
raise RuntimeError("Failed to initialize client")
|
||||
|
||||
# Get messages in standard format from the message manager
|
||||
self.message_manager.messages = messages.copy()
|
||||
prepared_messages = self.message_manager.get_messages()
|
||||
|
||||
# Convert messages to UI-TARS format
|
||||
uitars_messages = self.to_uitars_format(prepared_messages)
|
||||
|
||||
# Log request
|
||||
request_data = {
|
||||
"messages": uitars_messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"system": system_prompt,
|
||||
}
|
||||
|
||||
self._log_api_call("request", request_data)
|
||||
|
||||
# Make API call
|
||||
response = await self.client.run_interleaved(
|
||||
messages=uitars_messages,
|
||||
system=system_prompt,
|
||||
max_tokens=self.max_tokens,
|
||||
)
|
||||
|
||||
# Log success response
|
||||
self._log_api_call("response", request_data, response)
|
||||
|
||||
return response
|
||||
|
||||
except (ConnectError, ReadTimeout) as e:
|
||||
last_error = e
|
||||
logger.warning(
|
||||
f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
|
||||
)
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff
|
||||
# Reset client on connection errors to force re-initialization
|
||||
self.client = None
|
||||
continue
|
||||
|
||||
except RuntimeError as e:
|
||||
# Handle client initialization errors specifically
|
||||
last_error = e
|
||||
self._log_api_call("error", request_data, error=e)
|
||||
logger.error(
|
||||
f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
|
||||
)
|
||||
if attempt < self.max_retries - 1:
|
||||
# Reset client to force re-initialization
|
||||
self.client = None
|
||||
await asyncio.sleep(self.retry_delay)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
# Log unexpected error
|
||||
last_error = e
|
||||
self._log_api_call("error", request_data, error=e)
|
||||
logger.error(f"Unexpected error in API call: {str(e)}")
|
||||
if attempt < self.max_retries - 1:
|
||||
await asyncio.sleep(self.retry_delay)
|
||||
continue
|
||||
|
||||
# If we get here, all retries failed
|
||||
error_message = f"API call failed after {self.max_retries} attempts"
|
||||
if last_error:
|
||||
error_message += f": {str(last_error)}"
|
||||
|
||||
logger.error(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
###########################################
|
||||
# RESPONSE AND ACTION HANDLING
|
||||
###########################################
|
||||
|
||||
async def _handle_response(
|
||||
self, response: Any, messages: List[Dict[str, Any]]
|
||||
) -> Tuple[bool, bool]:
|
||||
"""Handle API response.
|
||||
|
||||
Args:
|
||||
response: API response
|
||||
messages: List of messages to update
|
||||
|
||||
Returns:
|
||||
Tuple of (should_continue, action_screenshot_saved)
|
||||
"""
|
||||
action_screenshot_saved = False
|
||||
|
||||
try:
|
||||
# Step 1: Extract the raw response text
|
||||
raw_text = None
|
||||
|
||||
try:
|
||||
# OpenAI-compatible response format
|
||||
raw_text = response["choices"][0]["message"]["content"]
|
||||
except (KeyError, TypeError, IndexError) as e:
|
||||
logger.error(f"Invalid response format: {str(e)}")
|
||||
return True, action_screenshot_saved
|
||||
|
||||
# Step 2: Add the response to message history
|
||||
self.message_manager.add_assistant_message([{"type": "text", "text": raw_text}])
|
||||
|
||||
# Step 3: Parse actions from the response
|
||||
parsed_actions = parse_actions(raw_text)
|
||||
|
||||
if not parsed_actions:
|
||||
logger.warning("No action found in the response")
|
||||
return True, action_screenshot_saved
|
||||
|
||||
# Step 4: Execute each action
|
||||
for action in parsed_actions:
|
||||
action_type = None
|
||||
|
||||
# Handle "finished" action
|
||||
if action.startswith("finished"):
|
||||
logger.info("Agent completed the task")
|
||||
return False, action_screenshot_saved
|
||||
|
||||
# Process other action types (click, type, etc.)
|
||||
try:
|
||||
# Parse action parameters using the utility function
|
||||
action_name, tool_args = parse_action_parameters(action)
|
||||
|
||||
if not action_name:
|
||||
logger.warning(f"Could not parse action: {action}")
|
||||
continue
|
||||
|
||||
# Mark actions that would create screenshots
|
||||
if action_name in ["click", "left_double", "right_single", "drag", "scroll"]:
|
||||
action_screenshot_saved = True
|
||||
|
||||
# Execute the tool with prepared arguments
|
||||
await self._ensure_tools_initialized()
|
||||
|
||||
# Let's log what we're about to execute for debugging
|
||||
logger.info(f"Executing computer tool with arguments: {tool_args}")
|
||||
|
||||
result = await self.tool_manager.execute_tool(name="computer", tool_input=tool_args)
|
||||
|
||||
# Handle the result
|
||||
if hasattr(result, "error") and result.error:
|
||||
logger.error(f"Error executing tool: {result.error}")
|
||||
else:
|
||||
# Action was successful
|
||||
logger.info(f"Successfully executed {action_name}")
|
||||
|
||||
# Save screenshot if one was returned and we haven't already saved one
|
||||
if hasattr(result, "base64_image") and result.base64_image:
|
||||
self._save_screenshot(result.base64_image, action_type=action_name)
|
||||
action_screenshot_saved = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing action {action}: {str(e)}")
|
||||
|
||||
# Continue the loop if there are actions to process
|
||||
return True, action_screenshot_saved
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling response: {str(e)}")
|
||||
# Add error message using the message manager
|
||||
error_message = [{"type": "text", "text": f"Error: {str(e)}"}]
|
||||
self.message_manager.add_assistant_message(error_message)
|
||||
raise
|
||||
|
||||
###########################################
|
||||
# SCREEN HANDLING
|
||||
###########################################
|
||||
|
||||
async def _get_current_screen(self, save_screenshot: bool = True) -> str:
|
||||
"""Get the current screen as a base64 encoded image.
|
||||
|
||||
Args:
|
||||
save_screenshot: Whether to save the screenshot
|
||||
|
||||
Returns:
|
||||
Base64 encoded screenshot
|
||||
"""
|
||||
try:
|
||||
# Take a screenshot
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
|
||||
# Convert to base64
|
||||
img_base64 = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
# Process screenshot through hooks and save if needed
|
||||
await self.handle_screenshot(img_base64, action_type="state")
|
||||
|
||||
# Save screenshot if requested
|
||||
if save_screenshot and self.save_trajectory:
|
||||
self._save_screenshot(img_base64, action_type="state")
|
||||
|
||||
return img_base64
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting current screen: {str(e)}")
|
||||
raise
|
||||
|
||||
###########################################
|
||||
# SYSTEM PROMPT
|
||||
###########################################
|
||||
|
||||
def _get_system_prompt(self) -> str:
|
||||
"""Get the system prompt for the model."""
|
||||
return SYSTEM_PROMPT
|
||||
|
||||
###########################################
|
||||
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
Args:
|
||||
messages: List of messages in standard OpenAI format
|
||||
|
||||
Yields:
|
||||
Agent response format
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting UITARSLoop run with {len(messages)} messages")
|
||||
|
||||
# Initialize the message manager with the provided messages
|
||||
self.message_manager.messages = messages.copy()
|
||||
|
||||
# Create queue for response streaming
|
||||
queue = asyncio.Queue()
|
||||
|
||||
# Start loop in background task
|
||||
self.loop_task = asyncio.create_task(self._run_loop(queue, messages))
|
||||
|
||||
# Process and yield messages as they arrive
|
||||
while True:
|
||||
try:
|
||||
item = await queue.get()
|
||||
if item is None: # Stop signal
|
||||
break
|
||||
yield item
|
||||
queue.task_done()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing queue item: {str(e)}")
|
||||
continue
|
||||
|
||||
# Wait for loop to complete
|
||||
await self.loop_task
|
||||
|
||||
# Send completion message
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": "Task completed successfully.",
|
||||
"metadata": {"title": "✅ Complete"},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run method: {str(e)}")
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
|
||||
async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Internal method to run the agent loop with provided messages.
|
||||
|
||||
Args:
|
||||
queue: Queue to put responses into
|
||||
messages: List of messages in standard OpenAI format
|
||||
"""
|
||||
# Continue running until explicitly told to stop
|
||||
running = True
|
||||
turn_created = False
|
||||
# Track if an action-specific screenshot has been saved this turn
|
||||
action_screenshot_saved = False
|
||||
|
||||
attempt = 0
|
||||
max_attempts = 3
|
||||
|
||||
try:
|
||||
while running and attempt < max_attempts:
|
||||
try:
|
||||
# Create a new turn directory if it's not already created
|
||||
if not turn_created:
|
||||
self._create_turn_dir()
|
||||
turn_created = True
|
||||
|
||||
# Ensure client is initialized
|
||||
if self.client is None:
|
||||
logger.info("Initializing client...")
|
||||
await self.initialize_client()
|
||||
if self.client is None:
|
||||
raise RuntimeError("Failed to initialize client")
|
||||
logger.info("Client initialized successfully")
|
||||
|
||||
# Get current screen
|
||||
base64_screenshot = await self._get_current_screen()
|
||||
|
||||
# Add screenshot to message history
|
||||
self.message_manager.add_user_message(
|
||||
[
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{base64_screenshot}"},
|
||||
}
|
||||
]
|
||||
)
|
||||
logger.info("Added screenshot to message history")
|
||||
|
||||
# Get system prompt
|
||||
system_prompt = self._get_system_prompt()
|
||||
|
||||
# Make API call with retries
|
||||
response = await self._make_api_call(
|
||||
self.message_manager.messages, system_prompt
|
||||
)
|
||||
|
||||
# Handle the response (may execute actions)
|
||||
# Returns: (should_continue, action_screenshot_saved)
|
||||
should_continue, new_screenshot_saved = await self._handle_response(
|
||||
response, self.message_manager.messages
|
||||
)
|
||||
|
||||
# Update whether an action screenshot was saved this turn
|
||||
action_screenshot_saved = action_screenshot_saved or new_screenshot_saved
|
||||
|
||||
agent_response = await to_agent_response_format(
|
||||
response,
|
||||
messages,
|
||||
model=self.model,
|
||||
)
|
||||
# Log standardized response for ease of parsing
|
||||
self._log_api_call("agent_response", request=None, response=agent_response)
|
||||
|
||||
# Put the response in the queue
|
||||
await queue.put(agent_response)
|
||||
|
||||
# Check if we should continue this conversation
|
||||
running = should_continue
|
||||
|
||||
# Create a new turn directory if we're continuing
|
||||
if running:
|
||||
turn_created = False
|
||||
|
||||
# Reset attempt counter on success
|
||||
attempt = 0
|
||||
|
||||
except Exception as e:
|
||||
attempt += 1
|
||||
error_msg = f"Error in run method (attempt {attempt}/{max_attempts}): {str(e)}"
|
||||
logger.error(error_msg)
|
||||
|
||||
# If this is our last attempt, provide more info about the error
|
||||
if attempt >= max_attempts:
|
||||
logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")
|
||||
|
||||
await queue.put({
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
})
|
||||
|
||||
# Create a brief delay before retrying
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
# Signal that we're done
|
||||
await queue.put(None)
|
||||
|
||||
async def cancel(self) -> None:
|
||||
"""Cancel the currently running agent loop task.
|
||||
|
||||
This method stops the ongoing processing in the agent loop
|
||||
by cancelling the loop_task if it exists and is running.
|
||||
"""
|
||||
if self.loop_task and not self.loop_task.done():
|
||||
logger.info("Cancelling UITARS loop task")
|
||||
self.loop_task.cancel()
|
||||
try:
|
||||
# Wait for the task to be cancelled with a timeout
|
||||
await asyncio.wait_for(self.loop_task, timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Timeout while waiting for loop task to cancel")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Loop task cancelled successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error while cancelling loop task: {str(e)}")
|
||||
finally:
|
||||
logger.info("UITARS loop task cancelled")
|
||||
else:
|
||||
logger.info("No active UITARS loop task to cancel")
|
||||
|
||||
###########################################
|
||||
# UTILITY METHODS
|
||||
###########################################
|
||||
|
||||
async def _ensure_tools_initialized(self) -> None:
|
||||
"""Ensure the tool manager and tools are initialized before use."""
|
||||
if not hasattr(self.tool_manager, "tools") or self.tool_manager.tools is None:
|
||||
logger.info("Tools not initialized. Initializing now...")
|
||||
await self.tool_manager.initialize()
|
||||
logger.info("Tools initialized successfully.")
|
||||
|
||||
async def process_model_response(self, response_text: str) -> Optional[Dict[str, Any]]:
|
||||
"""Process model response to extract tool calls.
|
||||
|
||||
Args:
|
||||
response_text: Model response text
|
||||
|
||||
Returns:
|
||||
Extracted tool information, or None if no tool call was found
|
||||
"""
|
||||
# UI-TARS doesn't use the standard tool call format, so we parse its actions differently
|
||||
parsed_actions = parse_actions(response_text)
|
||||
|
||||
if parsed_actions:
|
||||
return {"actions": parsed_actions}
|
||||
|
||||
return None
|
||||
@@ -1,63 +0,0 @@
|
||||
"""Prompts for UI-TARS agent."""
|
||||
|
||||
MAC_SPECIFIC_NOTES = """
|
||||
(You are operating on macOS, use 'cmd' instead of 'ctrl' for most shortcuts e.g., hotkey(key='cmd c') for copy, hotkey(key='cmd v') for paste, hotkey(key='cmd t') for new tab).)
|
||||
"""
|
||||
|
||||
SYSTEM_PROMPT = "You are a helpful assistant."
|
||||
|
||||
COMPUTER_USE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
|
||||
## Action Space
|
||||
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
left_double(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
right_single(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
hotkey(key='')
|
||||
type(content='xxx') # Use escape characters \\', \\\", and \\n in content part to ensure we can parse the content in normal python string format. If you want to submit your input, use \\n at the end of content.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
wait() #Sleep for 5s and take a screenshot to check for any changes.
|
||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
|
||||
|
||||
## Note
|
||||
- Use {language} in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
|
||||
MOBILE_USE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
## Output Format
|
||||
```
|
||||
Thought: ...
|
||||
Action: ...
|
||||
```
|
||||
## Action Space
|
||||
|
||||
click(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
long_press(start_box='<|box_start|>(x1,y1)<|box_end|>')
|
||||
type(content='') #If you want to submit your input, use "\\n" at the end of `content`.
|
||||
scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
|
||||
open_app(app_name=\'\')
|
||||
drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>')
|
||||
press_home()
|
||||
press_back()
|
||||
finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format.
|
||||
|
||||
|
||||
## Note
|
||||
- Use {language} in `Thought` part.
|
||||
- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part.
|
||||
|
||||
## User Instruction
|
||||
{instruction}
|
||||
"""
|
||||
@@ -1 +0,0 @@
|
||||
"""UI-TARS tools package."""
|
||||
@@ -1,283 +0,0 @@
|
||||
"""Computer tool for UI-TARS."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Literal, Union
|
||||
|
||||
from computer import Computer
|
||||
from ....core.tools.base import ToolResult, ToolFailure
|
||||
from ....core.tools.computer import BaseComputerTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ComputerTool(BaseComputerTool):
|
||||
"""
|
||||
A tool that allows the UI-TARS agent to interact with the screen, keyboard, and mouse.
|
||||
"""
|
||||
|
||||
name: str = "computer"
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
computer: Computer
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the computer tool.
|
||||
|
||||
Args:
|
||||
computer: Computer instance
|
||||
"""
|
||||
super().__init__(computer)
|
||||
self.computer = computer
|
||||
self.width = None
|
||||
self.height = None
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters
|
||||
"""
|
||||
if self.width is None or self.height is None:
|
||||
raise RuntimeError(
|
||||
"Screen dimensions not initialized. Call initialize_dimensions() first."
|
||||
)
|
||||
return {
|
||||
"type": "computer",
|
||||
"display_width": self.width,
|
||||
"display_height": self.height,
|
||||
}
|
||||
|
||||
async def initialize_dimensions(self) -> None:
|
||||
"""Initialize screen dimensions from the computer interface."""
|
||||
try:
|
||||
display_size = await self.computer.interface.get_screen_size()
|
||||
self.width = display_size["width"]
|
||||
self.height = display_size["height"]
|
||||
self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
|
||||
except Exception as e:
|
||||
# Fall back to defaults if we can't get accurate dimensions
|
||||
self.width = 1024
|
||||
self.height = 768
|
||||
self.logger.warning(
|
||||
f"Failed to get screen dimensions, using defaults: {self.width}x{self.height}. Error: {e}"
|
||||
)
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
action: str,
|
||||
**kwargs,
|
||||
) -> ToolResult:
|
||||
"""Execute a computer action.
|
||||
|
||||
Args:
|
||||
action: The action to perform (based on UI-TARS action space)
|
||||
**kwargs: Additional parameters for the action
|
||||
|
||||
Returns:
|
||||
ToolResult containing action output and possibly a base64 image
|
||||
"""
|
||||
try:
|
||||
# Ensure dimensions are initialized
|
||||
if self.width is None or self.height is None:
|
||||
await self.initialize_dimensions()
|
||||
if self.width is None or self.height is None:
|
||||
return ToolFailure(error="Failed to initialize screen dimensions")
|
||||
|
||||
# Handle actions defined in UI-TARS action space (from prompts.py)
|
||||
# Handle standard click (left click)
|
||||
if action == "click":
|
||||
if "x" in kwargs and "y" in kwargs:
|
||||
x, y = kwargs["x"], kwargs["y"]
|
||||
await self.computer.interface.left_click(x, y)
|
||||
|
||||
# Wait briefly for UI to update
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(
|
||||
output=f"Clicked at ({x}, {y})",
|
||||
base64_image=base64_screenshot,
|
||||
)
|
||||
else:
|
||||
return ToolFailure(error="Missing coordinates for click action")
|
||||
|
||||
# Handle double click
|
||||
elif action == "left_double":
|
||||
if "x" in kwargs and "y" in kwargs:
|
||||
x, y = kwargs["x"], kwargs["y"]
|
||||
await self.computer.interface.double_click(x, y)
|
||||
|
||||
# Wait briefly for UI to update
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(
|
||||
output=f"Double-clicked at ({x}, {y})",
|
||||
base64_image=base64_screenshot,
|
||||
)
|
||||
else:
|
||||
return ToolFailure(error="Missing coordinates for left_double action")
|
||||
|
||||
# Handle right click
|
||||
elif action == "right_single":
|
||||
if "x" in kwargs and "y" in kwargs:
|
||||
x, y = kwargs["x"], kwargs["y"]
|
||||
await self.computer.interface.right_click(x, y)
|
||||
|
||||
# Wait briefly for UI to update
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(
|
||||
output=f"Right-clicked at ({x}, {y})",
|
||||
base64_image=base64_screenshot,
|
||||
)
|
||||
else:
|
||||
return ToolFailure(error="Missing coordinates for right_single action")
|
||||
|
||||
# Handle typing text
|
||||
elif action == "type_text":
|
||||
if "text" in kwargs:
|
||||
text = kwargs["text"]
|
||||
await self.computer.interface.type_text(text)
|
||||
|
||||
# Wait for UI to update
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(
|
||||
output=f"Typed: {text}",
|
||||
base64_image=base64_screenshot,
|
||||
)
|
||||
else:
|
||||
return ToolFailure(error="Missing text for type action")
|
||||
|
||||
# Handle hotkey
|
||||
elif action == "hotkey":
|
||||
if "keys" in kwargs:
|
||||
keys = kwargs["keys"]
|
||||
|
||||
if len(keys) > 1:
|
||||
await self.computer.interface.hotkey(*keys)
|
||||
else:
|
||||
# Single key press
|
||||
await self.computer.interface.press_key(keys[0])
|
||||
|
||||
# Wait for UI to update
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(
|
||||
output=f"Pressed hotkey: {', '.join(keys)}",
|
||||
base64_image=base64_screenshot,
|
||||
)
|
||||
else:
|
||||
return ToolFailure(error="Missing keys for hotkey action")
|
||||
|
||||
# Handle drag action
|
||||
elif action == "drag":
|
||||
if all(k in kwargs for k in ["start_x", "start_y", "end_x", "end_y"]):
|
||||
start_x, start_y = kwargs["start_x"], kwargs["start_y"]
|
||||
end_x, end_y = kwargs["end_x"], kwargs["end_y"]
|
||||
|
||||
# Perform drag
|
||||
await self.computer.interface.move_cursor(start_x, start_y)
|
||||
await self.computer.interface.drag_to(end_x, end_y)
|
||||
|
||||
# Wait for UI to update
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(
|
||||
output=f"Dragged from ({start_x}, {start_y}) to ({end_x}, {end_y})",
|
||||
base64_image=base64_screenshot,
|
||||
)
|
||||
else:
|
||||
return ToolFailure(error="Missing coordinates for drag action")
|
||||
|
||||
# Handle scroll action
|
||||
elif action == "scroll":
|
||||
if all(k in kwargs for k in ["x", "y", "direction"]):
|
||||
x, y = kwargs["x"], kwargs["y"]
|
||||
direction = kwargs["direction"]
|
||||
|
||||
# Move cursor to position
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
|
||||
# Scroll based on direction
|
||||
if direction == "down":
|
||||
await self.computer.interface.scroll_down(5)
|
||||
elif direction == "up":
|
||||
await self.computer.interface.scroll_up(5)
|
||||
elif direction == "right":
|
||||
pass # await self.computer.interface.scroll_right(5)
|
||||
elif direction == "left":
|
||||
pass # await self.computer.interface.scroll_left(5)
|
||||
else:
|
||||
return ToolFailure(error=f"Invalid scroll direction: {direction}")
|
||||
|
||||
# Wait for UI to update
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(
|
||||
output=f"Scrolled {direction} at ({x}, {y})",
|
||||
base64_image=base64_screenshot,
|
||||
)
|
||||
else:
|
||||
return ToolFailure(error="Missing parameters for scroll action")
|
||||
|
||||
# Handle wait action
|
||||
elif action == "wait":
|
||||
# Sleep for 5 seconds as specified in the action space
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Take screenshot after waiting
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(
|
||||
output="Waited for 5 seconds",
|
||||
base64_image=base64_screenshot,
|
||||
)
|
||||
|
||||
# Handle finished action (task completion)
|
||||
elif action == "finished":
|
||||
content = kwargs.get("content", "Task completed")
|
||||
return ToolResult(
|
||||
output=f"Task finished: {content}",
|
||||
)
|
||||
|
||||
return await self._handle_scroll(action)
|
||||
else:
|
||||
return ToolFailure(error=f"Unsupported action: {action}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in ComputerTool.__call__: {str(e)}")
|
||||
return ToolFailure(error=f"Failed to execute {action}: {str(e)}")
|
||||
@@ -1,60 +0,0 @@
|
||||
"""Tool manager for the UI-TARS provider."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from computer import Computer
|
||||
from ....core.tools import BaseToolManager
|
||||
from ....core.tools.collection import ToolCollection
|
||||
from .computer import ComputerTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolManager(BaseToolManager):
|
||||
"""Manages UI-TARS provider tool initialization and execution."""
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the tool manager.
|
||||
|
||||
Args:
|
||||
computer: Computer instance for computer-related tools
|
||||
"""
|
||||
super().__init__(computer)
|
||||
# Initialize UI-TARS-specific tools
|
||||
self.computer_tool = ComputerTool(self.computer)
|
||||
self._initialized = False
|
||||
|
||||
def _initialize_tools(self) -> ToolCollection:
|
||||
"""Initialize all available tools."""
|
||||
return ToolCollection(self.computer_tool)
|
||||
|
||||
async def _initialize_tools_specific(self) -> None:
|
||||
"""Initialize UI-TARS provider-specific tool requirements."""
|
||||
await self.computer_tool.initialize_dimensions()
|
||||
|
||||
def get_tool_params(self) -> List[Dict[str, Any]]:
|
||||
"""Get tool parameters for API calls.
|
||||
|
||||
Returns:
|
||||
List of tool parameters for the current provider's API
|
||||
"""
|
||||
if self.tools is None:
|
||||
raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
|
||||
return self.tools.to_params()
|
||||
|
||||
async def execute_tool(self, name: str, tool_input: dict[str, Any]) -> Any:
|
||||
"""Execute a tool with the given input.
|
||||
|
||||
Args:
|
||||
name: Name of the tool to execute
|
||||
tool_input: Input parameters for the tool
|
||||
|
||||
Returns:
|
||||
Result of the tool execution
|
||||
"""
|
||||
if self.tools is None:
|
||||
raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
|
||||
return await self.tools.run(name=name, tool_input=tool_input)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user