Standardize Agent Loop

This commit is contained in:
f-trycua
2025-03-23 23:40:18 +01:00
parent cc3891a7ad
commit e32b64590a
21 changed files with 2243 additions and 1767 deletions

View File

@@ -6,6 +6,7 @@ import logging
import traceback
from pathlib import Path
import signal
import json
from computer import Computer
@@ -32,42 +33,31 @@ async def run_omni_agent_example():
# Create agent with loop and provider
agent = ComputerAgent(
computer=computer,
loop=AgentLoop.ANTHROPIC,
# loop=AgentLoop.OMNI,
# loop=AgentLoop.ANTHROPIC,
loop=AgentLoop.OMNI,
# model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"),
model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"),
save_trajectory=True,
trajectory_dir=str(Path("trajectories")),
only_n_most_recent_images=3,
verbosity=logging.INFO,
verbosity=logging.DEBUG,
)
tasks = [
"""
1. Look for a repository named trycua/lume on GitHub.
2. Check the open issues, open the most recent one and read it.
3. Clone the repository in users/lume/projects if it doesn't exist yet.
4. Open the repository with an app named Cursor (on the dock, black background and white cube icon).
5. From Cursor, open Composer if not already open.
6. Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.
"""
"Look for a repository named trycua/cua on GitHub.",
"Check the open issues, open the most recent one and read it.",
"Clone the repository in users/lume/projects if it doesn't exist yet.",
"Open the repository with an app named Cursor (on the dock, black background and white cube icon).",
"From Cursor, open Composer if not already open.",
"Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.",
]
async with agent:
for i, task in enumerate(tasks, 1):
for i, task in enumerate(tasks):
print(f"\nExecuting task {i}/{len(tasks)}: {task}")
async for result in agent.run(task):
# Check if result has the expected structure
if "role" in result and "content" in result and "metadata" in result:
title = result["metadata"].get("title", "Screen Analysis")
content = result["content"]
else:
title = result.get("metadata", {}).get("title", "Screen Analysis")
content = result.get("content", str(result))
print(result)
print(f"\n{title}")
print(content)
print(f"Task {i} completed")
print(f"\n✅ Task {i+1}/{len(tasks)} completed: {task}")
except Exception as e:
logger.error(f"Error in run_omni_agent_example: {e}")

View File

@@ -2,11 +2,6 @@
from .loop import BaseLoop
from .messages import (
create_user_message,
create_assistant_message,
create_system_message,
create_image_message,
create_screen_message,
BaseMessageManager,
ImageRetentionConfig,
)

View File

@@ -3,8 +3,7 @@
import asyncio
import logging
import os
from typing import Any, AsyncGenerator, Dict, Optional, cast
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional, cast, List
from computer import Computer
from ..providers.anthropic.loop import AnthropicLoop
@@ -12,6 +11,7 @@ from ..providers.omni.loop import OmniLoop
from ..providers.omni.parser import OmniParser
from ..providers.omni.types import LLMProvider, LLM
from .. import AgentLoop
from .messages import StandardMessageManager, ImageRetentionConfig
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -44,7 +44,6 @@ class ComputerAgent:
save_trajectory: bool = True,
trajectory_dir: str = "trajectories",
only_n_most_recent_images: Optional[int] = None,
parser: Optional[OmniParser] = None,
verbosity: int = logging.INFO,
):
"""Initialize the ComputerAgent.
@@ -61,7 +60,6 @@ class ComputerAgent:
save_trajectory: Whether to save the trajectory.
trajectory_dir: Directory to save the trajectory.
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests.
parser: Parser instance for the OmniLoop. Only used if provider is not ANTHROPIC.
verbosity: Logging level.
"""
# Basic agent configuration
@@ -74,6 +72,11 @@ class ComputerAgent:
self._initialized = False
self._in_context = False
# Initialize the message manager for standardized message handling
self.message_manager = StandardMessageManager(
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
)
# Set logging level
logger.setLevel(verbosity)
@@ -118,10 +121,6 @@ class ComputerAgent:
only_n_most_recent_images=only_n_most_recent_images,
)
else:
# Default to OmniLoop for other loop types
# Initialize parser if not provided
actual_parser = parser or OmniParser()
self._loop = OmniLoop(
provider=self.provider,
api_key=actual_api_key,
@@ -130,7 +129,7 @@ class ComputerAgent:
save_trajectory=save_trajectory,
base_dir=trajectory_dir,
only_n_most_recent_images=only_n_most_recent_images,
parser=actual_parser,
parser=OmniParser(),
)
logger.info(
@@ -224,13 +223,25 @@ class ComputerAgent:
"""
try:
logger.info(f"Running task: {task}")
logger.info(
f"Message history before task has {len(self.message_manager.messages)} messages"
)
# Initialize the computer if needed
if not self._initialized:
await self.initialize()
# Format task as a message
messages = [{"role": "user", "content": task}]
# Add task as a user message using the message manager
self.message_manager.add_user_message([{"type": "text", "text": task}])
logger.info(
f"Added task message. Message history now has {len(self.message_manager.messages)} messages"
)
# Log message history types to help with debugging
message_types = [
f"{i}: {msg['role']}" for i, msg in enumerate(self.message_manager.messages)
]
logger.info(f"Message history roles: {', '.join(message_types)}")
# Pass properly formatted messages to the loop
if self._loop is None:
@@ -239,9 +250,28 @@ class ComputerAgent:
return
# Execute the task and yield results
async for result in self._loop.run(messages):
async for result in self._loop.run(self.message_manager.messages):
# Extract the assistant message from the result and add it to our history
assistant_response = result["response"]["choices"][0].get("message", None)
if assistant_response and assistant_response.get("role") == "assistant":
# Extract the content from the assistant response
content = assistant_response.get("content")
self.message_manager.add_assistant_message(content)
logger.info("Added assistant response to message history")
# Yield the result to the caller
yield result
# Logging the message history for debugging
logger.info(
f"Updated message history now has {len(self.message_manager.messages)} messages"
)
message_types = [
f"{i}: {msg['role']}" for i, msg in enumerate(self.message_manager.messages)
]
logger.info(f"Updated message history roles: {', '.join(message_types)}")
except Exception as e:
logger.error(f"Error in agent run method: {str(e)}")
yield {

View File

@@ -2,12 +2,9 @@
import logging
import asyncio
import json
import os
from abc import ABC, abstractmethod
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from datetime import datetime
import base64
from computer import Computer
from .experiment import ExperimentManager
@@ -18,6 +15,10 @@ logger = logging.getLogger(__name__)
class BaseLoop(ABC):
"""Base class for agent loops that handle message processing and tool execution."""
###########################################
# INITIALIZATION AND CONFIGURATION
###########################################
def __init__(
self,
computer: Computer,
@@ -75,6 +76,64 @@ class BaseLoop(ABC):
# Initialize basic tracking
self.turn_count = 0
async def initialize(self) -> None:
"""Initialize both the API client and computer interface with retries."""
for attempt in range(self.max_retries):
try:
logger.info(
f"Starting initialization (attempt {attempt + 1}/{self.max_retries})..."
)
# Initialize API client
await self.initialize_client()
logger.info("Initialization complete.")
return
except Exception as e:
if attempt < self.max_retries - 1:
logger.warning(
f"Initialization failed (attempt {attempt + 1}/{self.max_retries}): {str(e)}. Retrying..."
)
await asyncio.sleep(self.retry_delay)
else:
logger.error(
f"Initialization failed after {self.max_retries} attempts: {str(e)}"
)
raise RuntimeError(f"Failed to initialize: {str(e)}")
###########################################
# ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES
###########################################
@abstractmethod
async def initialize_client(self) -> None:
"""Initialize the API client and any provider-specific components.
This method must be implemented by subclasses to set up
provider-specific clients and tools.
"""
raise NotImplementedError
@abstractmethod
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
"""Run the agent loop with provided messages.
This method handles the main agent loop including message processing,
API calls, response handling, and action execution.
Args:
messages: List of message objects
Yields:
Dict containing response data
"""
raise NotImplementedError
###########################################
# EXPERIMENT AND TRAJECTORY MANAGEMENT
###########################################
def _setup_experiment_dirs(self) -> None:
"""Setup the experiment directory structure."""
if self.experiment_manager:
@@ -100,10 +159,13 @@ class BaseLoop(ABC):
) -> None:
"""Log API call details to file.
Preserves provider-specific formats for requests and responses to ensure
accurate logging for debugging and analysis purposes.
Args:
call_type: Type of API call (e.g., 'request', 'response', 'error')
request: The API request data
response: Optional API response data
request: The API request data in provider-specific format
response: Optional API response data in provider-specific format
error: Optional error information
"""
if self.experiment_manager:
@@ -130,119 +192,155 @@ class BaseLoop(ABC):
if self.experiment_manager:
self.experiment_manager.save_screenshot(img_base64, action_type)
async def initialize(self) -> None:
"""Initialize both the API client and computer interface with retries."""
for attempt in range(self.max_retries):
try:
logger.info(
f"Starting initialization (attempt {attempt + 1}/{self.max_retries})..."
)
def _create_openai_compatible_response(
self, response: Any, messages: List[Dict[str, Any]]
) -> Dict[str, Any]:
"""Create an OpenAI computer use agent compatible response format.
# Initialize API client
await self.initialize_client()
logger.info("Initialization complete.")
return
except Exception as e:
if attempt < self.max_retries - 1:
logger.warning(
f"Initialization failed (attempt {attempt + 1}/{self.max_retries}): {str(e)}. Retrying..."
)
await asyncio.sleep(self.retry_delay)
else:
logger.error(
f"Initialization failed after {self.max_retries} attempts: {str(e)}"
)
raise RuntimeError(f"Failed to initialize: {str(e)}")
async def _get_parsed_screen_som(self) -> Dict[str, Any]:
"""Get parsed screen information.
Args:
response: The original API response
messages: List of messages in standard OpenAI format
Returns:
Dict containing screen information
A response formatted according to OpenAI's computer use agent standard
"""
try:
# Take screenshot
screenshot = await self.computer.interface.screenshot()
import json
# Initialize with default values
width, height = 1024, 768
base64_image = ""
# Create a unique ID for this response
response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(response)}"
reasoning_id = f"rs_{response_id}"
action_id = f"cu_{response_id}"
call_id = f"call_{response_id}"
# Handle different types of screenshot returns
if isinstance(screenshot, (bytes, bytearray, memoryview)):
# Raw bytes screenshot
base64_image = base64.b64encode(screenshot).decode("utf-8")
elif hasattr(screenshot, "base64_image"):
# Object-style screenshot with attributes
# Type checking can't infer these attributes, but they exist at runtime
# on certain screenshot return types
base64_image = getattr(screenshot, "base64_image")
width = (
getattr(screenshot, "width", width) if hasattr(screenshot, "width") else width
)
height = (
getattr(screenshot, "height", height)
if hasattr(screenshot, "height")
else height
)
# Extract the last assistant message
assistant_msg = None
for msg in reversed(messages):
if msg["role"] == "assistant":
assistant_msg = msg
break
# Create parsed screen data
parsed_screen = {
"width": width,
"height": height,
"parsed_content_list": [],
"timestamp": datetime.now().isoformat(),
"screenshot_base64": base64_image,
if not assistant_msg:
# If no assistant message found, create a default one
assistant_msg = {"role": "assistant", "content": "No response available"}
# Initialize output array
output_items = []
# Extract reasoning and action details from the response
content = assistant_msg["content"]
reasoning_text = None
action_details = None
# Extract reasoning and action from different content formats
if isinstance(content, str):
try:
# Try to parse JSON
parsed_content = json.loads(content)
reasoning_text = parsed_content.get("Explanation", "")
# Extract action details
action = parsed_content.get("Action", "")
position = parsed_content.get("Position", {})
text_input = parsed_content.get("Text", "")
if action.lower() == "click" and position:
action_details = {
"type": "click",
"button": "left",
"x": position.get("x", 100),
"y": position.get("y", 100),
}
elif action.lower() == "type" and text_input:
action_details = {
"type": "type",
"text": text_input,
}
elif action.lower() == "scroll":
action_details = {
"type": "scroll",
"x": 100,
"y": 100,
"scroll_x": position.get("delta_x", 0),
"scroll_y": position.get("delta_y", 0),
}
except json.JSONDecodeError:
# If not valid JSON, use the content as reasoning
reasoning_text = content
elif isinstance(content, list):
# Handle list of content blocks (like Anthropic format)
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
# Collect text blocks for reasoning
if reasoning_text is None:
reasoning_text = ""
reasoning_text += item.get("text", "")
elif item.get("type") == "tool_use":
# Extract action from tool_use (similar to Anthropic format)
tool_input = item.get("input", {})
if "click" in tool_input or "position" in tool_input:
position = tool_input.get("click", tool_input.get("position", {}))
if isinstance(position, dict) and "x" in position and "y" in position:
action_details = {
"type": "click",
"button": "left",
"x": position.get("x", 100),
"y": position.get("y", 100),
}
elif "type" in tool_input or "text" in tool_input:
action_details = {
"type": "type",
"text": tool_input.get("type", tool_input.get("text", "")),
}
elif "scroll" in tool_input:
scroll = tool_input.get("scroll", {})
action_details = {
"type": "scroll",
"x": 100,
"y": 100,
"scroll_x": scroll.get("x", 0),
"scroll_y": scroll.get("y", 0),
}
# Add reasoning item if we have text content
if reasoning_text:
output_items.append(
{
"type": "reasoning",
"id": reasoning_id,
"summary": [
{
"type": "summary_text",
"text": reasoning_text[:200], # Truncate to reasonable length
}
],
}
)
# If no action details extracted, use default
if not action_details:
action_details = {
"type": "click",
"button": "left",
"x": 100,
"y": 100,
}
# Save screenshot if requested
if self.save_trajectory and self.experiment_manager:
try:
img_data = base64_image
if "," in img_data:
img_data = img_data.split(",")[1]
self._save_screenshot(img_data, action_type="state")
except Exception as e:
logger.error(f"Error saving screenshot: {str(e)}")
# Add computer_call item
computer_call = {
"type": "computer_call",
"id": action_id,
"call_id": call_id,
"action": action_details,
"pending_safety_checks": [],
"status": "completed",
}
output_items.append(computer_call)
return parsed_screen
except Exception as e:
logger.error(f"Error taking screenshot: {str(e)}")
return {
"width": 1024,
"height": 768,
"parsed_content_list": [],
"timestamp": datetime.now().isoformat(),
"error": f"Error taking screenshot: {str(e)}",
"screenshot_base64": "",
}
@abstractmethod
async def initialize_client(self) -> None:
"""Initialize the API client and any provider-specific components."""
raise NotImplementedError
@abstractmethod
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
"""Run the agent loop with provided messages.
Args:
messages: List of message objects
Yields:
Dict containing response data
"""
raise NotImplementedError
@abstractmethod
async def _process_screen(
self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]]
) -> None:
"""Process screen information and add to messages.
Args:
parsed_screen: Dictionary containing parsed screen info
messages: List of messages to update
"""
raise NotImplementedError
# Create the OpenAI-compatible response format
return {
"output": output_items,
"id": response_id,
# Include the original response for compatibility
"response": {"choices": [{"message": assistant_msg, "finish_reason": "stop"}]},
}

