Merge pull request #65 from trycua/feature/agent/agent-loop

[Agent] Standardize Agent Loop
This commit is contained in:
f-trycua
2025-03-24 12:41:15 -07:00
committed by GitHub
38 changed files with 2927 additions and 2669 deletions

View File

@@ -1,17 +1,14 @@
"""Example demonstrating the ComputerAgent capabilities with the Omni provider."""
import os
import asyncio
import logging
import traceback
from pathlib import Path
import signal
from computer import Computer
# Import the unified agent class and types
from agent import AgentLoop, LLMProvider, LLM
from agent.core.computer_agent import ComputerAgent
from agent import ComputerAgent, LLMProvider, LLM, AgentLoop
# Import utility functions
from utils import load_dotenv_files, handle_sigint
@@ -21,7 +18,7 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def run_omni_agent_example():
async def run_agent_example():
"""Run example of using the ComputerAgent with OpenAI and Omni provider."""
print("\n=== Example: ComputerAgent with OpenAI and Omni provider ===")
@@ -32,42 +29,31 @@ async def run_omni_agent_example():
# Create agent with loop and provider
agent = ComputerAgent(
computer=computer,
loop=AgentLoop.ANTHROPIC,
# loop=AgentLoop.OMNI,
# model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"),
model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"),
# loop=AgentLoop.ANTHROPIC,
loop=AgentLoop.OMNI,
model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"),
# model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"),
save_trajectory=True,
trajectory_dir=str(Path("trajectories")),
only_n_most_recent_images=3,
verbosity=logging.INFO,
verbosity=logging.DEBUG,
)
tasks = [
"""
1. Look for a repository named trycua/lume on GitHub.
2. Check the open issues, open the most recent one and read it.
3. Clone the repository in users/lume/projects if it doesn't exist yet.
4. Open the repository with an app named Cursor (on the dock, black background and white cube icon).
5. From Cursor, open Composer if not already open.
6. Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.
"""
"Look for a repository named trycua/cua on GitHub.",
"Check the open issues, open the most recent one and read it.",
"Clone the repository in users/lume/projects if it doesn't exist yet.",
"Open the repository with an app named Cursor (on the dock, black background and white cube icon).",
"From Cursor, open Composer if not already open.",
"Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.",
]
async with agent:
for i, task in enumerate(tasks, 1):
print(f"\nExecuting task {i}/{len(tasks)}: {task}")
async for result in agent.run(task):
# Check if result has the expected structure
if "role" in result and "content" in result and "metadata" in result:
title = result["metadata"].get("title", "Screen Analysis")
content = result["content"]
else:
title = result.get("metadata", {}).get("title", "Screen Analysis")
content = result.get("content", str(result))
for i, task in enumerate(tasks):
print(f"\nExecuting task {i}/{len(tasks)}: {task}")
async for result in agent.run(task):
# print(result)
pass
print(f"\n{title}")
print(content)
print(f"Task {i} completed")
print(f"\n✅ Task {i+1}/{len(tasks)} completed: {task}")
except Exception as e:
logger.error(f"Error in run_omni_agent_example: {e}")
@@ -91,7 +77,7 @@ def main():
# Register signal handler for graceful exit
signal.signal(signal.SIGINT, handle_sigint)
asyncio.run(run_omni_agent_example())
asyncio.run(run_agent_example())
except Exception as e:
print(f"Error running example: {e}")
traceback.print_exc()

View File

@@ -49,6 +49,7 @@ except Exception as e:
logger.warning(f"Error initializing telemetry: {e}")
from .providers.omni.types import LLMProvider, LLM
from .types.base import AgentLoop
from .core.loop import AgentLoop
from .core.computer_agent import ComputerAgent
__all__ = ["AgentLoop", "LLMProvider", "LLM"]
__all__ = ["AgentLoop", "LLMProvider", "LLM", "ComputerAgent"]

View File

@@ -2,11 +2,6 @@
from .loop import BaseLoop
from .messages import (
create_user_message,
create_assistant_message,
create_system_message,
create_image_message,
create_screen_message,
BaseMessageManager,
ImageRetentionConfig,
)

View File