View File

@@ -4,9 +4,10 @@ import base64
from datetime import datetime
from io import BytesIO
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast, Tuple
from PIL import Image
from dataclasses import dataclass
import re
logger = logging.getLogger(__name__)
@@ -123,123 +124,278 @@ class BaseMessageManager:
break
def create_user_message(text: str) -> Dict[str, str]:
"""Create a user message.
class StandardMessageManager:
"""Manages messages in a standardized OpenAI format across different providers."""
Args:
text: The message text
def __init__(self, config: Optional[ImageRetentionConfig] = None):
"""Initialize message manager.
Returns:
Message dictionary
"""
return {
"role": "user",
"content": text,
}
Args:
config: Configuration for image retention
"""
self.messages: List[Dict[str, Any]] = []
self.config = config or ImageRetentionConfig()
def add_user_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
"""Add a user message.
def create_assistant_message(text: str) -> Dict[str, str]:
"""Create an assistant message.
Args:
content: Message content (text or multimodal content)
"""
self.messages.append({"role": "user", "content": content})
Args:
text: The message text
def add_assistant_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
"""Add an assistant message.
Returns:
Message dictionary
"""
return {
"role": "assistant",
"content": text,
}
Args:
content: Message content (text or multimodal content)
"""
self.messages.append({"role": "assistant", "content": content})
def add_system_message(self, content: str) -> None:
"""Add a system message.
def create_system_message(text: str) -> Dict[str, str]:
"""Create a system message.
Args:
content: System message content
"""
self.messages.append({"role": "system", "content": content})
Args:
text: The message text
def get_messages(self) -> List[Dict[str, Any]]:
"""Get all messages in standard format.
Returns:
Message dictionary
"""
return {
"role": "system",
"content": text,
}
Returns:
List of messages
"""
# If image retention is configured, apply it
if self.config.num_images_to_keep is not None:
return self._apply_image_retention(self.messages)
return self.messages
def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Apply image retention policy to messages.
def create_image_message(
image_base64: Optional[str] = None,
image_path: Optional[str] = None,
image_obj: Optional[Image.Image] = None,
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
"""Create a message with an image.
Args:
messages: List of messages
Args:
image_base64: Base64 encoded image
image_path: Path to image file
image_obj: PIL Image object
Returns:
List of messages with image retention applied
"""
if not self.config.num_images_to_keep:
return messages
Returns:
Message dictionary with content list
# Find user messages with images
image_messages = []
for msg in messages:
if msg["role"] == "user" and isinstance(msg["content"], list):
has_image = any(
item.get("type") == "image_url" or item.get("type") == "image"
for item in msg["content"]
)
if has_image:
image_messages.append(msg)
Raises:
ValueError: If no image source is provided
"""
if not any([image_base64, image_path, image_obj]):
raise ValueError("Must provide one of image_base64, image_path, or image_obj")
# If we don't have more images than the limit, return all messages
if len(image_messages) <= self.config.num_images_to_keep:
return messages
# Convert to base64 if needed
if image_path and not image_base64:
with open(image_path, "rb") as f:
image_bytes = f.read()
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
elif image_obj and not image_base64:
buffer = BytesIO()
image_obj.save(buffer, format="PNG")
image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
# Get the most recent N images to keep
images_to_keep = image_messages[-self.config.num_images_to_keep :]
images_to_remove = image_messages[: -self.config.num_images_to_keep]
return {
"role": "user",
"content": [
{"type": "image", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
],
}
# Create a new message list without the older images
result = []
for msg in messages:
if msg in images_to_remove:
# Skip this message
continue
result.append(msg)
return result
def create_screen_message(
parsed_screen: Dict[str, Any],
include_raw: bool = False,
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
"""Create a message with screen information.
def to_anthropic_format(
self, messages: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], str]:
"""Convert standard OpenAI format messages to Anthropic format.
Args:
parsed_screen: Dictionary containing parsed screen info
include_raw: Whether to include raw screenshot base64
Args:
messages: List of messages in OpenAI format
Returns:
Message dictionary with content
"""
if include_raw and "screenshot_base64" in parsed_screen:
# Create content list with both image and text
return {
"role": "user",
"content": [
{
"type": "image",
"image_url": {
"url": f"data:image/png;base64,{parsed_screen['screenshot_base64']}"
},
},
{
"type": "text",
"text": f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}",
},
],
}
else:
# Create text-only message with screen info
return {
"role": "user",
"content": f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}",
}
Returns:
Tuple containing (anthropic_messages, system_content)
"""
result = []
system_content = ""
# Process messages in order to maintain conversation flow
previous_assistant_tool_use_ids = (
set()
) # Track tool_use_ids in the previous assistant message
for i, msg in enumerate(messages):
role = msg.get("role", "")
content = msg.get("content", "")
if role == "system":
# Collect system messages for later use
system_content += content + "\n"
continue
if role == "assistant":
# Track tool_use_ids in this assistant message for the next user message
previous_assistant_tool_use_ids = set()
if isinstance(content, list):
for item in content:
if (
isinstance(item, dict)
and item.get("type") == "tool_use"
and "id" in item
):
previous_assistant_tool_use_ids.add(item["id"])
logger.info(
f"Tool use IDs in assistant message #{i}: {previous_assistant_tool_use_ids}"
)
if role in ["user", "assistant"]:
anthropic_msg = {"role": role}
# Convert content based on type
if isinstance(content, str):
# Simple text content
anthropic_msg["content"] = [{"type": "text", "text": content}]
elif isinstance(content, list):
# Convert complex content
anthropic_content = []
for item in content:
item_type = item.get("type", "")
if item_type == "text":
anthropic_content.append({"type": "text", "text": item.get("text", "")})
elif item_type == "image_url":
# Convert OpenAI image format to Anthropic
image_url = item.get("image_url", {}).get("url", "")
if image_url.startswith("data:"):
# Extract base64 data and media type
match = re.match(r"data:(.+);base64,(.+)", image_url)
if match:
media_type, data = match.groups()
anthropic_content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": data,
},
}
)
else:
# Regular URL
anthropic_content.append(
{
"type": "image",
"source": {
"type": "url",
"url": image_url,
},
}
)
elif item_type == "tool_use":
# Always include tool_use blocks
anthropic_content.append(item)
elif item_type == "tool_result":
# Check if this is a user message AND if the tool_use_id exists in the previous assistant message
tool_use_id = item.get("tool_use_id")
# Only include tool_result if it references a tool_use from the immediately preceding assistant message
if (
role == "user"
and tool_use_id
and tool_use_id in previous_assistant_tool_use_ids
):
anthropic_content.append(item)
logger.info(
f"Including tool_result with tool_use_id: {tool_use_id}"
)
else:
# Convert to text to preserve information
logger.warning(
f"Converting tool_result to text. Tool use ID {tool_use_id} not found in previous assistant message"
)
content_text = "Tool Result: "
if "content" in item:
if isinstance(item["content"], list):
for content_item in item["content"]:
if (
isinstance(content_item, dict)
and content_item.get("type") == "text"
):
content_text += content_item.get("text", "")
elif isinstance(item["content"], str):
content_text += item["content"]
anthropic_content.append({"type": "text", "text": content_text})
anthropic_msg["content"] = anthropic_content
result.append(anthropic_msg)
return result, system_content
def from_anthropic_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert Anthropic format messages to standard OpenAI format.

    System messages and unknown roles are dropped; tool_use/tool_result
    blocks are passed through unchanged.

    Args:
        messages: List of messages in Anthropic format

    Returns:
        List of messages in OpenAI format
    """
    result = []

    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content", [])

        if role not in ["user", "assistant"]:
            continue

        openai_msg: Dict[str, Any] = {"role": role}

        # Anthropic also allows plain-string content; pass it through as-is
        # (the original code crashed here trying content[0].get on a str).
        if isinstance(content, str):
            openai_msg["content"] = content
        # Simple case: a single text block collapses to a plain string
        elif (
            len(content) == 1
            and isinstance(content[0], dict)
            and content[0].get("type") == "text"
        ):
            openai_msg["content"] = content[0].get("text", "")
        else:
            # Complex case: multiple blocks or non-text content
            openai_content = []
            for item in content:
                if not isinstance(item, dict):
                    # Skip malformed blocks rather than raising
                    continue
                item_type = item.get("type", "")

                if item_type == "text":
                    openai_content.append({"type": "text", "text": item.get("text", "")})
                elif item_type == "image":
                    # Convert Anthropic image source to OpenAI image_url format
                    source = item.get("source", {})
                    if source.get("type") == "base64":
                        media_type = source.get("media_type", "image/png")
                        data = source.get("data", "")
                        openai_content.append(
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:{media_type};base64,{data}"},
                            }
                        )
                    else:
                        # URL source
                        openai_content.append(
                            {
                                "type": "image_url",
                                "image_url": {"url": source.get("url", "")},
                            }
                        )
                elif item_type in ["tool_use", "tool_result"]:
                    # Pass through tool-related content unchanged
                    openai_content.append(item)

            openai_msg["content"] = openai_content

        result.append(openai_msg)

    return result

View File

@@ -0,0 +1,197 @@
"""Core visualization utilities for agents."""
import logging
import base64
from typing import Dict, Tuple
from PIL import Image, ImageDraw
from io import BytesIO
logger = logging.getLogger(__name__)
def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
    """Draw a click marker (circle plus crosshair) onto a screenshot.

    Args:
        x: Click x coordinate in pixels
        y: Click y coordinate in pixels
        img_base64: Base64-encoded screenshot

    Returns:
        A copy of the screenshot with the click annotated, or a blank
        800x600 white image if the screenshot cannot be decoded.
    """
    try:
        # Decode and copy so the caller's image data is never mutated
        annotated = Image.open(BytesIO(base64.b64decode(img_base64))).copy()
        pen = ImageDraw.Draw(annotated)

        # Red circle centered on the click point
        radius = 15
        pen.ellipse(
            [(x - radius, y - radius), (x + radius, y + radius)], outline="red", width=3
        )

        # Crosshair through the center
        reach = 20
        pen.line([(x - reach, y), (x + reach, y)], fill="red", width=3)
        pen.line([(x, y - reach), (x, y + reach)], fill="red", width=3)

        return annotated
    except Exception as e:
        logger.error(f"Error visualizing click: {str(e)}")
        # Return a blank image as fallback
        return Image.new("RGB", (800, 600), "white")
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Visualize a scroll action by drawing arrows on the screenshot.

    Args:
        direction: Direction of scroll ('up' or 'down'; anything that is
            not 'down' is treated as 'up')
        clicks: Number of scroll clicks (capped at 3 drawn arrows)
        img_base64: Base64-encoded screenshot

    Returns:
        PIL Image with visualization (a blank white 800x600 image on error)
    """
    try:
        # Decode the base64 image
        image_data = base64.b64decode(img_base64)
        img = Image.open(BytesIO(image_data))

        # Create a copy to draw on so the decoded original stays untouched
        draw_img = img.copy()
        draw = ImageDraw.Draw(draw_img)

        # Calculate parameters for visualization
        width, height = img.size
        center_x = width // 2

        # Draw arrows to indicate scrolling; arrow length scales with the
        # screenshot height but never exceeds 100px
        arrow_length = min(100, height // 4)
        arrow_width = 30
        num_arrows = min(clicks, 3)  # Don't draw too many arrows

        # Calculate starting position: down-arrows start in the upper third,
        # up-arrows in the lower third; arrow_dir is the vertical sign
        if direction == "down":
            start_y = height // 3
            arrow_dir = 1  # Down
        else:
            start_y = height * 2 // 3
            arrow_dir = -1  # Up

        # Draw the arrows, each offset 70% of an arrow length from the last
        for i in range(num_arrows):
            y_pos = start_y + (i * arrow_length * arrow_dir * 0.7)
            arrow_top = (center_x, y_pos)
            arrow_bottom = (center_x, y_pos + arrow_length * arrow_dir)

            # Draw the main line (arrow shaft)
            draw.line([arrow_top, arrow_bottom], fill="red", width=5)

            # Draw the arrowhead: a chevron at the far end of the shaft,
            # pointing in the scroll direction
            arrowhead_size = 20
            if direction == "down":
                draw.line(
                    [
                        (center_x - arrow_width // 2, arrow_bottom[1] - arrowhead_size),
                        arrow_bottom,
                        (center_x + arrow_width // 2, arrow_bottom[1] - arrowhead_size),
                    ],
                    fill="red",
                    width=5,
                )
            else:
                draw.line(
                    [
                        (center_x - arrow_width // 2, arrow_bottom[1] + arrowhead_size),
                        arrow_bottom,
                        (center_x + arrow_width // 2, arrow_bottom[1] + arrowhead_size),
                    ],
                    fill="red",
                    width=5,
                )

        return draw_img
    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
        # Return a blank image as fallback
        return Image.new("RGB", (800, 600), "white")
def calculate_element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]:
    """Compute the pixel-space center point of a UI element.

    Args:
        bbox: Bounding box with normalized (0-1) x1, y1, x2, y2 coordinates
        width: Screen width in pixels
        height: Screen height in pixels

    Returns:
        (x, y) tuple of integer pixel coordinates
    """
    # Midpoint in normalized space, then scale to pixels
    mid_x = (bbox["x1"] + bbox["x2"]) / 2
    mid_y = (bbox["y1"] + bbox["y2"]) / 2
    return int(mid_x * width), int(mid_y * height)
class VisualizationHelper:
    """Helper class for visualizing agent actions."""

    def __init__(self, agent):
        """Initialize visualization helper.

        Args:
            agent: Reference to the agent that will use this helper
        """
        self.agent = agent

    def _can_record(self) -> bool:
        """Return True when trajectory saving is on and a manager is attached."""
        return bool(
            self.agent.save_trajectory
            and getattr(self.agent, "experiment_manager", None)
        )

    def visualize_action(self, x: int, y: int, img_base64: str) -> None:
        """Visualize a click action by drawing on the screenshot."""
        if not self._can_record():
            return

        try:
            # Use the visualization utility, then persist the result
            rendered = visualize_click(x, y, img_base64)
            self.agent.experiment_manager.save_action_visualization(
                rendered, "click", f"x{x}_y{y}"
            )
        except Exception as e:
            logger.error(f"Error visualizing action: {str(e)}")

    def visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
        """Visualize a scroll action by drawing arrows on the screenshot."""
        if not self._can_record():
            return

        try:
            # Use the visualization utility, then persist the result
            rendered = visualize_scroll(direction, clicks, img_base64)
            self.agent.experiment_manager.save_action_visualization(
                rendered, "scroll", f"{direction}_{clicks}"
            )
        except Exception as e:
            logger.error(f"Error visualizing scroll: {str(e)}")

    def save_action_visualization(
        self, img: Image.Image, action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action; returns "" if no manager exists."""
        manager = getattr(self.agent, "experiment_manager", None)
        if manager:
            return manager.save_action_visualization(img, action_name, details)
        return ""

View File

@@ -0,0 +1,141 @@
"""API call handling for Anthropic provider."""
import logging
import asyncio
from typing import Any, Dict, List, Optional
from httpx import ConnectError, ReadTimeout
from anthropic.types.beta import (
BetaMessage,
BetaMessageParam,
BetaTextBlockParam,
)
from .types import LLMProvider
from .prompts import SYSTEM_PROMPT
# Constants
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
logger = logging.getLogger(__name__)
class AnthropicAPIHandler:
    """Handles API calls to Anthropic's API with structured error handling and retries."""

    def __init__(self, loop):
        """Initialize the API handler.

        Args:
            loop: Reference to the parent loop instance that provides context
                (client, tool manager, message manager, retry settings, logging)
        """
        self.loop = loop

    def _audit_tool_ids(self, messages: List[BetaMessageParam]) -> None:
        """Log tool_use/tool_result IDs and warn about orphaned tool results.

        Every tool_result block must reference the ID of a preceding
        tool_use block; a mismatch usually indicates message-history
        corruption, so it is surfaced before the API call is made.

        Args:
            messages: Messages about to be sent to the API
        """
        tool_use_ids = set()
        tool_result_ids = set()

        for i, msg in enumerate(messages):
            # Per-message detail is diagnostic only, so log at DEBUG level.
            logger.debug("Message %d: role=%s", i, msg.get("role"))
            content = msg.get("content")
            if not isinstance(content, list):
                continue
            for content_block in content:
                if not isinstance(content_block, dict):
                    continue
                block_type = content_block.get("type")
                if block_type == "tool_use" and "id" in content_block:
                    tool_id = content_block.get("id")
                    tool_use_ids.add(tool_id)
                    logger.debug("  - Found tool_use with ID: %s", tool_id)
                elif block_type == "tool_result" and "tool_use_id" in content_block:
                    result_id = content_block.get("tool_use_id")
                    tool_result_ids.add(result_id)
                    logger.debug("  - Found tool_result referencing ID: %s", result_id)

        # Check for mismatches
        missing_tool_uses = tool_result_ids - tool_use_ids
        if missing_tool_uses:
            logger.warning(
                f"Found tool_result IDs without matching tool_use IDs: {missing_tool_uses}"
            )

    async def make_api_call(
        self, messages: List[BetaMessageParam], system_prompt: str = SYSTEM_PROMPT
    ) -> BetaMessage:
        """Make API call to Anthropic with retry logic.

        Args:
            messages: List of messages to send to the API
            system_prompt: System prompt to use (default: SYSTEM_PROMPT)

        Returns:
            API response

        Raises:
            RuntimeError: If the client or tool manager is not initialized,
                or if the API call fails after all retries
        """
        if self.loop.client is None:
            raise RuntimeError("Client not initialized. Call initialize_client() first.")
        if self.loop.tool_manager is None:
            raise RuntimeError("Tool manager not initialized. Call initialize_client() first.")

        last_error = None

        logger.info(f"Sending {len(messages)} messages to Anthropic API")
        # Validate tool_use/tool_result ID pairing before making the call.
        self._audit_tool_ids(messages)

        for attempt in range(self.loop.max_retries):
            try:
                # Log request
                request_data = {
                    "messages": messages,
                    "max_tokens": self.loop.max_tokens,
                    "system": system_prompt,
                }

                # Let ExperimentManager handle sanitization
                self.loop._log_api_call("request", request_data)

                # Setup betas and system
                system = BetaTextBlockParam(
                    type="text",
                    text=system_prompt,
                )

                betas = [COMPUTER_USE_BETA_FLAG]

                # Add prompt caching if enabled in the message manager's config
                if self.loop.message_manager.config.enable_caching:
                    betas.append(PROMPT_CACHING_BETA_FLAG)
                    system["cache_control"] = {"type": "ephemeral"}

                # Make API call
                response = await self.loop.client.create_message(
                    messages=messages,
                    system=[system],
                    tools=self.loop.tool_manager.get_tool_params(),
                    max_tokens=self.loop.max_tokens,
                    betas=betas,
                )

                # Let ExperimentManager handle sanitization
                self.loop._log_api_call("response", request_data, response)

                return response
            except Exception as e:
                last_error = e
                logger.error(
                    f"Error in API call (attempt {attempt + 1}/{self.loop.max_retries}): {str(e)}"
                )
                self.loop._log_api_call("error", {"messages": messages}, error=e)

                if attempt < self.loop.max_retries - 1:
                    # Linear backoff: retry_delay, 2*retry_delay, 3*retry_delay, ...
                    await asyncio.sleep(self.loop.retry_delay * (attempt + 1))
                    continue

        # If we get here, all retries failed
        error_message = f"API call failed after {self.loop.max_retries} attempts"
        if last_error:
            error_message += f": {str(last_error)}"

        logger.error(error_message)
        raise RuntimeError(error_message)

View File

@@ -0,0 +1,5 @@
"""Anthropic callbacks package."""
from .manager import CallbackManager
__all__ = ["CallbackManager"]

View File