@@ -3,8 +3,7 @@
import asyncio
import logging
import os
from typing import Any, AsyncGenerator, Dict, Optional, cast
from dataclasses import dataclass
from typing import Any, AsyncGenerator, Dict, Optional, cast, List
from computer import Computer
from ..providers.anthropic.loop import AnthropicLoop
@@ -12,6 +11,8 @@ from ..providers.omni.loop import OmniLoop
from ..providers.omni.parser import OmniParser
from ..providers.omni.types import LLMProvider, LLM
from .. import AgentLoop
from .messages import StandardMessageManager, ImageRetentionConfig
from .types import AgentResponse
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -44,7 +45,6 @@ class ComputerAgent:
save_trajectory: bool = True,
trajectory_dir: str = "trajectories",
only_n_most_recent_images: Optional[int] = None,
parser: Optional[OmniParser] = None,
verbosity: int = logging.INFO,
):
"""Initialize the ComputerAgent.
@@ -61,12 +61,11 @@ class ComputerAgent:
save_trajectory: Whether to save the trajectory.
trajectory_dir: Directory to save the trajectory.
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests.
parser: Parser instance for the OmniLoop. Only used if provider is not ANTHROPIC.
verbosity: Logging level.
"""
# Basic agent configuration
self.max_retries = max_retries
self.computer = computer or Computer()
self.computer = computer
self.queue = asyncio.Queue()
self.screenshot_dir = screenshot_dir
self.log_dir = log_dir
@@ -100,7 +99,7 @@ class ComputerAgent:
)
# Ensure computer is properly cast for typing purposes
computer_instance = cast(Computer, self.computer)
computer_instance = self.computer
# Get API key from environment if not provided
actual_api_key = api_key or os.environ.get(ENV_VARS[self.provider], "")
@@ -118,10 +117,6 @@ class ComputerAgent:
only_n_most_recent_images=only_n_most_recent_images,
)
else:
# Default to OmniLoop for other loop types
# Initialize parser if not provided
actual_parser = parser or OmniParser()
self._loop = OmniLoop(
provider=self.provider,
api_key=actual_api_key,
@@ -130,9 +125,12 @@ class ComputerAgent:
save_trajectory=save_trajectory,
base_dir=trajectory_dir,
only_n_most_recent_images=only_n_most_recent_images,
parser=actual_parser,
parser=OmniParser(),
)
# Initialize the message manager from the loop
self.message_manager = self._loop.message_manager
logger.info(
f"ComputerAgent initialized with provider: {self.provider}, model: {actual_model_name}"
)
@@ -201,36 +199,30 @@ class ComputerAgent:
await self.computer.run()
self._initialized = True
async def _init_if_needed(self):
"""Initialize the computer interface if it hasn't been initialized yet."""
if not self.computer._initialized:
logger.info("Computer not initialized, initializing now...")
try:
# Call run directly
await self.computer.run()
logger.info("Computer interface initialized successfully")
except Exception as e:
logger.error(f"Error initializing computer interface: {str(e)}")
raise
async def run(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
async def run(self, task: str) -> AsyncGenerator[AgentResponse, None]:
"""Run a task using the computer agent.
Args:
task: Task description
Yields:
Task execution updates
Agent response format
"""
try:
logger.info(f"Running task: {task}")
logger.info(
f"Message history before task has {len(self.message_manager.messages)} messages"
)
# Initialize the computer if needed
if not self._initialized:
await self.initialize()
# Format task as a message
messages = [{"role": "user", "content": task}]
# Add task as a user message using the message manager
self.message_manager.add_user_message([{"type": "text", "text": task}])
logger.info(
f"Added task message. Message history now has {len(self.message_manager.messages)} messages"
)
# Pass properly formatted messages to the loop
if self._loop is None:
@@ -239,7 +231,8 @@ class ComputerAgent:
return
# Execute the task and yield results
async for result in self._loop.run(messages):
async for result in self._loop.run(self.message_manager.messages):
# Yield the result to the caller
yield result
except Exception as e:

View File

@@ -2,22 +2,34 @@
import logging
import asyncio
import json
import os
from abc import ABC, abstractmethod
from enum import Enum, auto
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
from datetime import datetime
import base64
from computer import Computer
from .experiment import ExperimentManager
from .messages import StandardMessageManager, ImageRetentionConfig
from .types import AgentResponse
logger = logging.getLogger(__name__)
class AgentLoop(Enum):
"""Enumeration of available loop types."""
ANTHROPIC = auto() # Anthropic implementation
OMNI = auto() # OmniLoop implementation
# Add more loop types as needed
class BaseLoop(ABC):
"""Base class for agent loops that handle message processing and tool execution."""
###########################################
# INITIALIZATION AND CONFIGURATION
###########################################
def __init__(
self,
computer: Computer,
@@ -55,8 +67,6 @@ class BaseLoop(ABC):
self.save_trajectory = save_trajectory
self.only_n_most_recent_images = only_n_most_recent_images
self._kwargs = kwargs
self.message_history = []
# self.tool_manager = BaseToolManager(computer)
# Initialize experiment manager
if self.save_trajectory and self.base_dir:
@@ -75,61 +85,6 @@ class BaseLoop(ABC):
# Initialize basic tracking
self.turn_count = 0
def _setup_experiment_dirs(self) -> None:
"""Setup the experiment directory structure."""
if self.experiment_manager:
# Use the experiment manager to set up directories
self.experiment_manager.setup_experiment_dirs()
# Update local tracking variables
self.run_dir = self.experiment_manager.run_dir
self.current_turn_dir = self.experiment_manager.current_turn_dir
def _create_turn_dir(self) -> None:
"""Create a new directory for the current turn."""
if self.experiment_manager:
# Use the experiment manager to create the turn directory
self.experiment_manager.create_turn_dir()
# Update local tracking variables
self.current_turn_dir = self.experiment_manager.current_turn_dir
self.turn_count = self.experiment_manager.turn_count
def _log_api_call(
self, call_type: str, request: Any, response: Any = None, error: Optional[Exception] = None
) -> None:
"""Log API call details to file.
Args:
call_type: Type of API call (e.g., 'request', 'response', 'error')
request: The API request data
response: Optional API response data
error: Optional error information
"""
if self.experiment_manager:
# Use the experiment manager to log the API call
provider = getattr(self, "provider", "unknown")
provider_str = str(provider) if provider else "unknown"
self.experiment_manager.log_api_call(
call_type=call_type,
request=request,
provider=provider_str,
model=self.model,
response=response,
error=error,
)
def _save_screenshot(self, img_base64: str, action_type: str = "") -> None:
"""Save a screenshot to the experiment directory.
Args:
img_base64: Base64 encoded screenshot
action_type: Type of action that triggered the screenshot
"""
if self.experiment_manager:
self.experiment_manager.save_screenshot(img_base64, action_type)
async def initialize(self) -> None:
"""Initialize both the API client and computer interface with retries."""
for attempt in range(self.max_retries):
@@ -155,94 +110,93 @@ class BaseLoop(ABC):
)
raise RuntimeError(f"Failed to initialize: {str(e)}")
async def _get_parsed_screen_som(self) -> Dict[str, Any]:
"""Get parsed screen information.
###########################################
Returns:
Dict containing screen information
"""
try:
# Take screenshot
screenshot = await self.computer.interface.screenshot()
# Initialize with default values
width, height = 1024, 768
base64_image = ""
# Handle different types of screenshot returns
if isinstance(screenshot, (bytes, bytearray, memoryview)):
# Raw bytes screenshot
base64_image = base64.b64encode(screenshot).decode("utf-8")
elif hasattr(screenshot, "base64_image"):
# Object-style screenshot with attributes
# Type checking can't infer these attributes, but they exist at runtime
# on certain screenshot return types
base64_image = getattr(screenshot, "base64_image")
width = (
getattr(screenshot, "width", width) if hasattr(screenshot, "width") else width
)
height = (
getattr(screenshot, "height", height)
if hasattr(screenshot, "height")
else height
)
# Create parsed screen data
parsed_screen = {
"width": width,
"height": height,
"parsed_content_list": [],
"timestamp": datetime.now().isoformat(),
"screenshot_base64": base64_image,
}
# Save screenshot if requested
if self.save_trajectory and self.experiment_manager:
try:
img_data = base64_image
if "," in img_data:
img_data = img_data.split(",")[1]
self._save_screenshot(img_data, action_type="state")
except Exception as e:
logger.error(f"Error saving screenshot: {str(e)}")
return parsed_screen
except Exception as e:
logger.error(f"Error taking screenshot: {str(e)}")
return {
"width": 1024,
"height": 768,
"parsed_content_list": [],
"timestamp": datetime.now().isoformat(),
"error": f"Error taking screenshot: {str(e)}",
"screenshot_base64": "",
}
# ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES
###########################################
@abstractmethod
async def initialize_client(self) -> None:
"""Initialize the API client and any provider-specific components."""
"""Initialize the API client and any provider-specific components.
This method must be implemented by subclasses to set up
provider-specific clients and tools.
"""
raise NotImplementedError
@abstractmethod
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
"""Run the agent loop with provided messages.
This method handles the main agent loop including message processing,
API calls, response handling, and action execution.
Args:
messages: List of message objects
Yields:
Dict containing response data
Agent response format
"""
raise NotImplementedError
@abstractmethod
async def _process_screen(
self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]]
###########################################
# EXPERIMENT AND TRAJECTORY MANAGEMENT
###########################################
def _setup_experiment_dirs(self) -> None:
"""Setup the experiment directory structure."""
if self.experiment_manager:
# Use the experiment manager to set up directories
self.experiment_manager.setup_experiment_dirs()
# Update local tracking variables
self.run_dir = self.experiment_manager.run_dir
self.current_turn_dir = self.experiment_manager.current_turn_dir
def _create_turn_dir(self) -> None:
"""Create a new directory for the current turn."""
if self.experiment_manager:
# Use the experiment manager to create the turn directory
self.experiment_manager.create_turn_dir()
# Update local tracking variables
self.current_turn_dir = self.experiment_manager.current_turn_dir
self.turn_count = self.experiment_manager.turn_count
def _log_api_call(
self, call_type: str, request: Any, response: Any = None, error: Optional[Exception] = None
) -> None:
"""Process screen information and add to messages.
"""Log API call details to file.
Preserves provider-specific formats for requests and responses to ensure
accurate logging for debugging and analysis purposes.
Args:
parsed_screen: Dictionary containing parsed screen info
messages: List of messages to update
call_type: Type of API call (e.g., 'request', 'response', 'error')
request: The API request data in provider-specific format
response: Optional API response data in provider-specific format
error: Optional error information
"""
raise NotImplementedError
if self.experiment_manager:
# Use the experiment manager to log the API call
provider = getattr(self, "provider", "unknown")
provider_str = str(provider) if provider else "unknown"
self.experiment_manager.log_api_call(
call_type=call_type,
request=request,
provider=provider_str,
model=self.model,
response=response,
error=error,
)
def _save_screenshot(self, img_base64: str, action_type: str = "") -> None:
"""Save a screenshot to the experiment directory.
Args:
img_base64: Base64 encoded screenshot
action_type: Type of action that triggered the screenshot
"""
if self.experiment_manager:
self.experiment_manager.save_screenshot(img_base64, action_type)

View File

@@ -1,12 +1,11 @@
"""Message handling utilities for agent."""
import base64
from datetime import datetime
from io import BytesIO
import logging
from typing import Any, Dict, List, Optional, Union
from PIL import Image
import json
from typing import Any, Dict, List, Optional, Union, Tuple
from dataclasses import dataclass
import re
from ..providers.omni.parser import ParseResult
logger = logging.getLogger(__name__)
@@ -123,123 +122,278 @@ class BaseMessageManager:
break
def create_user_message(text: str) -> Dict[str, str]:
"""Create a user message.
class StandardMessageManager:
"""Manages messages in a standardized OpenAI format across different providers."""
Args:
text: The message text
def __init__(self, config: Optional[ImageRetentionConfig] = None):
"""Initialize message manager.
Returns:
Message dictionary
"""
return {
"role": "user",
"content": text,
}
Args:
config: Configuration for image retention
"""
self.messages: List[Dict[str, Any]] = []
self.config = config or ImageRetentionConfig()
def add_user_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
"""Add a user message.
def create_assistant_message(text: str) -> Dict[str, str]:
"""Create an assistant message.
Args:
content: Message content (text or multimodal content)
"""
self.messages.append({"role": "user", "content": content})
Args:
text: The message text
def add_assistant_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
"""Add an assistant message.
Returns:
Message dictionary
"""
return {
"role": "assistant",
"content": text,
}
Args:
content: Message content (text or multimodal content)
"""
self.messages.append({"role": "assistant", "content": content})
def add_system_message(self, content: str) -> None:
"""Add a system message.
def create_system_message(text: str) -> Dict[str, str]:
"""Create a system message.
Args:
content: System message content
"""
self.messages.append({"role": "system", "content": content})
Args:
text: The message text
def get_messages(self) -> List[Dict[str, Any]]:
"""Get all messages in standard format.
Returns:
Message dictionary
"""
return {
"role": "system",
"content": text,
}
Returns:
List of messages
"""
# If image retention is configured, apply it
if self.config.num_images_to_keep is not None:
return self._apply_image_retention(self.messages)
return self.messages
def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Apply image retention policy to messages.
def create_image_message(
image_base64: Optional[str] = None,
image_path: Optional[str] = None,
image_obj: Optional[Image.Image] = None,
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
"""Create a message with an image.
Args:
messages: List of messages
Args:
image_base64: Base64 encoded image
image_path: Path to image file
image_obj: PIL Image object
Returns:
List of messages with image retention applied
"""
if not self.config.num_images_to_keep:
return messages
Returns:
Message dictionary with content list
# Find user messages with images
image_messages = []
for msg in messages:
if msg["role"] == "user" and isinstance(msg["content"], list):
has_image = any(
item.get("type") == "image_url" or item.get("type") == "image"
for item in msg["content"]
)
if has_image:
image_messages.append(msg)
Raises:
ValueError: If no image source is provided
"""
if not any([image_base64, image_path, image_obj]):
raise ValueError("Must provide one of image_base64, image_path, or image_obj")
# If we don't have more images than the limit, return all messages
if len(image_messages) <= self.config.num_images_to_keep:
return messages
# Convert to base64 if needed
if image_path and not image_base64:
with open(image_path, "rb") as f:
image_bytes = f.read()
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
elif image_obj and not image_base64:
buffer = BytesIO()
image_obj.save(buffer, format="PNG")
image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
# Get the most recent N images to keep
images_to_keep = image_messages[-self.config.num_images_to_keep :]
images_to_remove = image_messages[: -self.config.num_images_to_keep]
return {
"role": "user",
"content": [
{"type": "image", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
],
}
# Create a new message list without the older images
result = []
for msg in messages:
if msg in images_to_remove:
# Skip this message
continue
result.append(msg)
return result
def create_screen_message(
parsed_screen: Dict[str, Any],
include_raw: bool = False,
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
"""Create a message with screen information.
def to_anthropic_format(
self, messages: List[Dict[str, Any]]
) -> Tuple[List[Dict[str, Any]], str]:
"""Convert standard OpenAI format messages to Anthropic format.
Args:
parsed_screen: Dictionary containing parsed screen info
include_raw: Whether to include raw screenshot base64
Args:
messages: List of messages in OpenAI format
Returns:
Message dictionary with content
"""
if include_raw and "screenshot_base64" in parsed_screen:
# Create content list with both image and text
return {
"role": "user",
"content": [
{
"type": "image",
"image_url": {
"url": f"data:image/png;base64,{parsed_screen['screenshot_base64']}"
},
},
{
"type": "text",
"text": f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}",
},
],
}
else:
# Create text-only message with screen info
return {
"role": "user",
"content": f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}",
}
Returns:
Tuple containing (anthropic_messages, system_content)
"""
result = []
system_content = ""
# Process messages in order to maintain conversation flow
previous_assistant_tool_use_ids = (
set()
) # Track tool_use_ids in the previous assistant message
for i, msg in enumerate(messages):
role = msg.get("role", "")
content = msg.get("content", "")
if role == "system":
# Collect system messages for later use
system_content += content + "\n"
continue
if role == "assistant":
# Track tool_use_ids in this assistant message for the next user message
previous_assistant_tool_use_ids = set()
if isinstance(content, list):
for item in content:
if (
isinstance(item, dict)
and item.get("type") == "tool_use"
and "id" in item
):
previous_assistant_tool_use_ids.add(item["id"])
logger.info(
f"Tool use IDs in assistant message #{i}: {previous_assistant_tool_use_ids}"
)
if role in ["user", "assistant"]:
anthropic_msg = {"role": role}
# Convert content based on type
if isinstance(content, str):
# Simple text content
anthropic_msg["content"] = [{"type": "text", "text": content}]
elif isinstance(content, list):
# Convert complex content
anthropic_content = []
for item in content:
item_type = item.get("type", "")
if item_type == "text":
anthropic_content.append({"type": "text", "text": item.get("text", "")})
elif item_type == "image_url":
# Convert OpenAI image format to Anthropic
image_url = item.get("image_url", {}).get("url", "")
if image_url.startswith("data:"):
# Extract base64 data and media type
match = re.match(r"data:(.+);base64,(.+)", image_url)
if match:
media_type, data = match.groups()
anthropic_content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": media_type,
"data": data,
},
}
)
else:
# Regular URL
anthropic_content.append(
{
"type": "image",
"source": {
"type": "url",
"url": image_url,
},
}
)
elif item_type == "tool_use":
# Always include tool_use blocks
anthropic_content.append(item)
elif item_type == "tool_result":
# Check if this is a user message AND if the tool_use_id exists in the previous assistant message
tool_use_id = item.get("tool_use_id")
# Only include tool_result if it references a tool_use from the immediately preceding assistant message
if (
role == "user"
and tool_use_id
and tool_use_id in previous_assistant_tool_use_ids
):
anthropic_content.append(item)
logger.info(
f"Including tool_result with tool_use_id: {tool_use_id}"
)
else:
# Convert to text to preserve information
logger.warning(
f"Converting tool_result to text. Tool use ID {tool_use_id} not found in previous assistant message"
)
content_text = "Tool Result: "
if "content" in item:
if isinstance(item["content"], list):
for content_item in item["content"]:
if (
isinstance(content_item, dict)
and content_item.get("type") == "text"
):
content_text += content_item.get("text", "")
elif isinstance(item["content"], str):
content_text += item["content"]
anthropic_content.append({"type": "text", "text": content_text})
anthropic_msg["content"] = anthropic_content
result.append(anthropic_msg)
return result, system_content
def from_anthropic_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert Anthropic format messages to standard OpenAI format.

    Only "user" and "assistant" messages are converted; any other role is
    dropped. A single text block collapses to a plain string content;
    anything else becomes a list of OpenAI-style content items.

    Args:
        messages: List of messages in Anthropic format

    Returns:
        List of messages in OpenAI format
    """
    result = []
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content", [])
        if role in ["user", "assistant"]:
            openai_msg = {"role": role}
            # Simple case: single text block -> plain string content
            if len(content) == 1 and content[0].get("type") == "text":
                openai_msg["content"] = content[0].get("text", "")
            else:
                # Complex case: multiple blocks or non-text content
                openai_content = []
                for item in content:
                    item_type = item.get("type", "")
                    if item_type == "text":
                        openai_content.append({"type": "text", "text": item.get("text", "")})
                    elif item_type == "image":
                        # Convert Anthropic image block to OpenAI image_url format
                        source = item.get("source", {})
                        if source.get("type") == "base64":
                            # Re-encode as a data URL; default media type is PNG
                            media_type = source.get("media_type", "image/png")
                            data = source.get("data", "")
                            openai_content.append(
                                {
                                    "type": "image_url",
                                    "image_url": {"url": f"data:{media_type};base64,{data}"},
                                }
                            )
                        else:
                            # URL-sourced image: pass the URL through unchanged
                            openai_content.append(
                                {
                                    "type": "image_url",
                                    "image_url": {"url": source.get("url", "")},
                                }
                            )
                    elif item_type in ["tool_use", "tool_result"]:
                        # Pass through tool-related content unchanged
                        openai_content.append(item)
                openai_msg["content"] = openai_content
            result.append(openai_msg)
    return result

View File

@@ -0,0 +1,35 @@
"""Core type definitions."""
from typing import Any, Dict, List, Optional, TypedDict, Union
class AgentResponse(TypedDict, total=False):
    """Agent response format.

    All keys are optional (``total=False``), so partial payloads — e.g.
    error responses carrying only ``role``/``content`` — are still valid.
    """

    # Response identity and lifecycle
    id: str
    object: str
    created_at: int
    status: str
    # Error / incompleteness details
    error: Optional[str]
    incomplete_details: Optional[Any]
    instructions: Optional[Any]
    # Model and generation parameters
    max_output_tokens: Optional[int]
    model: str
    output: List[Dict[str, Any]]
    parallel_tool_calls: bool
    previous_response_id: Optional[str]
    reasoning: Dict[str, str]
    store: bool
    temperature: float
    text: Dict[str, Dict[str, str]]
    tool_choice: str
    tools: List[Dict[str, Union[str, int]]]
    top_p: float
    truncation: str
    # Token accounting and caller metadata
    usage: Dict[str, Any]
    user: Optional[str]
    metadata: Dict[str, Any]
    response: Dict[str, List[Dict[str, Any]]]
    # Additional fields for error responses
    role: str
    content: Union[str, List[Dict[str, Any]]]

View File

@@ -0,0 +1,197 @@
"""Core visualization utilities for agents."""
import logging
import base64
from typing import Dict, Tuple
from PIL import Image, ImageDraw
from io import BytesIO
logger = logging.getLogger(__name__)
def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
    """Visualize a click action by drawing a circle on the screenshot.

    Args:
        x: X coordinate of the click
        y: Y coordinate of the click
        img_base64: Base64-encoded screenshot

    Returns:
        PIL Image with a red circle and crosshairs at the click point, or a
        blank 800x600 white image if decoding/drawing fails.
    """
    try:
        # Decode the screenshot and draw on a copy so the source stays intact
        source = Image.open(BytesIO(base64.b64decode(img_base64)))
        annotated = source.copy()
        pen = ImageDraw.Draw(annotated)

        # Circle centered on the click location
        radius = 15
        pen.ellipse(
            [(x - radius, y - radius), (x + radius, y + radius)], outline="red", width=3
        )

        # Crosshair lines through the center
        half = 20
        pen.line([(x - half, y), (x + half, y)], fill="red", width=3)
        pen.line([(x, y - half), (x, y + half)], fill="red", width=3)

        return annotated
    except Exception as e:
        logger.error(f"Error visualizing click: {str(e)}")
        # Fallback: return a blank canvas so callers always receive an image
        return Image.new("RGB", (800, 600), "white")
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Visualize a scroll action by drawing arrows on the screenshot.

    Args:
        direction: Direction of scroll ('up' or 'down')
        clicks: Number of scroll clicks
        img_base64: Base64-encoded screenshot

    Returns:
        PIL Image with red scroll arrows drawn on it, or a blank 800x600
        white image if decoding/drawing fails.
    """
    try:
        # Decode the screenshot and work on a copy
        source = Image.open(BytesIO(base64.b64decode(img_base64)))
        annotated = source.copy()
        pen = ImageDraw.Draw(annotated)

        # Geometry for the arrow overlay
        width, height = source.size
        center_x = width // 2
        arrow_length = min(100, height // 4)
        arrow_width = 30
        num_arrows = min(clicks, 3)  # cap the number of arrows drawn

        # Starting position and direction sign (+1 = down, -1 = up)
        if direction == "down":
            start_y = height // 3
            arrow_dir = 1
        else:
            start_y = height * 2 // 3
            arrow_dir = -1

        for idx in range(num_arrows):
            # Successive arrows overlap slightly (0.7 spacing factor)
            y_pos = start_y + (idx * arrow_length * arrow_dir * 0.7)
            shaft_start = (center_x, y_pos)
            shaft_end = (center_x, y_pos + arrow_length * arrow_dir)

            # Arrow shaft
            pen.line([shaft_start, shaft_end], fill="red", width=5)

            # Arrowhead: two slanted strokes meeting at the shaft end
            arrowhead_size = 20
            if direction == "down":
                pen.line(
                    [
                        (center_x - arrow_width // 2, shaft_end[1] - arrowhead_size),
                        shaft_end,
                        (center_x + arrow_width // 2, shaft_end[1] - arrowhead_size),
                    ],
                    fill="red",
                    width=5,
                )
            else:
                pen.line(
                    [
                        (center_x - arrow_width // 2, shaft_end[1] + arrowhead_size),
                        shaft_end,
                        (center_x + arrow_width // 2, shaft_end[1] + arrowhead_size),
                    ],
                    fill="red",
                    width=5,
                )

        return annotated
    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
        # Fallback: return a blank canvas so callers always receive an image
        return Image.new("RGB", (800, 600), "white")
def calculate_element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]:
    """Calculate the center point of a UI element.

    Args:
        bbox: Bounding box dictionary with x1, y1, x2, y2 coordinates (0-1 normalized)
        width: Screen width in pixels
        height: Screen height in pixels

    Returns:
        (x, y) tuple with pixel coordinates
    """
    # Midpoint of the normalized box, scaled to pixel space and truncated
    mid_x = (bbox["x1"] + bbox["x2"]) / 2
    mid_y = (bbox["y1"] + bbox["y2"]) / 2
    return int(mid_x * width), int(mid_y * height)
class VisualizationHelper:
    """Helper class that saves visualizations of agent actions via the
    agent's experiment manager."""

    def __init__(self, agent):
        """Initialize visualization helper.

        Args:
            agent: Reference to the agent that will use this helper
        """
        self.agent = agent

    def visualize_action(self, x: int, y: int, img_base64: str) -> None:
        """Visualize a click action by drawing on the screenshot."""
        # Skip entirely unless trajectory saving is on and a manager exists
        if not self.agent.save_trajectory:
            return
        manager = getattr(self.agent, "experiment_manager", None)
        if not manager:
            return
        try:
            # Delegate drawing to the module-level utility, then persist
            annotated = visualize_click(x, y, img_base64)
            manager.save_action_visualization(annotated, "click", f"x{x}_y{y}")
        except Exception as e:
            logger.error(f"Error visualizing action: {str(e)}")

    def visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
        """Visualize a scroll action by drawing arrows on the screenshot."""
        # Skip entirely unless trajectory saving is on and a manager exists
        if not self.agent.save_trajectory:
            return
        manager = getattr(self.agent, "experiment_manager", None)
        if not manager:
            return
        try:
            # Delegate drawing to the module-level utility, then persist
            annotated = visualize_scroll(direction, clicks, img_base64)
            manager.save_action_visualization(annotated, "scroll", f"{direction}_{clicks}")
        except Exception as e:
            logger.error(f"Error visualizing scroll: {str(e)}")

    def save_action_visualization(
        self, img: Image.Image, action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action.

        Returns the saved path from the experiment manager, or "" when no
        manager is available.
        """
        manager = getattr(self.agent, "experiment_manager", None)
        if manager:
            return manager.save_action_visualization(img, action_name, details)
        return ""

View File

@@ -1,4 +1,4 @@
from typing import Any
from typing import Any, List, Dict, cast
import httpx
import asyncio
from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex
@@ -80,6 +80,147 @@ class BaseAnthropicClient:
f"Failed after {self.MAX_RETRIES} retries. " f"Last error: {str(last_error)}"
)
async def run_interleaved(
    self, messages: List[Dict[str, Any]], system: str, max_tokens: int = 4096
) -> Any:
    """Run the Anthropic API with the Claude model, supports interleaved tool calling.

    Args:
        messages: List of message objects
        system: System prompt
        max_tokens: Maximum tokens to generate

    Returns:
        API response

    Raises:
        RuntimeError: If the API call still fails after MAX_RETRIES attempts;
            the message includes the last underlying error.
    """
    # Repair conversations where a tool_use block has no matching tool_result,
    # which the Anthropic API rejects.
    fixed_messages = self._fix_missing_tool_results(messages)

    # Get model name from concrete implementation if available
    model_name = getattr(self, "model", "unknown model")
    logger.info(f"Running Anthropic API call with model {model_name}")

    last_error = None
    for retry_count in range(1, self.MAX_RETRIES + 1):
        try:
            # Call the Anthropic API through create_message which is implemented by subclasses
            # Convert system str to the list format expected by create_message
            system_list = [system]

            # Convert message format if needed - concrete implementations may do further conversion
            response = await self.create_message(
                messages=cast(list[BetaMessageParam], fixed_messages),
                system=system_list,
                tools=[],  # Tools are included in the messages
                max_tokens=max_tokens,
                betas=["tools-2023-12-13"],
            )

            logger.info("Anthropic API call successful")
            return response

        except Exception as e:
            # Keep the most recent error so the final RuntimeError can report it.
            last_error = e
            if retry_count >= self.MAX_RETRIES:
                # Do not sleep after the final failed attempt; fail fast below.
                break
            wait_time = self.INITIAL_RETRY_DELAY * (
                2 ** (retry_count - 1)
            )  # Exponential backoff
            logger.info(
                f"Retrying request (attempt {retry_count}/{self.MAX_RETRIES}) in {wait_time:.2f} seconds after error: {str(e)}"
            )
            await asyncio.sleep(wait_time)

    # If we get here, all retries failed; surface the last error for debugging.
    raise RuntimeError(
        f"Failed to call Anthropic API after {self.MAX_RETRIES} attempts: {str(last_error)}"
    )
def _fix_missing_tool_results(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Check for and fix any missing tool_result blocks after tool_use blocks.
Args:
messages: List of message objects
Returns:
Fixed messages with proper tool_result blocks
"""
fixed_messages = []
pending_tool_uses = {} # Map of tool_use IDs to their details
for i, message in enumerate(messages):
# Track any tool_use blocks in this message
if message.get("role") == "assistant" and "content" in message:
content = message.get("content", [])
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_use":
tool_id = block.get("id")
if tool_id:
pending_tool_uses[tool_id] = {
"name": block.get("name", ""),
"input": block.get("input", {}),
}
# Check if this message handles any pending tool_use blocks
if message.get("role") == "user" and "content" in message:
# Check for tool_result blocks in this message
content = message.get("content", [])
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_result":
tool_id = block.get("tool_use_id")
if tool_id in pending_tool_uses:
# This tool_result handles a pending tool_use
pending_tool_uses.pop(tool_id)
# Add the message to our fixed list
fixed_messages.append(message)
# If this is an assistant message with tool_use blocks and there are
# pending tool uses that need to be resolved before the next assistant message
if (
i + 1 < len(messages)
and message.get("role") == "assistant"
and messages[i + 1].get("role") == "assistant"
and pending_tool_uses
):
# We need to insert a user message with tool_results for all pending tool_uses
tool_results = []
for tool_id, tool_info in pending_tool_uses.items():
tool_results.append(
{
"type": "tool_result",
"tool_use_id": tool_id,
"content": {
"type": "error",
"message": "Tool execution was skipped or failed",
},
}
)
# Insert a synthetic user message with the tool results
if tool_results:
fixed_messages.append({"role": "user", "content": tool_results})
# Clear pending tools since we've added results for them
pending_tool_uses = {}
# Check if there are any remaining pending tool_uses at the end of the conversation
if pending_tool_uses and fixed_messages and fixed_messages[-1].get("role") == "assistant":
# Add a final user message with tool results for any pending tool_uses
tool_results = []
for tool_id, tool_info in pending_tool_uses.items():
tool_results.append(
{
"type": "tool_result",
"tool_use_id": tool_id,
"content": {
"type": "error",
"message": "Tool execution was skipped or failed",
},
}
)
if tool_results:
fixed_messages.append({"role": "user", "content": tool_results})
return fixed_messages
class AnthropicDirectClient(BaseAnthropicClient):
"""Direct Anthropic API client implementation."""

View File

@@ -0,0 +1,140 @@
"""API call handling for Anthropic provider."""
import logging
import asyncio
from typing import List
from anthropic.types.beta import (
BetaMessage,
BetaMessageParam,
BetaTextBlockParam,
)
from .types import LLMProvider
from .prompts import SYSTEM_PROMPT
# Constants
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
logger = logging.getLogger(__name__)
class AnthropicAPIHandler:
    """Handles API calls to Anthropic's API with structured error handling and retries."""

    def __init__(self, loop):
        """Initialize the API handler.

        Args:
            loop: Reference to the parent loop instance that provides context
                (API client, tool manager, message manager, retry settings and
                experiment logging via ``_log_api_call``).
        """
        self.loop = loop

    async def make_api_call(
        self, messages: List[BetaMessageParam], system_prompt: str = SYSTEM_PROMPT
    ) -> BetaMessage:
        """Make API call to Anthropic with retry logic.

        Args:
            messages: List of messages to send to the API
            system_prompt: System prompt to use (default: SYSTEM_PROMPT)

        Returns:
            API response

        Raises:
            RuntimeError: If API call fails after all retries
        """
        # Both the client and tool manager live on the parent loop and must be
        # initialized before any API traffic is attempted.
        if self.loop.client is None:
            raise RuntimeError("Client not initialized. Call initialize_client() first.")

        if self.loop.tool_manager is None:
            raise RuntimeError("Tool manager not initialized. Call initialize_client() first.")

        last_error = None

        # Add detailed debug logging to examine messages
        logger.info(f"Sending {len(messages)} messages to Anthropic API")

        # Log tool use IDs and tool result IDs for debugging
        tool_use_ids = set()
        tool_result_ids = set()

        for i, msg in enumerate(messages):
            logger.info(f"Message {i}: role={msg.get('role')}")
            if isinstance(msg.get("content"), list):
                for content_block in msg.get("content", []):
                    if isinstance(content_block, dict):
                        block_type = content_block.get("type")
                        if block_type == "tool_use" and "id" in content_block:
                            tool_id = content_block.get("id")
                            tool_use_ids.add(tool_id)
                            logger.info(f" - Found tool_use with ID: {tool_id}")
                        elif block_type == "tool_result" and "tool_use_id" in content_block:
                            result_id = content_block.get("tool_use_id")
                            tool_result_ids.add(result_id)
                            logger.info(f" - Found tool_result referencing ID: {result_id}")

        # Check for mismatches: tool_result blocks referencing a tool_use ID
        # that appears nowhere in the outgoing conversation (the API rejects
        # such conversations, so surface them before sending).
        missing_tool_uses = tool_result_ids - tool_use_ids
        if missing_tool_uses:
            logger.warning(
                f"Found tool_result IDs without matching tool_use IDs: {missing_tool_uses}"
            )

        for attempt in range(self.loop.max_retries):
            try:
                # Log request
                request_data = {
                    "messages": messages,
                    "max_tokens": self.loop.max_tokens,
                    "system": system_prompt,
                }

                # Let ExperimentManager handle sanitization
                self.loop._log_api_call("request", request_data)

                # Setup betas and system
                system = BetaTextBlockParam(
                    type="text",
                    text=system_prompt,
                )
                betas = [COMPUTER_USE_BETA_FLAG]

                # Add prompt caching if enabled in the message manager's config
                if self.loop.message_manager.config.enable_caching:
                    betas.append(PROMPT_CACHING_BETA_FLAG)
                    system["cache_control"] = {"type": "ephemeral"}

                # Make API call
                response = await self.loop.client.create_message(
                    messages=messages,
                    system=[system],
                    tools=self.loop.tool_manager.get_tool_params(),
                    max_tokens=self.loop.max_tokens,
                    betas=betas,
                )

                # Let ExperimentManager handle sanitization
                self.loop._log_api_call("response", request_data, response)

                return response
            except Exception as e:
                last_error = e
                logger.error(
                    f"Error in API call (attempt {attempt + 1}/{self.loop.max_retries}): {str(e)}"
                )
                self.loop._log_api_call("error", {"messages": messages}, error=e)

                if attempt < self.loop.max_retries - 1:
                    await asyncio.sleep(
                        self.loop.retry_delay * (attempt + 1)
                    )  # NOTE(review): this is a linear backoff (delay * attempt), not exponential
                    continue

        # If we get here, all retries failed
        error_message = f"API call failed after {self.loop.max_retries} attempts"
        if last_error:
            error_message += f": {str(last_error)}"

        logger.error(error_message)
        raise RuntimeError(error_message)

View File

@@ -0,0 +1,5 @@
"""Anthropic callbacks package."""
from .manager import CallbackManager
__all__ = ["CallbackManager"]

View File

@@ -2,39 +2,36 @@
import logging
import asyncio
import json
import os
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, cast
import base64
from datetime import datetime
from httpx import ConnectError, ReadTimeout
# Anthropic-specific imports
from anthropic import AsyncAnthropic
from anthropic.types.beta import (
BetaMessage,
BetaMessageParam,
BetaTextBlock,
BetaTextBlockParam,
BetaToolUseBlockParam,
BetaContentBlockParam,
)
import base64
from datetime import datetime
# Computer
from computer import Computer
# Base imports
from ...core.loop import BaseLoop
from ...core.messages import ImageRetentionConfig as CoreImageRetentionConfig
from ...core.messages import StandardMessageManager, ImageRetentionConfig
from ...core.types import AgentResponse
# Anthropic provider-specific imports
from .api.client import AnthropicClientFactory, BaseAnthropicClient
from .tools.manager import ToolManager
from .messages.manager import MessageManager, ImageRetentionConfig
from .callbacks.manager import CallbackManager
from .prompts import SYSTEM_PROMPT
from .types import LLMProvider
from .tools import ToolResult
from .utils import to_anthropic_format, to_agent_response_format
# Import the new modules we created
from .api_handler import AnthropicAPIHandler
from .response_handler import AnthropicResponseHandler
from .callbacks.manager import CallbackManager
# Constants
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
@@ -44,13 +41,22 @@ logger = logging.getLogger(__name__)
class AnthropicLoop(BaseLoop):
"""Anthropic-specific implementation of the agent loop."""
"""Anthropic-specific implementation of the agent loop.
This class extends BaseLoop to provide specialized support for Anthropic's Claude models
with their unique tool-use capabilities, custom message formatting, and
callback-driven approach to handling responses.
"""
###########################################
# INITIALIZATION AND CONFIGURATION
###########################################
def __init__(
self,
api_key: str,
computer: Computer,
model: str = "claude-3-7-sonnet-20250219", # Fixed model
model: str = "claude-3-7-sonnet-20250219",
only_n_most_recent_images: Optional[int] = 2,
base_dir: Optional[str] = "trajectories",
max_retries: int = 3,
@@ -83,27 +89,33 @@ class AnthropicLoop(BaseLoop):
**kwargs,
)
# Ensure model is always the fixed one
self.model = "claude-3-7-sonnet-20250219"
# Initialize message manager
self.message_manager = StandardMessageManager(
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
)
# Anthropic-specific attributes
self.provider = LLMProvider.ANTHROPIC
self.client = None
self.retry_count = 0
self.tool_manager = None
self.message_manager = None
self.callback_manager = None
self.queue = asyncio.Queue() # Initialize queue
# Configure image retention with core config
self.image_retention_config = CoreImageRetentionConfig(
num_images_to_keep=only_n_most_recent_images
)
# Initialize handlers
self.api_handler = AnthropicAPIHandler(self)
self.response_handler = AnthropicResponseHandler(self)
# Message history
self.message_history = []
###########################################
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
###########################################
async def initialize_client(self) -> None:
"""Initialize the Anthropic API client and tools."""
"""Initialize the Anthropic API client and tools.
Implements abstract method from BaseLoop to set up the Anthropic-specific
client, tool manager, message manager, and callback handlers.
"""
try:
logger.info(f"Initializing Anthropic client with model {self.model}...")
@@ -112,14 +124,7 @@ class AnthropicLoop(BaseLoop):
provider=self.provider, api_key=self.api_key, model=self.model
)
# Initialize message manager
self.message_manager = MessageManager(
image_retention_config=ImageRetentionConfig(
num_images_to_keep=self.only_n_most_recent_images, enable_caching=True
)
)
# Initialize callback manager
# Initialize callback manager with our callback handlers
self.callback_manager = CallbackManager(
content_callback=self._handle_content,
tool_callback=self._handle_tool_result,
@@ -136,62 +141,22 @@ class AnthropicLoop(BaseLoop):
self.client = None
raise RuntimeError(f"Failed to initialize Anthropic client: {str(e)}")
async def _process_screen(
self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]]
) -> None:
"""Process screen information and add to messages.
###########################################
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
###########################################
Args:
parsed_screen: Dictionary containing parsed screen info
messages: List of messages to update
"""
try:
# Extract screenshot from parsed screen
screenshot_base64 = parsed_screen.get("screenshot_base64")
if screenshot_base64:
# Remove data URL prefix if present
if "," in screenshot_base64:
screenshot_base64 = screenshot_base64.split(",")[1]
# Create Anthropic-compatible message with image
screen_info_msg = {
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": screenshot_base64,
},
}
],
}
# Add screen info message to messages
messages.append(screen_info_msg)
except Exception as e:
logger.error(f"Error processing screen info: {str(e)}")
raise
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
"""Run the agent loop with provided messages.
Args:
messages: List of message objects
messages: List of message objects in standard OpenAI format
Yields:
Dict containing response data
Agent response format
"""
try:
logger.info("Starting Anthropic loop run")
# Reset message history and add new messages
self.message_history = []
self.message_history.extend(messages)
# Create queue for response streaming
queue = asyncio.Queue()
@@ -204,7 +169,7 @@ class AnthropicLoop(BaseLoop):
logger.info("Client initialized successfully")
# Start loop in background task
loop_task = asyncio.create_task(self._run_loop(queue))
loop_task = asyncio.create_task(self._run_loop(queue, messages))
# Process and yield messages as they arrive
while True:
@@ -236,37 +201,87 @@ class AnthropicLoop(BaseLoop):
"metadata": {"title": "❌ Error"},
}
async def _run_loop(self, queue: asyncio.Queue) -> None:
"""Run the agent loop with current message history.
###########################################
# AGENT LOOP IMPLEMENTATION
###########################################
async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
"""Run the agent loop with provided messages.
Args:
queue: Queue for response streaming
messages: List of messages in standard OpenAI format
"""
try:
while True:
# Get up-to-date screen information
parsed_screen = await self._get_parsed_screen_som()
# Capture screenshot
try:
# Take screenshot - always returns raw PNG bytes
screenshot = await self.computer.interface.screenshot()
logger.info("Screenshot captured successfully")
# Process screen info and update messages
await self._process_screen(parsed_screen, self.message_history)
# Convert PNG bytes to base64
base64_image = base64.b64encode(screenshot).decode("utf-8")
logger.info(f"Screenshot converted to base64 (size: {len(base64_image)} bytes)")
# Prepare messages and make API call
if self.message_manager is None:
raise RuntimeError(
"Message manager not initialized. Call initialize_client() first."
)
prepared_messages = self.message_manager.prepare_messages(
cast(List[BetaMessageParam], self.message_history.copy())
)
# Save screenshot if requested
if self.save_trajectory and self.experiment_manager:
try:
self._save_screenshot(base64_image, action_type="state")
logger.info("Screenshot saved to trajectory")
except Exception as e:
logger.error(f"Error saving screenshot: {str(e)}")
# Create screenshot message
screen_info_msg = {
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": base64_image,
},
}
],
}
# Add screenshot to messages
messages.append(screen_info_msg)
logger.info("Screenshot message added to conversation")
except Exception as e:
logger.error(f"Error capturing or processing screenshot: {str(e)}")
raise
# Create new turn directory for this API call
self._create_turn_dir()
# Use _make_api_call instead of direct client call to ensure logging
response = await self._make_api_call(prepared_messages)
# Convert standard messages to Anthropic format using utility function
anthropic_messages, system_content = to_anthropic_format(messages.copy())
# Handle the response
if not await self._handle_response(response, self.message_history):
# Use API handler to make API call with Anthropic format
response = await self.api_handler.make_api_call(
messages=cast(List[BetaMessageParam], anthropic_messages),
system_prompt=system_content or SYSTEM_PROMPT,
)
# Use response handler to handle the response and get new messages
new_messages, should_continue = await self.response_handler.handle_response(
response, messages
)
# Add new messages to the parent's message history
messages.extend(new_messages)
openai_compatible_response = await to_agent_response_format(
response,
messages,
model=self.model,
)
await queue.put(openai_compatible_response)
if not should_continue:
break
# Signal completion
@@ -283,142 +298,101 @@ class AnthropicLoop(BaseLoop):
)
await queue.put(None)
async def _make_api_call(self, messages: List[BetaMessageParam]) -> BetaMessage:
"""Make API call to Anthropic with retry logic.
Args:
messages: List of messages to send to the API
Returns:
API response
"""
if self.client is None:
raise RuntimeError("Client not initialized. Call initialize_client() first.")
if self.tool_manager is None:
raise RuntimeError("Tool manager not initialized. Call initialize_client() first.")
last_error = None
for attempt in range(self.max_retries):
try:
# Log request
request_data = {
"messages": messages,
"max_tokens": self.max_tokens,
"system": SYSTEM_PROMPT,
}
# Let ExperimentManager handle sanitization
self._log_api_call("request", request_data)
# Setup betas and system
system = BetaTextBlockParam(
type="text",
text=SYSTEM_PROMPT,
)
betas = [COMPUTER_USE_BETA_FLAG]
# Temporarily disable prompt caching due to "A maximum of 4 blocks with cache_control may be provided" error
# if self.message_manager.image_retention_config.enable_caching:
# betas.append(PROMPT_CACHING_BETA_FLAG)
# system["cache_control"] = {"type": "ephemeral"}
# Make API call
response = await self.client.create_message(
messages=messages,
system=[system],
tools=self.tool_manager.get_tool_params(),
max_tokens=self.max_tokens,
betas=betas,
)
# Let ExperimentManager handle sanitization
self._log_api_call("response", request_data, response)
return response
except Exception as e:
last_error = e
logger.error(
f"Error in API call (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
)
self._log_api_call("error", {"messages": messages}, error=e)
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff
continue
# If we get here, all retries failed
error_message = f"API call failed after {self.max_retries} attempts"
if last_error:
error_message += f": {str(last_error)}"
logger.error(error_message)
raise RuntimeError(error_message)
###########################################
# RESPONSE AND CALLBACK HANDLING
###########################################
async def _handle_response(self, response: BetaMessage, messages: List[Dict[str, Any]]) -> bool:
"""Handle the Anthropic API response.
"""Handle a response from the Anthropic API.
Args:
response: API response
messages: List of messages to update
response: The response from the Anthropic API
messages: The message history
Returns:
True if the loop should continue, False otherwise
bool: Whether to continue the conversation
"""
try:
# Convert response to parameter format
response_params = self._response_to_params(response)
# Add response to messages
messages.append(
{
"role": "assistant",
"content": response_params,
}
# Convert response to standard format
openai_compatible_response = await to_agent_response_format(
response,
messages,
model=self.model,
)
# Put the response on the queue
await self.queue.put(openai_compatible_response)
if self.callback_manager is None:
raise RuntimeError(
"Callback manager not initialized. Call initialize_client() first."
)
# Handle tool use blocks and collect results
# Handle tool use blocks and collect ALL results before adding to messages
tool_result_content = []
for content_block in response_params:
has_tool_use = False
for content_block in response.content:
# Notify callback of content
self.callback_manager.on_content(cast(BetaContentBlockParam, content_block))
# Handle tool use
if content_block.get("type") == "tool_use":
# Handle tool use - carefully check and access attributes
if hasattr(content_block, "type") and content_block.type == "tool_use":
has_tool_use = True
if self.tool_manager is None:
raise RuntimeError(
"Tool manager not initialized. Call initialize_client() first."
)
# Safely get attributes
tool_name = getattr(content_block, "name", "")
tool_input = getattr(content_block, "input", {})
tool_id = getattr(content_block, "id", "")
result = await self.tool_manager.execute_tool(
name=content_block["name"],
tool_input=cast(Dict[str, Any], content_block["input"]),
name=tool_name,
tool_input=cast(Dict[str, Any], tool_input),
)
# Create tool result and add to content
tool_result = self._make_tool_result(
cast(ToolResult, result), content_block["id"]
)
# Create tool result
tool_result = self._make_tool_result(cast(ToolResult, result), tool_id)
tool_result_content.append(tool_result)
# Notify callback of tool result
self.callback_manager.on_tool_result(
cast(ToolResult, result), content_block["id"]
)
self.callback_manager.on_tool_result(cast(ToolResult, result), tool_id)
# If no tool results, we're done
if not tool_result_content:
# Signal completion
# If we had any tool_use blocks, we MUST add the tool_result message
# even if there were errors or no actual results
if has_tool_use:
# If somehow we have no tool results but had tool uses, add synthetic error results
if not tool_result_content:
logger.warning(
"Had tool uses but no tool results, adding synthetic error results"
)
for content_block in response.content:
if hasattr(content_block, "type") and content_block.type == "tool_use":
tool_id = getattr(content_block, "id", "")
if tool_id:
tool_result_content.append(
{
"type": "tool_result",
"tool_use_id": tool_id,
"content": {
"type": "error",
"text": "Tool execution was skipped or failed",
},
"is_error": True,
}
)
# Add ALL tool results as a SINGLE user message
messages.append({"role": "user", "content": tool_result_content})
return True
else:
# No tool uses, we're done
self.callback_manager.on_content({"type": "text", "text": "<DONE>"})
return False
# Add tool results to message history
messages.append({"content": tool_result_content, "role": "user"})
return True
except Exception as e:
logger.error(f"Error handling response: {str(e)}")
messages.append(
@@ -429,28 +403,41 @@ class AnthropicLoop(BaseLoop):
)
return False
def _response_to_params(
self,
response: BetaMessage,
) -> List[Dict[str, Any]]:
"""Convert API response to message parameters.
def _response_to_blocks(self, response: BetaMessage) -> List[Dict[str, Any]]:
"""Convert Anthropic API response to standard blocks format.
Args:
response: API response message
Returns:
List of content blocks
List of content blocks in standard format
"""
result = []
for block in response.content:
if isinstance(block, BetaTextBlock):
result.append({"type": "text", "text": block.text})
elif hasattr(block, "type") and block.type == "tool_use":
# Safely access attributes after confirming it's a tool_use
result.append(
{
"type": "tool_use",
"id": getattr(block, "id", ""),
"name": getattr(block, "name", ""),
"input": getattr(block, "input", {}),
}
)
else:
result.append(cast(Dict[str, Any], block.model_dump()))
# For other block types, convert to dict
block_dict = {}
for key, value in vars(block).items():
if not key.startswith("_"):
block_dict[key] = value
result.append(block_dict)
return result
def _make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]:
"""Convert a tool result to API format.
"""Convert a tool result to standard format.
Args:
result: Tool execution result
@@ -489,12 +476,8 @@ class AnthropicLoop(BaseLoop):
if result.base64_image:
tool_result_content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": result.base64_image,
},
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{result.base64_image}"},
}
)
@@ -519,16 +502,19 @@ class AnthropicLoop(BaseLoop):
result_text = f"<s>{result.system}</s>\n{result_text}"
return result_text
def _handle_content(self, content: BetaContentBlockParam) -> None:
###########################################
# CALLBACK HANDLERS
###########################################
def _handle_content(self, content):
"""Handle content updates from the assistant."""
if content.get("type") == "text":
text_content = cast(BetaTextBlockParam, content)
text = text_content["text"]
text = content.get("text", "")
if text == "<DONE>":
return
logger.info(f"Assistant: {text}")
def _handle_tool_result(self, result: ToolResult, tool_id: str) -> None:
def _handle_tool_result(self, result, tool_id):
"""Handle tool execution results."""
if result.error:
logger.error(f"Tool {tool_id} error: {result.error}")

View File

@@ -0,0 +1,229 @@
"""Response and tool handling for Anthropic provider."""
import logging
from typing import Any, Dict, List, Optional, Tuple, cast
from anthropic.types.beta import (
BetaMessage,
BetaMessageParam,
BetaTextBlock,
BetaTextBlockParam,
BetaToolUseBlockParam,
BetaContentBlockParam,
)
from .tools import ToolResult
logger = logging.getLogger(__name__)
class AnthropicResponseHandler:
    """Handles Anthropic API responses and tool execution results."""

    def __init__(self, loop):
        """Initialize the response handler.

        Args:
            loop: Reference to the parent loop instance that provides context
                (tool manager and callback manager).
        """
        self.loop = loop

    async def handle_response(
        self, response: BetaMessage, messages: List[Dict[str, Any]]
    ) -> Tuple[List[Dict[str, Any]], bool]:
        """Handle the Anthropic API response.

        Executes any tool_use blocks in the response via the loop's tool
        manager and packages their results as a follow-up user message.

        Args:
            response: API response
            messages: List of messages for context

        Returns:
            Tuple containing:
                - List of new messages to be added
                - Boolean indicating if the loop should continue
        """
        try:
            new_messages = []

            # Convert response to parameter format
            response_params = self.response_to_params(response)

            # Collect all existing tool_use IDs from previous messages for validation
            existing_tool_use_ids = set()
            for msg in messages:
                if msg.get("role") == "assistant" and isinstance(msg.get("content"), list):
                    for block in msg.get("content", []):
                        if (
                            isinstance(block, dict)
                            and block.get("type") == "tool_use"
                            and "id" in block
                        ):
                            existing_tool_use_ids.add(block["id"])

            # Also add new tool_use IDs from the current response
            current_tool_use_ids = set()
            for block in response_params:
                if isinstance(block, dict) and block.get("type") == "tool_use" and "id" in block:
                    current_tool_use_ids.add(block["id"])
                    existing_tool_use_ids.add(block["id"])

            logger.info(f"Existing tool_use IDs in conversation: {existing_tool_use_ids}")
            logger.info(f"New tool_use IDs in current response: {current_tool_use_ids}")

            # Create assistant message
            new_messages.append(
                {
                    "role": "assistant",
                    "content": response_params,
                }
            )

            if self.loop.callback_manager is None:
                raise RuntimeError(
                    "Callback manager not initialized. Call initialize_client() first."
                )

            # Handle tool use blocks and collect results
            tool_result_content = []
            for content_block in response_params:
                # Notify callback of content
                self.loop.callback_manager.on_content(cast(BetaContentBlockParam, content_block))

                # Handle tool use
                if content_block.get("type") == "tool_use":
                    if self.loop.tool_manager is None:
                        raise RuntimeError(
                            "Tool manager not initialized. Call initialize_client() first."
                        )

                    # Execute the tool
                    result = await self.loop.tool_manager.execute_tool(
                        name=content_block["name"],
                        tool_input=cast(Dict[str, Any], content_block["input"]),
                    )

                    # Verify the tool_use ID exists in the conversation (which it should now)
                    tool_use_id = content_block["id"]
                    if tool_use_id in existing_tool_use_ids:
                        # Create tool result and add to content
                        tool_result = self.make_tool_result(cast(ToolResult, result), tool_use_id)
                        tool_result_content.append(tool_result)

                        # Notify callback of tool result
                        self.loop.callback_manager.on_tool_result(
                            cast(ToolResult, result), content_block["id"]
                        )
                    else:
                        logger.warning(
                            f"Tool use ID {tool_use_id} not found in previous messages. Skipping tool result."
                        )

            # If no tool results, we're done
            if not tool_result_content:
                # Signal completion with the sentinel the callback manager expects
                self.loop.callback_manager.on_content({"type": "text", "text": "<DONE>"})
                return new_messages, False

            # Add tool results as user message
            new_messages.append({"content": tool_result_content, "role": "user"})
            return new_messages, True
        except Exception as e:
            logger.error(f"Error handling response: {str(e)}")
            new_messages.append(
                {
                    "role": "assistant",
                    "content": f"Error: {str(e)}",
                }
            )
            return new_messages, False

    def response_to_params(
        self,
        response: BetaMessage,
    ) -> List[Dict[str, Any]]:
        """Convert API response to message parameters.

        Args:
            response: API response message

        Returns:
            List of content blocks
        """
        result = []
        for block in response.content:
            if isinstance(block, BetaTextBlock):
                result.append({"type": "text", "text": block.text})
            else:
                # Non-text blocks (e.g. tool_use) are serialized to plain dicts
                result.append(cast(Dict[str, Any], block.model_dump()))
        return result

    def make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]:
        """Convert a tool result to API format.

        Args:
            result: Tool execution result
            tool_use_id: ID of the tool use

        Returns:
            Formatted tool result
        """
        # If the tool already produced structured content, pass it through as-is.
        if result.content:
            return {
                "type": "tool_result",
                "content": result.content,
                "tool_use_id": tool_use_id,
                "is_error": bool(result.error),
            }

        tool_result_content = []
        is_error = False

        if result.error:
            is_error = True
            tool_result_content = [
                {
                    "type": "text",
                    "text": self.maybe_prepend_system_tool_result(result, result.error),
                }
            ]
        else:
            if result.output:
                tool_result_content.append(
                    {
                        "type": "text",
                        "text": self.maybe_prepend_system_tool_result(result, result.output),
                    }
                )
            if result.base64_image:
                tool_result_content.append(
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": result.base64_image,
                        },
                    }
                )

        return {
            "type": "tool_result",
            "content": tool_result_content,
            "tool_use_id": tool_use_id,
            "is_error": is_error,
        }

    def maybe_prepend_system_tool_result(self, result: ToolResult, result_text: str) -> str:
        """Prepend system information to tool result if available.

        Args:
            result: Tool execution result
            result_text: Text to prepend to

        Returns:
            Text with system information prepended if available
        """
        if result.system:
            result_text = f"<s>{result.system}</s>\n{result_text}"
        return result_text

View File

@@ -7,101 +7,6 @@ from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
from ....core.tools.bash import BaseBashTool
class _BashSession:
    """A session of a bash shell.

    Wraps a long-lived ``/bin/bash`` subprocess and runs commands in it by
    writing to stdin and scanning stdout for a sentinel marker.
    """
    _started: bool
    _process: asyncio.subprocess.Process
    command: str = "/bin/bash"
    _output_delay: float = 0.2  # seconds
    _timeout: float = 120.0  # seconds
    _sentinel: str = "<<exit>>"
    def __init__(self):
        self._started = False
        self._timed_out = False
    async def start(self):
        """Start the underlying bash process (idempotent)."""
        if self._started:
            return
        self._process = await asyncio.create_subprocess_shell(
            self.command,
            preexec_fn=os.setsid,
            shell=True,
            bufsize=0,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        self._started = True
    def stop(self):
        """Terminate the bash shell."""
        if not self._started:
            raise ToolError("Session has not started.")
        if self._process.returncode is not None:
            return
        self._process.terminate()
    async def run(self, command: str):
        """Execute a command in the bash shell.
        Args:
            command: Shell command to run
        Returns:
            CLIResult with the command's stdout/stderr
        Raises:
            ToolError: If the session never started or previously timed out
        """
        if not self._started:
            raise ToolError("Session has not started.")
        if self._process.returncode is not None:
            return ToolResult(
                system="tool must be restarted",
                error=f"bash has exited with returncode {self._process.returncode}",
            )
        if self._timed_out:
            raise ToolError(
                f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
            )
        # we know these are not None because we created the process with PIPEs
        assert self._process.stdin
        assert self._process.stdout
        assert self._process.stderr
        # send command to the process, followed by an echoed sentinel that
        # marks the end of the command's output
        self._process.stdin.write(command.encode() + f"; echo '{self._sentinel}'\n".encode())
        await self._process.stdin.drain()
        # Read stdout until the sentinel appears.
        # Bug fix: the previous code called stdout.read() with no size, which
        # waits for EOF — EOF never arrives while the shell is alive, so every
        # command timed out. Bounded reads return as soon as data is available.
        output = ""
        try:
            async with asyncio.timeout(self._timeout):
                while True:
                    await asyncio.sleep(self._output_delay)
                    chunk = await self._process.stdout.read(4096)
                    if chunk:
                        # Accumulate across reads: the sentinel may be split
                        # over chunk boundaries (the old code overwrote the
                        # buffer each iteration and could miss it).
                        output += chunk.decode()
                        if self._sentinel in output:
                            # strip the sentinel and stop reading
                            output = output[: output.index(self._sentinel)]
                            break
        except asyncio.TimeoutError:
            self._timed_out = True
            raise ToolError(
                f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
            ) from None
        if output.endswith("\n"):
            output = output[:-1]
        # Drain whatever stderr is already buffered. A plain read() would
        # block until EOF here as well, so bound the wait instead; no data
        # within the window simply means "no error output".
        error = ""
        try:
            error_bytes = await asyncio.wait_for(
                self._process.stderr.read(65536), timeout=self._output_delay
            )
            error = error_bytes.decode() if error_bytes else ""
        except asyncio.TimeoutError:
            pass
        if error.endswith("\n"):
            error = error[:-1]
        return CLIResult(output=output, error=error)
class BashTool(BaseBashTool, BaseAnthropicTool):
"""
A tool that allows the agent to run bash commands.
@@ -123,7 +28,6 @@ class BashTool(BaseBashTool, BaseAnthropicTool):
# Then initialize the Anthropic tool
BaseAnthropicTool.__init__(self)
# Initialize bash session
self._session = _BashSession()
async def __call__(self, command: str | None = None, restart: bool = False, **kwargs):
"""Execute a bash command.

View File

@@ -0,0 +1,370 @@
"""Utility functions for Anthropic message handling."""
import time
import logging
import re
from typing import Any, Dict, List, Optional, Tuple, cast
from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaTextBlock
from ..omni.parser import ParseResult
from ...core.types import AgentResponse
from datetime import datetime
import json
# Configure module logger
logger = logging.getLogger(__name__)
def to_anthropic_format(
    messages: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], str]:
    """Convert standard OpenAI format messages to Anthropic format.
    Args:
        messages: List of messages in OpenAI format
    Returns:
        Tuple containing (anthropic_messages, system_content)
    """
    result = []
    system_content = ""
    # tool_use ids seen in the most recent assistant message. Anthropic only
    # accepts a tool_result block that answers a tool_use from the immediately
    # preceding assistant turn; anything older is downgraded to text below.
    previous_assistant_tool_use_ids = set()
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content", "")
        if role == "system":
            # System text is returned separately: Anthropic takes it as a
            # top-level `system` parameter, not as a message.
            if isinstance(content, str):
                system_content += content + "\n"
            elif isinstance(content, list):
                # Robustness fix: an OpenAI-style block list used to crash on
                # string concatenation; keep only its text parts instead.
                for part in content:
                    if isinstance(part, dict) and part.get("type") == "text":
                        system_content += part.get("text", "") + "\n"
            continue
        if role == "assistant":
            # Reset tracking for each assistant turn.
            previous_assistant_tool_use_ids = set()
            if isinstance(content, list):
                for item in content:
                    if isinstance(item, dict) and item.get("type") == "tool_use" and "id" in item:
                        previous_assistant_tool_use_ids.add(item["id"])
        if role in ["user", "assistant"]:
            anthropic_msg = {"role": role}
            if isinstance(content, str):
                # Simple text content
                anthropic_msg["content"] = [{"type": "text", "text": content}]
            elif isinstance(content, list):
                # Convert complex (multimodal / tool) content block by block
                anthropic_content = []
                for item in content:
                    item_type = item.get("type", "")
                    if item_type == "text":
                        anthropic_content.append({"type": "text", "text": item.get("text", "")})
                    elif item_type == "image_url":
                        # Convert OpenAI image format to Anthropic
                        image_url = item.get("image_url", {}).get("url", "")
                        if image_url.startswith("data:"):
                            # Data URL: split into media type and base64 payload
                            match = re.match(r"data:(.+);base64,(.+)", image_url)
                            if match:
                                media_type, data = match.groups()
                                anthropic_content.append(
                                    {
                                        "type": "image",
                                        "source": {
                                            "type": "base64",
                                            "media_type": media_type,
                                            "data": data,
                                        },
                                    }
                                )
                        else:
                            # Regular URL
                            anthropic_content.append(
                                {
                                    "type": "image",
                                    "source": {
                                        "type": "url",
                                        "url": image_url,
                                    },
                                }
                            )
                    elif item_type == "tool_use":
                        # Always include tool_use blocks
                        anthropic_content.append(item)
                    elif item_type == "tool_result":
                        tool_use_id = item.get("tool_use_id")
                        # Only pass the block through when it references a
                        # tool_use from the immediately preceding assistant turn
                        if (
                            role == "user"
                            and tool_use_id
                            and tool_use_id in previous_assistant_tool_use_ids
                        ):
                            anthropic_content.append(item)
                        else:
                            # Orphaned result: flatten to text so the API does
                            # not reject the whole request.
                            content_text = "Tool Result: "
                            if "content" in item:
                                if isinstance(item["content"], list):
                                    for content_item in item["content"]:
                                        if (
                                            isinstance(content_item, dict)
                                            and content_item.get("type") == "text"
                                        ):
                                            content_text += content_item.get("text", "")
                                elif isinstance(item["content"], str):
                                    content_text += item["content"]
                            anthropic_content.append({"type": "text", "text": content_text})
                anthropic_msg["content"] = anthropic_content
            result.append(anthropic_msg)
    return result, system_content
def from_anthropic_format(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert Anthropic format messages to standard OpenAI format.
    Args:
        messages: List of messages in Anthropic format
    Returns:
        List of messages in OpenAI format (non user/assistant roles dropped)
    """
    result = []
    for msg in messages:
        role = msg.get("role", "")
        content = msg.get("content", [])
        if role not in ["user", "assistant"]:
            continue
        openai_msg = {"role": role}
        if isinstance(content, str):
            # Robustness fix: Anthropic also allows plain string content,
            # which previously crashed on content[0].get(...).
            openai_msg["content"] = content
        elif len(content) == 1 and content[0].get("type") == "text":
            # Simple case: a single text block collapses to a plain string
            openai_msg["content"] = content[0].get("text", "")
        else:
            # Complex case: multiple blocks or non-text content
            openai_content = []
            for item in content:
                item_type = item.get("type", "")
                if item_type == "text":
                    openai_content.append({"type": "text", "text": item.get("text", "")})
                elif item_type == "image":
                    # Convert Anthropic image to OpenAI format
                    source = item.get("source", {})
                    if source.get("type") == "base64":
                        media_type = source.get("media_type", "image/png")
                        data = source.get("data", "")
                        openai_content.append(
                            {
                                "type": "image_url",
                                "image_url": {"url": f"data:{media_type};base64,{data}"},
                            }
                        )
                    else:
                        # URL source
                        openai_content.append(
                            {
                                "type": "image_url",
                                "image_url": {"url": source.get("url", "")},
                            }
                        )
                elif item_type in ["tool_use", "tool_result"]:
                    # Pass through tool-related content unchanged
                    openai_content.append(item)
            openai_msg["content"] = openai_content
        result.append(openai_msg)
    return result
async def to_agent_response_format(
    response: BetaMessage,
    messages: List[Dict[str, Any]],
    parsed_screen: Optional[ParseResult] = None,
    parser: Optional[Any] = None,
    model: Optional[str] = None,
) -> AgentResponse:
    """Convert an Anthropic response to the standard agent response format.
    Args:
        response: The Anthropic API response (BetaMessage)
        messages: List of messages in standard format
        parsed_screen: Optional pre-parsed screen information
        parser: Optional parser instance for coordinate calculation
        model: Optional model name
    Returns:
        A response formatted according to the standard agent response format
    """
    # NOTE(review): parsed_screen, parser and messages are accepted but not
    # read anywhere in this function — presumably kept for signature parity
    # with other providers; confirm before removing.
    # Create unique IDs for this response (timestamp + object id keeps them
    # unique within a process)
    response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(response)}"
    reasoning_id = f"rs_{response_id}"
    action_id = f"cu_{response_id}"
    call_id = f"call_{response_id}"
    # Extract content and reasoning from Anthropic response
    content = []
    reasoning_text = None
    action_details = None
    # NOTE(review): if the response contains several "computer" tool_use
    # blocks, each iteration overwrites action_details, so only the LAST one
    # is reported as the computer_call action.
    for block in response.content:
        if block.type == "text":
            # Use the first text block as reasoning
            if reasoning_text is None:
                reasoning_text = block.text
            content.append({"type": "text", "text": block.text})
        elif block.type == "tool_use" and block.name == "computer":
            try:
                input_dict = cast(Dict[str, Any], block.input)
                action = input_dict.get("action", "").lower()
                # Extract coordinates from coordinate list if provided;
                # (100, 100) is the fallback when absent or malformed
                coordinates = input_dict.get("coordinate", [100, 100])
                x, y = coordinates if len(coordinates) == 2 else (100, 100)
                if action == "screenshot":
                    action_details = {
                        "type": "screenshot",
                    }
                elif action in ["click", "left_click", "right_click", "double_click"]:
                    action_details = {
                        "type": "click",
                        "button": "left" if action in ["click", "left_click"] else "right",
                        "double": action == "double_click",
                        "x": x,
                        "y": y,
                    }
                elif action == "type":
                    action_details = {
                        "type": "type",
                        "text": input_dict.get("text", ""),
                    }
                elif action == "key":
                    # Anthropic "key" maps to a single-hotkey press
                    action_details = {
                        "type": "hotkey",
                        "keys": [input_dict.get("text", "")],
                    }
                elif action == "scroll":
                    scroll_amount = input_dict.get("scroll_amount", 1)
                    scroll_direction = input_dict.get("scroll_direction", "down")
                    # Positive delta_y scrolls down, negative scrolls up
                    delta_y = scroll_amount if scroll_direction == "down" else -scroll_amount
                    action_details = {
                        "type": "scroll",
                        "x": x,
                        "y": y,
                        "delta_x": 0,
                        "delta_y": delta_y,
                    }
                elif action == "move":
                    action_details = {
                        "type": "move",
                        "x": x,
                        "y": y,
                    }
            except Exception as e:
                # Best-effort: a malformed tool_use leaves action_details as-is
                logger.error(f"Error extracting action details: {str(e)}")
    # Create output items with reasoning first (when any text was produced)
    output_items = []
    if reasoning_text:
        output_items.append(
            {
                "type": "reasoning",
                "id": reasoning_id,
                "summary": [
                    {
                        "type": "summary_text",
                        "text": reasoning_text,
                    }
                ],
            }
        )
    # Add computer_call item with extracted or default action
    computer_call = {
        "type": "computer_call",
        "id": action_id,
        "call_id": call_id,
        "action": action_details or {"type": "none", "description": "No action specified"},
        "pending_safety_checks": [],
        "status": "completed",
    }
    output_items.append(computer_call)
    # Create the standard response format. Many fields are fixed placeholders
    # (usage counters are zero, display size hard-coded to 1024x768/mac) —
    # they mirror the OpenAI "response" object shape rather than real values.
    standard_response = {
        "id": response_id,
        "object": "response",
        "created_at": int(datetime.now().timestamp()),
        "status": "completed",
        "error": None,
        "incomplete_details": None,
        "instructions": None,
        "max_output_tokens": None,
        "model": model or "anthropic-default",
        "output": output_items,
        "parallel_tool_calls": True,
        "previous_response_id": None,
        "reasoning": {"effort": "medium", "generate_summary": "concise"},
        "store": True,
        "temperature": 1.0,
        "text": {"format": {"type": "text"}},
        "tool_choice": "auto",
        "tools": [
            {
                "type": "computer_use_preview",
                "display_height": 768,
                "display_width": 1024,
                "environment": "mac",
            }
        ],
        "top_p": 1.0,
        "truncation": "auto",
        "usage": {
            "input_tokens": 0,
            "input_tokens_details": {"cached_tokens": 0},
            "output_tokens": 0,
            "output_tokens_details": {"reasoning_tokens": 0},
            "total_tokens": 0,
        },
        "user": None,
        "metadata": {},
        "response": {
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": content,
                        "tool_calls": [],
                    },
                    "finish_reason": response.stop_reason or "stop",
                }
            ]
        },
    }
    # Add tool calls if present (all tool_use blocks, not just "computer")
    tool_calls = []
    for block in response.content:
        if hasattr(block, "type") and block.type == "tool_use":
            tool_calls.append(
                {
                    "id": f"call_{block.id}",
                    "type": "function",
                    "function": {"name": block.name, "arguments": block.input},
                }
            )
    if tool_calls:
        standard_response["response"]["choices"][0]["message"]["tool_calls"] = tool_calls
    return cast(AgentResponse, standard_response)

View File

@@ -1,27 +1,8 @@
"""Omni provider implementation."""
# The OmniComputerAgent has been replaced by the unified ComputerAgent
# which can be found in agent.core.agent
from .types import LLMProvider
from .experiment import ExperimentManager
from .visualization import visualize_click, visualize_scroll, calculate_element_center
from .image_utils import (
decode_base64_image,
encode_image_base64,
clean_base64_data,
extract_base64_from_text,
get_image_dimensions,
)
__all__ = [
"LLMProvider",
"ExperimentManager",
"visualize_click",
"visualize_scroll",
"calculate_element_center",
"decode_base64_image",
"encode_image_base64",
"clean_base64_data",
"extract_base64_from_text",
"get_image_dimensions",
]
__all__ = ["LLMProvider", "decode_base64_image"]

View File

@@ -0,0 +1,42 @@
"""API handling for Omni provider."""
import logging
from typing import Any, Dict, List
from .prompts import SYSTEM_PROMPT
logger = logging.getLogger(__name__)
class OmniAPIHandler:
    """Handler for Omni API calls."""
    def __init__(self, loop):
        """Initialize the API handler.
        Args:
            loop: Parent loop instance
        """
        self.loop = loop
    async def make_api_call(
        self, messages: List[Dict[str, Any]], system_prompt: str = SYSTEM_PROMPT
    ) -> Any:
        """Delegate an API call to the parent loop's provider client.
        Args:
            messages: List of messages in standard OpenAI format
            system_prompt: System prompt to use
        Returns:
            API response
        """
        if not self.loop._make_api_call:
            raise RuntimeError("Loop does not have _make_api_call method")
        try:
            response = await self.loop._make_api_call(
                messages=messages, system_prompt=system_prompt
            )
        except Exception as e:
            # Log and surface the failure to the caller unchanged.
            logger.error(f"Error making API call: {str(e)}")
            raise
        return response