@@ -2,40 +2,35 @@
import logging
import asyncio
import json
import os
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, cast
import base64
from datetime import datetime
from httpx import ConnectError, ReadTimeout
# Anthropic-specific imports
from anthropic import AsyncAnthropic
from anthropic.types.beta import (
BetaMessage,
BetaMessageParam,
BetaTextBlock,
BetaTextBlockParam,
BetaToolUseBlockParam,
BetaContentBlockParam,
)
import base64
from datetime import datetime
# Computer
from computer import Computer
# Base imports
from ...core.loop import BaseLoop
from ...core.messages import ImageRetentionConfig as CoreImageRetentionConfig
from ...core.messages import StandardMessageManager, ImageRetentionConfig
# Anthropic provider-specific imports
from .api.client import AnthropicClientFactory, BaseAnthropicClient
from .tools.manager import ToolManager
from .messages.manager import MessageManager, ImageRetentionConfig
from .callbacks.manager import CallbackManager
from .prompts import SYSTEM_PROMPT
from .types import LLMProvider
from .tools import ToolResult
# Import the new modules we created
from .api_handler import AnthropicAPIHandler
from .response_handler import AnthropicResponseHandler
from .callbacks.manager import CallbackManager
# Constants
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
@@ -44,13 +39,22 @@ logger = logging.getLogger(__name__)
class AnthropicLoop(BaseLoop):
"""Anthropic-specific implementation of the agent loop."""
"""Anthropic-specific implementation of the agent loop.
This class extends BaseLoop to provide specialized support for Anthropic's Claude models
with their unique tool-use capabilities, custom message formatting, and
callback-driven approach to handling responses.
"""
###########################################
# INITIALIZATION AND CONFIGURATION
###########################################
def __init__(
self,
api_key: str,
computer: Computer,
model: str = "claude-3-7-sonnet-20250219", # Fixed model
model: str = "claude-3-7-sonnet-20250219",
only_n_most_recent_images: Optional[int] = 2,
base_dir: Optional[str] = "trajectories",
max_retries: int = 3,
@@ -83,27 +87,37 @@ class AnthropicLoop(BaseLoop):
**kwargs,
)
# Ensure model is always the fixed one
self.model = "claude-3-7-sonnet-20250219"
# Anthropic-specific attributes
self.provider = LLMProvider.ANTHROPIC
self.client = None
self.retry_count = 0
self.tool_manager = None
self.message_manager = None
self.callback_manager = None
# Configure image retention with core config
self.image_retention_config = CoreImageRetentionConfig(
num_images_to_keep=only_n_most_recent_images
# Initialize standard message manager with image retention config
self.message_manager = StandardMessageManager(
config=ImageRetentionConfig(
num_images_to_keep=only_n_most_recent_images, enable_caching=True
)
)
# Message history
# Message history (standard OpenAI format)
self.message_history = []
# Initialize handlers
self.api_handler = AnthropicAPIHandler(self)
self.response_handler = AnthropicResponseHandler(self)
###########################################
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
###########################################
async def initialize_client(self) -> None:
"""Initialize the Anthropic API client and tools."""
"""Initialize the Anthropic API client and tools.
Implements abstract method from BaseLoop to set up the Anthropic-specific
client, tool manager, message manager, and callback handlers.
"""
try:
logger.info(f"Initializing Anthropic client with model {self.model}...")
@@ -112,14 +126,7 @@ class AnthropicLoop(BaseLoop):
provider=self.provider, api_key=self.api_key, model=self.model
)
# Initialize message manager
self.message_manager = MessageManager(
image_retention_config=ImageRetentionConfig(
num_images_to_keep=self.only_n_most_recent_images, enable_caching=True
)
)
# Initialize callback manager
# Initialize callback manager with our callback handlers
self.callback_manager = CallbackManager(
content_callback=self._handle_content,
tool_callback=self._handle_tool_result,
@@ -136,51 +143,18 @@ class AnthropicLoop(BaseLoop):
self.client = None
raise RuntimeError(f"Failed to initialize Anthropic client: {str(e)}")
async def _process_screen(
self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]]
) -> None:
"""Process screen information and add to messages.
Args:
parsed_screen: Dictionary containing parsed screen info
messages: List of messages to update
"""
try:
# Extract screenshot from parsed screen
screenshot_base64 = parsed_screen.get("screenshot_base64")
if screenshot_base64:
# Remove data URL prefix if present
if "," in screenshot_base64:
screenshot_base64 = screenshot_base64.split(",")[1]
# Create Anthropic-compatible message with image
screen_info_msg = {
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": screenshot_base64,
},
}
],
}
# Add screen info message to messages
messages.append(screen_info_msg)
except Exception as e:
logger.error(f"Error processing screen info: {str(e)}")
raise
###########################################
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
###########################################
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
"""Run the agent loop with provided messages.
Implements abstract method from BaseLoop to handle the main agent loop
for the AnthropicLoop implementation, using async queues and callbacks.
Args:
messages: List of message objects
messages: List of message objects in standard OpenAI format
Yields:
Dict containing response data
@@ -188,7 +162,7 @@ class AnthropicLoop(BaseLoop):
try:
logger.info("Starting Anthropic loop run")
# Reset message history and add new messages
# Reset message history and add new messages in standard format
self.message_history = []
self.message_history.extend(messages)
@@ -236,6 +210,10 @@ class AnthropicLoop(BaseLoop):
"metadata": {"title": "❌ Error"},
}
###########################################
# AGENT LOOP IMPLEMENTATION
###########################################
async def _run_loop(self, queue: asyncio.Queue) -> None:
"""Run the agent loop with current message history.
@@ -244,31 +222,65 @@ class AnthropicLoop(BaseLoop):
"""
try:
while True:
# Get up-to-date screen information
parsed_screen = await self._get_parsed_screen_som()
# Capture screenshot
try:
# Take screenshot - always returns raw PNG bytes
screenshot = await self.computer.interface.screenshot()
# Process screen info and update messages
await self._process_screen(parsed_screen, self.message_history)
# Convert PNG bytes to base64
base64_image = base64.b64encode(screenshot).decode("utf-8")
# Prepare messages and make API call
if self.message_manager is None:
raise RuntimeError(
"Message manager not initialized. Call initialize_client() first."
)
prepared_messages = self.message_manager.prepare_messages(
cast(List[BetaMessageParam], self.message_history.copy())
)
# Save screenshot if requested
if self.save_trajectory and self.experiment_manager:
try:
self._save_screenshot(base64_image, action_type="state")
except Exception as e:
logger.error(f"Error saving screenshot: {str(e)}")
# Add screenshot to message history in OpenAI format
screen_info_msg = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{base64_image}"},
}
],
}
self.message_history.append(screen_info_msg)
except Exception as e:
logger.error(f"Error capturing or processing screenshot: {str(e)}")
raise
# Create new turn directory for this API call
self._create_turn_dir()
# Use _make_api_call instead of direct client call to ensure logging
response = await self._make_api_call(prepared_messages)
# Convert standard messages to Anthropic format
anthropic_messages, system_content = self.message_manager.to_anthropic_format(
self.message_history.copy()
)
# Handle the response
if not await self._handle_response(response, self.message_history):
# Use API handler to make API call with Anthropic format
response = await self.api_handler.make_api_call(
messages=cast(List[BetaMessageParam], anthropic_messages),
system_prompt=system_content or SYSTEM_PROMPT,
)
# Use response handler to handle the response and convert to standard format
# This adds the response to message_history
if not await self.response_handler.handle_response(response, self.message_history):
break
# Get the last assistant message and convert it to OpenAI computer use format
for msg in reversed(self.message_history):
if msg["role"] == "assistant":
# Create OpenAI-compatible response and add to queue
openai_compatible_response = self._create_openai_compatible_response(
msg, response
)
await queue.put(openai_compatible_response)
break
# Signal completion
await queue.put(None)
@@ -283,98 +295,128 @@ class AnthropicLoop(BaseLoop):
)
await queue.put(None)
async def _make_api_call(self, messages: List[BetaMessageParam]) -> BetaMessage:
"""Make API call to Anthropic with retry logic.
def _create_openai_compatible_response(
self, assistant_msg: Dict[str, Any], original_response: Any
) -> Dict[str, Any]:
"""Create an OpenAI computer use agent compatible response format.
Args:
messages: List of messages to send to the API
assistant_msg: The assistant message in standard OpenAI format
original_response: The original API response object for ID generation
Returns:
API response
A response formatted according to OpenAI's computer use agent standard
"""
if self.client is None:
raise RuntimeError("Client not initialized. Call initialize_client() first.")
if self.tool_manager is None:
raise RuntimeError("Tool manager not initialized. Call initialize_client() first.")
# Create a unique ID for this response
response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(original_response)}"
reasoning_id = f"rs_{response_id}"
action_id = f"cu_{response_id}"
call_id = f"call_{response_id}"
last_error = None
# Extract reasoning and action details from the response
content = assistant_msg["content"]
for attempt in range(self.max_retries):
try:
# Log request
request_data = {
"messages": messages,
"max_tokens": self.max_tokens,
"system": SYSTEM_PROMPT,
# Initialize output array
output_items = []
# Add reasoning item if we have text content
reasoning_text = None
action_details = None
# AnthropicLoop expects a list of content blocks with type "text" or "tool_use"
if isinstance(content, list):
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
reasoning_text = item.get("text", "")
elif isinstance(item, dict) and item.get("type") == "tool_use":
action_details = item
else:
# Fallback for string content
reasoning_text = content if isinstance(content, str) else None
# If we have reasoning text, add reasoning item
if reasoning_text:
output_items.append(
{
"type": "reasoning",
"id": reasoning_id,
"summary": [
{
"type": "summary_text",
"text": reasoning_text[:200], # Truncate to reasonable length
}
],
}
# Let ExperimentManager handle sanitization
self._log_api_call("request", request_data)
)
# Setup betas and system
system = BetaTextBlockParam(
type="text",
text=SYSTEM_PROMPT,
)
# Add computer_call item with action details if available
computer_call = {
"type": "computer_call",
"id": action_id,
"call_id": call_id,
"action": {"type": "click", "button": "left", "x": 100, "y": 100}, # Default action
"pending_safety_checks": [],
"status": "completed",
}
betas = [COMPUTER_USE_BETA_FLAG]
# Temporarily disable prompt caching due to "A maximum of 4 blocks with cache_control may be provided" error
# if self.message_manager.image_retention_config.enable_caching:
# betas.append(PROMPT_CACHING_BETA_FLAG)
# system["cache_control"] = {"type": "ephemeral"}
# If we have action details from a tool_use, update the computer_call
if action_details:
# Try to map tool_use to computer_call action
tool_input = action_details.get("input", {})
if "click" in tool_input or "position" in tool_input:
position = tool_input.get("click", tool_input.get("position", {}))
if isinstance(position, dict) and "x" in position and "y" in position:
computer_call["action"] = {
"type": "click",
"button": "left",
"x": position.get("x", 100),
"y": position.get("y", 100),
}
elif "type" in tool_input or "text" in tool_input:
computer_call["action"] = {
"type": "type",
"text": tool_input.get("type", tool_input.get("text", "")),
}
elif "scroll" in tool_input:
scroll = tool_input.get("scroll", {})
computer_call["action"] = {
"type": "scroll",
"x": 100,
"y": 100,
"scroll_x": scroll.get("x", 0),
"scroll_y": scroll.get("y", 0),
}
# Make API call
response = await self.client.create_message(
messages=messages,
system=[system],
tools=self.tool_manager.get_tool_params(),
max_tokens=self.max_tokens,
betas=betas,
)
output_items.append(computer_call)
# Let ExperimentManager handle sanitization
self._log_api_call("response", request_data, response)
# Create the OpenAI-compatible response format
return {
"output": output_items,
"id": response_id,
# Include the original format for backward compatibility
"response": {"choices": [{"message": assistant_msg, "finish_reason": "stop"}]},
}
return response
except Exception as e:
last_error = e
logger.error(
f"Error in API call (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
)
self._log_api_call("error", {"messages": messages}, error=e)
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff
continue
# If we get here, all retries failed
error_message = f"API call failed after {self.max_retries} attempts"
if last_error:
error_message += f": {str(last_error)}"
logger.error(error_message)
raise RuntimeError(error_message)
###########################################
# RESPONSE AND CALLBACK HANDLING
###########################################
async def _handle_response(self, response: BetaMessage, messages: List[Dict[str, Any]]) -> bool:
"""Handle the Anthropic API response.
Args:
response: API response
messages: List of messages to update
messages: List of messages to update in standard OpenAI format
Returns:
True if the loop should continue, False otherwise
"""
try:
# Convert response to parameter format
response_params = self._response_to_params(response)
# Convert Anthropic response to standard OpenAI format
response_blocks = self._response_to_blocks(response)
# Add response to messages
messages.append(
{
"role": "assistant",
"content": response_params,
}
)
# Add response to standard message history
messages.append({"role": "assistant", "content": response_blocks})
if self.callback_manager is None:
raise RuntimeError(
@@ -383,31 +425,33 @@ class AnthropicLoop(BaseLoop):
# Handle tool use blocks and collect results
tool_result_content = []
for content_block in response_params:
for content_block in response.content:
# Notify callback of content
self.callback_manager.on_content(cast(BetaContentBlockParam, content_block))
# Handle tool use
if content_block.get("type") == "tool_use":
# Handle tool use - carefully check and access attributes
if hasattr(content_block, "type") and content_block.type == "tool_use":
if self.tool_manager is None:
raise RuntimeError(
"Tool manager not initialized. Call initialize_client() first."
)
# Safely get attributes
tool_name = getattr(content_block, "name", "")
tool_input = getattr(content_block, "input", {})
tool_id = getattr(content_block, "id", "")
result = await self.tool_manager.execute_tool(
name=content_block["name"],
tool_input=cast(Dict[str, Any], content_block["input"]),
name=tool_name,
tool_input=cast(Dict[str, Any], tool_input),
)
# Create tool result and add to content
tool_result = self._make_tool_result(
cast(ToolResult, result), content_block["id"]
)
# Create tool result
tool_result = self._make_tool_result(cast(ToolResult, result), tool_id)
tool_result_content.append(tool_result)
# Notify callback of tool result
self.callback_manager.on_tool_result(
cast(ToolResult, result), content_block["id"]
)
self.callback_manager.on_tool_result(cast(ToolResult, result), tool_id)
# If no tool results, we're done
if not tool_result_content:
@@ -415,8 +459,8 @@ class AnthropicLoop(BaseLoop):
self.callback_manager.on_content({"type": "text", "text": "<DONE>"})
return False
# Add tool results to message history
messages.append({"content": tool_result_content, "role": "user"})
# Add tool results to message history in standard format
messages.append({"role": "user", "content": tool_result_content})
return True
except Exception as e:
@@ -429,28 +473,41 @@ class AnthropicLoop(BaseLoop):
)
return False
def _response_to_params(
self,
response: BetaMessage,
) -> List[Dict[str, Any]]:
"""Convert API response to message parameters.
def _response_to_blocks(self, response: BetaMessage) -> List[Dict[str, Any]]:
"""Convert Anthropic API response to standard blocks format.
Args:
response: API response message
Returns:
List of content blocks
List of content blocks in standard format
"""
result = []
for block in response.content:
if isinstance(block, BetaTextBlock):
result.append({"type": "text", "text": block.text})
elif hasattr(block, "type") and block.type == "tool_use":
# Safely access attributes after confirming it's a tool_use
result.append(
{
"type": "tool_use",
"id": getattr(block, "id", ""),
"name": getattr(block, "name", ""),
"input": getattr(block, "input", {}),
}
)
else:
result.append(cast(Dict[str, Any], block.model_dump()))
# For other block types, convert to dict
block_dict = {}
for key, value in vars(block).items():
if not key.startswith("_"):
block_dict[key] = value
result.append(block_dict)
return result
def _make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]:
"""Convert a tool result to API format.
"""Convert a tool result to standard format.
Args:
result: Tool execution result
@@ -489,12 +546,8 @@ class AnthropicLoop(BaseLoop):
if result.base64_image:
tool_result_content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": result.base64_image,
},
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{result.base64_image}"},
}
)
@@ -519,16 +572,19 @@ class AnthropicLoop(BaseLoop):
result_text = f"<s>{result.system}</s>\n{result_text}"
return result_text
def _handle_content(self, content: BetaContentBlockParam) -> None:
###########################################
# CALLBACK HANDLERS
###########################################
def _handle_content(self, content):
"""Handle content updates from the assistant."""
if content.get("type") == "text":
text_content = cast(BetaTextBlockParam, content)
text = text_content["text"]
text = content.get("text", "")
if text == "<DONE>":
return
logger.info(f"Assistant: {text}")
def _handle_tool_result(self, result: ToolResult, tool_id: str) -> None:
def _handle_tool_result(self, result, tool_id):
"""Handle tool execution results."""
if result.error:
logger.error(f"Tool {tool_id} error: {result.error}")

View File

@@ -0,0 +1,223 @@
"""Response and tool handling for Anthropic provider."""
import logging
from typing import Any, Dict, List, Optional, Tuple, cast
from anthropic.types.beta import (
BetaMessage,
BetaMessageParam,
BetaTextBlock,
BetaTextBlockParam,
BetaToolUseBlockParam,
BetaContentBlockParam,
)
from .tools import ToolResult
logger = logging.getLogger(__name__)
class AnthropicResponseHandler:
    """Handles Anthropic API responses and tool execution results."""

    def __init__(self, loop):
        """Initialize the response handler.

        Args:
            loop: Reference to the parent loop instance that provides context
                (tool manager, callback manager, logging)
        """
        self.loop = loop

    async def handle_response(self, response: BetaMessage, messages: List[Dict[str, Any]]) -> bool:
        """Handle the Anthropic API response.

        Converts the response to block format, appends it to the message
        history, executes any requested tools, and appends their results as
        a user message.

        Args:
            response: API response
            messages: List of messages to update (mutated in place)

        Returns:
            True if the loop should continue, False otherwise
        """
        try:
            # Convert response to parameter format
            response_params = self.response_to_params(response)

            # Collect all existing tool_use IDs from previous messages for validation
            existing_tool_use_ids = set()
            for msg in messages:
                if msg.get("role") == "assistant" and isinstance(msg.get("content"), list):
                    for block in msg.get("content", []):
                        if (
                            isinstance(block, dict)
                            and block.get("type") == "tool_use"
                            and "id" in block
                        ):
                            existing_tool_use_ids.add(block["id"])

            # Also add new tool_use IDs from the current response
            current_tool_use_ids = set()
            for block in response_params:
                if isinstance(block, dict) and block.get("type") == "tool_use" and "id" in block:
                    current_tool_use_ids.add(block["id"])
                    existing_tool_use_ids.add(block["id"])

            logger.info(f"Existing tool_use IDs in conversation: {existing_tool_use_ids}")
            logger.info(f"New tool_use IDs in current response: {current_tool_use_ids}")

            # Add response to messages
            messages.append(
                {
                    "role": "assistant",
                    "content": response_params,
                }
            )

            if self.loop.callback_manager is None:
                raise RuntimeError(
                    "Callback manager not initialized. Call initialize_client() first."
                )

            # Handle tool use blocks and collect results
            tool_result_content = []
            for content_block in response_params:
                # Notify callback of content
                self.loop.callback_manager.on_content(cast(BetaContentBlockParam, content_block))

                # Handle tool use
                if content_block.get("type") == "tool_use":
                    if self.loop.tool_manager is None:
                        raise RuntimeError(
                            "Tool manager not initialized. Call initialize_client() first."
                        )

                    # Execute the tool
                    result = await self.loop.tool_manager.execute_tool(
                        name=content_block["name"],
                        tool_input=cast(Dict[str, Any], content_block["input"]),
                    )

                    # Verify the tool_use ID exists in the conversation (which it should now)
                    tool_use_id = content_block["id"]
                    if tool_use_id in existing_tool_use_ids:
                        # Create tool result and add to content
                        tool_result = self.make_tool_result(cast(ToolResult, result), tool_use_id)
                        tool_result_content.append(tool_result)

                        # Notify callback of tool result (use the validated ID,
                        # consistent with the guard above)
                        self.loop.callback_manager.on_tool_result(
                            cast(ToolResult, result), tool_use_id
                        )
                    else:
                        logger.warning(
                            f"Tool use ID {tool_use_id} not found in previous messages. Skipping tool result."
                        )

            # If no tool results, we're done
            if not tool_result_content:
                # Signal completion
                self.loop.callback_manager.on_content({"type": "text", "text": "<DONE>"})
                return False

            # Add tool results to message history
            messages.append({"content": tool_result_content, "role": "user"})

            return True
        except Exception as e:
            logger.error(f"Error handling response: {str(e)}")
            messages.append(
                {
                    "role": "assistant",
                    "content": f"Error: {str(e)}",
                }
            )
            return False

    def response_to_params(
        self,
        response: BetaMessage,
    ) -> List[Dict[str, Any]]:
        """Convert API response to message parameters.

        Args:
            response: API response message

        Returns:
            List of content blocks as plain dicts (text blocks get a minimal
            {"type", "text"} shape; other block types are model_dump()'d)
        """
        result = []
        for block in response.content:
            if isinstance(block, BetaTextBlock):
                result.append({"type": "text", "text": block.text})
            else:
                result.append(cast(Dict[str, Any], block.model_dump()))
        return result

    def make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]:
        """Convert a tool result to API format.

        Args:
            result: Tool execution result
            tool_use_id: ID of the tool use this result answers

        Returns:
            Formatted tool_result block
        """
        # If the tool already produced API-ready content, pass it through.
        if result.content:
            return {
                "type": "tool_result",
                "content": result.content,
                "tool_use_id": tool_use_id,
                "is_error": bool(result.error),
            }

        tool_result_content = []
        is_error = False

        if result.error:
            is_error = True
            tool_result_content = [
                {
                    "type": "text",
                    "text": self.maybe_prepend_system_tool_result(result, result.error),
                }
            ]
        else:
            if result.output:
                tool_result_content.append(
                    {
                        "type": "text",
                        "text": self.maybe_prepend_system_tool_result(result, result.output),
                    }
                )
            if result.base64_image:
                tool_result_content.append(
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": result.base64_image,
                        },
                    }
                )

        return {
            "type": "tool_result",
            "content": tool_result_content,
            "tool_use_id": tool_use_id,
            "is_error": is_error,
        }

    def maybe_prepend_system_tool_result(self, result: ToolResult, result_text: str) -> str:
        """Prepend system information to tool result if available.

        Args:
            result: Tool execution result
            result_text: Text to prepend to

        Returns:
            Text with system information prepended if available
        """
        if result.system:
            result_text = f"<s>{result.system}</s>\n{result_text}"
        return result_text

View File

@@ -7,101 +7,6 @@ from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
from ....core.tools.bash import BaseBashTool
class _BashSession:
    """A session of a bash shell.

    Keeps one long-lived /bin/bash subprocess and runs commands in it by
    appending an echoed sentinel, then reading stdout until the sentinel
    appears.
    """

    _started: bool
    _process: asyncio.subprocess.Process

    command: str = "/bin/bash"
    _output_delay: float = 0.2  # seconds
    _timeout: float = 120.0  # seconds
    _sentinel: str = "<<exit>>"
    _chunk_size: int = 4096  # bytes per bounded pipe read

    def __init__(self):
        self._started = False
        self._timed_out = False

    async def start(self):
        """Start the bash subprocess (idempotent)."""
        if self._started:
            return

        self._process = await asyncio.create_subprocess_shell(
            self.command,
            # Run in its own session so signals can target the whole group.
            preexec_fn=os.setsid,
            shell=True,
            bufsize=0,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )

        self._started = True

    def stop(self):
        """Terminate the bash shell."""
        if not self._started:
            raise ToolError("Session has not started.")
        if self._process.returncode is not None:
            return
        self._process.terminate()

    async def run(self, command: str):
        """Execute a command in the bash shell.

        Args:
            command: Shell command to execute

        Returns:
            CLIResult with the command's stdout/stderr, or a ToolResult
            asking for a restart when bash has already exited.

        Raises:
            ToolError: If the session was never started or has timed out.
        """
        if not self._started:
            raise ToolError("Session has not started.")
        if self._process.returncode is not None:
            return ToolResult(
                system="tool must be restarted",
                error=f"bash has exited with returncode {self._process.returncode}",
            )
        if self._timed_out:
            raise ToolError(
                f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
            )

        # we know these are not None because we created the process with PIPEs
        assert self._process.stdin
        assert self._process.stdout
        assert self._process.stderr

        # send command to the process
        self._process.stdin.write(command.encode() + f"; echo '{self._sentinel}'\n".encode())
        await self._process.stdin.drain()

        # Read stdout incrementally until the sentinel is found.
        # NOTE: StreamReader.read() with no size argument waits for EOF,
        # which never arrives while bash keeps its pipe open — so reads
        # must be bounded, or every command would hang until the timeout.
        output = ""
        try:
            async with asyncio.timeout(self._timeout):
                while True:
                    chunk = await self._process.stdout.read(self._chunk_size)
                    if not chunk:
                        # EOF: bash exited before echoing the sentinel.
                        break
                    output += chunk.decode()
                    if self._sentinel in output:
                        # strip the sentinel and stop reading
                        output = output[: output.index(self._sentinel)]
                        break
        except asyncio.TimeoutError:
            self._timed_out = True
            raise ToolError(
                f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
            ) from None

        if output.endswith("\n"):
            output = output[:-1]

        # Drain whatever stderr is currently buffered. A plain read() would
        # also wait for EOF, so use a short timeout instead: if nothing is
        # ready within the polling delay, treat stderr as empty (best-effort).
        error = ""
        try:
            error_bytes = await asyncio.wait_for(
                self._process.stderr.read(self._chunk_size), timeout=self._output_delay
            )
            error = error_bytes.decode() if error_bytes else ""
        except asyncio.TimeoutError:
            error = ""
        if error.endswith("\n"):
            error = error[:-1]

        return CLIResult(output=output, error=error)
class BashTool(BaseBashTool, BaseAnthropicTool):
"""
A tool that allows the agent to run bash commands.
@@ -123,7 +28,6 @@ class BashTool(BaseBashTool, BaseAnthropicTool):
# Then initialize the Anthropic tool
BaseAnthropicTool.__init__(self)
# Initialize bash session
self._session = _BashSession()
async def __call__(self, command: str | None = None, restart: bool = False, **kwargs):
"""Execute a bash command.

View File

@@ -3,8 +3,6 @@
# The OmniComputerAgent has been replaced by the unified ComputerAgent
# which can be found in agent.core.agent
from .types import LLMProvider
from .experiment import ExperimentManager
from .visualization import visualize_click, visualize_scroll, calculate_element_center
from .image_utils import (
decode_base64_image,
encode_image_base64,
@@ -15,10 +13,6 @@ from .image_utils import (
__all__ = [
"LLMProvider",
"ExperimentManager",
"visualize_click",
"visualize_scroll",
"calculate_element_center",
"decode_base64_image",
"encode_image_base64",
"clean_base64_data",

View File

@@ -0,0 +1,264 @@
"""Action execution for the Omni agent."""
import logging
from typing import Dict, Any, Tuple
import json
from .parser import ParseResult
from ...core.visualization import calculate_element_center
logger = logging.getLogger(__name__)
class ActionExecutor:
    """Executes UI actions based on model instructions.

    Translates a parsed model response (an action name such as
    ``left_click`` or ``type_text`` plus its arguments, e.g. a "Box ID" or
    a "Value") into calls on ``self.loop.computer.interface``, and records
    screenshots/visualizations of each action via the parent loop's
    helpers.
    """
    def __init__(self, loop):
        """Initialize the action executor.
        Args:
            loop: Reference to the parent loop instance that provides context
                (the computer interface, ``viz_helper``, screen parsing and
                screenshot-saving methods used below)
        """
        self.loop = loop
    async def execute_action(self, content: Dict[str, Any], parsed_screen: ParseResult) -> bool:
        """Execute the action specified in the content.
        Args:
            content: Dictionary containing the action details.  Keys read
                here: "Action", "Box ID", "Value", "button", "duration",
                "amount".
            parsed_screen: Current parsed screen information
        Returns:
            Whether an action-specific screenshot was saved
        """
        try:
            action = content.get("Action", "").lower()
            if not action:
                # No action name — nothing to execute.
                return False
            # Track if we saved an action-specific screenshot
            action_screenshot_saved = False
            try:
                # Prepare kwargs based on action type
                kwargs = {}
                if action in ["left_click", "right_click", "double_click", "move_cursor"]:
                    # Coordinate-based actions: resolve the element Box ID to (x, y).
                    try:
                        box_id = int(content["Box ID"])
                        logger.info(f"Processing Box ID: {box_id}")
                        # Calculate click coordinates
                        x, y = await self.calculate_click_coordinates(box_id, parsed_screen)
                        logger.info(f"Calculated coordinates: x={x}, y={y}")
                        kwargs["x"] = x
                        kwargs["y"] = y
                        # Visualize action if screenshot is available
                        if parsed_screen.annotated_image_base64:
                            img_data = parsed_screen.annotated_image_base64
                            # Remove data URL prefix if present
                            if img_data.startswith("data:image"):
                                img_data = img_data.split(",")[1]
                            # Only save visualization for coordinate-based actions
                            self.loop.viz_helper.visualize_action(x, y, img_data)
                            action_screenshot_saved = True
                    # A non-numeric "Box ID" raises ValueError here; a missing
                    # key raises KeyError, which falls through to the outer
                    # except handlers instead.
                    except ValueError as e:
                        logger.error(f"Error processing Box ID: {str(e)}")
                        return False
                elif action == "drag_to":
                    try:
                        box_id = int(content["Box ID"])
                        x, y = await self.calculate_click_coordinates(box_id, parsed_screen)
                        kwargs.update(
                            {
                                "x": x,
                                "y": y,
                                "button": content.get("button", "left"),
                                "duration": float(content.get("duration", 0.5)),
                            }
                        )
                        # Visualize drag destination if screenshot is available
                        if parsed_screen.annotated_image_base64:
                            img_data = parsed_screen.annotated_image_base64
                            # Remove data URL prefix if present
                            if img_data.startswith("data:image"):
                                img_data = img_data.split(",")[1]
                            # Only save visualization for coordinate-based actions
                            self.loop.viz_helper.visualize_action(x, y, img_data)
                            action_screenshot_saved = True
                    except ValueError as e:
                        logger.error(f"Error processing drag coordinates: {str(e)}")
                        return False
                elif action == "type_text":
                    kwargs["text"] = content["Value"]
                    # For type_text, store the value in the action type
                    action_type = f"type_{content['Value'][:20]}"  # Truncate if too long
                elif action == "press_key":
                    kwargs["key"] = content["Value"]
                    action_type = f"press_{content['Value']}"
                elif action == "hotkey":
                    # "Value" may be a ready-made list of keys or a string
                    # like "command+space".
                    if isinstance(content.get("Value"), list):
                        keys = content["Value"]
                        action_type = f"hotkey_{'_'.join(keys)}"
                    else:
                        # Simply split string format like "command+space" into a list
                        keys = [k.strip() for k in content["Value"].lower().split("+")]
                        action_type = f"hotkey_{content['Value'].replace('+', '_')}"
                    logger.info(f"Preparing hotkey with keys: {keys}")
                    # Get the method but call it with *args instead of **kwargs
                    method = getattr(self.loop.computer.interface, action)
                    await method(*keys)  # Unpack the keys list as positional arguments
                    logger.info(f"Tool execution completed successfully: {action}")
                    # For hotkeys, take a screenshot after the action
                    try:
                        # Get a new screenshot after the action and save it with the action type
                        new_parsed_screen = await self.loop._get_parsed_screen_som(
                            save_screenshot=False
                        )
                        if new_parsed_screen and new_parsed_screen.annotated_image_base64:
                            img_data = new_parsed_screen.annotated_image_base64
                            # Remove data URL prefix if present
                            if img_data.startswith("data:image"):
                                img_data = img_data.split(",")[1]
                            # Save with action type to indicate this is a post-action screenshot
                            self.loop._save_screenshot(img_data, action_type=action_type)
                            action_screenshot_saved = True
                    except Exception as screenshot_error:
                        logger.error(
                            f"Error taking post-hotkey screenshot: {str(screenshot_error)}"
                        )
                    # Hotkeys were already executed above, so bypass the
                    # generic execution path below and return immediately.
                    return action_screenshot_saved
                elif action in ["scroll_down", "scroll_up"]:
                    clicks = int(content.get("amount", 1))
                    kwargs["clicks"] = clicks
                    action_type = f"scroll_{action.split('_')[1]}_{clicks}"
                    # Visualize scrolling if screenshot is available
                    if parsed_screen.annotated_image_base64:
                        img_data = parsed_screen.annotated_image_base64
                        # Remove data URL prefix if present
                        if img_data.startswith("data:image"):
                            img_data = img_data.split(",")[1]
                        direction = "down" if action == "scroll_down" else "up"
                        # For scrolling, we only save the visualization to avoid duplicate images
                        self.loop.viz_helper.visualize_scroll(direction, clicks, img_data)
                        action_screenshot_saved = True
                else:
                    logger.warning(f"Unknown action: {action}")
                    return False
                # Execute tool and handle result
                try:
                    method = getattr(self.loop.computer.interface, action)
                    logger.info(f"Found method for action '{action}': {method}")
                    await method(**kwargs)
                    logger.info(f"Tool execution completed successfully: {action}")
                    # For non-coordinate based actions that don't already have visualizations,
                    # take a new screenshot after the action
                    if not action_screenshot_saved:
                        # Take a new screenshot
                        try:
                            # Get a new screenshot after the action and save it with the action type
                            new_parsed_screen = await self.loop._get_parsed_screen_som(
                                save_screenshot=False
                            )
                            if new_parsed_screen and new_parsed_screen.annotated_image_base64:
                                img_data = new_parsed_screen.annotated_image_base64
                                # Remove data URL prefix if present
                                if img_data.startswith("data:image"):
                                    img_data = img_data.split(",")[1]
                                # Save with action type to indicate this is a post-action screenshot
                                # NOTE(review): action_type is only bound in some
                                # branches above, hence the locals() probe; pre-
                                # initializing action_type = action would be cleaner.
                                if "action_type" in locals():
                                    self.loop._save_screenshot(img_data, action_type=action_type)
                                else:
                                    self.loop._save_screenshot(img_data, action_type=action)
                                # Update the action screenshot flag for this turn
                                action_screenshot_saved = True
                        except Exception as screenshot_error:
                            logger.error(
                                f"Error taking post-action screenshot: {str(screenshot_error)}"
                            )
                except AttributeError as e:
                    # The computer interface does not implement this action name.
                    logger.error(f"Method not found for action '{action}': {str(e)}")
                    return False
                except Exception as tool_error:
                    logger.error(f"Tool execution failed: {str(tool_error)}")
                    return False
                return action_screenshot_saved
            except Exception as e:
                logger.error(f"Error executing action {action}: {str(e)}")
                return False
        except Exception as e:
            logger.error(f"Error in execute_action: {str(e)}")
            return False
    async def calculate_click_coordinates(
        self, box_id: int, parsed_screen: ParseResult
    ) -> Tuple[int, int]:
        """Calculate click coordinates based on box ID.
        Args:
            box_id: The ID of the box to click
            parsed_screen: The parsed screen information
        Returns:
            Tuple of (x, y) coordinates
        Raises:
            ValueError: If box_id is invalid or missing from parsed screen
        NOTE(review): despite the docstring, this implementation never
        raises — an unknown box_id falls back to the screen center instead.
        """
        # First try to use structured elements data
        logger.info(f"Elements count: {len(parsed_screen.elements)}")
        # Try to find element with matching ID
        for element in parsed_screen.elements:
            if element.id == box_id:
                logger.info(f"Found element with ID {box_id}: {element}")
                bbox = element.bbox
                # Get screen dimensions from the metadata if available, or fallback
                width = parsed_screen.metadata.width if parsed_screen.metadata else 1920
                height = parsed_screen.metadata.height if parsed_screen.metadata else 1080
                logger.info(f"Screen dimensions: width={width}, height={height}")
                # Create a dictionary from the element's bbox for calculate_element_center
                # NOTE(review): assumes core.visualization.calculate_element_center
                # accepts (bbox_dict, width, height) — confirm against that module.
                bbox_dict = {"x1": bbox.x1, "y1": bbox.y1, "x2": bbox.x2, "y2": bbox.y2}
                center_x, center_y = calculate_element_center(bbox_dict, width, height)
                logger.info(f"Calculated center: ({center_x}, {center_y})")
                # Validate coordinates - if they're (0,0) or unreasonably small,
                # use a default position in the center of the screen
                if center_x == 0 and center_y == 0:
                    logger.warning("Got (0,0) coordinates, using fallback position")
                    center_x = width // 2
                    center_y = height // 2
                    logger.info(f"Using fallback center: ({center_x}, {center_y})")
                return center_x, center_y
        # If we couldn't find the box, use center of screen
        logger.error(
            f"Box ID {box_id} not found in structured elements (count={len(parsed_screen.elements)})"
        )
        # Use center of screen as fallback
        width = parsed_screen.metadata.width if parsed_screen.metadata else 1920
        height = parsed_screen.metadata.height if parsed_screen.metadata else 1080
        logger.warning(f"Using fallback position in center of screen ({width//2}, {height//2})")
        return width // 2, height // 2

View File

@@ -0,0 +1,42 @@
"""API handling for Omni provider."""
import logging
from typing import Any, Dict, List
from .prompts import SYSTEM_PROMPT
logger = logging.getLogger(__name__)
class OmniAPIHandler:
    """Handler for Omni API calls.

    Thin delegation layer: forwards standard-format message batches to the
    parent loop's provider-specific ``_make_api_call`` implementation.
    """
    def __init__(self, loop):
        """Initialize the API handler.
        Args:
            loop: Parent loop instance (expected to expose ``_make_api_call``)
        """
        self.loop = loop
    async def make_api_call(
        self, messages: List[Dict[str, Any]], system_prompt: str = SYSTEM_PROMPT
    ) -> Any:
        """Make an API call to the appropriate provider.
        Args:
            messages: List of messages in standard OpenAI format
            system_prompt: System prompt to use
        Returns:
            API response
        Raises:
            RuntimeError: If the parent loop does not provide ``_make_api_call``.
        """
        # Use getattr with a default so a loop that lacks the method raises
        # our descriptive RuntimeError rather than a bare AttributeError
        # from the attribute lookup itself.
        if not getattr(self.loop, "_make_api_call", None):
            raise RuntimeError("Loop does not have _make_api_call method")
        try:
            # Use the loop's _make_api_call method with standard messages
            return await self.loop._make_api_call(messages=messages, system_prompt=system_prompt)
        except Exception as e:
            logger.error(f"Error making API call: {str(e)}")
            raise

View File

@@ -44,6 +44,10 @@ class AnthropicClient(BaseOmniClient):
anthropic_messages = []
for message in messages:
# Skip messages with empty content
if not message.get("content"):
continue
if message["role"] == "user":
anthropic_messages.append({"role": "user", "content": message["content"]})
elif message["role"] == "assistant":

View File

@@ -1,101 +0,0 @@
"""Groq client implementation."""
import os
import logging
from typing import Dict, List, Optional, Any, Tuple
from groq import Groq
import re
from .utils import is_image_path
from .base import BaseOmniClient
logger = logging.getLogger(__name__)
class GroqClient(BaseOmniClient):
    """Client for making Groq API calls.

    Wraps the synchronous Groq chat-completions endpoint.  Image content is
    stripped from the conversation before sending, and R1-style
    ``<think>...</think>`` reasoning is removed from the reply.
    """
    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "deepseek-r1-distill-llama-70b",
        max_tokens: int = 4096,
        temperature: float = 0.6,
    ):
        """Initialize Groq client.
        Args:
            api_key: Groq API key (if not provided, will try to get from env)
            model: Model name to use
            max_tokens: Maximum tokens to generate
            temperature: Temperature for sampling
        Raises:
            ValueError: If no API key is provided and GROQ_API_KEY is unset.
        """
        super().__init__(api_key=api_key, model=model)
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("No Groq API key provided")
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.client = Groq(api_key=self.api_key)
        self.model: str = model  # Add explicit type annotation
    def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> tuple[str, int]:
        """Run interleaved chat completion.
        Args:
            messages: List of message dicts
            system: System prompt
            max_tokens: Optional max tokens override
        Returns:
            Tuple of (response text, token usage)
        """
        # Avoid using system messages for R1: send the system prompt as the
        # first user message instead.
        final_messages = [{"role": "user", "content": system}]
        # Process messages
        if isinstance(messages, list):
            for item in messages:
                if isinstance(item, dict):
                    # For dict items, concatenate all text content, ignoring images
                    # NOTE(review): assumes item["content"] is an iterable of
                    # content parts; a plain string would be iterated char by
                    # char — confirm against callers.
                    text_contents = []
                    for cnt in item["content"]:
                        if isinstance(cnt, str):
                            if not is_image_path(cnt):  # Skip image paths
                                text_contents.append(cnt)
                        else:
                            text_contents.append(str(cnt))
                    if text_contents:  # Only add if there's text content
                        message = {"role": "user", "content": " ".join(text_contents)}
                        final_messages.append(message)
                else:  # str
                    message = {"role": "user", "content": item}
                    final_messages.append(message)
        elif isinstance(messages, str):
            final_messages.append({"role": "user", "content": messages})
        try:
            completion = self.client.chat.completions.create(  # type: ignore
                model=self.model,
                messages=final_messages,  # type: ignore
                temperature=self.temperature,
                max_tokens=max_tokens or self.max_tokens,
                top_p=0.95,
                stream=False,
            )
            # message.content is Optional in the Groq SDK; guard against None
            # so the substring checks below cannot raise TypeError.
            response = completion.choices[0].message.content or ""
            # Drop R1 chain-of-thought (everything up to </think>) and any
            # <output> wrapper tags around the final answer.
            final_answer = response.split("</think>\n")[-1] if "</think>" in response else response
            final_answer = final_answer.replace("<output>", "").replace("</output>", "")
            token_usage = completion.usage.total_tokens
            return final_answer, token_usage
        except Exception as e:
            logger.error(f"Error in Groq API call: {e}")
            raise

View File

@@ -1,276 +0,0 @@
"""Experiment management for the Cua provider."""
import os
import logging
import copy
import base64
from io import BytesIO
from datetime import datetime
from typing import Any, Dict, List, Optional
from PIL import Image
import json
import time
logger = logging.getLogger(__name__)
class ExperimentManager:
    """Manages experiment directories and logging for the agent.

    Lays out one run directory containing a ``turn_NNN`` subdirectory per
    agent turn; screenshots, action visualizations, and sanitized API-call
    logs are all written into the current turn directory.
    """
    def __init__(
        self,
        base_dir: Optional[str] = None,
        only_n_most_recent_images: Optional[int] = None,
    ):
        """Initialize the experiment manager.
        Args:
            base_dir: Base directory for saving experiment data
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
        """
        self.base_dir = base_dir
        self.only_n_most_recent_images = only_n_most_recent_images
        # Run-level directory; assigned by setup_experiment_dirs().
        self.run_dir = None
        # Per-turn directory; assigned by create_turn_dir().
        self.current_turn_dir = None
        self.turn_count = 0
        self.screenshot_count = 0
        # Track all screenshots for potential API request inclusion
        self.screenshot_paths = []
        # Set up experiment directories if base_dir is provided
        if self.base_dir:
            self.setup_experiment_dirs()
    def setup_experiment_dirs(self) -> None:
        """Setup the experiment directory structure.

        No-op when no base directory was configured.
        """
        if not self.base_dir:
            return
        # Create base experiments directory if it doesn't exist
        os.makedirs(self.base_dir, exist_ok=True)
        # Use the base_dir directly as the run_dir
        self.run_dir = self.base_dir
        logger.info(f"Using directory for experiment: {self.run_dir}")
        # Create first turn directory
        self.create_turn_dir()
    def create_turn_dir(self) -> None:
        """Create a new directory for the current turn.

        Increments the turn counter and points current_turn_dir at the new
        zero-padded ``turn_NNN`` directory.
        """
        if not self.run_dir:
            return
        self.turn_count += 1
        self.current_turn_dir = os.path.join(self.run_dir, f"turn_{self.turn_count:03d}")
        os.makedirs(self.current_turn_dir, exist_ok=True)
        logger.info(f"Created turn directory: {self.current_turn_dir}")
    def sanitize_log_data(self, data: Any) -> Any:
        """Sanitize data for logging by removing large base64 strings.
        Args:
            data: Data to sanitize (dict, list, or primitive)
        Returns:
            Sanitized copy of the data
        """
        if isinstance(data, dict):
            # Work on a deep copy so the caller's data is never mutated.
            result = copy.deepcopy(data)
            # Handle nested dictionaries and lists
            for key, value in result.items():
                # Process content arrays that contain image data
                if key == "content" and isinstance(value, list):
                    for i, item in enumerate(value):
                        if isinstance(item, dict):
                            # Handle Anthropic format
                            if item.get("type") == "image" and isinstance(item.get("source"), dict):
                                source = item["source"]
                                if "data" in source and isinstance(source["data"], str):
                                    # Replace base64 data with a placeholder and length info
                                    data_len = len(source["data"])
                                    source["data"] = f"[BASE64_IMAGE_DATA_LENGTH_{data_len}]"
                            # Handle OpenAI format
                            elif item.get("type") == "image_url" and isinstance(
                                item.get("image_url"), dict
                            ):
                                url_dict = item["image_url"]
                                if "url" in url_dict and isinstance(url_dict["url"], str):
                                    url = url_dict["url"]
                                    if url.startswith("data:"):
                                        # Replace base64 data with placeholder
                                        data_len = len(url)
                                        url_dict["url"] = f"[BASE64_IMAGE_URL_LENGTH_{data_len}]"
                # Handle other nested structures recursively
                # NOTE(review): a "content" list handled above is sanitized
                # again by the generic list branch below — redundant but
                # harmless (the placeholders no longer match either format).
                if isinstance(value, dict):
                    result[key] = self.sanitize_log_data(value)
                elif isinstance(value, list):
                    result[key] = [self.sanitize_log_data(item) for item in value]
            return result
        elif isinstance(data, list):
            return [self.sanitize_log_data(item) for item in data]
        else:
            # Primitives pass through unchanged.
            return data
    def save_debug_image(self, image_data: str, filename: str) -> None:
        """Save a debug image to the experiment directory.
        Args:
            image_data: Base64 encoded image data
            filename: Filename to save the image as
        """
        # Since we no longer want to use the images/ folder, we'll skip this functionality
        return
    def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
        """Save a screenshot to the experiment directory.
        Args:
            img_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot
        Returns:
            Optional[str]: Path to the saved screenshot, or None if saving failed
        """
        if not self.current_turn_dir:
            return None
        try:
            # Increment screenshot counter
            self.screenshot_count += 1
            # Create a descriptive filename (counter + action + ms timestamp)
            timestamp = int(time.time() * 1000)
            action_suffix = f"_{action_type}" if action_type else ""
            filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"
            # Save directly to the turn directory (no screenshots subdirectory)
            filepath = os.path.join(self.current_turn_dir, filename)
            # Save the screenshot
            img_data = base64.b64decode(img_base64)
            with open(filepath, "wb") as f:
                f.write(img_data)
            # Keep track of the file path for reference
            self.screenshot_paths.append(filepath)
            return filepath
        except Exception as e:
            # Saving is best-effort; failures are logged, not raised.
            logger.error(f"Error saving screenshot: {str(e)}")
            return None
    def should_save_debug_image(self) -> bool:
        """Determine if debug images should be saved.
        Returns:
            Boolean indicating if debug images should be saved
        """
        # We no longer need to save debug images, so always return False
        return False
    def save_action_visualization(
        self, img: Image.Image, action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action.
        Args:
            img: Image to save
            action_name: Name of the action
            details: Additional details about the action
        Returns:
            Path to the saved image (empty string if saving failed or no
            turn directory exists)
        """
        if not self.current_turn_dir:
            return ""
        try:
            # Create a descriptive filename
            timestamp = int(time.time() * 1000)
            details_suffix = f"_{details}" if details else ""
            filename = f"vis_{action_name}{details_suffix}_{timestamp}.png"
            # Save directly to the turn directory (no visualizations subdirectory)
            filepath = os.path.join(self.current_turn_dir, filename)
            # Save the image
            img.save(filepath)
            # Keep track of the file path for cleanup
            self.screenshot_paths.append(filepath)
            return filepath
        except Exception as e:
            logger.error(f"Error saving action visualization: {str(e)}")
            return ""
    def extract_and_save_images(self, data: Any, prefix: str) -> None:
        """Extract and save images from response data.
        Args:
            data: Response data to extract images from
            prefix: Prefix for saved image filenames
        """
        # Since we no longer want to save extracted images separately,
        # we'll skip this functionality entirely
        return
    def log_api_call(
        self,
        call_type: str,
        request: Any,
        provider: str,
        model: str,
        response: Any = None,
        error: Optional[Exception] = None,
    ) -> None:
        """Log API call details to file.
        Args:
            call_type: Type of API call (e.g., 'request', 'response', 'error')
            request: The API request data
            provider: The AI provider used
            model: The AI model used
            response: Optional API response data
            error: Optional error information
        """
        if not self.current_turn_dir:
            return
        try:
            # Create a unique filename with timestamp
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"api_call_{timestamp}_{call_type}.json"
            filepath = os.path.join(self.current_turn_dir, filename)
            # Sanitize data to remove large base64 strings
            sanitized_request = self.sanitize_log_data(request)
            sanitized_response = self.sanitize_log_data(response) if response is not None else None
            # Prepare log data
            log_data = {
                "timestamp": timestamp,
                "provider": provider,
                "model": model,
                "type": call_type,
                "request": sanitized_request,
            }
            if sanitized_response is not None:
                log_data["response"] = sanitized_response
            if error is not None:
                log_data["error"] = str(error)
            # Write to file; default=str stringifies anything json can't
            # serialize natively, so logging never fails on exotic types.
            with open(filepath, "w") as f:
                json.dump(log_data, f, indent=2, default=str)
            logger.info(f"Logged API {call_type} to {filepath}")
        except Exception as e:
            logger.error(f"Error logging API call: {str(e)}")

File diff suppressed because it is too large Load Diff

View File

@@ -1,171 +0,0 @@
"""Omni message manager implementation."""
import base64
from typing import Any, Dict, List, Optional
from io import BytesIO
from PIL import Image
from ...core.messages import BaseMessageManager, ImageRetentionConfig
class OmniMessageManager(BaseMessageManager):
    """Message manager for multi-provider support.

    Stores conversation history in OpenAI-style message dicts and converts
    it on demand to the format each supported provider expects.
    """
    def __init__(self, config: Optional[ImageRetentionConfig] = None):
        """Initialize the message manager.
        Args:
            config: Optional configuration for image retention
        """
        super().__init__(config)
        # Conversation history in OpenAI-style message dicts.
        self.messages: List[Dict[str, Any]] = []
        # NOTE(review): kept locally in addition to whatever the base class
        # stores — confirm this duplication is intentional.
        self.config = config
    def add_user_message(self, content: str, images: Optional[List[bytes]] = None) -> None:
        """Add a user message to the history.
        Args:
            content: Message content
            images: Optional list of image data (raw PNG bytes, encoded to
                base64 data URLs here)
        """
        # Add images if present
        if images:
            # Initialize with proper typing for mixed content
            message_content: List[Dict[str, Any]] = [{"type": "text", "text": content}]
            # Add each image
            for img in images:
                message_content.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{base64.b64encode(img).decode()}"
                        },
                    }
                )
            message = {"role": "user", "content": message_content}
        else:
            # Simple text message
            message = {"role": "user", "content": content}
        self.messages.append(message)
        # Apply retention policy
        if self.config and self.config.num_images_to_keep:
            self._apply_image_retention_policy()
    def add_assistant_message(self, content: str) -> None:
        """Add an assistant message to the history.
        Args:
            content: Message content
        """
        self.messages.append({"role": "assistant", "content": content})
    def add_system_message(self, content: str) -> None:
        """Add a system message to the history.
        Args:
            content: Message content
        """
        self.messages.append({"role": "system", "content": content})
    def _apply_image_retention_policy(self) -> None:
        """Apply image retention policy to message history.

        Walks the history newest-first and keeps only the configured number
        of images, dropping image items (in place) from older user messages
        while always preserving text items.
        """
        if not self.config or not self.config.num_images_to_keep:
            return
        # Count images from newest to oldest
        image_count = 0
        for message in reversed(self.messages):
            if message["role"] != "user":
                continue
            # Handle multimodal messages
            if isinstance(message["content"], list):
                new_content = []
                for item in message["content"]:
                    if item["type"] == "text":
                        new_content.append(item)
                    elif item["type"] == "image_url":
                        # Keep the image only while under the retention cap.
                        if image_count < self.config.num_images_to_keep:
                            new_content.append(item)
                            image_count += 1
                message["content"] = new_content
    def get_formatted_messages(self, provider: str) -> List[Dict[str, Any]]:
        """Get messages formatted for specific provider.
        Args:
            provider: Provider name to format messages for
        Returns:
            List of formatted messages
        Raises:
            ValueError: If the provider is not one of anthropic/openai/groq/qwen.
        """
        # Set the provider for message formatting (base-class hook)
        self.set_provider(provider)
        if provider == "anthropic":
            return self._format_for_anthropic()
        elif provider == "openai":
            return self._format_for_openai()
        elif provider == "groq":
            return self._format_for_groq()
        elif provider == "qwen":
            return self._format_for_qwen()
        else:
            raise ValueError(f"Unsupported provider: {provider}")
    def _format_for_anthropic(self) -> List[Dict[str, Any]]:
        """Format messages for Anthropic API.

        Converts OpenAI-style ``image_url`` data URLs into Anthropic
        base64 image source blocks; text passes through.
        """
        formatted = []
        for msg in self.messages:
            formatted_msg = {"role": msg["role"]}
            # Handle multimodal content
            if isinstance(msg["content"], list):
                formatted_msg["content"] = []
                for item in msg["content"]:
                    if item["type"] == "text":
                        formatted_msg["content"].append({"type": "text", "text": item["text"]})
                    elif item["type"] == "image_url":
                        formatted_msg["content"].append(
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/png",
                                    # Strip the "data:image/png;base64," prefix.
                                    "data": item["image_url"]["url"].split(",")[1],
                                },
                            }
                        )
            else:
                formatted_msg["content"] = msg["content"]
            formatted.append(formatted_msg)
        return formatted
    def _format_for_openai(self) -> List[Dict[str, Any]]:
        """Format messages for OpenAI API."""
        # OpenAI already uses the same format
        return self.messages
    def _format_for_groq(self) -> List[Dict[str, Any]]:
        """Format messages for Groq API."""
        # Groq uses OpenAI-compatible format
        return self.messages
    def _format_for_qwen(self) -> List[Dict[str, Any]]:
        """Format messages for Qwen API.

        Qwen gets text only: multimodal messages are reduced to their first
        text item; image items are discarded.
        """
        formatted = []
        for msg in self.messages:
            if isinstance(msg["content"], list):
                # Convert multimodal content to text-only
                text_content = next(
                    (item["text"] for item in msg["content"] if item["type"] == "text"), ""
                )
                formatted.append({"role": msg["role"], "content": text_content})
            else:
                formatted.append(msg)
        return formatted