View File

@@ -1,78 +0,0 @@
"""Omni callback manager implementation."""
import logging
from typing import Any, Dict, Optional, Set
from ...core.callbacks import BaseCallbackManager, ContentCallback, ToolCallback, APICallback
from ...types.tools import ToolResult
logger = logging.getLogger(__name__)
class OmniCallbackManager(BaseCallbackManager):
    """Callback manager for multi-provider support."""
    def __init__(
        self,
        content_callback: ContentCallback,
        tool_callback: ToolCallback,
        api_callback: APICallback,
    ):
        """Initialize Omni callback manager.
        Args:
            content_callback: Callback for content updates
            tool_callback: Callback for tool execution results
            api_callback: Callback for API interactions
        """
        super().__init__(
            content_callback=content_callback,
            tool_callback=tool_callback,
            api_callback=api_callback,
        )
        # Names of tools currently executing (exposed via get_active_tools).
        self._active_tools: Set[str] = set()
    def on_content(self, content: Any) -> None:
        """Forward a content update to the registered callback.
        Args:
            content: Content update data
        """
        logger.debug(f"Content update: {content}")
        self.content_callback(content)
    def on_tool_result(self, result: ToolResult, tool_id: str) -> None:
        """Forward a tool execution result to the registered callback.
        Args:
            result: Tool execution result
            tool_id: ID of the tool
        """
        logger.debug(f"Tool result for {tool_id}: {result}")
        self.tool_callback(result, tool_id)
    def on_api_interaction(
        self,
        request: Any,
        response: Any,
        error: Optional[Exception] = None,
    ) -> None:
        """Log an API interaction and forward it to the registered callback.
        Args:
            request: API request data
            response: API response data
            error: Optional error that occurred
        """
        if not error:
            logger.debug(f"API interaction - Request: {request}, Response: {response}")
        else:
            logger.error(f"API error: {str(error)}")
        self.api_callback(request, response, error)
    def get_active_tools(self) -> Set[str]:
        """Return a snapshot of the currently active tool names.
        Returns:
            Set of active tool names
        """
        return set(self._active_tools)

View File

@@ -44,6 +44,10 @@ class AnthropicClient(BaseOmniClient):
anthropic_messages = []
for message in messages:
# Skip messages with empty content
if not message.get("content"):
continue
if message["role"] == "user":
anthropic_messages.append({"role": "user", "content": message["content"]})
elif message["role"] == "assistant":

View File

@@ -1,101 +0,0 @@
"""Groq client implementation."""
import os
import logging
from typing import Dict, List, Optional, Any, Tuple
from groq import Groq
import re
from .utils import is_image_path
from .base import BaseOmniClient
logger = logging.getLogger(__name__)
class GroqClient(BaseOmniClient):
    """Client for making Groq API calls."""
    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "deepseek-r1-distill-llama-70b",
        max_tokens: int = 4096,
        temperature: float = 0.6,
    ):
        """Initialize Groq client.
        Args:
            api_key: Groq API key (if not provided, will try to get from env)
            model: Model name to use
            max_tokens: Maximum tokens to generate
            temperature: Temperature for sampling
        """
        super().__init__(api_key=api_key, model=model)
        # Fall back to the environment when no key is passed explicitly.
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("No Groq API key provided")
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.client = Groq(api_key=self.api_key)
        self.model: str = model  # Add explicit type annotation
    def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> tuple[str, int]:
        """Run interleaved chat completion.
        Args:
            messages: List of message dicts
            system: System prompt
            max_tokens: Optional max tokens override
        Returns:
            Tuple of (response text, token usage)
        """
        # R1 models take the system prompt as a leading user turn instead of
        # a system message.
        final_messages = [{"role": "user", "content": system}]
        if isinstance(messages, list):
            for entry in messages:
                if not isinstance(entry, dict):
                    # Bare strings become user turns directly.
                    final_messages.append({"role": "user", "content": entry})
                    continue
                # For dict entries, keep only textual content (images are
                # dropped) and join the pieces into one user turn.
                text_parts = [
                    part if isinstance(part, str) else str(part)
                    for part in entry["content"]
                    if not (isinstance(part, str) and is_image_path(part))
                ]
                if text_parts:
                    final_messages.append({"role": "user", "content": " ".join(text_parts)})
        elif isinstance(messages, str):
            final_messages.append({"role": "user", "content": messages})
        try:
            completion = self.client.chat.completions.create(  # type: ignore
                model=self.model,
                messages=final_messages,  # type: ignore
                temperature=self.temperature,
                max_tokens=max_tokens or self.max_tokens,
                top_p=0.95,
                stream=False,
            )
            reply = completion.choices[0].message.content
            # Strip the R1 "<think>" preamble and any <output> wrapper tags.
            if "</think>" in reply:
                reply = reply.split("</think>\n")[-1]
            final_answer = reply.replace("<output>", "").replace("</output>", "")
            return final_answer, completion.usage.total_tokens
        except Exception as e:
            logger.error(f"Error in Groq API call: {e}")
            raise

View File

@@ -1,276 +0,0 @@
"""Experiment management for the Cua provider."""
import os
import logging
import copy
import base64
from io import BytesIO
from datetime import datetime
from typing import Any, Dict, List, Optional
from PIL import Image
import json
import time
logger = logging.getLogger(__name__)
class ExperimentManager:
    """Manages experiment directories and logging for the agent."""
    def __init__(
        self,
        base_dir: Optional[str] = None,
        only_n_most_recent_images: Optional[int] = None,
    ):
        """Initialize the experiment manager.
        Args:
            base_dir: Base directory for saving experiment data
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
        """
        self.base_dir = base_dir
        self.only_n_most_recent_images = only_n_most_recent_images
        # Directory of the current run; set by setup_experiment_dirs()
        self.run_dir = None
        # Directory of the current turn; set by create_turn_dir()
        self.current_turn_dir = None
        self.turn_count = 0
        self.screenshot_count = 0
        # Track all screenshots for potential API request inclusion
        self.screenshot_paths = []
        # Set up experiment directories if base_dir is provided
        if self.base_dir:
            self.setup_experiment_dirs()
    def setup_experiment_dirs(self) -> None:
        """Setup the experiment directory structure."""
        if not self.base_dir:
            return
        # Create base experiments directory if it doesn't exist
        os.makedirs(self.base_dir, exist_ok=True)
        # Use the base_dir directly as the run_dir
        self.run_dir = self.base_dir
        logger.info(f"Using directory for experiment: {self.run_dir}")
        # Create first turn directory
        self.create_turn_dir()
    def create_turn_dir(self) -> None:
        """Create a new directory for the current turn."""
        if not self.run_dir:
            return
        self.turn_count += 1
        # Zero-padded numbering keeps turn folders sorted lexicographically
        self.current_turn_dir = os.path.join(self.run_dir, f"turn_{self.turn_count:03d}")
        os.makedirs(self.current_turn_dir, exist_ok=True)
        logger.info(f"Created turn directory: {self.current_turn_dir}")
    def sanitize_log_data(self, data: Any) -> Any:
        """Sanitize data for logging by removing large base64 strings.
        Args:
            data: Data to sanitize (dict, list, or primitive)
        Returns:
            Sanitized copy of the data
        """
        if isinstance(data, dict):
            # Deep-copy so the caller's data is never mutated
            result = copy.deepcopy(data)
            # Handle nested dictionaries and lists
            for key, value in result.items():
                # Process content arrays that contain image data
                if key == "content" and isinstance(value, list):
                    for i, item in enumerate(value):
                        if isinstance(item, dict):
                            # Handle Anthropic format
                            if item.get("type") == "image" and isinstance(item.get("source"), dict):
                                source = item["source"]
                                if "data" in source and isinstance(source["data"], str):
                                    # Replace base64 data with a placeholder and length info
                                    data_len = len(source["data"])
                                    source["data"] = f"[BASE64_IMAGE_DATA_LENGTH_{data_len}]"
                            # Handle OpenAI format
                            elif item.get("type") == "image_url" and isinstance(
                                item.get("image_url"), dict
                            ):
                                url_dict = item["image_url"]
                                if "url" in url_dict and isinstance(url_dict["url"], str):
                                    url = url_dict["url"]
                                    if url.startswith("data:"):
                                        # Replace base64 data with placeholder
                                        data_len = len(url)
                                        url_dict["url"] = f"[BASE64_IMAGE_URL_LENGTH_{data_len}]"
                # Handle other nested structures recursively
                # (the "content" list above is re-visited here; that second
                # pass is harmless since its items are already sanitized)
                if isinstance(value, dict):
                    result[key] = self.sanitize_log_data(value)
                elif isinstance(value, list):
                    result[key] = [self.sanitize_log_data(item) for item in value]
            return result
        elif isinstance(data, list):
            return [self.sanitize_log_data(item) for item in data]
        else:
            # Primitives pass through unchanged
            return data
    def save_debug_image(self, image_data: str, filename: str) -> None:
        """Save a debug image to the experiment directory.
        Args:
            image_data: Base64 encoded image data
            filename: Filename to save the image as
        """
        # Intentional no-op: the images/ folder workflow was retired.
        # Since we no longer want to use the images/ folder, we'll skip this functionality
        return
    def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
        """Save a screenshot to the experiment directory.
        Args:
            img_base64: Base64 encoded screenshot
            action_type: Type of action that triggered the screenshot
        Returns:
            Optional[str]: Path to the saved screenshot, or None if saving failed
        """
        if not self.current_turn_dir:
            return None
        try:
            # Increment screenshot counter
            self.screenshot_count += 1
            # Create a descriptive filename (millisecond timestamp avoids clashes)
            timestamp = int(time.time() * 1000)
            action_suffix = f"_{action_type}" if action_type else ""
            filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"
            # Save directly to the turn directory (no screenshots subdirectory)
            filepath = os.path.join(self.current_turn_dir, filename)
            # Save the screenshot
            img_data = base64.b64decode(img_base64)
            with open(filepath, "wb") as f:
                f.write(img_data)
            # Keep track of the file path for reference
            self.screenshot_paths.append(filepath)
            return filepath
        except Exception as e:
            # Best-effort: a failed save is logged, never raised
            logger.error(f"Error saving screenshot: {str(e)}")
            return None
    def should_save_debug_image(self) -> bool:
        """Determine if debug images should be saved.
        Returns:
            Boolean indicating if debug images should be saved
        """
        # We no longer need to save debug images, so always return False
        return False
    def save_action_visualization(
        self, img: Image.Image, action_name: str, details: str = ""
    ) -> str:
        """Save a visualization of an action.
        Args:
            img: Image to save
            action_name: Name of the action
            details: Additional details about the action
        Returns:
            Path to the saved image (empty string on failure)
        """
        if not self.current_turn_dir:
            return ""
        try:
            # Create a descriptive filename
            timestamp = int(time.time() * 1000)
            details_suffix = f"_{details}" if details else ""
            filename = f"vis_{action_name}{details_suffix}_{timestamp}.png"
            # Save directly to the turn directory (no visualizations subdirectory)
            filepath = os.path.join(self.current_turn_dir, filename)
            # Save the image
            img.save(filepath)
            # Keep track of the file path for cleanup
            self.screenshot_paths.append(filepath)
            return filepath
        except Exception as e:
            logger.error(f"Error saving action visualization: {str(e)}")
            return ""
    def extract_and_save_images(self, data: Any, prefix: str) -> None:
        """Extract and save images from response data.
        Args:
            data: Response data to extract images from
            prefix: Prefix for saved image filenames
        """
        # Intentional no-op: extracted images are no longer stored separately.
        # Since we no longer want to save extracted images separately,
        # we'll skip this functionality entirely
        return
    def log_api_call(
        self,
        call_type: str,
        request: Any,
        provider: str,
        model: str,
        response: Any = None,
        error: Optional[Exception] = None,
    ) -> None:
        """Log API call details to file.
        Args:
            call_type: Type of API call (e.g., 'request', 'response', 'error')
            request: The API request data
            provider: The AI provider used
            model: The AI model used
            response: Optional API response data
            error: Optional error information
        """
        if not self.current_turn_dir:
            return
        try:
            # Create a unique filename with timestamp
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"api_call_{timestamp}_{call_type}.json"
            filepath = os.path.join(self.current_turn_dir, filename)
            # Sanitize data to remove large base64 strings
            sanitized_request = self.sanitize_log_data(request)
            sanitized_response = self.sanitize_log_data(response) if response is not None else None
            # Prepare log data
            log_data = {
                "timestamp": timestamp,
                "provider": provider,
                "model": model,
                "type": call_type,
                "request": sanitized_request,
            }
            if sanitized_response is not None:
                log_data["response"] = sanitized_response
            if error is not None:
                log_data["error"] = str(error)
            # Write to file; default=str keeps non-JSON types loggable
            with open(filepath, "w") as f:
                json.dump(log_data, f, indent=2, default=str)
            logger.info(f"Logged API {call_type} to {filepath}")
        except Exception as e:
            # Logging must never break the agent loop
            logger.error(f"Error logging API call: {str(e)}")