View File

@@ -1,130 +0,0 @@
"""Visualization utilities for the Cua provider."""
import base64
import logging
from io import BytesIO
from typing import Tuple
from PIL import Image, ImageDraw
logger = logging.getLogger(__name__)
def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
    """Mark a click position on a screenshot with two concentric red circles.

    Args:
        x: X coordinate of the click
        y: Y coordinate of the click
        img_base64: Base64 encoded image to draw on

    Returns:
        Annotated PIL image, or a blank white canvas if anything fails
    """
    try:
        canvas = Image.open(BytesIO(base64.b64decode(img_base64)))
        pen = ImageDraw.Draw(canvas)
        inner_radius, outer_radius = 10, 30
        # Solid dot marking the exact click point
        pen.ellipse(
            [(x - inner_radius, y - inner_radius), (x + inner_radius, y + inner_radius)],
            fill="red",
        )
        # Larger ring so the marker stands out at a glance
        pen.ellipse(
            [(x - outer_radius, y - outer_radius), (x + outer_radius, y + outer_radius)],
            outline="red",
            width=3,
        )
        return canvas
    except Exception as e:
        logger.error(f"Error visualizing click: {str(e)}")
        # Fall back to a blank canvas rather than propagating the error
        return Image.new("RGB", (800, 600), color="white")
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Overlay a scroll-direction arrow and click count on a screenshot.

    Args:
        direction: 'up' or 'down'
        clicks: Number of scroll clicks
        img_base64: Base64 encoded image to draw on

    Returns:
        Annotated PIL image, or a blank white canvas if anything fails
    """
    try:
        canvas = Image.open(BytesIO(base64.b64decode(img_base64)))
        canvas_w, canvas_h = canvas.size
        pen = ImageDraw.Draw(canvas)
        # Arrow geometry: centered horizontally and vertically, 100px wide
        mid_x = canvas_w // 2
        mid_y = canvas_h // 2
        half_span = 100 // 2
        if direction.lower() == "up":
            # Blue triangle pointing up
            triangle = [
                (mid_x, mid_y - 50),               # tip
                (mid_x - half_span, mid_y + 50),   # base left
                (mid_x + half_span, mid_y + 50),   # base right
            ]
            shade = "blue"
        else:
            # Green triangle pointing down (any non-"up" direction)
            triangle = [
                (mid_x, mid_y + 50),               # tip
                (mid_x - half_span, mid_y - 50),   # base left
                (mid_x + half_span, mid_y - 50),   # base right
            ]
            shade = "green"
        pen.polygon(triangle, fill=shade)
        # Label placement mirrors the original: only an explicit "down"
        # puts the text below center; anything else puts it above.
        label_y = mid_y + 70 if direction.lower() == "down" else mid_y - 70
        pen.text((mid_x - 40, label_y), f"{clicks} clicks", fill="black")
        return canvas
    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
        # Fall back to a blank canvas rather than propagating the error
        return Image.new("RGB", (800, 600), color="white")
def calculate_element_center(box: Tuple[int, int, int, int]) -> Tuple[int, int]:
    """Calculate the center coordinates of a bounding box.

    Args:
        box: Tuple of (left, top, right, bottom) coordinates

    Returns:
        Tuple of (center_x, center_y) coordinates, using integer division
    """
    x1, y1, x2, y2 = box
    return (x1 + x2) // 2, (y1 + y2) // 2

View File

@@ -0,0 +1,134 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install openai\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"import os\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"import os\n",
"\n",
"response = requests.post(\n",
" \"https://api.openai.com/v1/responses\",\n",
" headers={\n",
" \"Content-Type\": \"application/json\", \n",
" \"Authorization\": f\"Bearer {os.environ['OPENAI_API_KEY']}\"\n",
" },\n",
" json={\n",
" \"model\": \"computer-use-preview\",\n",
" \"tools\": [{\n",
" \"type\": \"computer_use_preview\",\n",
" \"display_width\": 1024,\n",
" \"display_height\": 768,\n",
" \"environment\": \"mac\" # other possible values: \"mac\", \"windows\", \"ubuntu\"\n",
" }],\n",
" \"input\": [\n",
" {\n",
" \"role\": \"user\",\n",
" \"content\": \"Check the latest OpenAI news on bing.com.\"\n",
" }\n",
" # Optional: include a screenshot of the initial state of the environment\n",
" # {\n",
" # type: \"input_image\", \n",
" # image_url: f\"data:image/png;base64,{screenshot_base64}\"\n",
" # }\n",
" ],\n",
" \"reasoning\": {\n",
" \"generate_summary\": \"concise\",\n",
" },\n",
" \"truncation\": \"auto\"\n",
" }\n",
")\n",
"\n",
"print(response.json())"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n"
]
}
],
"source": [
"from openai import OpenAI\n",
"client = OpenAI() # assumes OPENAI_API_KEY is set in env\n",
"\n",
"def has_model_starting_with(prefix=\"computer\"):\n",
" models = client.models.list().data\n",
" return any(model.id.startswith(prefix) for model in models)\n",
"\n",
"print(has_model_starting_with())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"import os\n",
"\n",
"response = requests.post(\n",
" \"https://api.openai.com/v1/responses\",\n",
" headers={\n",
" \"Content-Type\": \"application/json\",\n",
" \"Authorization\": f\"Bearer {os.environ['OPENAI_API_KEY']}\"\n",
" },\n",
" json={\n",
" \"model\": \"gpt-4o\",\n",
" \"input\": \"Tell me a three sentence bedtime story about a unicorn.\"\n",
" }\n",
")\n",
"print(response.json())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}