View File

@@ -32,75 +32,3 @@ def decode_base64_image(img_base64: str) -> Optional[Image.Image]:
except Exception as e:
logger.error(f"Error decoding base64 image: {str(e)}")
return None
def encode_image_base64(img: Image.Image, format: str = "PNG") -> str:
    """Serialize a PIL Image into a base64 string.
    Args:
        img: PIL Image to encode
        format: Image format (PNG, JPEG, etc.)
    Returns:
        Base64 encoded image string (empty string on failure)
    """
    try:
        buffer = BytesIO()
        img.save(buffer, format=format)
        encoded = base64.b64encode(buffer.getvalue())
        return encoded.decode("utf-8")
    except Exception as e:
        logger.error(f"Error encoding image to base64: {str(e)}")
        return ""
def clean_base64_data(img_base64: str) -> str:
    """Strip a leading data-URL prefix from base64 image data, if present.
    Args:
        img_base64: Base64 encoded image, may include data URL prefix
    Returns:
        Clean base64 string without prefix
    """
    if not img_base64.startswith("data:image"):
        return img_base64
    # Everything after the first comma is the raw base64 payload.
    return img_base64.split(",")[1]
def extract_base64_from_text(text: str) -> Optional[str]:
    """Pull base64 image data out of free-form text.
    Args:
        text: Text potentially containing base64 image data
    Returns:
        Base64 string or None if not found
    """
    # Prefer an explicit data URL when one is present.
    data_url = re.search(r"data:image/[^;]+;base64,([a-zA-Z0-9+/=]+)", text)
    if data_url:
        return data_url.group(1)
    # Otherwise fall back to any long base64-looking run (basic heuristic).
    plain = re.search(r"([a-zA-Z0-9+/=]{100,})", text)
    return plain.group(1) if plain else None
def get_image_dimensions(img_base64: str) -> Tuple[int, int]:
    """Return the (width, height) of a base64 encoded image.
    Args:
        img_base64: Base64 encoded image
    Returns:
        Tuple of (width, height) or (0, 0) if decoding fails
    """
    decoded = decode_base64_image(img_base64)
    # PIL's Image.size is already a (width, height) tuple.
    return decoded.size if decoded else (0, 0)

File diff suppressed because it is too large Load Diff

View File

@@ -1,171 +0,0 @@
"""Omni message manager implementation."""
import base64
from typing import Any, Dict, List, Optional
from io import BytesIO
from PIL import Image
from ...core.messages import BaseMessageManager, ImageRetentionConfig
class OmniMessageManager(BaseMessageManager):
    """Message manager for multi-provider support."""
    def __init__(self, config: Optional[ImageRetentionConfig] = None):
        """Initialize the message manager.
        Args:
            config: Optional configuration for image retention
        """
        super().__init__(config)
        # Conversation history in OpenAI message format (role/content dicts)
        self.messages: List[Dict[str, Any]] = []
        self.config = config
    def add_user_message(self, content: str, images: Optional[List[bytes]] = None) -> None:
        """Add a user message to the history.
        Args:
            content: Message content
            images: Optional list of image data
        """
        # Add images if present
        if images:
            # Initialize with proper typing for mixed content
            message_content: List[Dict[str, Any]] = [{"type": "text", "text": content}]
            # Add each image as an OpenAI-style data-URL block
            for img in images:
                message_content.append(
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{base64.b64encode(img).decode()}"
                        },
                    }
                )
            message = {"role": "user", "content": message_content}
        else:
            # Simple text message
            message = {"role": "user", "content": content}
        self.messages.append(message)
        # Apply retention policy
        # (only enforced here — assistant/system adds never carry images)
        if self.config and self.config.num_images_to_keep:
            self._apply_image_retention_policy()
    def add_assistant_message(self, content: str) -> None:
        """Add an assistant message to the history.
        Args:
            content: Message content
        """
        self.messages.append({"role": "assistant", "content": content})
    def add_system_message(self, content: str) -> None:
        """Add a system message to the history.
        Args:
            content: Message content
        """
        self.messages.append({"role": "system", "content": content})
    def _apply_image_retention_policy(self) -> None:
        """Apply image retention policy to message history.
        Walks user messages from newest to oldest, keeping at most
        num_images_to_keep image blocks overall; text blocks are always kept.
        """
        if not self.config or not self.config.num_images_to_keep:
            return
        # Count images from newest to oldest
        image_count = 0
        for message in reversed(self.messages):
            if message["role"] != "user":
                continue
            # Handle multimodal messages
            if isinstance(message["content"], list):
                new_content = []
                for item in message["content"]:
                    if item["type"] == "text":
                        new_content.append(item)
                    elif item["type"] == "image_url":
                        # Keep only while under the retention budget
                        if image_count < self.config.num_images_to_keep:
                            new_content.append(item)
                            image_count += 1
                message["content"] = new_content
    def get_formatted_messages(self, provider: str) -> List[Dict[str, Any]]:
        """Get messages formatted for specific provider.
        Args:
            provider: Provider name to format messages for
        Returns:
            List of formatted messages
        Raises:
            ValueError: If the provider is not one of anthropic/openai/groq/qwen
        """
        # Set the provider for message formatting (inherited behavior)
        self.set_provider(provider)
        if provider == "anthropic":
            return self._format_for_anthropic()
        elif provider == "openai":
            return self._format_for_openai()
        elif provider == "groq":
            return self._format_for_groq()
        elif provider == "qwen":
            return self._format_for_qwen()
        else:
            raise ValueError(f"Unsupported provider: {provider}")
    def _format_for_anthropic(self) -> List[Dict[str, Any]]:
        """Format messages for Anthropic API."""
        formatted = []
        for msg in self.messages:
            formatted_msg = {"role": msg["role"]}
            # Handle multimodal content
            if isinstance(msg["content"], list):
                formatted_msg["content"] = []
                for item in msg["content"]:
                    if item["type"] == "text":
                        formatted_msg["content"].append({"type": "text", "text": item["text"]})
                    elif item["type"] == "image_url":
                        # Convert OpenAI data-URL blocks to Anthropic base64 source
                        formatted_msg["content"].append(
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": "image/png",
                                    "data": item["image_url"]["url"].split(",")[1],
                                },
                            }
                        )
            else:
                formatted_msg["content"] = msg["content"]
            formatted.append(formatted_msg)
        return formatted
    def _format_for_openai(self) -> List[Dict[str, Any]]:
        """Format messages for OpenAI API."""
        # OpenAI already uses the same format
        return self.messages
    def _format_for_groq(self) -> List[Dict[str, Any]]:
        """Format messages for Groq API."""
        # Groq uses OpenAI-compatible format
        return self.messages
    def _format_for_qwen(self) -> List[Dict[str, Any]]:
        """Format messages for Qwen API."""
        formatted = []
        for msg in self.messages:
            if isinstance(msg["content"], list):
                # Convert multimodal content to text-only
                # (keeps the FIRST text block; images are dropped)
                text_content = next(
                    (item["text"] for item in msg["content"] if item["type"] == "text"), ""
                )
                formatted.append({"role": msg["role"], "content": text_content})
            else:
                formatted.append(msg)
        return formatted

View File

@@ -3,14 +3,11 @@
import logging
from typing import Any, Dict, List, Optional, Tuple
import base64
from PIL import Image
from io import BytesIO
import json
import torch
# Import from the SOM package
from som import OmniParser as OmniDetectParser
from som.models import ParseResult, BoundingBox, UIElement, ImageData, ParserMetadata
from som.models import ParseResult, ParserMetadata
logger = logging.getLogger(__name__)
@@ -251,3 +248,60 @@ class OmniParser:
except Exception as e:
logger.error(f"Error formatting messages: {str(e)}")
return messages # Return original messages on error
async def calculate_click_coordinates(
self, box_id: int, parsed_screen: ParseResult
) -> Tuple[int, int]:
"""Calculate click coordinates based on box ID.
Args:
box_id: The ID of the box to click
parsed_screen: The parsed screen information
Returns:
Tuple of (x, y) coordinates
Raises:
ValueError: If box_id is invalid or missing from parsed screen
"""
# First try to use structured elements data
logger.info(f"Elements count: {len(parsed_screen.elements)}")
# Try to find element with matching ID
for element in parsed_screen.elements:
if element.id == box_id:
logger.info(f"Found element with ID {box_id}: {element}")
bbox = element.bbox
# Get screen dimensions from the metadata if available, or fallback
width = parsed_screen.metadata.width if parsed_screen.metadata else 1920
height = parsed_screen.metadata.height if parsed_screen.metadata else 1080
logger.info(f"Screen dimensions: width={width}, height={height}")
# Create a dictionary from the element's bbox for calculate_element_center
bbox_dict = {"x1": bbox.x1, "y1": bbox.y1, "x2": bbox.x2, "y2": bbox.y2}
from ...core.visualization import calculate_element_center
center_x, center_y = calculate_element_center(bbox_dict, width, height)
logger.info(f"Calculated center: ({center_x}, {center_y})")
# Validate coordinates - if they're (0,0) or unreasonably small,
# use a default position in the center of the screen
if center_x == 0 and center_y == 0:
logger.warning("Got (0,0) coordinates, using fallback position")
center_x = width // 2
center_y = height // 2
logger.info(f"Using fallback center: ({center_x}, {center_y})")
return center_x, center_y
# If we couldn't find the box, use center of screen
logger.error(
f"Box ID {box_id} not found in structured elements (count={len(parsed_screen.elements)})"
)
# Use center of screen as fallback
width = parsed_screen.metadata.width if parsed_screen.metadata else 1920
height = parsed_screen.metadata.height if parsed_screen.metadata else 1080
logger.warning(f"Using fallback position in center of screen ({width//2}, {height//2})")
return width // 2, height // 2

View File

@@ -1,91 +0,0 @@
# """Omni tool manager implementation."""
# from typing import Dict, List, Type, Any
# from computer import Computer
# from ...core.tools import BaseToolManager, BashTool, EditTool
# class OmniToolManager(BaseToolManager):
# """Tool manager for multi-provider support."""
# def __init__(self, computer: Computer):
# """Initialize Omni tool manager.
# Args:
# computer: Computer instance for tools
# """
# super().__init__(computer)
# def get_anthropic_tools(self) -> List[Dict[str, Any]]:
# """Get tools formatted for Anthropic API.
# Returns:
# List of tool parameters in Anthropic format
# """
# tools: List[Dict[str, Any]] = []
# # Map base tools to Anthropic format
# for tool in self.tools.values():
# if isinstance(tool, BashTool):
# tools.append({
# "type": "bash_20241022",
# "name": tool.name
# })
# elif isinstance(tool, EditTool):
# tools.append({
# "type": "text_editor_20241022",
# "name": "str_replace_editor"
# })
# return tools
# def get_openai_tools(self) -> List[Dict]:
# """Get tools formatted for OpenAI API.
# Returns:
# List of tool parameters in OpenAI format
# """
# tools = []
# # Map base tools to OpenAI format
# for tool in self.tools.values():
# tools.append({
# "type": "function",
# "function": tool.get_schema()
# })
# return tools
# def get_groq_tools(self) -> List[Dict]:
# """Get tools formatted for Groq API.
# Returns:
# List of tool parameters in Groq format
# """
# tools = []
# # Map base tools to Groq format
# for tool in self.tools.values():
# tools.append({
# "type": "function",
# "function": tool.get_schema()
# })
# return tools
# def get_qwen_tools(self) -> List[Dict]:
# """Get tools formatted for Qwen API.
# Returns:
# List of tool parameters in Qwen format
# """
# tools = []
# # Map base tools to Qwen format
# for tool in self.tools.values():
# tools.append({
# "type": "function",
# "function": tool.get_schema()
# })
# return tools

View File

@@ -1,11 +1,30 @@
"""Omni provider tools - compatible with multiple LLM providers."""
from .bash import OmniBashTool
from .computer import OmniComputerTool
from .manager import OmniToolManager
from ....core.tools import BaseTool, ToolResult, ToolError, ToolFailure, CLIResult
from .base import BaseOmniTool
from .computer import ComputerTool
from .bash import BashTool
from .manager import ToolManager
# Re-export the tools with Omni-specific names for backward compatibility
OmniToolResult = ToolResult
OmniToolError = ToolError
OmniToolFailure = ToolFailure
OmniCLIResult = CLIResult
# We'll export specific tools once implemented
__all__ = [
"OmniBashTool",
"OmniComputerTool",
"OmniToolManager",
"BaseTool",
"BaseOmniTool",
"ToolResult",
"ToolError",
"ToolFailure",
"CLIResult",
"OmniToolResult",
"OmniToolError",
"OmniToolFailure",
"OmniCLIResult",
"ComputerTool",
"BashTool",
"ToolManager",
]

View File

@@ -0,0 +1,29 @@
"""Omni-specific tool base classes."""
from abc import ABCMeta, abstractmethod
from typing import Any, Dict
from ....core.tools.base import BaseTool
class BaseOmniTool(BaseTool, metaclass=ABCMeta):
    """Abstract base class for Omni provider tools.

    Subclasses must implement ``__call__`` (asynchronous execution) and
    ``to_params`` (the provider-specific API schema for the tool).
    """

    def __init__(self):
        """Initialize the base Omni tool."""
        # No specific initialization needed yet, but included for future extensibility
        pass

    @abstractmethod
    async def __call__(self, **kwargs) -> Any:
        """Executes the tool with the given arguments."""
        ...

    @abstractmethod
    def to_params(self) -> Dict[str, Any]:
        """Convert tool to Omni provider-specific API parameters.

        Returns:
            Dictionary with tool parameters for the specific API
        """
        # Abstract: concrete tools return their own schema dict
        raise NotImplementedError

View File

@@ -1,69 +1,74 @@
"""Provider-agnostic implementation of the BashTool."""
"""Bash tool for Omni provider."""
import logging
from typing import Any, Dict
from computer.computer import Computer
from computer import Computer
from ....core.tools import ToolResult, ToolError
from .base import BaseOmniTool
from ....core.tools.bash import BaseBashTool
from ....core.tools import ToolResult
logger = logging.getLogger(__name__)
class OmniBashTool(BaseBashTool):
"""A provider-agnostic implementation of the bash tool."""
class BashTool(BaseOmniTool):
"""Tool for executing bash commands."""
name = "bash"
logger = logging.getLogger(__name__)
description = "Execute bash commands on the system"
def __init__(self, computer: Computer):
"""Initialize the BashTool.
"""Initialize the bash tool.
Args:
computer: Computer instance, may be used for related operations
computer: Computer instance
"""
super().__init__(computer)
super().__init__()
self.computer = computer
def to_params(self) -> Dict[str, Any]:
"""Convert tool to provider-agnostic parameters.
"""Convert tool to API parameters.
Returns:
Dictionary with tool parameters
"""
return {
"name": self.name,
"description": "A tool that allows the agent to run bash commands",
"parameters": {
"command": {"type": "string", "description": "The bash command to execute"},
"restart": {
"type": "boolean",
"description": "Whether to restart the bash session",
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The bash command to execute",
},
},
"required": ["command"],
},
},
}
async def __call__(self, **kwargs) -> ToolResult:
"""Execute the bash tool with the provided arguments.
"""Execute bash command.
Args:
command: The bash command to execute
restart: Whether to restart the bash session
**kwargs: Command parameters
Returns:
ToolResult with the command output
Tool execution result
"""
command = kwargs.get("command")
restart = kwargs.get("restart", False)
try:
command = kwargs.get("command", "")
if not command:
return ToolResult(error="No command specified")
if not command:
return ToolResult(error="Command is required")
# The true implementation would use the actual method to run terminal commands
# Since we're getting linter errors, we'll just implement a placeholder that will
# be replaced with the correct implementation when this tool is fully integrated
logger.info(f"Would execute command: {command}")
return ToolResult(output=f"Command executed (placeholder): {command}")
self.logger.info(f"Executing bash command: {command}")
exit_code, stdout, stderr = await self.run_command(command)
output = stdout
error = None
if exit_code != 0:
error = f"Command exited with code {exit_code}: {stderr}"
return ToolResult(output=output, error=error)
except Exception as e:
logger.error(f"Error in bash tool: {str(e)}")
return ToolResult(error=f"Error: {str(e)}")

View File

@@ -1,217 +1,179 @@
"""Provider-agnostic implementation of the ComputerTool."""
"""Computer tool for Omni provider."""
import logging
import base64
import io
from typing import Any, Dict
import json
from PIL import Image
from computer.computer import Computer
from ....core.tools.computer import BaseComputerTool
from computer import Computer
from ....core.tools import ToolResult, ToolError
from .base import BaseOmniTool
from ..parser import ParseResult
logger = logging.getLogger(__name__)
class OmniComputerTool(BaseComputerTool):
"""A provider-agnostic implementation of the computer tool."""
class ComputerTool(BaseOmniTool):
"""Tool for interacting with the computer UI."""
name = "computer"
logger = logging.getLogger(__name__)
description = "Interact with the computer's graphical user interface"
def __init__(self, computer: Computer):
"""Initialize the ComputerTool.
"""Initialize the computer tool.
Args:
computer: Computer instance for screen interactions
computer: Computer instance
"""
super().__init__(computer)
# Initialize dimensions to None, will be set in initialize_dimensions
self.width = None
self.height = None
self.display_num = None
super().__init__()
self.computer = computer
# Default to standard screen dimensions (will be set more accurately during initialization)
self.screen_dimensions = {"width": 1440, "height": 900}
async def initialize_dimensions(self) -> None:
"""Initialize screen dimensions."""
# For now, we'll use default values
# In the future, we can implement proper screen dimension detection
logger.info(f"Using default screen dimensions: {self.screen_dimensions}")
def to_params(self) -> Dict[str, Any]:
"""Convert tool to provider-agnostic parameters.
"""Convert tool to API parameters.
Returns:
Dictionary with tool parameters
"""
return {
"name": self.name,
"description": "A tool that allows the agent to interact with the screen, keyboard, and mouse",
"parameters": {
"action": {
"type": "string",
"enum": [
"key",
"type",
"mouse_move",
"left_click",
"left_click_drag",
"right_click",
"middle_click",
"double_click",
"screenshot",
"cursor_position",
"scroll",
],
"description": "The action to perform on the computer",
},
"text": {
"type": "string",
"description": "Text to type or key to press, required for 'key' and 'type' actions",
},
"coordinate": {
"type": "array",
"items": {"type": "integer"},
"description": "X,Y coordinates for mouse actions like click and move",
},
"direction": {
"type": "string",
"enum": ["up", "down"],
"description": "Direction to scroll, used with the 'scroll' action",
},
"amount": {
"type": "integer",
"description": "Amount to scroll, used with the 'scroll' action",
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"left_click",
"right_click",
"double_click",
"move_cursor",
"drag_to",
"type_text",
"press_key",
"hotkey",
"scroll_up",
"scroll_down",
],
"description": "The action to perform",
},
"x": {
"type": "number",
"description": "X coordinate for click or cursor movement",
},
"y": {
"type": "number",
"description": "Y coordinate for click or cursor movement",
},
"box_id": {
"type": "integer",
"description": "ID of the UI element to interact with",
},
"text": {
"type": "string",
"description": "Text to type",
},
"key": {
"type": "string",
"description": "Key to press",
},
"keys": {
"type": "array",
"items": {"type": "string"},
"description": "Keys to press as hotkey combination",
},
"amount": {
"type": "integer",
"description": "Amount to scroll",
},
"duration": {
"type": "number",
"description": "Duration for drag operations",
},
},
"required": ["action"],
},
},
**self.options,
}
async def __call__(self, **kwargs) -> ToolResult:
"""Execute the computer tool with the provided arguments.
"""Execute computer action.
Args:
action: The action to perform
text: Text to type or key to press (for key/type actions)
coordinate: X,Y coordinates (for mouse actions)
direction: Direction to scroll (for scroll action)
amount: Amount to scroll (for scroll action)
**kwargs: Action parameters
Returns:
ToolResult with the action output and optional screenshot
Tool execution result
"""
# Ensure dimensions are initialized
if self.width is None or self.height is None:
await self.initialize_dimensions()
action = kwargs.get("action")
text = kwargs.get("text")
coordinate = kwargs.get("coordinate")
direction = kwargs.get("direction", "down")
amount = kwargs.get("amount", 10)
self.logger.info(f"Executing computer action: {action}")
try:
if action == "screenshot":
return await self.screenshot()
elif action == "left_click" and coordinate:
x, y = coordinate
self.logger.info(f"Clicking at ({x}, {y})")
await self.computer.interface.move_cursor(x, y)
await self.computer.interface.left_click()
action = kwargs.get("action", "").lower()
if not action:
return ToolResult(error="No action specified")
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
screenshot = await self.resize_screenshot_if_needed(screenshot)
return ToolResult(
output=f"Performed left click at ({x}, {y})",
base64_image=base64.b64encode(screenshot).decode(),
# Execute the action on the computer
method = getattr(self.computer.interface, action, None)
if not method:
return ToolResult(error=f"Unsupported action: {action}")
# Prepare arguments based on action type
args = {}
if action in ["left_click", "right_click", "double_click", "move_cursor"]:
x = kwargs.get("x")
y = kwargs.get("y")
if x is None or y is None:
box_id = kwargs.get("box_id")
if box_id is None:
return ToolResult(error="Box ID or coordinates required")
# Get coordinates from box_id implementation would be here
# For now, return error
return ToolResult(error="Box ID-based clicking not implemented yet")
args["x"] = x
args["y"] = y
elif action == "drag_to":
x = kwargs.get("x")
y = kwargs.get("y")
if x is None or y is None:
return ToolResult(error="Coordinates required for drag_to")
args.update(
{
"x": x,
"y": y,
"button": kwargs.get("button", "left"),
"duration": float(kwargs.get("duration", 0.5)),
}
)
elif action == "right_click" and coordinate:
x, y = coordinate
self.logger.info(f"Right clicking at ({x}, {y})")
await self.computer.interface.move_cursor(x, y)
await self.computer.interface.right_click()
elif action == "type_text":
text = kwargs.get("text")
if not text:
return ToolResult(error="Text required for type_text")
args["text"] = text
elif action == "press_key":
key = kwargs.get("key")
if not key:
return ToolResult(error="Key required for press_key")
args["key"] = key
elif action == "hotkey":
keys = kwargs.get("keys")
if not keys:
return ToolResult(error="Keys required for hotkey")
# Call with positional arguments instead of kwargs
await method(*keys)
return ToolResult(output=f"Hotkey executed: {'+'.join(keys)}")
elif action in ["scroll_down", "scroll_up"]:
args["clicks"] = int(kwargs.get("amount", 1))
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
screenshot = await self.resize_screenshot_if_needed(screenshot)
return ToolResult(
output=f"Performed right click at ({x}, {y})",
base64_image=base64.b64encode(screenshot).decode(),
)
elif action == "double_click" and coordinate:
x, y = coordinate
self.logger.info(f"Double clicking at ({x}, {y})")
await self.computer.interface.move_cursor(x, y)
await self.computer.interface.double_click()
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
screenshot = await self.resize_screenshot_if_needed(screenshot)
return ToolResult(
output=f"Performed double click at ({x}, {y})",
base64_image=base64.b64encode(screenshot).decode(),
)
elif action == "mouse_move" and coordinate:
x, y = coordinate
self.logger.info(f"Moving cursor to ({x}, {y})")
await self.computer.interface.move_cursor(x, y)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
screenshot = await self.resize_screenshot_if_needed(screenshot)
return ToolResult(
output=f"Moved cursor to ({x}, {y})",
base64_image=base64.b64encode(screenshot).decode(),
)
elif action == "type" and text:
self.logger.info(f"Typing text: {text}")
await self.computer.interface.type_text(text)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
screenshot = await self.resize_screenshot_if_needed(screenshot)
return ToolResult(
output=f"Typed text: {text}",
base64_image=base64.b64encode(screenshot).decode(),
)
elif action == "key" and text:
self.logger.info(f"Pressing key: {text}")
# Handle special key combinations
if "+" in text:
keys = text.split("+")
await self.computer.interface.hotkey(*keys)
else:
await self.computer.interface.press_key(text)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
screenshot = await self.resize_screenshot_if_needed(screenshot)
return ToolResult(
output=f"Pressed key: {text}",
base64_image=base64.b64encode(screenshot).decode(),
)
elif action == "cursor_position":
pos = await self.computer.interface.get_cursor_position()
x, y = pos
return ToolResult(output=f"X={int(x)},Y={int(y)}")
elif action == "scroll":
if direction == "down":
self.logger.info(f"Scrolling down, amount: {amount}")
for _ in range(amount):
await self.computer.interface.hotkey("fn", "down")
else:
self.logger.info(f"Scrolling up, amount: {amount}")
for _ in range(amount):
await self.computer.interface.hotkey("fn", "up")
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
screenshot = await self.resize_screenshot_if_needed(screenshot)
return ToolResult(
output=f"Scrolled {direction} by {amount} steps",
base64_image=base64.b64encode(screenshot).decode(),
)
# Default to screenshot for unimplemented actions
self.logger.warning(f"Action {action} not fully implemented, taking screenshot")
return await self.screenshot()
# Execute action with prepared arguments
await method(**args)
return ToolResult(output=f"Action {action} executed successfully")
except Exception as e:
self.logger.error(f"Error during computer action: {str(e)}")
return ToolResult(error=f"Failed to perform {action}: {str(e)}")
logger.error(f"Error executing computer action: {str(e)}")
return ToolResult(error=f"Error: {str(e)}")

View File

@@ -1,81 +1,61 @@
"""Omni tool manager implementation."""
from typing import Dict, List, Any
from enum import Enum
"""Tool manager for the Omni provider."""
from typing import Any, Dict, List
from computer.computer import Computer
from ....core.tools import BaseToolManager
from ....core.tools import BaseToolManager, ToolResult
from ....core.tools.collection import ToolCollection
from .bash import OmniBashTool
from .computer import OmniComputerTool
from .computer import ComputerTool
from .bash import BashTool
from ..types import LLMProvider
class ProviderType(Enum):
"""Supported provider types."""
class ToolManager(BaseToolManager):
"""Manages Omni provider tool initialization and execution."""
ANTHROPIC = "anthropic"
OPENAI = "openai"
CLAUDE = "claude" # Alias for Anthropic
GPT = "gpt" # Alias for OpenAI
class OmniToolManager(BaseToolManager):
"""Tool manager for multi-provider support."""
def __init__(self, computer: Computer):
"""Initialize Omni tool manager.
def __init__(self, computer: Computer, provider: LLMProvider):
"""Initialize the tool manager.
Args:
computer: Computer instance for tools
computer: Computer instance for computer-related tools
provider: The LLM provider being used
"""
super().__init__(computer)
# Initialize tools
self.computer_tool = OmniComputerTool(self.computer)
self.bash_tool = OmniBashTool(self.computer)
self.provider = provider
# Initialize Omni-specific tools
self.computer_tool = ComputerTool(self.computer)
self.bash_tool = BashTool(self.computer)
def _initialize_tools(self) -> ToolCollection:
"""Initialize all available tools."""
return ToolCollection(self.computer_tool, self.bash_tool)
async def _initialize_tools_specific(self) -> None:
"""Initialize provider-specific tool requirements."""
"""Initialize Omni provider-specific tool requirements."""
await self.computer_tool.initialize_dimensions()
def get_tool_params(self) -> List[Dict[str, Any]]:
"""Get tool parameters for API calls.
Returns:
List of tool parameters in default format
List of tool parameters for the current provider's API
"""
if self.tools is None:
raise RuntimeError("Tools not initialized. Call initialize() first.")
return self.tools.to_params()
def get_provider_tools(self, provider: ProviderType) -> List[Dict[str, Any]]:
"""Get tools formatted for a specific provider.
async def execute_tool(self, name: str, tool_input: dict[str, Any]) -> ToolResult:
"""Execute a tool with the given input.
Args:
provider: Provider type to format tools for
name: Name of the tool to execute
tool_input: Input parameters for the tool
Returns:
List of tool parameters in provider-specific format
Result of the tool execution
"""
if self.tools is None:
raise RuntimeError("Tools not initialized. Call initialize() first.")
# Default is the base implementation
tools = self.tools.to_params()
# Customize for each provider if needed
if provider in [ProviderType.ANTHROPIC, ProviderType.CLAUDE]:
# Format for Anthropic API
# Additional adjustments can be made here
pass
elif provider in [ProviderType.OPENAI, ProviderType.GPT]:
# Format for OpenAI API
# Future implementation
pass
return tools
return await self.tools.run(name=name, tool_input=tool_input)

View File

@@ -1,157 +1,236 @@
"""Utility functions for Omni provider."""
"""Main entry point for computer agents."""
import base64
import io
import asyncio
import json
import logging
from typing import Tuple
from PIL import Image
import os
from typing import Any, Dict, List, Optional
from som.models import ParseResult
from ...core.types import AgentResponse
logger = logging.getLogger(__name__)
def compress_image_base64(
base64_str: str, max_size_bytes: int = 5 * 1024 * 1024, quality: int = 90
) -> tuple[str, str]:
"""Compress a base64 encoded image to ensure it's below a certain size.
async def to_openai_agent_response_format(
response: Any,
messages: List[Dict[str, Any]],
parsed_screen: Optional[ParseResult] = None,
parser: Optional[Any] = None,
model: Optional[str] = None,
) -> AgentResponse:
"""Create an OpenAI computer use agent compatible response format.
Args:
base64_str: Base64 encoded image string (with or without data URL prefix)
max_size_bytes: Maximum size in bytes (default: 5MB)
quality: Initial JPEG quality (0-100)
response: The original API response
messages: List of messages in standard OpenAI format
parsed_screen: Optional pre-parsed screen information
parser: Optional parser instance for coordinate calculation
model: Optional model name
Returns:
tuple[str, str]: (Compressed base64 encoded image, media_type)
A response formatted according to OpenAI's computer use agent standard, including:
- All standard OpenAI computer use agent fields
- Original response in response.choices[0].message
- Full message history in messages field
"""
# Handle data URL prefix if present (e.g., "data:image/png;base64,...")
original_prefix = ""
media_type = "image/png" # Default media type
from datetime import datetime
import time
if base64_str.startswith("data:"):
parts = base64_str.split(",", 1)
if len(parts) == 2:
original_prefix = parts[0] + ","
base64_str = parts[1]
# Try to extract media type from the prefix
if "image/jpeg" in original_prefix.lower():
media_type = "image/jpeg"
elif "image/png" in original_prefix.lower():
media_type = "image/png"
# Create a unique ID for this response
response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(response)}"
reasoning_id = f"rs_{response_id}"
action_id = f"cu_{response_id}"
call_id = f"call_{response_id}"
# Check if the base64 string is small enough already
if len(base64_str) <= max_size_bytes:
logger.info(f"Image already within size limit: {len(base64_str)} bytes")
return original_prefix + base64_str, media_type
# Extract the last assistant message
assistant_msg = None
for msg in reversed(messages):
if msg["role"] == "assistant":
assistant_msg = msg
break
try:
# Decode base64
img_data = base64.b64decode(base64_str)
img_size = len(img_data)
logger.info(f"Original image size: {img_size} bytes")
if not assistant_msg:
# If no assistant message found, create a default one
assistant_msg = {"role": "assistant", "content": "No response available"}
# Open image
img = Image.open(io.BytesIO(img_data))
# Initialize output array
output_items = []
# First, try to compress as PNG (maintains transparency if present)
buffer = io.BytesIO()
img.save(buffer, format="PNG", optimize=True)
buffer.seek(0)
compressed_data = buffer.getvalue()
compressed_b64 = base64.b64encode(compressed_data).decode("utf-8")
# Extract reasoning and action details from the response
content = assistant_msg["content"]
reasoning_text = None
action_details = None
if len(compressed_b64) <= max_size_bytes:
logger.info(f"Compressed to {len(compressed_data)} bytes as PNG")
return compressed_b64, "image/png"
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
try:
# Try to parse JSON from text block
text_content = item.get("text", "")
parsed_json = json.loads(text_content)
# Strategy 1: Try reducing quality with JPEG format
current_quality = quality
while current_quality > 20:
buffer = io.BytesIO()
# Convert to RGB if image has alpha channel (JPEG doesn't support transparency)
if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info):
logger.info("Converting transparent image to RGB for JPEG compression")
rgb_img = Image.new("RGB", img.size, (255, 255, 255))
rgb_img.paste(img, mask=img.split()[3] if img.mode == "RGBA" else None)
rgb_img.save(buffer, format="JPEG", quality=current_quality, optimize=True)
else:
img.save(buffer, format="JPEG", quality=current_quality, optimize=True)
# Get reasoning text
if reasoning_text is None:
reasoning_text = parsed_json.get("Explanation", "")
buffer.seek(0)
compressed_data = buffer.getvalue()
compressed_b64 = base64.b64encode(compressed_data).decode("utf-8")
# Extract action details
action = parsed_json.get("Action", "").lower()
text_input = parsed_json.get("Text", "")
value = parsed_json.get("Value", "") # Also handle Value field
box_id = parsed_json.get("Box ID") # Extract Box ID
if len(compressed_b64) <= max_size_bytes:
logger.info(
f"Compressed to {len(compressed_data)} bytes with JPEG quality {current_quality}"
)
return compressed_b64, "image/jpeg"
if action in ["click", "left_click"]:
# Always calculate coordinates from Box ID for click actions
x, y = 100, 100 # Default fallback values
# Reduce quality and try again
current_quality -= 10
if parsed_screen and box_id is not None and parser is not None:
try:
box_id_int = (
box_id
if isinstance(box_id, int)
else int(str(box_id)) if str(box_id).isdigit() else None
)
if box_id_int is not None:
# Use the parser's method to calculate coordinates
x, y = await parser.calculate_click_coordinates(
box_id_int, parsed_screen
)
except Exception as e:
logger.error(
f"Error extracting coordinates for Box ID {box_id}: {str(e)}"
)
# Strategy 2: If quality reduction isn't enough, reduce dimensions
scale_factor = 0.8
current_img = img
action_details = {
"type": "click",
"button": "left",
"box_id": (
(
box_id
if isinstance(box_id, int)
else int(box_id) if str(box_id).isdigit() else None
)
if box_id is not None
else None
),
"x": x,
"y": y,
}
elif action in ["type", "type_text"] and (text_input or value):
action_details = {
"type": "type",
"text": text_input or value,
}
elif action == "hotkey" and value:
action_details = {
"type": "hotkey",
"keys": value,
}
elif action == "scroll":
# Use default coordinates for scrolling
delta_x = 0
delta_y = 0
# Try to extract scroll delta values from content if available
scroll_data = parsed_json.get("Scroll", {})
if scroll_data:
delta_x = scroll_data.get("delta_x", 0)
delta_y = scroll_data.get("delta_y", 0)
action_details = {
"type": "scroll",
"x": 100,
"y": 100,
"scroll_x": delta_x,
"scroll_y": delta_y,
}
elif action == "none":
# Handle case when action is None (task completion)
action_details = {"type": "none", "description": "Task completed"}
except json.JSONDecodeError:
# If not JSON, just use as reasoning text
if reasoning_text is None:
reasoning_text = ""
reasoning_text += item.get("text", "")
while scale_factor > 0.3:
# Resize image
new_width = int(img.width * scale_factor)
new_height = int(img.height * scale_factor)
current_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Try with reduced size and quality
buffer = io.BytesIO()
# Convert to RGB if necessary for JPEG
if current_img.mode in ("RGBA", "LA") or (
current_img.mode == "P" and "transparency" in current_img.info
):
rgb_img = Image.new("RGB", current_img.size, (255, 255, 255))
rgb_img.paste(
current_img, mask=current_img.split()[3] if current_img.mode == "RGBA" else None
)
rgb_img.save(buffer, format="JPEG", quality=70, optimize=True)
else:
current_img.save(buffer, format="JPEG", quality=70, optimize=True)
buffer.seek(0)
compressed_data = buffer.getvalue()
compressed_b64 = base64.b64encode(compressed_data).decode("utf-8")
if len(compressed_b64) <= max_size_bytes:
logger.info(
f"Compressed to {len(compressed_data)} bytes with scale {scale_factor} and JPEG quality 70"
)
return compressed_b64, "image/jpeg"
# Reduce scale factor and try again
scale_factor -= 0.1
# If we get here, we couldn't compress enough
logger.warning("Could not compress image below required size with quality preservation")
# Last resort: Use minimum quality and size
buffer = io.BytesIO()
smallest_img = img.resize(
(int(img.width * 0.5), int(img.height * 0.5)), Image.Resampling.LANCZOS
# Add reasoning item if we have text content
if reasoning_text:
output_items.append(
{
"type": "reasoning",
"id": reasoning_id,
"summary": [
{
"type": "summary_text",
"text": reasoning_text[:200], # Truncate to reasonable length
}
],
}
)
# Convert to RGB if necessary
if smallest_img.mode in ("RGBA", "LA") or (
smallest_img.mode == "P" and "transparency" in smallest_img.info
):
rgb_img = Image.new("RGB", smallest_img.size, (255, 255, 255))
rgb_img.paste(
smallest_img, mask=smallest_img.split()[3] if smallest_img.mode == "RGBA" else None
)
rgb_img.save(buffer, format="JPEG", quality=20, optimize=True)
else:
smallest_img.save(buffer, format="JPEG", quality=20, optimize=True)
buffer.seek(0)
final_data = buffer.getvalue()
final_b64 = base64.b64encode(final_data).decode("utf-8")
# If no action details extracted, use default
if not action_details:
action_details = {
"type": "click",
"button": "left",
"x": 100,
"y": 100,
}
logger.warning(f"Final compressed size: {len(final_b64)} bytes (may still exceed limit)")
return final_b64, "image/jpeg"
# Add computer_call item
computer_call = {
"type": "computer_call",
"id": action_id,
"call_id": call_id,
"action": action_details,
"pending_safety_checks": [],
"status": "completed",
}
output_items.append(computer_call)
except Exception as e:
logger.error(f"Error compressing image: {str(e)}")
raise
# Extract user and assistant messages from the history
user_messages = []
assistant_messages = []
for msg in messages:
if msg["role"] == "user":
user_messages.append(msg)
elif msg["role"] == "assistant":
assistant_messages.append(msg)
# Create the OpenAI-compatible response format with all expected fields
return {
"id": response_id,
"object": "response",
"created_at": int(time.time()),
"status": "completed",
"error": None,
"incomplete_details": None,
"instructions": None,
"max_output_tokens": None,
"model": model or "unknown",
"output": output_items,
"parallel_tool_calls": True,
"previous_response_id": None,
"reasoning": {"effort": "medium", "generate_summary": "concise"},
"store": True,
"temperature": 1.0,
"text": {"format": {"type": "text"}},
"tool_choice": "auto",
"tools": [
{
"type": "computer_use_preview",
"display_height": 768,
"display_width": 1024,
"environment": "mac",
}
],
"top_p": 1.0,
"truncation": "auto",
"usage": {
"input_tokens": 0, # Placeholder values
"input_tokens_details": {"cached_tokens": 0},
"output_tokens": 0, # Placeholder values
"output_tokens_details": {"reasoning_tokens": 0},
"total_tokens": 0, # Placeholder values
},
"user": None,
"metadata": {},
# Include the original response for backward compatibility
"response": {"choices": [{"message": assistant_msg, "finish_reason": "stop"}]},
}

View File

@@ -1,130 +0,0 @@
"""Visualization utilities for the Cua provider."""
import base64
import logging
from io import BytesIO
from typing import Tuple
from PIL import Image, ImageDraw
logger = logging.getLogger(__name__)
def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
    """Visualize a click action by drawing on the screenshot.

    Args:
        x: X coordinate of the click
        y: Y coordinate of the click
        img_base64: Base64 encoded image to draw on

    Returns:
        PIL Image with the click marker drawn, or a blank white
        800x600 image if anything goes wrong.
    """
    try:
        # Decode the screenshot and get a drawing context for it.
        screenshot = Image.open(BytesIO(base64.b64decode(img_base64)))
        canvas = ImageDraw.Draw(screenshot)

        inner_radius = 10
        outer_radius = 30

        # Solid dot marking the exact click point.
        canvas.ellipse(
            (x - inner_radius, y - inner_radius, x + inner_radius, y + inner_radius),
            fill="red",
        )
        # Wider ring around it so the marker is easy to spot at a glance.
        canvas.ellipse(
            (x - outer_radius, y - outer_radius, x + outer_radius, y + outer_radius),
            outline="red",
            width=3,
        )

        return screenshot
    except Exception as e:
        logger.error(f"Error visualizing click: {str(e)}")
        # Return a blank image in case of error
        return Image.new("RGB", (800, 600), color="white")
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Visualize a scroll action by drawing an arrow on the screenshot.

    Args:
        direction: 'up' or 'down' (anything else is treated as down)
        clicks: Number of scroll clicks
        img_base64: Base64 encoded image to draw on

    Returns:
        PIL Image with the arrow drawn, or a blank white 800x600
        image if anything goes wrong.
    """
    try:
        # Decode the screenshot and get a drawing context for it.
        screenshot = Image.open(BytesIO(base64.b64decode(img_base64)))
        img_width, img_height = screenshot.size
        canvas = ImageDraw.Draw(screenshot)

        mid_x = img_width // 2
        mid_y = img_height // 2
        half_base = 100 // 2  # the arrow is 100px wide at its base

        if direction.lower() == "up":
            # Blue triangle pointing up, centered on the screen.
            vertices = [
                (mid_x, mid_y - 50),  # Top point
                (mid_x - half_base, mid_y + 50),  # Bottom left
                (mid_x + half_base, mid_y + 50),  # Bottom right
            ]
            fill_color = "blue"
        else:
            # Green triangle pointing down (any non-"up" direction).
            vertices = [
                (mid_x, mid_y + 50),  # Bottom point
                (mid_x - half_base, mid_y - 50),  # Top left
                (mid_x + half_base, mid_y - 50),  # Top right
            ]
            fill_color = "green"

        canvas.polygon(vertices, fill=fill_color)

        # NOTE: the label position keys off "down" specifically, while the
        # arrow shape keys off "up" -- an unknown direction gets a green
        # down arrow with the label placed above it (original behavior).
        label_y = mid_y + 70 if direction.lower() == "down" else mid_y - 70
        canvas.text((mid_x - 40, label_y), f"{clicks} clicks", fill="black")

        return screenshot
    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
        # Return a blank image in case of error
        return Image.new("RGB", (800, 600), color="white")
def calculate_element_center(box: Tuple[int, int, int, int]) -> Tuple[int, int]:
    """Return the integer midpoint of a bounding box.

    Args:
        box: Tuple of (left, top, right, bottom) coordinates

    Returns:
        Tuple of (center_x, center_y) coordinates (floor division)
    """
    x0, y0, x1, y1 = box
    return (x0 + x1) // 2, (y0 + y1) // 2

View File

@@ -1,23 +0,0 @@
"""Type definitions for the agent package."""
from .base import HostConfig, TaskResult, Annotation
from .messages import Message, Request, Response, StepMessage, DisengageMessage
from .tools import ToolInvocation, ToolInvocationState, ClientAttachment, ToolResult
__all__ = [
# Base types
"HostConfig",
"TaskResult",
"Annotation",
# Message types
"Message",
"Request",
"Response",
"StepMessage",
"DisengageMessage",
# Tool types
"ToolInvocation",
"ToolInvocationState",
"ClientAttachment",
"ToolResult",
]

View File

@@ -1,41 +0,0 @@
"""Base type definitions."""
from enum import Enum, auto
from typing import Dict, Any
from pydantic import BaseModel, ConfigDict
class HostConfig(BaseModel):
    """Host configuration: the network endpoint of a service."""

    # Reject unknown fields so typos in configuration surface as
    # validation errors instead of being silently dropped.
    model_config = ConfigDict(extra="forbid")

    hostname: str  # host name or IP address
    port: int  # TCP port number

    @property
    def address(self) -> str:
        """The endpoint formatted as "hostname:port"."""
        return f"{self.hostname}:{self.port}"
class TaskResult(BaseModel):
    """Result of a task execution."""

    # Reject unknown fields so mismatched payloads fail validation loudly.
    model_config = ConfigDict(extra="forbid")

    result: str  # textual outcome of the task
    vnc_password: str  # NOTE(review): presumably the VNC session password for the executing host -- confirm with producers
class Annotation(BaseModel):
    """Annotation metadata."""

    # Reject unknown fields so mismatched payloads fail validation loudly.
    model_config = ConfigDict(extra="forbid")

    id: str  # annotation identifier
    vm_url: str  # NOTE(review): presumably the URL of the VM this annotation refers to -- confirm with callers
class AgentLoop(Enum):
    """Enumeration of available agent-loop implementations."""

    # auto() values depend on declaration order -- append new members at
    # the end rather than reordering, in case values are persisted anywhere.
    ANTHROPIC = auto()  # Anthropic implementation
    OMNI = auto()  # OmniLoop implementation
    # Add more loop types as needed

View File

@@ -1,36 +0,0 @@
"""Message-related type definitions."""
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, ConfigDict
from .tools import ToolInvocation
class Message(BaseModel):
    """Base message type exchanged between client and agent."""

    # Reject unknown fields so malformed payloads fail validation loudly.
    # Double-quoted to match the quoting convention used by the sibling
    # type modules (see base.py).
    model_config = ConfigDict(extra="forbid")

    role: str  # typically "user" or "assistant"; not validated here
    content: str  # message text
    annotations: Optional[List[Dict[str, Any]]] = None
    # camelCase presumably mirrors the client-side JSON payload -- do not rename.
    toolInvocations: Optional[List[ToolInvocation]] = None
    data: Optional[List[Dict[str, Any]]] = None
    errors: Optional[List[str]] = None
class Request(BaseModel):
    """A client request: conversation history plus the selected model."""

    # Reject unknown fields; double-quoted to match the quoting convention
    # used by the sibling type modules (see base.py).
    model_config = ConfigDict(extra="forbid")

    messages: List[Message]  # full conversation history
    # camelCase presumably mirrors the client-side JSON payload -- do not rename.
    selectedModel: str
class Response(BaseModel):
    """A response: updated messages plus the VM URL."""

    # Reject unknown fields; double-quoted to match the quoting convention
    # used by the sibling type modules (see base.py).
    model_config = ConfigDict(extra="forbid")

    messages: List[Message]  # updated conversation history
    vm_url: str
class StepMessage(Message):
    """A message for a single step; inherits all fields from Message."""
class DisengageMessage(BaseModel):
    """Signals disengagement; carries no fields."""

148
libs/lume/scripts/install.sh Executable file
View File

@@ -0,0 +1,148 @@
#!/bin/bash
set -e

# Lume Installer
# This script installs Lume to your system

# Define colors for output. Guard the tput calls: with `set -e`, a failing
# tput (no tty, tput missing, or TERM unset/dumb -- e.g. when piped or run
# in CI) would abort the installer before it printed anything, so fall back
# to plain text in those cases.
if [ -t 1 ] && command -v tput > /dev/null 2>&1 && [ -n "${TERM:-}" ] && [ "$TERM" != "dumb" ]; then
    BOLD=$(tput bold)
    NORMAL=$(tput sgr0)
    RED=$(tput setaf 1)
    GREEN=$(tput setaf 2)
    BLUE=$(tput setaf 4)
else
    BOLD=""
    NORMAL=""
    RED=""
    GREEN=""
    BLUE=""
fi

# Default installation directory
DEFAULT_INSTALL_DIR="/usr/local/bin"
INSTALL_DIR="${INSTALL_DIR:-$DEFAULT_INSTALL_DIR}"

# GitHub info
GITHUB_REPO="trycua/cua"
# NOTE(review): currently unused -- downloads go through the
# releases/latest/download redirect instead of the API.
LATEST_RELEASE_URL="https://api.github.com/repos/$GITHUB_REPO/releases/latest"

echo "${BOLD}${BLUE}Lume Installer${NORMAL}"
echo "This script will install Lume to your system."
# Abort early when installing to the default system directory without root;
# any user-chosen INSTALL_DIR is assumed writable by the caller.
check_permissions() {
    if [ "$(id -u)" != "0" ] && [ "$INSTALL_DIR" = "$DEFAULT_INSTALL_DIR" ]; then
        echo "${RED}Error: Installing to $INSTALL_DIR requires root privileges.${NORMAL}"
        echo "Please run with sudo or specify a different directory with INSTALL_DIR environment variable."
        echo "Example: INSTALL_DIR=\$HOME/.local/bin $0"
        exit 1
    fi
}
# Verify we are on macOS/arm64 -- the only platform Lume ships for --
# and set PLATFORM accordingly.
detect_platform() {
    OS=$(uname -s | tr '[:upper:]' '[:lower:]')
    ARCH=$(uname -m)

    case "$OS" in
        darwin) ;;  # supported
        *)
            echo "${RED}Error: Currently only macOS is supported.${NORMAL}"
            exit 1
            ;;
    esac

    case "$ARCH" in
        arm64) ;;  # supported
        *)
            echo "${RED}Error: Lume only supports macOS on Apple Silicon (ARM64).${NORMAL}"
            exit 1
            ;;
    esac

    PLATFORM="darwin-arm64"
    echo "Detected platform: ${BOLD}$PLATFORM${NORMAL}"
}
# Create a scratch directory for the download and register an EXIT trap so
# it is removed whether the script succeeds, fails, or is interrupted.
create_temp_dir() {
    TEMP_DIR=$(mktemp -d)
    echo "Using temporary directory: $TEMP_DIR"

    # Make sure we clean up on exit
    # (single quotes defer $TEMP_DIR expansion until the trap fires)
    trap 'rm -rf "$TEMP_DIR"' EXIT
}
# Download the latest release into $TEMP_DIR/lume.tar.gz.
# Tries the non-versioned asset name first, then falls back to a
# platform-suffixed name; validates each download before accepting it.
download_release() {
    echo "Downloading latest Lume release..."

    # Use the direct download link with the non-versioned symlink
    DOWNLOAD_URL="https://github.com/$GITHUB_REPO/releases/latest/download/lume.tar.gz"
    echo "Downloading from: $DOWNLOAD_URL"

    # Download the tarball
    if command -v curl &> /dev/null; then
        curl -L --progress-bar "$DOWNLOAD_URL" -o "$TEMP_DIR/lume.tar.gz"

        # Verify the download was successful (file exists and is non-empty)
        if [ ! -s "$TEMP_DIR/lume.tar.gz" ]; then
            echo "${RED}Error: Failed to download Lume.${NORMAL}"
            echo "The download URL may be incorrect or the file may not exist."
            exit 1
        fi

        # Verify the file is a valid archive -- a "successful" download may
        # still be a non-archive response (presumably an error page; verify).
        if ! tar -tzf "$TEMP_DIR/lume.tar.gz" > /dev/null 2>&1; then
            echo "${RED}Error: The downloaded file is not a valid tar.gz archive.${NORMAL}"
            echo "Let's try the alternative URL..."

            # Try alternative URL (platform-suffixed asset name)
            ALT_DOWNLOAD_URL="https://github.com/$GITHUB_REPO/releases/latest/download/lume-$PLATFORM.tar.gz"
            echo "Downloading from alternative URL: $ALT_DOWNLOAD_URL"
            curl -L --progress-bar "$ALT_DOWNLOAD_URL" -o "$TEMP_DIR/lume.tar.gz"

            # Check again; if this copy is also invalid, give up and point
            # the user at manual installation.
            if ! tar -tzf "$TEMP_DIR/lume.tar.gz" > /dev/null 2>&1; then
                echo "${RED}Error: Could not download a valid Lume archive.${NORMAL}"
                echo "Please try installing Lume manually from: https://github.com/$GITHUB_REPO/releases/latest"
                exit 1
            fi
        fi
    else
        echo "${RED}Error: curl is required but not installed.${NORMAL}"
        exit 1
    fi
}
# Extract the downloaded archive and install the binary into $INSTALL_DIR,
# warning if that directory is not on the user's PATH.
install_binary() {
    echo "Extracting archive..."
    tar -xzf "$TEMP_DIR/lume.tar.gz" -C "$TEMP_DIR"

    echo "Installing to $INSTALL_DIR..."
    # Create install directory if it doesn't exist
    mkdir -p "$INSTALL_DIR"

    # Move the binary to the installation directory
    mv "$TEMP_DIR/lume" "$INSTALL_DIR/"

    # Make the binary executable
    chmod +x "$INSTALL_DIR/lume"

    echo "${GREEN}Installation complete!${NORMAL}"
    echo "Lume has been installed to ${BOLD}$INSTALL_DIR/lume${NORMAL}"

    # Check whether INSTALL_DIR is on PATH. Match against ":$PATH:" with the
    # directory wrapped in colons so a substring of another entry (e.g.
    # /usr/local/bin inside /usr/local/bin2) cannot produce a false match,
    # which the previous ${PATH##*$INSTALL_DIR*} test allowed.
    case ":$PATH:" in
        *":$INSTALL_DIR:"*)
            ;;  # already on PATH -- nothing to do
        *)
            echo "${RED}Warning: $INSTALL_DIR is not in your PATH.${NORMAL}"
            echo "You may need to add it to your shell profile:"
            echo " For bash: echo 'export PATH=\"\$PATH:$INSTALL_DIR\"' >> ~/.bash_profile"
            echo " For zsh: echo 'export PATH=\"\$PATH:$INSTALL_DIR\"' >> ~/.zshrc"
            echo " For fish: echo 'fish_add_path $INSTALL_DIR' >> ~/.config/fish/config.fish"
            ;;
    esac
}
# Main installation flow. Order matters: permissions and platform are
# validated before any network or filesystem work happens, and the temp
# directory (with its cleanup trap) exists before anything is downloaded.
main() {
    check_permissions
    detect_platform
    create_temp_dir
    download_release
    install_binary

    echo ""
    echo "${GREEN}${BOLD}Lume has been successfully installed!${NORMAL}"
    echo "Run ${BOLD}lume${NORMAL} to get started."
}

# Run the installation
main