mirror of
https://github.com/trycua/computer.git
synced 2026-02-22 06:19:07 -06:00
Standardize Agent Loop
This commit is contained in:
@@ -6,6 +6,7 @@ import logging
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
import signal
|
||||
import json
|
||||
|
||||
from computer import Computer
|
||||
|
||||
@@ -32,42 +33,31 @@ async def run_omni_agent_example():
|
||||
# Create agent with loop and provider
|
||||
agent = ComputerAgent(
|
||||
computer=computer,
|
||||
loop=AgentLoop.ANTHROPIC,
|
||||
# loop=AgentLoop.OMNI,
|
||||
# loop=AgentLoop.ANTHROPIC,
|
||||
loop=AgentLoop.OMNI,
|
||||
# model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"),
|
||||
model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"),
|
||||
save_trajectory=True,
|
||||
trajectory_dir=str(Path("trajectories")),
|
||||
only_n_most_recent_images=3,
|
||||
verbosity=logging.INFO,
|
||||
verbosity=logging.DEBUG,
|
||||
)
|
||||
|
||||
tasks = [
|
||||
"""
|
||||
1. Look for a repository named trycua/lume on GitHub.
|
||||
2. Check the open issues, open the most recent one and read it.
|
||||
3. Clone the repository in users/lume/projects if it doesn't exist yet.
|
||||
4. Open the repository with an app named Cursor (on the dock, black background and white cube icon).
|
||||
5. From Cursor, open Composer if not already open.
|
||||
6. Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.
|
||||
"""
|
||||
"Look for a repository named trycua/cua on GitHub.",
|
||||
"Check the open issues, open the most recent one and read it.",
|
||||
"Clone the repository in users/lume/projects if it doesn't exist yet.",
|
||||
"Open the repository with an app named Cursor (on the dock, black background and white cube icon).",
|
||||
"From Cursor, open Composer if not already open.",
|
||||
"Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.",
|
||||
]
|
||||
|
||||
async with agent:
|
||||
for i, task in enumerate(tasks, 1):
|
||||
for i, task in enumerate(tasks):
|
||||
print(f"\nExecuting task {i}/{len(tasks)}: {task}")
|
||||
async for result in agent.run(task):
|
||||
# Check if result has the expected structure
|
||||
if "role" in result and "content" in result and "metadata" in result:
|
||||
title = result["metadata"].get("title", "Screen Analysis")
|
||||
content = result["content"]
|
||||
else:
|
||||
title = result.get("metadata", {}).get("title", "Screen Analysis")
|
||||
content = result.get("content", str(result))
|
||||
print(result)
|
||||
|
||||
print(f"\n{title}")
|
||||
print(content)
|
||||
print(f"Task {i} completed")
|
||||
print(f"\n✅ Task {i+1}/{len(tasks)} completed: {task}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_omni_agent_example: {e}")
|
||||
|
||||
@@ -2,11 +2,6 @@
|
||||
|
||||
from .loop import BaseLoop
|
||||
from .messages import (
|
||||
create_user_message,
|
||||
create_assistant_message,
|
||||
create_system_message,
|
||||
create_image_message,
|
||||
create_screen_message,
|
||||
BaseMessageManager,
|
||||
ImageRetentionConfig,
|
||||
)
|
||||
|
||||
@@ -3,8 +3,7 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Dict, Optional, cast
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, AsyncGenerator, Dict, Optional, cast, List
|
||||
|
||||
from computer import Computer
|
||||
from ..providers.anthropic.loop import AnthropicLoop
|
||||
@@ -12,6 +11,7 @@ from ..providers.omni.loop import OmniLoop
|
||||
from ..providers.omni.parser import OmniParser
|
||||
from ..providers.omni.types import LLMProvider, LLM
|
||||
from .. import AgentLoop
|
||||
from .messages import StandardMessageManager, ImageRetentionConfig
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,7 +44,6 @@ class ComputerAgent:
|
||||
save_trajectory: bool = True,
|
||||
trajectory_dir: str = "trajectories",
|
||||
only_n_most_recent_images: Optional[int] = None,
|
||||
parser: Optional[OmniParser] = None,
|
||||
verbosity: int = logging.INFO,
|
||||
):
|
||||
"""Initialize the ComputerAgent.
|
||||
@@ -61,7 +60,6 @@ class ComputerAgent:
|
||||
save_trajectory: Whether to save the trajectory.
|
||||
trajectory_dir: Directory to save the trajectory.
|
||||
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests.
|
||||
parser: Parser instance for the OmniLoop. Only used if provider is not ANTHROPIC.
|
||||
verbosity: Logging level.
|
||||
"""
|
||||
# Basic agent configuration
|
||||
@@ -74,6 +72,11 @@ class ComputerAgent:
|
||||
self._initialized = False
|
||||
self._in_context = False
|
||||
|
||||
# Initialize the message manager for standardized message handling
|
||||
self.message_manager = StandardMessageManager(
|
||||
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
|
||||
)
|
||||
|
||||
# Set logging level
|
||||
logger.setLevel(verbosity)
|
||||
|
||||
@@ -118,10 +121,6 @@ class ComputerAgent:
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
)
|
||||
else:
|
||||
# Default to OmniLoop for other loop types
|
||||
# Initialize parser if not provided
|
||||
actual_parser = parser or OmniParser()
|
||||
|
||||
self._loop = OmniLoop(
|
||||
provider=self.provider,
|
||||
api_key=actual_api_key,
|
||||
@@ -130,7 +129,7 @@ class ComputerAgent:
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
parser=actual_parser,
|
||||
parser=OmniParser(),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -224,13 +223,25 @@ class ComputerAgent:
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Running task: {task}")
|
||||
logger.info(
|
||||
f"Message history before task has {len(self.message_manager.messages)} messages"
|
||||
)
|
||||
|
||||
# Initialize the computer if needed
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
# Format task as a message
|
||||
messages = [{"role": "user", "content": task}]
|
||||
# Add task as a user message using the message manager
|
||||
self.message_manager.add_user_message([{"type": "text", "text": task}])
|
||||
logger.info(
|
||||
f"Added task message. Message history now has {len(self.message_manager.messages)} messages"
|
||||
)
|
||||
|
||||
# Log message history types to help with debugging
|
||||
message_types = [
|
||||
f"{i}: {msg['role']}" for i, msg in enumerate(self.message_manager.messages)
|
||||
]
|
||||
logger.info(f"Message history roles: {', '.join(message_types)}")
|
||||
|
||||
# Pass properly formatted messages to the loop
|
||||
if self._loop is None:
|
||||
@@ -239,9 +250,28 @@ class ComputerAgent:
|
||||
return
|
||||
|
||||
# Execute the task and yield results
|
||||
async for result in self._loop.run(messages):
|
||||
async for result in self._loop.run(self.message_manager.messages):
|
||||
# Extract the assistant message from the result and add it to our history
|
||||
assistant_response = result["response"]["choices"][0].get("message", None)
|
||||
if assistant_response and assistant_response.get("role") == "assistant":
|
||||
# Extract the content from the assistant response
|
||||
content = assistant_response.get("content")
|
||||
self.message_manager.add_assistant_message(content)
|
||||
|
||||
logger.info("Added assistant response to message history")
|
||||
|
||||
# Yield the result to the caller
|
||||
yield result
|
||||
|
||||
# Logging the message history for debugging
|
||||
logger.info(
|
||||
f"Updated message history now has {len(self.message_manager.messages)} messages"
|
||||
)
|
||||
message_types = [
|
||||
f"{i}: {msg['role']}" for i, msg in enumerate(self.message_manager.messages)
|
||||
]
|
||||
logger.info(f"Updated message history roles: {', '.join(message_types)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in agent run method: {str(e)}")
|
||||
yield {
|
||||
|
||||
@@ -2,12 +2,9 @@
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
import base64
|
||||
|
||||
from computer import Computer
|
||||
from .experiment import ExperimentManager
|
||||
@@ -18,6 +15,10 @@ logger = logging.getLogger(__name__)
|
||||
class BaseLoop(ABC):
|
||||
"""Base class for agent loops that handle message processing and tool execution."""
|
||||
|
||||
###########################################
|
||||
# INITIALIZATION AND CONFIGURATION
|
||||
###########################################
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
computer: Computer,
|
||||
@@ -75,6 +76,64 @@ class BaseLoop(ABC):
|
||||
# Initialize basic tracking
|
||||
self.turn_count = 0
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize both the API client and computer interface with retries."""
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
logger.info(
|
||||
f"Starting initialization (attempt {attempt + 1}/{self.max_retries})..."
|
||||
)
|
||||
|
||||
# Initialize API client
|
||||
await self.initialize_client()
|
||||
|
||||
logger.info("Initialization complete.")
|
||||
return
|
||||
except Exception as e:
|
||||
if attempt < self.max_retries - 1:
|
||||
logger.warning(
|
||||
f"Initialization failed (attempt {attempt + 1}/{self.max_retries}): {str(e)}. Retrying..."
|
||||
)
|
||||
await asyncio.sleep(self.retry_delay)
|
||||
else:
|
||||
logger.error(
|
||||
f"Initialization failed after {self.max_retries} attempts: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(f"Failed to initialize: {str(e)}")
|
||||
|
||||
###########################################
|
||||
|
||||
# ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES
|
||||
###########################################
|
||||
|
||||
@abstractmethod
|
||||
async def initialize_client(self) -> None:
|
||||
"""Initialize the API client and any provider-specific components.
|
||||
|
||||
This method must be implemented by subclasses to set up
|
||||
provider-specific clients and tools.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
This method handles the main agent loop including message processing,
|
||||
API calls, response handling, and action execution.
|
||||
|
||||
Args:
|
||||
messages: List of message objects
|
||||
|
||||
Yields:
|
||||
Dict containing response data
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
###########################################
|
||||
# EXPERIMENT AND TRAJECTORY MANAGEMENT
|
||||
###########################################
|
||||
|
||||
def _setup_experiment_dirs(self) -> None:
|
||||
"""Setup the experiment directory structure."""
|
||||
if self.experiment_manager:
|
||||
@@ -100,10 +159,13 @@ class BaseLoop(ABC):
|
||||
) -> None:
|
||||
"""Log API call details to file.
|
||||
|
||||
Preserves provider-specific formats for requests and responses to ensure
|
||||
accurate logging for debugging and analysis purposes.
|
||||
|
||||
Args:
|
||||
call_type: Type of API call (e.g., 'request', 'response', 'error')
|
||||
request: The API request data
|
||||
response: Optional API response data
|
||||
request: The API request data in provider-specific format
|
||||
response: Optional API response data in provider-specific format
|
||||
error: Optional error information
|
||||
"""
|
||||
if self.experiment_manager:
|
||||
@@ -130,119 +192,155 @@ class BaseLoop(ABC):
|
||||
if self.experiment_manager:
|
||||
self.experiment_manager.save_screenshot(img_base64, action_type)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize both the API client and computer interface with retries."""
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
logger.info(
|
||||
f"Starting initialization (attempt {attempt + 1}/{self.max_retries})..."
|
||||
)
|
||||
def _create_openai_compatible_response(
|
||||
self, response: Any, messages: List[Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""Create an OpenAI computer use agent compatible response format.
|
||||
|
||||
# Initialize API client
|
||||
await self.initialize_client()
|
||||
|
||||
logger.info("Initialization complete.")
|
||||
return
|
||||
except Exception as e:
|
||||
if attempt < self.max_retries - 1:
|
||||
logger.warning(
|
||||
f"Initialization failed (attempt {attempt + 1}/{self.max_retries}): {str(e)}. Retrying..."
|
||||
)
|
||||
await asyncio.sleep(self.retry_delay)
|
||||
else:
|
||||
logger.error(
|
||||
f"Initialization failed after {self.max_retries} attempts: {str(e)}"
|
||||
)
|
||||
raise RuntimeError(f"Failed to initialize: {str(e)}")
|
||||
|
||||
async def _get_parsed_screen_som(self) -> Dict[str, Any]:
|
||||
"""Get parsed screen information.
|
||||
Args:
|
||||
response: The original API response
|
||||
messages: List of messages in standard OpenAI format
|
||||
|
||||
Returns:
|
||||
Dict containing screen information
|
||||
A response formatted according to OpenAI's computer use agent standard
|
||||
"""
|
||||
try:
|
||||
# Take screenshot
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
import json
|
||||
|
||||
# Initialize with default values
|
||||
width, height = 1024, 768
|
||||
base64_image = ""
|
||||
# Create a unique ID for this response
|
||||
response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(response)}"
|
||||
reasoning_id = f"rs_{response_id}"
|
||||
action_id = f"cu_{response_id}"
|
||||
call_id = f"call_{response_id}"
|
||||
|
||||
# Handle different types of screenshot returns
|
||||
if isinstance(screenshot, (bytes, bytearray, memoryview)):
|
||||
# Raw bytes screenshot
|
||||
base64_image = base64.b64encode(screenshot).decode("utf-8")
|
||||
elif hasattr(screenshot, "base64_image"):
|
||||
# Object-style screenshot with attributes
|
||||
# Type checking can't infer these attributes, but they exist at runtime
|
||||
# on certain screenshot return types
|
||||
base64_image = getattr(screenshot, "base64_image")
|
||||
width = (
|
||||
getattr(screenshot, "width", width) if hasattr(screenshot, "width") else width
|
||||
)
|
||||
height = (
|
||||
getattr(screenshot, "height", height)
|
||||
if hasattr(screenshot, "height")
|
||||
else height
|
||||
)
|
||||
# Extract the last assistant message
|
||||
assistant_msg = None
|
||||
for msg in reversed(messages):
|
||||
if msg["role"] == "assistant":
|
||||
assistant_msg = msg
|
||||
break
|
||||
|
||||
# Create parsed screen data
|
||||
parsed_screen = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"parsed_content_list": [],
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"screenshot_base64": base64_image,
|
||||
if not assistant_msg:
|
||||
# If no assistant message found, create a default one
|
||||
assistant_msg = {"role": "assistant", "content": "No response available"}
|
||||
|
||||
# Initialize output array
|
||||
output_items = []
|
||||
|
||||
# Extract reasoning and action details from the response
|
||||
content = assistant_msg["content"]
|
||||
reasoning_text = None
|
||||
action_details = None
|
||||
|
||||
# Extract reasoning and action from different content formats
|
||||
if isinstance(content, str):
|
||||
try:
|
||||
# Try to parse JSON
|
||||
parsed_content = json.loads(content)
|
||||
reasoning_text = parsed_content.get("Explanation", "")
|
||||
|
||||
# Extract action details
|
||||
action = parsed_content.get("Action", "")
|
||||
position = parsed_content.get("Position", {})
|
||||
text_input = parsed_content.get("Text", "")
|
||||
|
||||
if action.lower() == "click" and position:
|
||||
action_details = {
|
||||
"type": "click",
|
||||
"button": "left",
|
||||
"x": position.get("x", 100),
|
||||
"y": position.get("y", 100),
|
||||
}
|
||||
elif action.lower() == "type" and text_input:
|
||||
action_details = {
|
||||
"type": "type",
|
||||
"text": text_input,
|
||||
}
|
||||
elif action.lower() == "scroll":
|
||||
action_details = {
|
||||
"type": "scroll",
|
||||
"x": 100,
|
||||
"y": 100,
|
||||
"scroll_x": position.get("delta_x", 0),
|
||||
"scroll_y": position.get("delta_y", 0),
|
||||
}
|
||||
except json.JSONDecodeError:
|
||||
# If not valid JSON, use the content as reasoning
|
||||
reasoning_text = content
|
||||
elif isinstance(content, list):
|
||||
# Handle list of content blocks (like Anthropic format)
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
if item.get("type") == "text":
|
||||
# Collect text blocks for reasoning
|
||||
if reasoning_text is None:
|
||||
reasoning_text = ""
|
||||
reasoning_text += item.get("text", "")
|
||||
elif item.get("type") == "tool_use":
|
||||
# Extract action from tool_use (similar to Anthropic format)
|
||||
tool_input = item.get("input", {})
|
||||
if "click" in tool_input or "position" in tool_input:
|
||||
position = tool_input.get("click", tool_input.get("position", {}))
|
||||
if isinstance(position, dict) and "x" in position and "y" in position:
|
||||
action_details = {
|
||||
"type": "click",
|
||||
"button": "left",
|
||||
"x": position.get("x", 100),
|
||||
"y": position.get("y", 100),
|
||||
}
|
||||
elif "type" in tool_input or "text" in tool_input:
|
||||
action_details = {
|
||||
"type": "type",
|
||||
"text": tool_input.get("type", tool_input.get("text", "")),
|
||||
}
|
||||
elif "scroll" in tool_input:
|
||||
scroll = tool_input.get("scroll", {})
|
||||
action_details = {
|
||||
"type": "scroll",
|
||||
"x": 100,
|
||||
"y": 100,
|
||||
"scroll_x": scroll.get("x", 0),
|
||||
"scroll_y": scroll.get("y", 0),
|
||||
}
|
||||
|
||||
# Add reasoning item if we have text content
|
||||
if reasoning_text:
|
||||
output_items.append(
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": reasoning_id,
|
||||
"summary": [
|
||||
{
|
||||
"type": "summary_text",
|
||||
"text": reasoning_text[:200], # Truncate to reasonable length
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# If no action details extracted, use default
|
||||
if not action_details:
|
||||
action_details = {
|
||||
"type": "click",
|
||||
"button": "left",
|
||||
"x": 100,
|
||||
"y": 100,
|
||||
}
|
||||
|
||||
# Save screenshot if requested
|
||||
if self.save_trajectory and self.experiment_manager:
|
||||
try:
|
||||
img_data = base64_image
|
||||
if "," in img_data:
|
||||
img_data = img_data.split(",")[1]
|
||||
self._save_screenshot(img_data, action_type="state")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving screenshot: {str(e)}")
|
||||
# Add computer_call item
|
||||
computer_call = {
|
||||
"type": "computer_call",
|
||||
"id": action_id,
|
||||
"call_id": call_id,
|
||||
"action": action_details,
|
||||
"pending_safety_checks": [],
|
||||
"status": "completed",
|
||||
}
|
||||
output_items.append(computer_call)
|
||||
|
||||
return parsed_screen
|
||||
except Exception as e:
|
||||
logger.error(f"Error taking screenshot: {str(e)}")
|
||||
return {
|
||||
"width": 1024,
|
||||
"height": 768,
|
||||
"parsed_content_list": [],
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"error": f"Error taking screenshot: {str(e)}",
|
||||
"screenshot_base64": "",
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
async def initialize_client(self) -> None:
|
||||
"""Initialize the API client and any provider-specific components."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
Args:
|
||||
messages: List of message objects
|
||||
|
||||
Yields:
|
||||
Dict containing response data
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def _process_screen(
|
||||
self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Process screen information and add to messages.
|
||||
|
||||
Args:
|
||||
parsed_screen: Dictionary containing parsed screen info
|
||||
messages: List of messages to update
|
||||
"""
|
||||
raise NotImplementedError
|
||||
# Create the OpenAI-compatible response format
|
||||
return {
|
||||
"output": output_items,
|
||||
"id": response_id,
|
||||
# Include the original response for compatibility
|
||||
"response": {"choices": [{"message": assistant_msg, "finish_reason": "stop"}]},
|
||||
}
|
||||
|
||||
@@ -4,9 +4,10 @@ import base64
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union, cast, Tuple
|
||||
from PIL import Image
|
||||
from dataclasses import dataclass
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -123,123 +124,278 @@ class BaseMessageManager:
|
||||
break
|
||||
|
||||
|
||||
def create_user_message(text: str) -> Dict[str, str]:
|
||||
"""Create a user message.
|
||||
class StandardMessageManager:
|
||||
"""Manages messages in a standardized OpenAI format across different providers."""
|
||||
|
||||
Args:
|
||||
text: The message text
|
||||
def __init__(self, config: Optional[ImageRetentionConfig] = None):
|
||||
"""Initialize message manager.
|
||||
|
||||
Returns:
|
||||
Message dictionary
|
||||
"""
|
||||
return {
|
||||
"role": "user",
|
||||
"content": text,
|
||||
}
|
||||
Args:
|
||||
config: Configuration for image retention
|
||||
"""
|
||||
self.messages: List[Dict[str, Any]] = []
|
||||
self.config = config or ImageRetentionConfig()
|
||||
|
||||
def add_user_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
|
||||
"""Add a user message.
|
||||
|
||||
def create_assistant_message(text: str) -> Dict[str, str]:
|
||||
"""Create an assistant message.
|
||||
Args:
|
||||
content: Message content (text or multimodal content)
|
||||
"""
|
||||
self.messages.append({"role": "user", "content": content})
|
||||
|
||||
Args:
|
||||
text: The message text
|
||||
def add_assistant_message(self, content: Union[str, List[Dict[str, Any]]]) -> None:
|
||||
"""Add an assistant message.
|
||||
|
||||
Returns:
|
||||
Message dictionary
|
||||
"""
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
}
|
||||
Args:
|
||||
content: Message content (text or multimodal content)
|
||||
"""
|
||||
self.messages.append({"role": "assistant", "content": content})
|
||||
|
||||
def add_system_message(self, content: str) -> None:
|
||||
"""Add a system message.
|
||||
|
||||
def create_system_message(text: str) -> Dict[str, str]:
|
||||
"""Create a system message.
|
||||
Args:
|
||||
content: System message content
|
||||
"""
|
||||
self.messages.append({"role": "system", "content": content})
|
||||
|
||||
Args:
|
||||
text: The message text
|
||||
def get_messages(self) -> List[Dict[str, Any]]:
|
||||
"""Get all messages in standard format.
|
||||
|
||||
Returns:
|
||||
Message dictionary
|
||||
"""
|
||||
return {
|
||||
"role": "system",
|
||||
"content": text,
|
||||
}
|
||||
Returns:
|
||||
List of messages
|
||||
"""
|
||||
# If image retention is configured, apply it
|
||||
if self.config.num_images_to_keep is not None:
|
||||
return self._apply_image_retention(self.messages)
|
||||
return self.messages
|
||||
|
||||
def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Apply image retention policy to messages.
|
||||
|
||||
def create_image_message(
|
||||
image_base64: Optional[str] = None,
|
||||
image_path: Optional[str] = None,
|
||||
image_obj: Optional[Image.Image] = None,
|
||||
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
|
||||
"""Create a message with an image.
|
||||
Args:
|
||||
messages: List of messages
|
||||
|
||||
Args:
|
||||
image_base64: Base64 encoded image
|
||||
image_path: Path to image file
|
||||
image_obj: PIL Image object
|
||||
Returns:
|
||||
List of messages with image retention applied
|
||||
"""
|
||||
if not self.config.num_images_to_keep:
|
||||
return messages
|
||||
|
||||
Returns:
|
||||
Message dictionary with content list
|
||||
# Find user messages with images
|
||||
image_messages = []
|
||||
for msg in messages:
|
||||
if msg["role"] == "user" and isinstance(msg["content"], list):
|
||||
has_image = any(
|
||||
item.get("type") == "image_url" or item.get("type") == "image"
|
||||
for item in msg["content"]
|
||||
)
|
||||
if has_image:
|
||||
image_messages.append(msg)
|
||||
|
||||
Raises:
|
||||
ValueError: If no image source is provided
|
||||
"""
|
||||
if not any([image_base64, image_path, image_obj]):
|
||||
raise ValueError("Must provide one of image_base64, image_path, or image_obj")
|
||||
# If we don't have more images than the limit, return all messages
|
||||
if len(image_messages) <= self.config.num_images_to_keep:
|
||||
return messages
|
||||
|
||||
# Convert to base64 if needed
|
||||
if image_path and not image_base64:
|
||||
with open(image_path, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
image_base64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
elif image_obj and not image_base64:
|
||||
buffer = BytesIO()
|
||||
image_obj.save(buffer, format="PNG")
|
||||
image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
|
||||
# Get the most recent N images to keep
|
||||
images_to_keep = image_messages[-self.config.num_images_to_keep :]
|
||||
images_to_remove = image_messages[: -self.config.num_images_to_keep]
|
||||
|
||||
return {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
|
||||
],
|
||||
}
|
||||
# Create a new message list without the older images
|
||||
result = []
|
||||
for msg in messages:
|
||||
if msg in images_to_remove:
|
||||
# Skip this message
|
||||
continue
|
||||
result.append(msg)
|
||||
|
||||
return result
|
||||
|
||||
def create_screen_message(
|
||||
parsed_screen: Dict[str, Any],
|
||||
include_raw: bool = False,
|
||||
) -> Dict[str, Union[str, List[Dict[str, Any]]]]:
|
||||
"""Create a message with screen information.
|
||||
def to_anthropic_format(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> Tuple[List[Dict[str, Any]], str]:
|
||||
"""Convert standard OpenAI format messages to Anthropic format.
|
||||
|
||||
Args:
|
||||
parsed_screen: Dictionary containing parsed screen info
|
||||
include_raw: Whether to include raw screenshot base64
|
||||
Args:
|
||||
messages: List of messages in OpenAI format
|
||||
|
||||
Returns:
|
||||
Message dictionary with content
|
||||
"""
|
||||
if include_raw and "screenshot_base64" in parsed_screen:
|
||||
# Create content list with both image and text
|
||||
return {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{parsed_screen['screenshot_base64']}"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}",
|
||||
},
|
||||
],
|
||||
}
|
||||
else:
|
||||
# Create text-only message with screen info
|
||||
return {
|
||||
"role": "user",
|
||||
"content": f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}",
|
||||
}
|
||||
Returns:
|
||||
Tuple containing (anthropic_messages, system_content)
|
||||
"""
|
||||
result = []
|
||||
system_content = ""
|
||||
|
||||
# Process messages in order to maintain conversation flow
|
||||
previous_assistant_tool_use_ids = (
|
||||
set()
|
||||
) # Track tool_use_ids in the previous assistant message
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", "")
|
||||
|
||||
if role == "system":
|
||||
# Collect system messages for later use
|
||||
system_content += content + "\n"
|
||||
continue
|
||||
|
||||
if role == "assistant":
|
||||
# Track tool_use_ids in this assistant message for the next user message
|
||||
previous_assistant_tool_use_ids = set()
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if (
|
||||
isinstance(item, dict)
|
||||
and item.get("type") == "tool_use"
|
||||
and "id" in item
|
||||
):
|
||||
previous_assistant_tool_use_ids.add(item["id"])
|
||||
|
||||
logger.info(
|
||||
f"Tool use IDs in assistant message #{i}: {previous_assistant_tool_use_ids}"
|
||||
)
|
||||
|
||||
if role in ["user", "assistant"]:
|
||||
anthropic_msg = {"role": role}
|
||||
|
||||
# Convert content based on type
|
||||
if isinstance(content, str):
|
||||
# Simple text content
|
||||
anthropic_msg["content"] = [{"type": "text", "text": content}]
|
||||
elif isinstance(content, list):
|
||||
# Convert complex content
|
||||
anthropic_content = []
|
||||
for item in content:
|
||||
item_type = item.get("type", "")
|
||||
|
||||
if item_type == "text":
|
||||
anthropic_content.append({"type": "text", "text": item.get("text", "")})
|
||||
elif item_type == "image_url":
|
||||
# Convert OpenAI image format to Anthropic
|
||||
image_url = item.get("image_url", {}).get("url", "")
|
||||
if image_url.startswith("data:"):
|
||||
# Extract base64 data and media type
|
||||
match = re.match(r"data:(.+);base64,(.+)", image_url)
|
||||
if match:
|
||||
media_type, data = match.groups()
|
||||
anthropic_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": media_type,
|
||||
"data": data,
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# Regular URL
|
||||
anthropic_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "url",
|
||||
"url": image_url,
|
||||
},
|
||||
}
|
||||
)
|
||||
elif item_type == "tool_use":
|
||||
# Always include tool_use blocks
|
||||
anthropic_content.append(item)
|
||||
elif item_type == "tool_result":
|
||||
# Check if this is a user message AND if the tool_use_id exists in the previous assistant message
|
||||
tool_use_id = item.get("tool_use_id")
|
||||
|
||||
# Only include tool_result if it references a tool_use from the immediately preceding assistant message
|
||||
if (
|
||||
role == "user"
|
||||
and tool_use_id
|
||||
and tool_use_id in previous_assistant_tool_use_ids
|
||||
):
|
||||
anthropic_content.append(item)
|
||||
logger.info(
|
||||
f"Including tool_result with tool_use_id: {tool_use_id}"
|
||||
)
|
||||
else:
|
||||
# Convert to text to preserve information
|
||||
logger.warning(
|
||||
f"Converting tool_result to text. Tool use ID {tool_use_id} not found in previous assistant message"
|
||||
)
|
||||
content_text = "Tool Result: "
|
||||
if "content" in item:
|
||||
if isinstance(item["content"], list):
|
||||
for content_item in item["content"]:
|
||||
if (
|
||||
isinstance(content_item, dict)
|
||||
and content_item.get("type") == "text"
|
||||
):
|
||||
content_text += content_item.get("text", "")
|
||||
elif isinstance(item["content"], str):
|
||||
content_text += item["content"]
|
||||
anthropic_content.append({"type": "text", "text": content_text})
|
||||
|
||||
anthropic_msg["content"] = anthropic_content
|
||||
|
||||
result.append(anthropic_msg)
|
||||
|
||||
return result, system_content
|
||||
|
||||
def from_anthropic_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Convert Anthropic format messages to standard OpenAI format.
|
||||
|
||||
Args:
|
||||
messages: List of messages in Anthropic format
|
||||
|
||||
Returns:
|
||||
List of messages in OpenAI format
|
||||
"""
|
||||
result = []
|
||||
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", [])
|
||||
|
||||
if role in ["user", "assistant"]:
|
||||
openai_msg = {"role": role}
|
||||
|
||||
# Simple case: single text block
|
||||
if len(content) == 1 and content[0].get("type") == "text":
|
||||
openai_msg["content"] = content[0].get("text", "")
|
||||
else:
|
||||
# Complex case: multiple blocks or non-text
|
||||
openai_content = []
|
||||
for item in content:
|
||||
item_type = item.get("type", "")
|
||||
|
||||
if item_type == "text":
|
||||
openai_content.append({"type": "text", "text": item.get("text", "")})
|
||||
elif item_type == "image":
|
||||
# Convert Anthropic image to OpenAI format
|
||||
source = item.get("source", {})
|
||||
if source.get("type") == "base64":
|
||||
media_type = source.get("media_type", "image/png")
|
||||
data = source.get("data", "")
|
||||
openai_content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:{media_type};base64,{data}"},
|
||||
}
|
||||
)
|
||||
else:
|
||||
# URL
|
||||
openai_content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": source.get("url", "")},
|
||||
}
|
||||
)
|
||||
elif item_type in ["tool_use", "tool_result"]:
|
||||
# Pass through tool-related content
|
||||
openai_content.append(item)
|
||||
|
||||
openai_msg["content"] = openai_content
|
||||
|
||||
result.append(openai_msg)
|
||||
|
||||
return result
|
||||
|
||||
197
libs/agent/agent/core/visualization.py
Normal file
197
libs/agent/agent/core/visualization.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Core visualization utilities for agents."""
|
||||
|
||||
import logging
|
||||
import base64
|
||||
from typing import Dict, Tuple
|
||||
from PIL import Image, ImageDraw
|
||||
from io import BytesIO
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
|
||||
"""Visualize a click action by drawing a circle on the screenshot.
|
||||
|
||||
Args:
|
||||
x: X coordinate of the click
|
||||
y: Y coordinate of the click
|
||||
img_base64: Base64-encoded screenshot
|
||||
|
||||
Returns:
|
||||
PIL Image with visualization
|
||||
"""
|
||||
try:
|
||||
# Decode the base64 image
|
||||
image_data = base64.b64decode(img_base64)
|
||||
img = Image.open(BytesIO(image_data))
|
||||
|
||||
# Create a copy to draw on
|
||||
draw_img = img.copy()
|
||||
draw = ImageDraw.Draw(draw_img)
|
||||
|
||||
# Draw a circle at the click location
|
||||
radius = 15
|
||||
draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], outline="red", width=3)
|
||||
|
||||
# Draw crosshairs
|
||||
line_length = 20
|
||||
draw.line([(x - line_length, y), (x + line_length, y)], fill="red", width=3)
|
||||
draw.line([(x, y - line_length), (x, y + line_length)], fill="red", width=3)
|
||||
|
||||
return draw_img
|
||||
except Exception as e:
|
||||
logger.error(f"Error visualizing click: {str(e)}")
|
||||
# Return a blank image as fallback
|
||||
return Image.new("RGB", (800, 600), "white")
|
||||
|
||||
|
||||
def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
|
||||
"""Visualize a scroll action by drawing arrows on the screenshot.
|
||||
|
||||
Args:
|
||||
direction: Direction of scroll ('up' or 'down')
|
||||
clicks: Number of scroll clicks
|
||||
img_base64: Base64-encoded screenshot
|
||||
|
||||
Returns:
|
||||
PIL Image with visualization
|
||||
"""
|
||||
try:
|
||||
# Decode the base64 image
|
||||
image_data = base64.b64decode(img_base64)
|
||||
img = Image.open(BytesIO(image_data))
|
||||
|
||||
# Create a copy to draw on
|
||||
draw_img = img.copy()
|
||||
draw = ImageDraw.Draw(draw_img)
|
||||
|
||||
# Calculate parameters for visualization
|
||||
width, height = img.size
|
||||
center_x = width // 2
|
||||
|
||||
# Draw arrows to indicate scrolling
|
||||
arrow_length = min(100, height // 4)
|
||||
arrow_width = 30
|
||||
num_arrows = min(clicks, 3) # Don't draw too many arrows
|
||||
|
||||
# Calculate starting position
|
||||
if direction == "down":
|
||||
start_y = height // 3
|
||||
arrow_dir = 1 # Down
|
||||
else:
|
||||
start_y = height * 2 // 3
|
||||
arrow_dir = -1 # Up
|
||||
|
||||
# Draw the arrows
|
||||
for i in range(num_arrows):
|
||||
y_pos = start_y + (i * arrow_length * arrow_dir * 0.7)
|
||||
arrow_top = (center_x, y_pos)
|
||||
arrow_bottom = (center_x, y_pos + arrow_length * arrow_dir)
|
||||
|
||||
# Draw the main line
|
||||
draw.line([arrow_top, arrow_bottom], fill="red", width=5)
|
||||
|
||||
# Draw the arrowhead
|
||||
arrowhead_size = 20
|
||||
if direction == "down":
|
||||
draw.line(
|
||||
[
|
||||
(center_x - arrow_width // 2, arrow_bottom[1] - arrowhead_size),
|
||||
arrow_bottom,
|
||||
(center_x + arrow_width // 2, arrow_bottom[1] - arrowhead_size),
|
||||
],
|
||||
fill="red",
|
||||
width=5,
|
||||
)
|
||||
else:
|
||||
draw.line(
|
||||
[
|
||||
(center_x - arrow_width // 2, arrow_bottom[1] + arrowhead_size),
|
||||
arrow_bottom,
|
||||
(center_x + arrow_width // 2, arrow_bottom[1] + arrowhead_size),
|
||||
],
|
||||
fill="red",
|
||||
width=5,
|
||||
)
|
||||
|
||||
return draw_img
|
||||
except Exception as e:
|
||||
logger.error(f"Error visualizing scroll: {str(e)}")
|
||||
# Return a blank image as fallback
|
||||
return Image.new("RGB", (800, 600), "white")
|
||||
|
||||
|
||||
def calculate_element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]:
|
||||
"""Calculate the center point of a UI element.
|
||||
|
||||
Args:
|
||||
bbox: Bounding box dictionary with x1, y1, x2, y2 coordinates (0-1 normalized)
|
||||
width: Screen width in pixels
|
||||
height: Screen height in pixels
|
||||
|
||||
Returns:
|
||||
(x, y) tuple with pixel coordinates
|
||||
"""
|
||||
center_x = int((bbox["x1"] + bbox["x2"]) / 2 * width)
|
||||
center_y = int((bbox["y1"] + bbox["y2"]) / 2 * height)
|
||||
return center_x, center_y
|
||||
|
||||
|
||||
class VisualizationHelper:
|
||||
"""Helper class for visualizing agent actions."""
|
||||
|
||||
def __init__(self, agent):
|
||||
"""Initialize visualization helper.
|
||||
|
||||
Args:
|
||||
agent: Reference to the agent that will use this helper
|
||||
"""
|
||||
self.agent = agent
|
||||
|
||||
def visualize_action(self, x: int, y: int, img_base64: str) -> None:
|
||||
"""Visualize a click action by drawing on the screenshot."""
|
||||
if (
|
||||
not self.agent.save_trajectory
|
||||
or not hasattr(self.agent, "experiment_manager")
|
||||
or not self.agent.experiment_manager
|
||||
):
|
||||
return
|
||||
|
||||
try:
|
||||
# Use the visualization utility
|
||||
img = visualize_click(x, y, img_base64)
|
||||
|
||||
# Save the visualization
|
||||
self.agent.experiment_manager.save_action_visualization(img, "click", f"x{x}_y{y}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error visualizing action: {str(e)}")
|
||||
|
||||
def visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None:
|
||||
"""Visualize a scroll action by drawing arrows on the screenshot."""
|
||||
if (
|
||||
not self.agent.save_trajectory
|
||||
or not hasattr(self.agent, "experiment_manager")
|
||||
or not self.agent.experiment_manager
|
||||
):
|
||||
return
|
||||
|
||||
try:
|
||||
# Use the visualization utility
|
||||
img = visualize_scroll(direction, clicks, img_base64)
|
||||
|
||||
# Save the visualization
|
||||
self.agent.experiment_manager.save_action_visualization(
|
||||
img, "scroll", f"{direction}_{clicks}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error visualizing scroll: {str(e)}")
|
||||
|
||||
def save_action_visualization(
|
||||
self, img: Image.Image, action_name: str, details: str = ""
|
||||
) -> str:
|
||||
"""Save a visualization of an action."""
|
||||
if hasattr(self.agent, "experiment_manager") and self.agent.experiment_manager:
|
||||
return self.agent.experiment_manager.save_action_visualization(
|
||||
img, action_name, details
|
||||
)
|
||||
return ""
|
||||
141
libs/agent/agent/providers/anthropic/api_handler.py
Normal file
141
libs/agent/agent/providers/anthropic/api_handler.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""API call handling for Anthropic provider."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
from httpx import ConnectError, ReadTimeout
|
||||
|
||||
from anthropic.types.beta import (
|
||||
BetaMessage,
|
||||
BetaMessageParam,
|
||||
BetaTextBlockParam,
|
||||
)
|
||||
|
||||
from .types import LLMProvider
|
||||
from .prompts import SYSTEM_PROMPT
|
||||
|
||||
# Constants
|
||||
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
|
||||
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicAPIHandler:
|
||||
"""Handles API calls to Anthropic's API with structured error handling and retries."""
|
||||
|
||||
def __init__(self, loop):
|
||||
"""Initialize the API handler.
|
||||
|
||||
Args:
|
||||
loop: Reference to the parent loop instance that provides context
|
||||
"""
|
||||
self.loop = loop
|
||||
|
||||
async def make_api_call(
|
||||
self, messages: List[BetaMessageParam], system_prompt: str = SYSTEM_PROMPT
|
||||
) -> BetaMessage:
|
||||
"""Make API call to Anthropic with retry logic.
|
||||
|
||||
Args:
|
||||
messages: List of messages to send to the API
|
||||
system_prompt: System prompt to use (default: SYSTEM_PROMPT)
|
||||
|
||||
Returns:
|
||||
API response
|
||||
|
||||
Raises:
|
||||
RuntimeError: If API call fails after all retries
|
||||
"""
|
||||
if self.loop.client is None:
|
||||
raise RuntimeError("Client not initialized. Call initialize_client() first.")
|
||||
if self.loop.tool_manager is None:
|
||||
raise RuntimeError("Tool manager not initialized. Call initialize_client() first.")
|
||||
|
||||
last_error = None
|
||||
|
||||
# Add detailed debug logging to examine messages
|
||||
logger.info(f"Sending {len(messages)} messages to Anthropic API")
|
||||
|
||||
# Log tool use IDs and tool result IDs for debugging
|
||||
tool_use_ids = set()
|
||||
tool_result_ids = set()
|
||||
|
||||
for i, msg in enumerate(messages):
|
||||
logger.info(f"Message {i}: role={msg.get('role')}")
|
||||
if isinstance(msg.get("content"), list):
|
||||
for content_block in msg.get("content", []):
|
||||
if isinstance(content_block, dict):
|
||||
block_type = content_block.get("type")
|
||||
if block_type == "tool_use" and "id" in content_block:
|
||||
tool_id = content_block.get("id")
|
||||
tool_use_ids.add(tool_id)
|
||||
logger.info(f" - Found tool_use with ID: {tool_id}")
|
||||
elif block_type == "tool_result" and "tool_use_id" in content_block:
|
||||
result_id = content_block.get("tool_use_id")
|
||||
tool_result_ids.add(result_id)
|
||||
logger.info(f" - Found tool_result referencing ID: {result_id}")
|
||||
|
||||
# Check for mismatches
|
||||
missing_tool_uses = tool_result_ids - tool_use_ids
|
||||
if missing_tool_uses:
|
||||
logger.warning(
|
||||
f"Found tool_result IDs without matching tool_use IDs: {missing_tool_uses}"
|
||||
)
|
||||
|
||||
for attempt in range(self.loop.max_retries):
|
||||
try:
|
||||
# Log request
|
||||
request_data = {
|
||||
"messages": messages,
|
||||
"max_tokens": self.loop.max_tokens,
|
||||
"system": system_prompt,
|
||||
}
|
||||
# Let ExperimentManager handle sanitization
|
||||
self.loop._log_api_call("request", request_data)
|
||||
|
||||
# Setup betas and system
|
||||
system = BetaTextBlockParam(
|
||||
type="text",
|
||||
text=system_prompt,
|
||||
)
|
||||
|
||||
betas = [COMPUTER_USE_BETA_FLAG]
|
||||
# Add prompt caching if enabled in the message manager's config
|
||||
if self.loop.message_manager.config.enable_caching:
|
||||
betas.append(PROMPT_CACHING_BETA_FLAG)
|
||||
system["cache_control"] = {"type": "ephemeral"}
|
||||
|
||||
# Make API call
|
||||
response = await self.loop.client.create_message(
|
||||
messages=messages,
|
||||
system=[system],
|
||||
tools=self.loop.tool_manager.get_tool_params(),
|
||||
max_tokens=self.loop.max_tokens,
|
||||
betas=betas,
|
||||
)
|
||||
|
||||
# Let ExperimentManager handle sanitization
|
||||
self.loop._log_api_call("response", request_data, response)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.error(
|
||||
f"Error in API call (attempt {attempt + 1}/{self.loop.max_retries}): {str(e)}"
|
||||
)
|
||||
self.loop._log_api_call("error", {"messages": messages}, error=e)
|
||||
|
||||
if attempt < self.loop.max_retries - 1:
|
||||
await asyncio.sleep(
|
||||
self.loop.retry_delay * (attempt + 1)
|
||||
) # Exponential backoff
|
||||
continue
|
||||
|
||||
# If we get here, all retries failed
|
||||
error_message = f"API call failed after {self.loop.max_retries} attempts"
|
||||
if last_error:
|
||||
error_message += f": {str(last_error)}"
|
||||
|
||||
logger.error(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Anthropic callbacks package."""
|
||||
|
||||
from .manager import CallbackManager
|
||||
|
||||
__all__ = ["CallbackManager"]
|
||||
@@ -2,40 +2,35 @@
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, cast
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from httpx import ConnectError, ReadTimeout
|
||||
|
||||
# Anthropic-specific imports
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types.beta import (
|
||||
BetaMessage,
|
||||
BetaMessageParam,
|
||||
BetaTextBlock,
|
||||
BetaTextBlockParam,
|
||||
BetaToolUseBlockParam,
|
||||
BetaContentBlockParam,
|
||||
)
|
||||
import base64
|
||||
from datetime import datetime
|
||||
|
||||
# Computer
|
||||
from computer import Computer
|
||||
|
||||
# Base imports
|
||||
from ...core.loop import BaseLoop
|
||||
from ...core.messages import ImageRetentionConfig as CoreImageRetentionConfig
|
||||
from ...core.messages import StandardMessageManager, ImageRetentionConfig
|
||||
|
||||
# Anthropic provider-specific imports
|
||||
from .api.client import AnthropicClientFactory, BaseAnthropicClient
|
||||
from .tools.manager import ToolManager
|
||||
from .messages.manager import MessageManager, ImageRetentionConfig
|
||||
from .callbacks.manager import CallbackManager
|
||||
from .prompts import SYSTEM_PROMPT
|
||||
from .types import LLMProvider
|
||||
from .tools import ToolResult
|
||||
|
||||
# Import the new modules we created
|
||||
from .api_handler import AnthropicAPIHandler
|
||||
from .response_handler import AnthropicResponseHandler
|
||||
from .callbacks.manager import CallbackManager
|
||||
|
||||
# Constants
|
||||
COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24"
|
||||
PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31"
|
||||
@@ -44,13 +39,22 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicLoop(BaseLoop):
|
||||
"""Anthropic-specific implementation of the agent loop."""
|
||||
"""Anthropic-specific implementation of the agent loop.
|
||||
|
||||
This class extends BaseLoop to provide specialized support for Anthropic's Claude models
|
||||
with their unique tool-use capabilities, custom message formatting, and
|
||||
callback-driven approach to handling responses.
|
||||
"""
|
||||
|
||||
###########################################
|
||||
# INITIALIZATION AND CONFIGURATION
|
||||
###########################################
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
computer: Computer,
|
||||
model: str = "claude-3-7-sonnet-20250219", # Fixed model
|
||||
model: str = "claude-3-7-sonnet-20250219",
|
||||
only_n_most_recent_images: Optional[int] = 2,
|
||||
base_dir: Optional[str] = "trajectories",
|
||||
max_retries: int = 3,
|
||||
@@ -83,27 +87,37 @@ class AnthropicLoop(BaseLoop):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Ensure model is always the fixed one
|
||||
self.model = "claude-3-7-sonnet-20250219"
|
||||
|
||||
# Anthropic-specific attributes
|
||||
self.provider = LLMProvider.ANTHROPIC
|
||||
self.client = None
|
||||
self.retry_count = 0
|
||||
self.tool_manager = None
|
||||
self.message_manager = None
|
||||
self.callback_manager = None
|
||||
|
||||
# Configure image retention with core config
|
||||
self.image_retention_config = CoreImageRetentionConfig(
|
||||
num_images_to_keep=only_n_most_recent_images
|
||||
# Initialize standard message manager with image retention config
|
||||
self.message_manager = StandardMessageManager(
|
||||
config=ImageRetentionConfig(
|
||||
num_images_to_keep=only_n_most_recent_images, enable_caching=True
|
||||
)
|
||||
)
|
||||
|
||||
# Message history
|
||||
# Message history (standard OpenAI format)
|
||||
self.message_history = []
|
||||
|
||||
# Initialize handlers
|
||||
self.api_handler = AnthropicAPIHandler(self)
|
||||
self.response_handler = AnthropicResponseHandler(self)
|
||||
|
||||
###########################################
|
||||
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def initialize_client(self) -> None:
|
||||
"""Initialize the Anthropic API client and tools."""
|
||||
"""Initialize the Anthropic API client and tools.
|
||||
|
||||
Implements abstract method from BaseLoop to set up the Anthropic-specific
|
||||
client, tool manager, message manager, and callback handlers.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Initializing Anthropic client with model {self.model}...")
|
||||
|
||||
@@ -112,14 +126,7 @@ class AnthropicLoop(BaseLoop):
|
||||
provider=self.provider, api_key=self.api_key, model=self.model
|
||||
)
|
||||
|
||||
# Initialize message manager
|
||||
self.message_manager = MessageManager(
|
||||
image_retention_config=ImageRetentionConfig(
|
||||
num_images_to_keep=self.only_n_most_recent_images, enable_caching=True
|
||||
)
|
||||
)
|
||||
|
||||
# Initialize callback manager
|
||||
# Initialize callback manager with our callback handlers
|
||||
self.callback_manager = CallbackManager(
|
||||
content_callback=self._handle_content,
|
||||
tool_callback=self._handle_tool_result,
|
||||
@@ -136,51 +143,18 @@ class AnthropicLoop(BaseLoop):
|
||||
self.client = None
|
||||
raise RuntimeError(f"Failed to initialize Anthropic client: {str(e)}")
|
||||
|
||||
async def _process_screen(
|
||||
self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
"""Process screen information and add to messages.
|
||||
|
||||
Args:
|
||||
parsed_screen: Dictionary containing parsed screen info
|
||||
messages: List of messages to update
|
||||
"""
|
||||
try:
|
||||
# Extract screenshot from parsed screen
|
||||
screenshot_base64 = parsed_screen.get("screenshot_base64")
|
||||
|
||||
if screenshot_base64:
|
||||
# Remove data URL prefix if present
|
||||
if "," in screenshot_base64:
|
||||
screenshot_base64 = screenshot_base64.split(",")[1]
|
||||
|
||||
# Create Anthropic-compatible message with image
|
||||
screen_info_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": screenshot_base64,
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Add screen info message to messages
|
||||
messages.append(screen_info_msg)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing screen info: {str(e)}")
|
||||
raise
|
||||
###########################################
|
||||
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
Implements abstract method from BaseLoop to handle the main agent loop
|
||||
for the AnthropicLoop implementation, using async queues and callbacks.
|
||||
|
||||
Args:
|
||||
messages: List of message objects
|
||||
messages: List of message objects in standard OpenAI format
|
||||
|
||||
Yields:
|
||||
Dict containing response data
|
||||
@@ -188,7 +162,7 @@ class AnthropicLoop(BaseLoop):
|
||||
try:
|
||||
logger.info("Starting Anthropic loop run")
|
||||
|
||||
# Reset message history and add new messages
|
||||
# Reset message history and add new messages in standard format
|
||||
self.message_history = []
|
||||
self.message_history.extend(messages)
|
||||
|
||||
@@ -236,6 +210,10 @@ class AnthropicLoop(BaseLoop):
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
|
||||
###########################################
|
||||
# AGENT LOOP IMPLEMENTATION
|
||||
###########################################
|
||||
|
||||
async def _run_loop(self, queue: asyncio.Queue) -> None:
|
||||
"""Run the agent loop with current message history.
|
||||
|
||||
@@ -244,31 +222,65 @@ class AnthropicLoop(BaseLoop):
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
# Get up-to-date screen information
|
||||
parsed_screen = await self._get_parsed_screen_som()
|
||||
# Capture screenshot
|
||||
try:
|
||||
# Take screenshot - always returns raw PNG bytes
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
|
||||
# Process screen info and update messages
|
||||
await self._process_screen(parsed_screen, self.message_history)
|
||||
# Convert PNG bytes to base64
|
||||
base64_image = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
# Prepare messages and make API call
|
||||
if self.message_manager is None:
|
||||
raise RuntimeError(
|
||||
"Message manager not initialized. Call initialize_client() first."
|
||||
)
|
||||
prepared_messages = self.message_manager.prepare_messages(
|
||||
cast(List[BetaMessageParam], self.message_history.copy())
|
||||
)
|
||||
# Save screenshot if requested
|
||||
if self.save_trajectory and self.experiment_manager:
|
||||
try:
|
||||
self._save_screenshot(base64_image, action_type="state")
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving screenshot: {str(e)}")
|
||||
|
||||
# Add screenshot to message history in OpenAI format
|
||||
screen_info_msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{base64_image}"},
|
||||
}
|
||||
],
|
||||
}
|
||||
self.message_history.append(screen_info_msg)
|
||||
except Exception as e:
|
||||
logger.error(f"Error capturing or processing screenshot: {str(e)}")
|
||||
raise
|
||||
|
||||
# Create new turn directory for this API call
|
||||
self._create_turn_dir()
|
||||
|
||||
# Use _make_api_call instead of direct client call to ensure logging
|
||||
response = await self._make_api_call(prepared_messages)
|
||||
# Convert standard messages to Anthropic format
|
||||
anthropic_messages, system_content = self.message_manager.to_anthropic_format(
|
||||
self.message_history.copy()
|
||||
)
|
||||
|
||||
# Handle the response
|
||||
if not await self._handle_response(response, self.message_history):
|
||||
# Use API handler to make API call with Anthropic format
|
||||
response = await self.api_handler.make_api_call(
|
||||
messages=cast(List[BetaMessageParam], anthropic_messages),
|
||||
system_prompt=system_content or SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
# Use response handler to handle the response and convert to standard format
|
||||
# This adds the response to message_history
|
||||
if not await self.response_handler.handle_response(response, self.message_history):
|
||||
break
|
||||
|
||||
# Get the last assistant message and convert it to OpenAI computer use format
|
||||
for msg in reversed(self.message_history):
|
||||
if msg["role"] == "assistant":
|
||||
# Create OpenAI-compatible response and add to queue
|
||||
openai_compatible_response = self._create_openai_compatible_response(
|
||||
msg, response
|
||||
)
|
||||
await queue.put(openai_compatible_response)
|
||||
break
|
||||
|
||||
# Signal completion
|
||||
await queue.put(None)
|
||||
|
||||
@@ -283,98 +295,128 @@ class AnthropicLoop(BaseLoop):
|
||||
)
|
||||
await queue.put(None)
|
||||
|
||||
async def _make_api_call(self, messages: List[BetaMessageParam]) -> BetaMessage:
|
||||
"""Make API call to Anthropic with retry logic.
|
||||
def _create_openai_compatible_response(
|
||||
self, assistant_msg: Dict[str, Any], original_response: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""Create an OpenAI computer use agent compatible response format.
|
||||
|
||||
Args:
|
||||
messages: List of messages to send to the API
|
||||
assistant_msg: The assistant message in standard OpenAI format
|
||||
original_response: The original API response object for ID generation
|
||||
|
||||
Returns:
|
||||
API response
|
||||
A response formatted according to OpenAI's computer use agent standard
|
||||
"""
|
||||
if self.client is None:
|
||||
raise RuntimeError("Client not initialized. Call initialize_client() first.")
|
||||
if self.tool_manager is None:
|
||||
raise RuntimeError("Tool manager not initialized. Call initialize_client() first.")
|
||||
# Create a unique ID for this response
|
||||
response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(original_response)}"
|
||||
reasoning_id = f"rs_{response_id}"
|
||||
action_id = f"cu_{response_id}"
|
||||
call_id = f"call_{response_id}"
|
||||
|
||||
last_error = None
|
||||
# Extract reasoning and action details from the response
|
||||
content = assistant_msg["content"]
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
try:
|
||||
# Log request
|
||||
request_data = {
|
||||
"messages": messages,
|
||||
"max_tokens": self.max_tokens,
|
||||
"system": SYSTEM_PROMPT,
|
||||
# Initialize output array
|
||||
output_items = []
|
||||
|
||||
# Add reasoning item if we have text content
|
||||
reasoning_text = None
|
||||
action_details = None
|
||||
|
||||
# AnthropicLoop expects a list of content blocks with type "text" or "tool_use"
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict) and item.get("type") == "text":
|
||||
reasoning_text = item.get("text", "")
|
||||
elif isinstance(item, dict) and item.get("type") == "tool_use":
|
||||
action_details = item
|
||||
else:
|
||||
# Fallback for string content
|
||||
reasoning_text = content if isinstance(content, str) else None
|
||||
|
||||
# If we have reasoning text, add reasoning item
|
||||
if reasoning_text:
|
||||
output_items.append(
|
||||
{
|
||||
"type": "reasoning",
|
||||
"id": reasoning_id,
|
||||
"summary": [
|
||||
{
|
||||
"type": "summary_text",
|
||||
"text": reasoning_text[:200], # Truncate to reasonable length
|
||||
}
|
||||
],
|
||||
}
|
||||
# Let ExperimentManager handle sanitization
|
||||
self._log_api_call("request", request_data)
|
||||
)
|
||||
|
||||
# Setup betas and system
|
||||
system = BetaTextBlockParam(
|
||||
type="text",
|
||||
text=SYSTEM_PROMPT,
|
||||
)
|
||||
# Add computer_call item with action details if available
|
||||
computer_call = {
|
||||
"type": "computer_call",
|
||||
"id": action_id,
|
||||
"call_id": call_id,
|
||||
"action": {"type": "click", "button": "left", "x": 100, "y": 100}, # Default action
|
||||
"pending_safety_checks": [],
|
||||
"status": "completed",
|
||||
}
|
||||
|
||||
betas = [COMPUTER_USE_BETA_FLAG]
|
||||
# Temporarily disable prompt caching due to "A maximum of 4 blocks with cache_control may be provided" error
|
||||
# if self.message_manager.image_retention_config.enable_caching:
|
||||
# betas.append(PROMPT_CACHING_BETA_FLAG)
|
||||
# system["cache_control"] = {"type": "ephemeral"}
|
||||
# If we have action details from a tool_use, update the computer_call
|
||||
if action_details:
|
||||
# Try to map tool_use to computer_call action
|
||||
tool_input = action_details.get("input", {})
|
||||
if "click" in tool_input or "position" in tool_input:
|
||||
position = tool_input.get("click", tool_input.get("position", {}))
|
||||
if isinstance(position, dict) and "x" in position and "y" in position:
|
||||
computer_call["action"] = {
|
||||
"type": "click",
|
||||
"button": "left",
|
||||
"x": position.get("x", 100),
|
||||
"y": position.get("y", 100),
|
||||
}
|
||||
elif "type" in tool_input or "text" in tool_input:
|
||||
computer_call["action"] = {
|
||||
"type": "type",
|
||||
"text": tool_input.get("type", tool_input.get("text", "")),
|
||||
}
|
||||
elif "scroll" in tool_input:
|
||||
scroll = tool_input.get("scroll", {})
|
||||
computer_call["action"] = {
|
||||
"type": "scroll",
|
||||
"x": 100,
|
||||
"y": 100,
|
||||
"scroll_x": scroll.get("x", 0),
|
||||
"scroll_y": scroll.get("y", 0),
|
||||
}
|
||||
|
||||
# Make API call
|
||||
response = await self.client.create_message(
|
||||
messages=messages,
|
||||
system=[system],
|
||||
tools=self.tool_manager.get_tool_params(),
|
||||
max_tokens=self.max_tokens,
|
||||
betas=betas,
|
||||
)
|
||||
output_items.append(computer_call)
|
||||
|
||||
# Let ExperimentManager handle sanitization
|
||||
self._log_api_call("response", request_data, response)
|
||||
# Create the OpenAI-compatible response format
|
||||
return {
|
||||
"output": output_items,
|
||||
"id": response_id,
|
||||
# Include the original format for backward compatibility
|
||||
"response": {"choices": [{"message": assistant_msg, "finish_reason": "stop"}]},
|
||||
}
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.error(
|
||||
f"Error in API call (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
|
||||
)
|
||||
self._log_api_call("error", {"messages": messages}, error=e)
|
||||
|
||||
if attempt < self.max_retries - 1:
|
||||
                    await asyncio.sleep(self.retry_delay * (attempt + 1))  # Increase the delay on each retry
|
||||
continue
|
||||
|
||||
# If we get here, all retries failed
|
||||
error_message = f"API call failed after {self.max_retries} attempts"
|
||||
if last_error:
|
||||
error_message += f": {str(last_error)}"
|
||||
|
||||
logger.error(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
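# The retry behaviour implemented above, reduced to a standalone sketch.
# call_api, max_retries and retry_delay are placeholders for this example only;
# note the delay grows linearly with the attempt number.
import asyncio
import logging

logger = logging.getLogger(__name__)


async def call_with_retries(call_api, max_retries: int = 3, retry_delay: float = 1.0):
    """Retry an async API call, waiting a little longer after each failure."""
    last_error = None
    for attempt in range(max_retries):
        try:
            return await call_api()
        except Exception as e:
            last_error = e
            logger.error(f"API call failed (attempt {attempt + 1}/{max_retries}): {e}")
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay * (attempt + 1))
    raise RuntimeError(f"API call failed after {max_retries} attempts: {last_error}")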
###########################################
|
||||
# RESPONSE AND CALLBACK HANDLING
|
||||
###########################################
|
||||
|
||||
async def _handle_response(self, response: BetaMessage, messages: List[Dict[str, Any]]) -> bool:
|
||||
"""Handle the Anthropic API response.
|
||||
|
||||
Args:
|
||||
response: API response
|
||||
messages: List of messages to update
|
||||
messages: List of messages to update in standard OpenAI format
|
||||
|
||||
Returns:
|
||||
True if the loop should continue, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Convert response to parameter format
|
||||
response_params = self._response_to_params(response)
|
||||
# Convert Anthropic response to standard OpenAI format
|
||||
response_blocks = self._response_to_blocks(response)
|
||||
|
||||
# Add response to messages
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": response_params,
|
||||
}
|
||||
)
|
||||
# Add response to standard message history
|
||||
messages.append({"role": "assistant", "content": response_blocks})
|
||||
|
||||
if self.callback_manager is None:
|
||||
raise RuntimeError(
|
||||
@@ -383,31 +425,33 @@ class AnthropicLoop(BaseLoop):
|
||||
|
||||
# Handle tool use blocks and collect results
|
||||
tool_result_content = []
|
||||
for content_block in response_params:
|
||||
for content_block in response.content:
|
||||
# Notify callback of content
|
||||
self.callback_manager.on_content(cast(BetaContentBlockParam, content_block))
|
||||
|
||||
# Handle tool use
|
||||
if content_block.get("type") == "tool_use":
|
||||
# Handle tool use - carefully check and access attributes
|
||||
if hasattr(content_block, "type") and content_block.type == "tool_use":
|
||||
if self.tool_manager is None:
|
||||
raise RuntimeError(
|
||||
"Tool manager not initialized. Call initialize_client() first."
|
||||
)
|
||||
|
||||
# Safely get attributes
|
||||
tool_name = getattr(content_block, "name", "")
|
||||
tool_input = getattr(content_block, "input", {})
|
||||
tool_id = getattr(content_block, "id", "")
|
||||
|
||||
result = await self.tool_manager.execute_tool(
|
||||
name=content_block["name"],
|
||||
tool_input=cast(Dict[str, Any], content_block["input"]),
|
||||
name=tool_name,
|
||||
tool_input=cast(Dict[str, Any], tool_input),
|
||||
)
|
||||
|
||||
# Create tool result and add to content
|
||||
tool_result = self._make_tool_result(
|
||||
cast(ToolResult, result), content_block["id"]
|
||||
)
|
||||
# Create tool result
|
||||
tool_result = self._make_tool_result(cast(ToolResult, result), tool_id)
|
||||
tool_result_content.append(tool_result)
|
||||
|
||||
# Notify callback of tool result
|
||||
self.callback_manager.on_tool_result(
|
||||
cast(ToolResult, result), content_block["id"]
|
||||
)
|
||||
self.callback_manager.on_tool_result(cast(ToolResult, result), tool_id)
|
||||
|
||||
# If no tool results, we're done
|
||||
if not tool_result_content:
|
||||
@@ -415,8 +459,8 @@ class AnthropicLoop(BaseLoop):
|
||||
self.callback_manager.on_content({"type": "text", "text": "<DONE>"})
|
||||
return False
|
||||
|
||||
# Add tool results to message history
|
||||
messages.append({"content": tool_result_content, "role": "user"})
|
||||
# Add tool results to message history in standard format
|
||||
messages.append({"role": "user", "content": tool_result_content})
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
@@ -429,28 +473,41 @@ class AnthropicLoop(BaseLoop):
|
||||
)
|
||||
return False
|
||||
|
||||
def _response_to_params(
|
||||
self,
|
||||
response: BetaMessage,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert API response to message parameters.
|
||||
def _response_to_blocks(self, response: BetaMessage) -> List[Dict[str, Any]]:
|
||||
"""Convert Anthropic API response to standard blocks format.
|
||||
|
||||
Args:
|
||||
response: API response message
|
||||
|
||||
Returns:
|
||||
List of content blocks
|
||||
List of content blocks in standard format
|
||||
"""
|
||||
result = []
|
||||
for block in response.content:
|
||||
if isinstance(block, BetaTextBlock):
|
||||
result.append({"type": "text", "text": block.text})
|
||||
elif hasattr(block, "type") and block.type == "tool_use":
|
||||
# Safely access attributes after confirming it's a tool_use
|
||||
result.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": getattr(block, "id", ""),
|
||||
"name": getattr(block, "name", ""),
|
||||
"input": getattr(block, "input", {}),
|
||||
}
|
||||
)
|
||||
else:
|
||||
result.append(cast(Dict[str, Any], block.model_dump()))
|
||||
# For other block types, convert to dict
|
||||
block_dict = {}
|
||||
for key, value in vars(block).items():
|
||||
if not key.startswith("_"):
|
||||
block_dict[key] = value
|
||||
result.append(block_dict)
|
||||
|
||||
return result
|
||||
|
||||
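# For orientation, this is roughly what _response_to_blocks produces for a response
# holding one text block and one tool_use block; every value below is made up.
expected_blocks = [
    {"type": "text", "text": "I will click the button."},
    {
        "type": "tool_use",
        "id": "toolu_01",
        "name": "computer",
        "input": {"action": "left_click", "coordinate": [120, 240]},
    },
]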
def _make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]:
|
||||
"""Convert a tool result to API format.
|
||||
"""Convert a tool result to standard format.
|
||||
|
||||
Args:
|
||||
result: Tool execution result
|
||||
@@ -489,12 +546,8 @@ class AnthropicLoop(BaseLoop):
|
||||
if result.base64_image:
|
||||
tool_result_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": result.base64_image,
|
||||
},
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/png;base64,{result.base64_image}"},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -519,16 +572,19 @@ class AnthropicLoop(BaseLoop):
|
||||
result_text = f"<s>{result.system}</s>\n{result_text}"
|
||||
return result_text
|
||||
|
||||
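# The hunk above switches tool-result screenshots from Anthropic's base64 source
# blocks to OpenAI-style data URLs. Both shapes side by side, as a small sketch
# (the helper names are illustrative):
from typing import Any, Dict


def to_openai_image_block(base64_png: str) -> Dict[str, Any]:
    """Wrap a base64-encoded PNG in the OpenAI image_url content format."""
    return {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{base64_png}"},
    }


def to_anthropic_image_block(base64_png: str) -> Dict[str, Any]:
    """The equivalent Anthropic source block, for comparison."""
    return {
        "type": "image",
        "source": {"type": "base64", "media_type": "image/png", "data": base64_png},
    }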
def _handle_content(self, content: BetaContentBlockParam) -> None:
|
||||
###########################################
|
||||
# CALLBACK HANDLERS
|
||||
###########################################
|
||||
|
||||
def _handle_content(self, content):
|
||||
"""Handle content updates from the assistant."""
|
||||
if content.get("type") == "text":
|
||||
text_content = cast(BetaTextBlockParam, content)
|
||||
text = text_content["text"]
|
||||
text = content.get("text", "")
|
||||
if text == "<DONE>":
|
||||
return
|
||||
logger.info(f"Assistant: {text}")
|
||||
|
||||
def _handle_tool_result(self, result: ToolResult, tool_id: str) -> None:
|
||||
def _handle_tool_result(self, result, tool_id):
|
||||
"""Handle tool execution results."""
|
||||
if result.error:
|
||||
logger.error(f"Tool {tool_id} error: {result.error}")
|
||||
|
||||
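# A condensed sketch of the tool_use-to-computer_call mapping performed in
# _create_openai_compatible_response above. The defaults and fallback coordinates
# mirror that code; the helper name itself is illustrative.
from typing import Any, Dict


def map_tool_use_to_action(tool_input: Dict[str, Any]) -> Dict[str, Any]:
    """Translate an Anthropic tool_use input into an OpenAI-style computer_call action."""
    if "click" in tool_input or "position" in tool_input:
        pos = tool_input.get("click", tool_input.get("position", {}))
        if isinstance(pos, dict) and "x" in pos and "y" in pos:
            return {"type": "click", "button": "left", "x": pos["x"], "y": pos["y"]}
    elif "type" in tool_input or "text" in tool_input:
        return {"type": "type", "text": tool_input.get("type", tool_input.get("text", ""))}
    elif "scroll" in tool_input:
        scroll = tool_input.get("scroll", {})
        return {
            "type": "scroll",
            "x": 100,
            "y": 100,
            "scroll_x": scroll.get("x", 0),
            "scroll_y": scroll.get("y", 0),
        }
    return {"type": "click", "button": "left", "x": 100, "y": 100}  # default action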
223
libs/agent/agent/providers/anthropic/response_handler.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""Response and tool handling for Anthropic provider."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
|
||||
from anthropic.types.beta import (
|
||||
BetaMessage,
|
||||
BetaMessageParam,
|
||||
BetaTextBlock,
|
||||
BetaTextBlockParam,
|
||||
BetaToolUseBlockParam,
|
||||
BetaContentBlockParam,
|
||||
)
|
||||
|
||||
from .tools import ToolResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnthropicResponseHandler:
|
||||
"""Handles Anthropic API responses and tool execution results."""
|
||||
|
||||
def __init__(self, loop):
|
||||
"""Initialize the response handler.
|
||||
|
||||
Args:
|
||||
loop: Reference to the parent loop instance that provides context
|
||||
"""
|
||||
self.loop = loop
|
||||
|
||||
async def handle_response(self, response: BetaMessage, messages: List[Dict[str, Any]]) -> bool:
|
||||
"""Handle the Anthropic API response.
|
||||
|
||||
Args:
|
||||
response: API response
|
||||
messages: List of messages to update
|
||||
|
||||
Returns:
|
||||
True if the loop should continue, False otherwise
|
||||
"""
|
||||
try:
|
||||
# Convert response to parameter format
|
||||
response_params = self.response_to_params(response)
|
||||
|
||||
# Collect all existing tool_use IDs from previous messages for validation
|
||||
existing_tool_use_ids = set()
|
||||
for msg in messages:
|
||||
if msg.get("role") == "assistant" and isinstance(msg.get("content"), list):
|
||||
for block in msg.get("content", []):
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "tool_use"
|
||||
and "id" in block
|
||||
):
|
||||
existing_tool_use_ids.add(block["id"])
|
||||
|
||||
# Also add new tool_use IDs from the current response
|
||||
current_tool_use_ids = set()
|
||||
for block in response_params:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use" and "id" in block:
|
||||
current_tool_use_ids.add(block["id"])
|
||||
existing_tool_use_ids.add(block["id"])
|
||||
|
||||
logger.info(f"Existing tool_use IDs in conversation: {existing_tool_use_ids}")
|
||||
logger.info(f"New tool_use IDs in current response: {current_tool_use_ids}")
|
||||
|
||||
# Add response to messages
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": response_params,
|
||||
}
|
||||
)
|
||||
|
||||
if self.loop.callback_manager is None:
|
||||
raise RuntimeError(
|
||||
"Callback manager not initialized. Call initialize_client() first."
|
||||
)
|
||||
|
||||
# Handle tool use blocks and collect results
|
||||
tool_result_content = []
|
||||
for content_block in response_params:
|
||||
# Notify callback of content
|
||||
self.loop.callback_manager.on_content(cast(BetaContentBlockParam, content_block))
|
||||
|
||||
# Handle tool use
|
||||
if content_block.get("type") == "tool_use":
|
||||
if self.loop.tool_manager is None:
|
||||
raise RuntimeError(
|
||||
"Tool manager not initialized. Call initialize_client() first."
|
||||
)
|
||||
|
||||
# Execute the tool
|
||||
result = await self.loop.tool_manager.execute_tool(
|
||||
name=content_block["name"],
|
||||
tool_input=cast(Dict[str, Any], content_block["input"]),
|
||||
)
|
||||
|
||||
# Verify the tool_use ID exists in the conversation (which it should now)
|
||||
tool_use_id = content_block["id"]
|
||||
if tool_use_id in existing_tool_use_ids:
|
||||
# Create tool result and add to content
|
||||
tool_result = self.make_tool_result(cast(ToolResult, result), tool_use_id)
|
||||
tool_result_content.append(tool_result)
|
||||
|
||||
# Notify callback of tool result
|
||||
self.loop.callback_manager.on_tool_result(
|
||||
cast(ToolResult, result), content_block["id"]
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Tool use ID {tool_use_id} not found in previous messages. Skipping tool result."
|
||||
)
|
||||
|
||||
# If no tool results, we're done
|
||||
if not tool_result_content:
|
||||
# Signal completion
|
||||
self.loop.callback_manager.on_content({"type": "text", "text": "<DONE>"})
|
||||
return False
|
||||
|
||||
# Add tool results to message history
|
||||
messages.append({"content": tool_result_content, "role": "user"})
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling response: {str(e)}")
|
||||
messages.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
}
|
||||
)
|
||||
return False
|
||||
|
||||
def response_to_params(
|
||||
self,
|
||||
response: BetaMessage,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Convert API response to message parameters.
|
||||
|
||||
Args:
|
||||
response: API response message
|
||||
|
||||
Returns:
|
||||
List of content blocks
|
||||
"""
|
||||
result = []
|
||||
for block in response.content:
|
||||
if isinstance(block, BetaTextBlock):
|
||||
result.append({"type": "text", "text": block.text})
|
||||
else:
|
||||
result.append(cast(Dict[str, Any], block.model_dump()))
|
||||
return result
|
||||
|
||||
def make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]:
|
||||
"""Convert a tool result to API format.
|
||||
|
||||
Args:
|
||||
result: Tool execution result
|
||||
tool_use_id: ID of the tool use
|
||||
|
||||
Returns:
|
||||
Formatted tool result
|
||||
"""
|
||||
if result.content:
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"content": result.content,
|
||||
"tool_use_id": tool_use_id,
|
||||
"is_error": bool(result.error),
|
||||
}
|
||||
|
||||
tool_result_content = []
|
||||
is_error = False
|
||||
|
||||
if result.error:
|
||||
is_error = True
|
||||
tool_result_content = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": self.maybe_prepend_system_tool_result(result, result.error),
|
||||
}
|
||||
]
|
||||
else:
|
||||
if result.output:
|
||||
tool_result_content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": self.maybe_prepend_system_tool_result(result, result.output),
|
||||
}
|
||||
)
|
||||
if result.base64_image:
|
||||
tool_result_content.append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": result.base64_image,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"type": "tool_result",
|
||||
"content": tool_result_content,
|
||||
"tool_use_id": tool_use_id,
|
||||
"is_error": is_error,
|
||||
}
|
||||
|
||||
def maybe_prepend_system_tool_result(self, result: ToolResult, result_text: str) -> str:
|
||||
"""Prepend system information to tool result if available.
|
||||
|
||||
Args:
|
||||
result: Tool execution result
|
||||
result_text: Text to prepend to
|
||||
|
||||
Returns:
|
||||
Text with system information prepended if available
|
||||
"""
|
||||
if result.system:
|
||||
result_text = f"<s>{result.system}</s>\n{result_text}"
|
||||
return result_text
|
||||
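# The id bookkeeping in handle_response boils down to collecting every tool_use id
# seen so far and only emitting tool_result blocks for ids in that set. A stripped-down
# version of the collection step (helper name is illustrative):
from typing import Any, Dict, List, Set


def collect_tool_use_ids(messages: List[Dict[str, Any]]) -> Set[str]:
    """Gather tool_use ids from every assistant message in the history."""
    ids: Set[str] = set()
    for msg in messages:
        if msg.get("role") == "assistant" and isinstance(msg.get("content"), list):
            for block in msg["content"]:
                if isinstance(block, dict) and block.get("type") == "tool_use" and "id" in block:
                    ids.add(block["id"])
    return ids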
@@ -7,101 +7,6 @@ from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
|
||||
from ....core.tools.bash import BaseBashTool
|
||||
|
||||
|
||||
class _BashSession:
|
||||
"""A session of a bash shell."""
|
||||
|
||||
_started: bool
|
||||
_process: asyncio.subprocess.Process
|
||||
|
||||
command: str = "/bin/bash"
|
||||
_output_delay: float = 0.2 # seconds
|
||||
_timeout: float = 120.0 # seconds
|
||||
_sentinel: str = "<<exit>>"
|
||||
|
||||
def __init__(self):
|
||||
self._started = False
|
||||
self._timed_out = False
|
||||
|
||||
async def start(self):
|
||||
if self._started:
|
||||
return
|
||||
|
||||
self._process = await asyncio.create_subprocess_shell(
|
||||
self.command,
|
||||
preexec_fn=os.setsid,
|
||||
shell=True,
|
||||
bufsize=0,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
self._started = True
|
||||
|
||||
def stop(self):
|
||||
"""Terminate the bash shell."""
|
||||
if not self._started:
|
||||
raise ToolError("Session has not started.")
|
||||
if self._process.returncode is not None:
|
||||
return
|
||||
self._process.terminate()
|
||||
|
||||
async def run(self, command: str):
|
||||
"""Execute a command in the bash shell."""
|
||||
if not self._started:
|
||||
raise ToolError("Session has not started.")
|
||||
if self._process.returncode is not None:
|
||||
return ToolResult(
|
||||
system="tool must be restarted",
|
||||
error=f"bash has exited with returncode {self._process.returncode}",
|
||||
)
|
||||
if self._timed_out:
|
||||
raise ToolError(
|
||||
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
||||
)
|
||||
|
||||
# we know these are not None because we created the process with PIPEs
|
||||
assert self._process.stdin
|
||||
assert self._process.stdout
|
||||
assert self._process.stderr
|
||||
|
||||
# send command to the process
|
||||
self._process.stdin.write(command.encode() + f"; echo '{self._sentinel}'\n".encode())
|
||||
await self._process.stdin.drain()
|
||||
|
||||
        # read output from the process, until the sentinel is found
        output = ""
        try:
            async with asyncio.timeout(self._timeout):
                while True:
                    await asyncio.sleep(self._output_delay)
                    # Read whatever is currently available on stdout and accumulate it;
                    # a bare read() would wait for EOF, which never arrives while the
                    # shell session stays open
                    output_bytes = await self._process.stdout.read(4096)
                    if output_bytes:
                        output += output_bytes.decode()
|
||||
if self._sentinel in output:
|
||||
# strip the sentinel and break
|
||||
output = output[: output.index(self._sentinel)]
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
self._timed_out = True
|
||||
raise ToolError(
|
||||
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
||||
) from None
|
||||
|
||||
if output and output.endswith("\n"):
|
||||
output = output[:-1]
|
||||
|
||||
        # Read any stderr output that is already buffered; a bare read() would
        # block until the shell exits, so cap the wait at the output delay
        try:
            error_bytes = await asyncio.wait_for(self._process.stderr.read(4096), timeout=self._output_delay)
        except asyncio.TimeoutError:
            error_bytes = b""
        error = error_bytes.decode() if error_bytes else ""
|
||||
if error and error.endswith("\n"):
|
||||
error = error[:-1]
|
||||
|
||||
# No need to clear buffers as we're using read() which consumes the data
|
||||
|
||||
return CLIResult(output=output, error=error)
|
||||
|
||||
|
||||
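# A minimal usage sketch for the _BashSession defined above: start the shell,
# run one command, then terminate it. The command and printed fields are examples only.
import asyncio


async def demo_bash_session() -> None:
    session = _BashSession()
    await session.start()
    try:
        result = await session.run("echo hello && pwd")
        print(result.output)
        if result.error:
            print("stderr:", result.error)
    finally:
        session.stop()


# asyncio.run(demo_bash_session())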
class BashTool(BaseBashTool, BaseAnthropicTool):
|
||||
"""
|
||||
A tool that allows the agent to run bash commands.
|
||||
@@ -123,7 +28,6 @@ class BashTool(BaseBashTool, BaseAnthropicTool):
|
||||
# Then initialize the Anthropic tool
|
||||
BaseAnthropicTool.__init__(self)
|
||||
# Initialize bash session
|
||||
self._session = _BashSession()
|
||||
|
||||
async def __call__(self, command: str | None = None, restart: bool = False, **kwargs):
|
||||
"""Execute a bash command.
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
# The OmniComputerAgent has been replaced by the unified ComputerAgent
|
||||
# which can be found in agent.core.agent
|
||||
from .types import LLMProvider
|
||||
from .experiment import ExperimentManager
|
||||
from .visualization import visualize_click, visualize_scroll, calculate_element_center
|
||||
from .image_utils import (
|
||||
decode_base64_image,
|
||||
encode_image_base64,
|
||||
@@ -15,10 +13,6 @@ from .image_utils import (
|
||||
|
||||
__all__ = [
|
||||
"LLMProvider",
|
||||
"ExperimentManager",
|
||||
"visualize_click",
|
||||
"visualize_scroll",
|
||||
"calculate_element_center",
|
||||
"decode_base64_image",
|
||||
"encode_image_base64",
|
||||
"clean_base64_data",
|
||||
|
||||
264
libs/agent/agent/providers/omni/action_executor.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""Action execution for the Omni agent."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Tuple
|
||||
import json
|
||||
|
||||
from .parser import ParseResult
|
||||
from ...core.visualization import calculate_element_center
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionExecutor:
|
||||
"""Executes UI actions based on model instructions."""
|
||||
|
||||
def __init__(self, loop):
|
||||
"""Initialize the action executor.
|
||||
|
||||
Args:
|
||||
loop: Reference to the parent loop instance that provides context
|
||||
"""
|
||||
self.loop = loop
|
||||
|
||||
async def execute_action(self, content: Dict[str, Any], parsed_screen: ParseResult) -> bool:
|
||||
"""Execute the action specified in the content.
|
||||
|
||||
Args:
|
||||
content: Dictionary containing the action details
|
||||
parsed_screen: Current parsed screen information
|
||||
|
||||
Returns:
|
||||
Whether an action-specific screenshot was saved
|
||||
"""
|
||||
try:
|
||||
action = content.get("Action", "").lower()
|
||||
if not action:
|
||||
return False
|
||||
|
||||
# Track if we saved an action-specific screenshot
|
||||
action_screenshot_saved = False
|
||||
|
||||
try:
|
||||
# Prepare kwargs based on action type
|
||||
kwargs = {}
|
||||
|
||||
if action in ["left_click", "right_click", "double_click", "move_cursor"]:
|
||||
try:
|
||||
box_id = int(content["Box ID"])
|
||||
logger.info(f"Processing Box ID: {box_id}")
|
||||
|
||||
# Calculate click coordinates
|
||||
x, y = await self.calculate_click_coordinates(box_id, parsed_screen)
|
||||
logger.info(f"Calculated coordinates: x={x}, y={y}")
|
||||
|
||||
kwargs["x"] = x
|
||||
kwargs["y"] = y
|
||||
|
||||
# Visualize action if screenshot is available
|
||||
if parsed_screen.annotated_image_base64:
|
||||
img_data = parsed_screen.annotated_image_base64
|
||||
# Remove data URL prefix if present
|
||||
if img_data.startswith("data:image"):
|
||||
img_data = img_data.split(",")[1]
|
||||
# Only save visualization for coordinate-based actions
|
||||
self.loop.viz_helper.visualize_action(x, y, img_data)
|
||||
action_screenshot_saved = True
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Error processing Box ID: {str(e)}")
|
||||
return False
|
||||
|
||||
elif action == "drag_to":
|
||||
try:
|
||||
box_id = int(content["Box ID"])
|
||||
x, y = await self.calculate_click_coordinates(box_id, parsed_screen)
|
||||
kwargs.update(
|
||||
{
|
||||
"x": x,
|
||||
"y": y,
|
||||
"button": content.get("button", "left"),
|
||||
"duration": float(content.get("duration", 0.5)),
|
||||
}
|
||||
)
|
||||
|
||||
# Visualize drag destination if screenshot is available
|
||||
if parsed_screen.annotated_image_base64:
|
||||
img_data = parsed_screen.annotated_image_base64
|
||||
# Remove data URL prefix if present
|
||||
if img_data.startswith("data:image"):
|
||||
img_data = img_data.split(",")[1]
|
||||
# Only save visualization for coordinate-based actions
|
||||
self.loop.viz_helper.visualize_action(x, y, img_data)
|
||||
action_screenshot_saved = True
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"Error processing drag coordinates: {str(e)}")
|
||||
return False
|
||||
|
||||
elif action == "type_text":
|
||||
kwargs["text"] = content["Value"]
|
||||
# For type_text, store the value in the action type
|
||||
action_type = f"type_{content['Value'][:20]}" # Truncate if too long
|
||||
elif action == "press_key":
|
||||
kwargs["key"] = content["Value"]
|
||||
action_type = f"press_{content['Value']}"
|
||||
elif action == "hotkey":
|
||||
if isinstance(content.get("Value"), list):
|
||||
keys = content["Value"]
|
||||
action_type = f"hotkey_{'_'.join(keys)}"
|
||||
else:
|
||||
# Simply split string format like "command+space" into a list
|
||||
keys = [k.strip() for k in content["Value"].lower().split("+")]
|
||||
action_type = f"hotkey_{content['Value'].replace('+', '_')}"
|
||||
logger.info(f"Preparing hotkey with keys: {keys}")
|
||||
# Get the method but call it with *args instead of **kwargs
|
||||
method = getattr(self.loop.computer.interface, action)
|
||||
await method(*keys) # Unpack the keys list as positional arguments
|
||||
logger.info(f"Tool execution completed successfully: {action}")
|
||||
|
||||
# For hotkeys, take a screenshot after the action
|
||||
try:
|
||||
# Get a new screenshot after the action and save it with the action type
|
||||
new_parsed_screen = await self.loop._get_parsed_screen_som(
|
||||
save_screenshot=False
|
||||
)
|
||||
if new_parsed_screen and new_parsed_screen.annotated_image_base64:
|
||||
img_data = new_parsed_screen.annotated_image_base64
|
||||
# Remove data URL prefix if present
|
||||
if img_data.startswith("data:image"):
|
||||
img_data = img_data.split(",")[1]
|
||||
# Save with action type to indicate this is a post-action screenshot
|
||||
self.loop._save_screenshot(img_data, action_type=action_type)
|
||||
action_screenshot_saved = True
|
||||
except Exception as screenshot_error:
|
||||
logger.error(
|
||||
f"Error taking post-hotkey screenshot: {str(screenshot_error)}"
|
||||
)
|
||||
|
||||
return action_screenshot_saved
|
||||
|
||||
elif action in ["scroll_down", "scroll_up"]:
|
||||
clicks = int(content.get("amount", 1))
|
||||
kwargs["clicks"] = clicks
|
||||
action_type = f"scroll_{action.split('_')[1]}_{clicks}"
|
||||
|
||||
# Visualize scrolling if screenshot is available
|
||||
if parsed_screen.annotated_image_base64:
|
||||
img_data = parsed_screen.annotated_image_base64
|
||||
# Remove data URL prefix if present
|
||||
if img_data.startswith("data:image"):
|
||||
img_data = img_data.split(",")[1]
|
||||
direction = "down" if action == "scroll_down" else "up"
|
||||
# For scrolling, we only save the visualization to avoid duplicate images
|
||||
self.loop.viz_helper.visualize_scroll(direction, clicks, img_data)
|
||||
action_screenshot_saved = True
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown action: {action}")
|
||||
return False
|
||||
|
||||
# Execute tool and handle result
|
||||
try:
|
||||
method = getattr(self.loop.computer.interface, action)
|
||||
logger.info(f"Found method for action '{action}': {method}")
|
||||
await method(**kwargs)
|
||||
logger.info(f"Tool execution completed successfully: {action}")
|
||||
|
||||
# For non-coordinate based actions that don't already have visualizations,
|
||||
# take a new screenshot after the action
|
||||
if not action_screenshot_saved:
|
||||
# Take a new screenshot
|
||||
try:
|
||||
# Get a new screenshot after the action and save it with the action type
|
||||
new_parsed_screen = await self.loop._get_parsed_screen_som(
|
||||
save_screenshot=False
|
||||
)
|
||||
if new_parsed_screen and new_parsed_screen.annotated_image_base64:
|
||||
img_data = new_parsed_screen.annotated_image_base64
|
||||
# Remove data URL prefix if present
|
||||
if img_data.startswith("data:image"):
|
||||
img_data = img_data.split(",")[1]
|
||||
# Save with action type to indicate this is a post-action screenshot
|
||||
if "action_type" in locals():
|
||||
self.loop._save_screenshot(img_data, action_type=action_type)
|
||||
else:
|
||||
self.loop._save_screenshot(img_data, action_type=action)
|
||||
# Update the action screenshot flag for this turn
|
||||
action_screenshot_saved = True
|
||||
except Exception as screenshot_error:
|
||||
logger.error(
|
||||
f"Error taking post-action screenshot: {str(screenshot_error)}"
|
||||
)
|
||||
|
||||
except AttributeError as e:
|
||||
logger.error(f"Method not found for action '{action}': {str(e)}")
|
||||
return False
|
||||
except Exception as tool_error:
|
||||
logger.error(f"Tool execution failed: {str(tool_error)}")
|
||||
return False
|
||||
|
||||
return action_screenshot_saved
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing action {action}: {str(e)}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in execute_action: {str(e)}")
|
||||
return False
|
||||
|
||||
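# The dispatch pattern that execute_action uses throughout, reduced to its core:
# resolve the interface method by action name and call it with keyword arguments,
# except for hotkeys whose keys are passed positionally. The interface object and
# the "keys" entry below are stand-ins for this sketch, not the real call signature.
import logging
from typing import Any, Dict

logger = logging.getLogger(__name__)


async def dispatch_action(interface: Any, action: str, kwargs: Dict[str, Any]) -> bool:
    """Invoke interface.<action>(**kwargs), treating hotkey keys as positional args."""
    try:
        method = getattr(interface, action)
    except AttributeError:
        logger.error(f"Method not found for action '{action}'")
        return False
    if action == "hotkey":
        await method(*kwargs.get("keys", []))  # e.g. interface.hotkey("command", "space")
    else:
        await method(**kwargs)
    return True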
async def calculate_click_coordinates(
|
||||
self, box_id: int, parsed_screen: ParseResult
|
||||
) -> Tuple[int, int]:
|
||||
"""Calculate click coordinates based on box ID.
|
||||
|
||||
Args:
|
||||
box_id: The ID of the box to click
|
||||
parsed_screen: The parsed screen information
|
||||
|
||||
Returns:
|
||||
Tuple of (x, y) coordinates
|
||||
|
||||
Raises:
|
||||
ValueError: If box_id is invalid or missing from parsed screen
|
||||
"""
|
||||
# First try to use structured elements data
|
||||
logger.info(f"Elements count: {len(parsed_screen.elements)}")
|
||||
|
||||
# Try to find element with matching ID
|
||||
for element in parsed_screen.elements:
|
||||
if element.id == box_id:
|
||||
logger.info(f"Found element with ID {box_id}: {element}")
|
||||
bbox = element.bbox
|
||||
|
||||
# Get screen dimensions from the metadata if available, or fallback
|
||||
width = parsed_screen.metadata.width if parsed_screen.metadata else 1920
|
||||
height = parsed_screen.metadata.height if parsed_screen.metadata else 1080
|
||||
logger.info(f"Screen dimensions: width={width}, height={height}")
|
||||
|
||||
# Create a dictionary from the element's bbox for calculate_element_center
|
||||
bbox_dict = {"x1": bbox.x1, "y1": bbox.y1, "x2": bbox.x2, "y2": bbox.y2}
|
||||
center_x, center_y = calculate_element_center(bbox_dict, width, height)
|
||||
logger.info(f"Calculated center: ({center_x}, {center_y})")
|
||||
|
||||
# Validate coordinates - if they're (0,0) or unreasonably small,
|
||||
# use a default position in the center of the screen
|
||||
if center_x == 0 and center_y == 0:
|
||||
logger.warning("Got (0,0) coordinates, using fallback position")
|
||||
center_x = width // 2
|
||||
center_y = height // 2
|
||||
logger.info(f"Using fallback center: ({center_x}, {center_y})")
|
||||
|
||||
return center_x, center_y
|
||||
|
||||
# If we couldn't find the box, use center of screen
|
||||
logger.error(
|
||||
f"Box ID {box_id} not found in structured elements (count={len(parsed_screen.elements)})"
|
||||
)
|
||||
|
||||
# Use center of screen as fallback
|
||||
width = parsed_screen.metadata.width if parsed_screen.metadata else 1920
|
||||
height = parsed_screen.metadata.height if parsed_screen.metadata else 1080
|
||||
logger.warning(f"Using fallback position in center of screen ({width//2}, {height//2})")
|
||||
return width // 2, height // 2
|
||||
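# The center calculation used above takes a bbox dict plus the screen size; assuming
# the element coordinates are normalized to [0, 1] (as the width/height scaling
# suggests), the arithmetic is simply:
from typing import Dict, Tuple


def element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]:
    """Center of an x1/y1/x2/y2 bbox, scaled from normalized coordinates to pixels."""
    center_x = int((bbox["x1"] + bbox["x2"]) / 2 * width)
    center_y = int((bbox["y1"] + bbox["y2"]) / 2 * height)
    return center_x, center_y


# element_center({"x1": 0.25, "y1": 0.4, "x2": 0.35, "y2": 0.5}, 1920, 1080) -> (576, 486)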
42
libs/agent/agent/providers/omni/api_handler.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""API handling for Omni provider."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .prompts import SYSTEM_PROMPT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OmniAPIHandler:
|
||||
"""Handler for Omni API calls."""
|
||||
|
||||
def __init__(self, loop):
|
||||
"""Initialize the API handler.
|
||||
|
||||
Args:
|
||||
loop: Parent loop instance
|
||||
"""
|
||||
self.loop = loop
|
||||
|
||||
async def make_api_call(
|
||||
self, messages: List[Dict[str, Any]], system_prompt: str = SYSTEM_PROMPT
|
||||
) -> Any:
|
||||
"""Make an API call to the appropriate provider.
|
||||
|
||||
Args:
|
||||
messages: List of messages in standard OpenAI format
|
||||
system_prompt: System prompt to use
|
||||
|
||||
Returns:
|
||||
API response
|
||||
"""
|
||||
if not self.loop._make_api_call:
|
||||
raise RuntimeError("Loop does not have _make_api_call method")
|
||||
|
||||
try:
|
||||
# Use the loop's _make_api_call method with standard messages
|
||||
return await self.loop._make_api_call(messages=messages, system_prompt=system_prompt)
|
||||
except Exception as e:
|
||||
logger.error(f"Error making API call: {str(e)}")
|
||||
raise
|
||||
@@ -44,6 +44,10 @@ class AnthropicClient(BaseOmniClient):
|
||||
anthropic_messages = []
|
||||
|
||||
for message in messages:
|
||||
# Skip messages with empty content
|
||||
if not message.get("content"):
|
||||
continue
|
||||
|
||||
if message["role"] == "user":
|
||||
anthropic_messages.append({"role": "user", "content": message["content"]})
|
||||
elif message["role"] == "assistant":
|
||||
|
||||
@@ -1,101 +0,0 @@
|
||||
"""Groq client implementation."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
|
||||
from groq import Groq
|
||||
import re
|
||||
from .utils import is_image_path
|
||||
from .base import BaseOmniClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GroqClient(BaseOmniClient):
|
||||
"""Client for making Groq API calls."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
model: str = "deepseek-r1-distill-llama-70b",
|
||||
max_tokens: int = 4096,
|
||||
temperature: float = 0.6,
|
||||
):
|
||||
"""Initialize Groq client.
|
||||
|
||||
Args:
|
||||
api_key: Groq API key (if not provided, will try to get from env)
|
||||
model: Model name to use
|
||||
max_tokens: Maximum tokens to generate
|
||||
temperature: Temperature for sampling
|
||||
"""
|
||||
super().__init__(api_key=api_key, model=model)
|
||||
self.api_key = api_key or os.getenv("GROQ_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("No Groq API key provided")
|
||||
|
||||
self.max_tokens = max_tokens
|
||||
self.temperature = temperature
|
||||
self.client = Groq(api_key=self.api_key)
|
||||
self.model: str = model # Add explicit type annotation
|
||||
|
||||
def run_interleaved(
|
||||
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
|
||||
) -> tuple[str, int]:
|
||||
"""Run interleaved chat completion.
|
||||
|
||||
Args:
|
||||
messages: List of message dicts
|
||||
system: System prompt
|
||||
max_tokens: Optional max tokens override
|
||||
|
||||
Returns:
|
||||
Tuple of (response text, token usage)
|
||||
"""
|
||||
# Avoid using system messages for R1
|
||||
final_messages = [{"role": "user", "content": system}]
|
||||
|
||||
# Process messages
|
||||
if isinstance(messages, list):
|
||||
for item in messages:
|
||||
if isinstance(item, dict):
|
||||
# For dict items, concatenate all text content, ignoring images
|
||||
text_contents = []
|
||||
for cnt in item["content"]:
|
||||
if isinstance(cnt, str):
|
||||
if not is_image_path(cnt): # Skip image paths
|
||||
text_contents.append(cnt)
|
||||
else:
|
||||
text_contents.append(str(cnt))
|
||||
|
||||
if text_contents: # Only add if there's text content
|
||||
message = {"role": "user", "content": " ".join(text_contents)}
|
||||
final_messages.append(message)
|
||||
else: # str
|
||||
message = {"role": "user", "content": item}
|
||||
final_messages.append(message)
|
||||
|
||||
elif isinstance(messages, str):
|
||||
final_messages.append({"role": "user", "content": messages})
|
||||
|
||||
try:
|
||||
completion = self.client.chat.completions.create( # type: ignore
|
||||
model=self.model,
|
||||
messages=final_messages, # type: ignore
|
||||
temperature=self.temperature,
|
||||
max_tokens=max_tokens or self.max_tokens,
|
||||
top_p=0.95,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
response = completion.choices[0].message.content
|
||||
final_answer = response.split("</think>\n")[-1] if "</think>" in response else response
|
||||
final_answer = final_answer.replace("<output>", "").replace("</output>", "")
|
||||
token_usage = completion.usage.total_tokens
|
||||
|
||||
return final_answer, token_usage
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in Groq API call: {e}")
|
||||
raise
|
||||
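# The post-processing the removed Groq client applied to R1-style completions,
# shown in isolation: drop the chain-of-thought section and any <output> wrapper.
def strip_reasoning(response: str) -> str:
    """Return only the final answer from an R1-style reply."""
    answer = response.split("</think>\n")[-1] if "</think>" in response else response
    return answer.replace("<output>", "").replace("</output>", "")


# strip_reasoning("<think>working...</think>\nThe result is 4.") -> "The result is 4."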
@@ -1,276 +0,0 @@
|
||||
"""Experiment management for the Cua provider."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import copy
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from PIL import Image
|
||||
import json
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExperimentManager:
|
||||
"""Manages experiment directories and logging for the agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_dir: Optional[str] = None,
|
||||
only_n_most_recent_images: Optional[int] = None,
|
||||
):
|
||||
"""Initialize the experiment manager.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory for saving experiment data
|
||||
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
|
||||
"""
|
||||
self.base_dir = base_dir
|
||||
self.only_n_most_recent_images = only_n_most_recent_images
|
||||
self.run_dir = None
|
||||
self.current_turn_dir = None
|
||||
self.turn_count = 0
|
||||
self.screenshot_count = 0
|
||||
# Track all screenshots for potential API request inclusion
|
||||
self.screenshot_paths = []
|
||||
|
||||
# Set up experiment directories if base_dir is provided
|
||||
if self.base_dir:
|
||||
self.setup_experiment_dirs()
|
||||
|
||||
def setup_experiment_dirs(self) -> None:
|
||||
"""Setup the experiment directory structure."""
|
||||
if not self.base_dir:
|
||||
return
|
||||
|
||||
# Create base experiments directory if it doesn't exist
|
||||
os.makedirs(self.base_dir, exist_ok=True)
|
||||
|
||||
# Use the base_dir directly as the run_dir
|
||||
self.run_dir = self.base_dir
|
||||
logger.info(f"Using directory for experiment: {self.run_dir}")
|
||||
|
||||
# Create first turn directory
|
||||
self.create_turn_dir()
|
||||
|
||||
def create_turn_dir(self) -> None:
|
||||
"""Create a new directory for the current turn."""
|
||||
if not self.run_dir:
|
||||
return
|
||||
|
||||
self.turn_count += 1
|
||||
self.current_turn_dir = os.path.join(self.run_dir, f"turn_{self.turn_count:03d}")
|
||||
os.makedirs(self.current_turn_dir, exist_ok=True)
|
||||
logger.info(f"Created turn directory: {self.current_turn_dir}")
|
||||
|
||||
def sanitize_log_data(self, data: Any) -> Any:
|
||||
"""Sanitize data for logging by removing large base64 strings.
|
||||
|
||||
Args:
|
||||
data: Data to sanitize (dict, list, or primitive)
|
||||
|
||||
Returns:
|
||||
Sanitized copy of the data
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
result = copy.deepcopy(data)
|
||||
|
||||
# Handle nested dictionaries and lists
|
||||
for key, value in result.items():
|
||||
# Process content arrays that contain image data
|
||||
if key == "content" and isinstance(value, list):
|
||||
for i, item in enumerate(value):
|
||||
if isinstance(item, dict):
|
||||
# Handle Anthropic format
|
||||
if item.get("type") == "image" and isinstance(item.get("source"), dict):
|
||||
source = item["source"]
|
||||
if "data" in source and isinstance(source["data"], str):
|
||||
# Replace base64 data with a placeholder and length info
|
||||
data_len = len(source["data"])
|
||||
source["data"] = f"[BASE64_IMAGE_DATA_LENGTH_{data_len}]"
|
||||
|
||||
# Handle OpenAI format
|
||||
elif item.get("type") == "image_url" and isinstance(
|
||||
item.get("image_url"), dict
|
||||
):
|
||||
url_dict = item["image_url"]
|
||||
if "url" in url_dict and isinstance(url_dict["url"], str):
|
||||
url = url_dict["url"]
|
||||
if url.startswith("data:"):
|
||||
# Replace base64 data with placeholder
|
||||
data_len = len(url)
|
||||
url_dict["url"] = f"[BASE64_IMAGE_URL_LENGTH_{data_len}]"
|
||||
|
||||
# Handle other nested structures recursively
|
||||
if isinstance(value, dict):
|
||||
result[key] = self.sanitize_log_data(value)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [self.sanitize_log_data(item) for item in value]
|
||||
|
||||
return result
|
||||
elif isinstance(data, list):
|
||||
return [self.sanitize_log_data(item) for item in data]
|
||||
else:
|
||||
return data
|
||||
|
||||
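# The essential effect of sanitize_log_data on a single OpenAI-format content item:
# replace inline base64 data URLs with a length placeholder before logging.
# This sketch mutates the item in place; the placeholder text mirrors the method above.
from typing import Any, Dict


def redact_image_url(item: Dict[str, Any]) -> None:
    """Swap a base64 data URL for a short placeholder so logs stay readable."""
    if item.get("type") == "image_url" and isinstance(item.get("image_url"), dict):
        url = item["image_url"].get("url", "")
        if isinstance(url, str) and url.startswith("data:"):
            item["image_url"]["url"] = f"[BASE64_IMAGE_URL_LENGTH_{len(url)}]"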
def save_debug_image(self, image_data: str, filename: str) -> None:
|
||||
"""Save a debug image to the experiment directory.
|
||||
|
||||
Args:
|
||||
image_data: Base64 encoded image data
|
||||
filename: Filename to save the image as
|
||||
"""
|
||||
# Since we no longer want to use the images/ folder, we'll skip this functionality
|
||||
return
|
||||
|
||||
def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
|
||||
"""Save a screenshot to the experiment directory.
|
||||
|
||||
Args:
|
||||
img_base64: Base64 encoded screenshot
|
||||
action_type: Type of action that triggered the screenshot
|
||||
|
||||
Returns:
|
||||
Optional[str]: Path to the saved screenshot, or None if saving failed
|
||||
"""
|
||||
if not self.current_turn_dir:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Increment screenshot counter
|
||||
self.screenshot_count += 1
|
||||
|
||||
# Create a descriptive filename
|
||||
timestamp = int(time.time() * 1000)
|
||||
action_suffix = f"_{action_type}" if action_type else ""
|
||||
filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png"
|
||||
|
||||
# Save directly to the turn directory (no screenshots subdirectory)
|
||||
filepath = os.path.join(self.current_turn_dir, filename)
|
||||
|
||||
# Save the screenshot
|
||||
img_data = base64.b64decode(img_base64)
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(img_data)
|
||||
|
||||
# Keep track of the file path for reference
|
||||
self.screenshot_paths.append(filepath)
|
||||
|
||||
return filepath
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving screenshot: {str(e)}")
|
||||
return None
|
||||
|
||||
def should_save_debug_image(self) -> bool:
|
||||
"""Determine if debug images should be saved.
|
||||
|
||||
Returns:
|
||||
Boolean indicating if debug images should be saved
|
||||
"""
|
||||
# We no longer need to save debug images, so always return False
|
||||
return False
|
||||
|
||||
def save_action_visualization(
|
||||
self, img: Image.Image, action_name: str, details: str = ""
|
||||
) -> str:
|
||||
"""Save a visualization of an action.
|
||||
|
||||
Args:
|
||||
img: Image to save
|
||||
action_name: Name of the action
|
||||
details: Additional details about the action
|
||||
|
||||
Returns:
|
||||
Path to the saved image
|
||||
"""
|
||||
if not self.current_turn_dir:
|
||||
return ""
|
||||
|
||||
try:
|
||||
# Create a descriptive filename
|
||||
timestamp = int(time.time() * 1000)
|
||||
details_suffix = f"_{details}" if details else ""
|
||||
filename = f"vis_{action_name}{details_suffix}_{timestamp}.png"
|
||||
|
||||
# Save directly to the turn directory (no visualizations subdirectory)
|
||||
filepath = os.path.join(self.current_turn_dir, filename)
|
||||
|
||||
# Save the image
|
||||
img.save(filepath)
|
||||
|
||||
# Keep track of the file path for cleanup
|
||||
self.screenshot_paths.append(filepath)
|
||||
|
||||
return filepath
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving action visualization: {str(e)}")
|
||||
return ""
|
||||
|
||||
def extract_and_save_images(self, data: Any, prefix: str) -> None:
|
||||
"""Extract and save images from response data.
|
||||
|
||||
Args:
|
||||
data: Response data to extract images from
|
||||
prefix: Prefix for saved image filenames
|
||||
"""
|
||||
# Since we no longer want to save extracted images separately,
|
||||
# we'll skip this functionality entirely
|
||||
return
|
||||
|
||||
def log_api_call(
|
||||
self,
|
||||
call_type: str,
|
||||
request: Any,
|
||||
provider: str,
|
||||
model: str,
|
||||
response: Any = None,
|
||||
error: Optional[Exception] = None,
|
||||
) -> None:
|
||||
"""Log API call details to file.
|
||||
|
||||
Args:
|
||||
call_type: Type of API call (e.g., 'request', 'response', 'error')
|
||||
request: The API request data
|
||||
provider: The AI provider used
|
||||
model: The AI model used
|
||||
response: Optional API response data
|
||||
error: Optional error information
|
||||
"""
|
||||
if not self.current_turn_dir:
|
||||
return
|
||||
|
||||
try:
|
||||
# Create a unique filename with timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"api_call_{timestamp}_{call_type}.json"
|
||||
filepath = os.path.join(self.current_turn_dir, filename)
|
||||
|
||||
# Sanitize data to remove large base64 strings
|
||||
sanitized_request = self.sanitize_log_data(request)
|
||||
sanitized_response = self.sanitize_log_data(response) if response is not None else None
|
||||
|
||||
# Prepare log data
|
||||
log_data = {
|
||||
"timestamp": timestamp,
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"type": call_type,
|
||||
"request": sanitized_request,
|
||||
}
|
||||
|
||||
if sanitized_response is not None:
|
||||
log_data["response"] = sanitized_response
|
||||
if error is not None:
|
||||
log_data["error"] = str(error)
|
||||
|
||||
# Write to file
|
||||
with open(filepath, "w") as f:
|
||||
json.dump(log_data, f, indent=2, default=str)
|
||||
|
||||
logger.info(f"Logged API {call_type} to {filepath}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error logging API call: {str(e)}")
|
||||
File diff suppressed because it is too large
@@ -1,171 +0,0 @@
|
||||
"""Omni message manager implementation."""
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
from ...core.messages import BaseMessageManager, ImageRetentionConfig
|
||||
|
||||
|
||||
class OmniMessageManager(BaseMessageManager):
|
||||
"""Message manager for multi-provider support."""
|
||||
|
||||
def __init__(self, config: Optional[ImageRetentionConfig] = None):
|
||||
"""Initialize the message manager.
|
||||
|
||||
Args:
|
||||
config: Optional configuration for image retention
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.messages: List[Dict[str, Any]] = []
|
||||
self.config = config
|
||||
|
||||
def add_user_message(self, content: str, images: Optional[List[bytes]] = None) -> None:
|
||||
"""Add a user message to the history.
|
||||
|
||||
Args:
|
||||
content: Message content
|
||||
images: Optional list of image data
|
||||
"""
|
||||
# Add images if present
|
||||
if images:
|
||||
# Initialize with proper typing for mixed content
|
||||
message_content: List[Dict[str, Any]] = [{"type": "text", "text": content}]
|
||||
|
||||
# Add each image
|
||||
for img in images:
|
||||
message_content.append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64.b64encode(img).decode()}"
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
message = {"role": "user", "content": message_content}
|
||||
else:
|
||||
# Simple text message
|
||||
message = {"role": "user", "content": content}
|
||||
|
||||
self.messages.append(message)
|
||||
|
||||
# Apply retention policy
|
||||
if self.config and self.config.num_images_to_keep:
|
||||
self._apply_image_retention_policy()
|
||||
|
||||
def add_assistant_message(self, content: str) -> None:
|
||||
"""Add an assistant message to the history.
|
||||
|
||||
Args:
|
||||
content: Message content
|
||||
"""
|
||||
self.messages.append({"role": "assistant", "content": content})
|
||||
|
||||
def add_system_message(self, content: str) -> None:
|
||||
"""Add a system message to the history.
|
||||
|
||||
Args:
|
||||
content: Message content
|
||||
"""
|
||||
self.messages.append({"role": "system", "content": content})
|
||||
|
||||
def _apply_image_retention_policy(self) -> None:
|
||||
"""Apply image retention policy to message history."""
|
||||
if not self.config or not self.config.num_images_to_keep:
|
||||
return
|
||||
|
||||
# Count images from newest to oldest
|
||||
image_count = 0
|
||||
for message in reversed(self.messages):
|
||||
if message["role"] != "user":
|
||||
continue
|
||||
|
||||
# Handle multimodal messages
|
||||
if isinstance(message["content"], list):
|
||||
new_content = []
|
||||
for item in message["content"]:
|
||||
if item["type"] == "text":
|
||||
new_content.append(item)
|
||||
elif item["type"] == "image_url":
|
||||
if image_count < self.config.num_images_to_keep:
|
||||
new_content.append(item)
|
||||
image_count += 1
|
||||
message["content"] = new_content
|
||||
|
||||
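# The retention policy above keeps only the newest N screenshots across user messages.
# A compact sketch of the same idea (operates on standard OpenAI-format messages,
# mutating them in place):
from typing import Any, Dict, List


def keep_recent_images(messages: List[Dict[str, Any]], num_to_keep: int) -> None:
    """Drop all but the newest num_to_keep image_url items from user messages."""
    kept = 0
    for message in reversed(messages):  # newest messages first
        if message.get("role") != "user" or not isinstance(message.get("content"), list):
            continue
        new_content = []
        for item in message["content"]:
            if isinstance(item, dict) and item.get("type") == "image_url":
                if kept < num_to_keep:
                    new_content.append(item)
                    kept += 1
            else:
                new_content.append(item)
        message["content"] = new_content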
def get_formatted_messages(self, provider: str) -> List[Dict[str, Any]]:
|
||||
"""Get messages formatted for specific provider.
|
||||
|
||||
Args:
|
||||
provider: Provider name to format messages for
|
||||
|
||||
Returns:
|
||||
List of formatted messages
|
||||
"""
|
||||
# Set the provider for message formatting
|
||||
self.set_provider(provider)
|
||||
|
||||
if provider == "anthropic":
|
||||
return self._format_for_anthropic()
|
||||
elif provider == "openai":
|
||||
return self._format_for_openai()
|
||||
elif provider == "groq":
|
||||
return self._format_for_groq()
|
||||
elif provider == "qwen":
|
||||
return self._format_for_qwen()
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
||||
|
||||
def _format_for_anthropic(self) -> List[Dict[str, Any]]:
|
||||
"""Format messages for Anthropic API."""
|
||||
formatted = []
|
||||
for msg in self.messages:
|
||||
formatted_msg = {"role": msg["role"]}
|
||||
|
||||
# Handle multimodal content
|
||||
if isinstance(msg["content"], list):
|
||||
formatted_msg["content"] = []
|
||||
for item in msg["content"]:
|
||||
if item["type"] == "text":
|
||||
formatted_msg["content"].append({"type": "text", "text": item["text"]})
|
||||
elif item["type"] == "image_url":
|
||||
formatted_msg["content"].append(
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": item["image_url"]["url"].split(",")[1],
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
formatted_msg["content"] = msg["content"]
|
||||
|
||||
formatted.append(formatted_msg)
|
||||
return formatted
|
||||
|
||||
def _format_for_openai(self) -> List[Dict[str, Any]]:
|
||||
"""Format messages for OpenAI API."""
|
||||
# OpenAI already uses the same format
|
||||
return self.messages
|
||||
|
||||
def _format_for_groq(self) -> List[Dict[str, Any]]:
|
||||
"""Format messages for Groq API."""
|
||||
# Groq uses OpenAI-compatible format
|
||||
return self.messages
|
||||
|
||||
def _format_for_qwen(self) -> List[Dict[str, Any]]:
|
||||
"""Format messages for Qwen API."""
|
||||
formatted = []
|
||||
for msg in self.messages:
|
||||
if isinstance(msg["content"], list):
|
||||
# Convert multimodal content to text-only
|
||||
text_content = next(
|
||||
(item["text"] for item in msg["content"] if item["type"] == "text"), ""
|
||||
)
|
||||
formatted.append({"role": msg["role"], "content": text_content})
|
||||
else:
|
||||
formatted.append(msg)
|
||||
return formatted
|
||||
@@ -1,130 +0,0 @@
"""Visualization utilities for the Cua provider."""

import base64
import logging
from io import BytesIO
from typing import Tuple

from PIL import Image, ImageDraw

logger = logging.getLogger(__name__)


def visualize_click(x: int, y: int, img_base64: str) -> Image.Image:
    """Visualize a click action by drawing on the screenshot.

    Args:
        x: X coordinate of the click
        y: Y coordinate of the click
        img_base64: Base64 encoded image to draw on

    Returns:
        PIL Image with visualization
    """
    try:
        # Decode the base64 image
        img_data = base64.b64decode(img_base64)
        img = Image.open(BytesIO(img_data))

        # Create a drawing context
        draw = ImageDraw.Draw(img)

        # Draw concentric circles at the click position
        small_radius = 10
        large_radius = 30

        # Draw filled inner circle
        draw.ellipse(
            [(x - small_radius, y - small_radius), (x + small_radius, y + small_radius)],
            fill="red",
        )

        # Draw outlined outer circle
        draw.ellipse(
            [(x - large_radius, y - large_radius), (x + large_radius, y + large_radius)],
            outline="red",
            width=3,
        )

        return img

    except Exception as e:
        logger.error(f"Error visualizing click: {str(e)}")
        # Return a blank image in case of error
        return Image.new("RGB", (800, 600), color="white")


def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image:
    """Visualize a scroll action by drawing arrows on the screenshot.

    Args:
        direction: 'up' or 'down'
        clicks: Number of scroll clicks
        img_base64: Base64 encoded image to draw on

    Returns:
        PIL Image with visualization
    """
    try:
        # Decode the base64 image
        img_data = base64.b64decode(img_base64)
        img = Image.open(BytesIO(img_data))

        # Get image dimensions
        width, height = img.size

        # Create a drawing context
        draw = ImageDraw.Draw(img)

        # Determine arrow direction and positions
        center_x = width // 2
        arrow_width = 100

        if direction.lower() == "up":
            # Draw up arrow in the middle of the screen
            arrow_y = height // 2
            # Arrow points
            points = [
                (center_x, arrow_y - 50),  # Top point
                (center_x - arrow_width // 2, arrow_y + 50),  # Bottom left
                (center_x + arrow_width // 2, arrow_y + 50),  # Bottom right
            ]
            color = "blue"
        else:  # down
            # Draw down arrow in the middle of the screen
            arrow_y = height // 2
            # Arrow points
            points = [
                (center_x, arrow_y + 50),  # Bottom point
                (center_x - arrow_width // 2, arrow_y - 50),  # Top left
                (center_x + arrow_width // 2, arrow_y - 50),  # Top right
            ]
            color = "green"

        # Draw filled arrow
        draw.polygon(points, fill=color)

        # Add text showing number of clicks
        text_y = arrow_y + 70 if direction.lower() == "down" else arrow_y - 70
        draw.text((center_x - 40, text_y), f"{clicks} clicks", fill="black")

        return img

    except Exception as e:
        logger.error(f"Error visualizing scroll: {str(e)}")
        # Return a blank image in case of error
        return Image.new("RGB", (800, 600), color="white")


def calculate_element_center(box: Tuple[int, int, int, int]) -> Tuple[int, int]:
    """Calculate the center coordinates of a bounding box.

    Args:
        box: Tuple of (left, top, right, bottom) coordinates

    Returns:
        Tuple of (center_x, center_y) coordinates
    """
    left, top, right, bottom = box
    center_x = (left + right) // 2
    center_y = (top + bottom) // 2
    return center_x, center_y
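The hunk above deletes this module. For reference, here is a minimal, hypothetical sketch of how its helpers could be combined; the screenshot file name and bounding box values are made up for illustration and assume the functions above are importable.

# Hypothetical usage of the removed helpers; "screenshot.png" and the box are illustrative.
import base64

with open("screenshot.png", "rb") as f:
    img_base64 = base64.b64encode(f.read()).decode("utf-8")

# Find the center of a detected UI element and visualize a click there.
box = (100, 200, 180, 240)  # (left, top, right, bottom), illustrative values
x, y = calculate_element_center(box)
annotated = visualize_click(x, y, img_base64)
annotated.save("click_visualization.png")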
134
notebooks/openai_cua_nb.ipynb
Normal file
@@ -0,0 +1,134 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install openai\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import os\n",
    "\n",
    "response = requests.post(\n",
    "    \"https://api.openai.com/v1/responses\",\n",
    "    headers={\n",
    "        \"Content-Type\": \"application/json\",\n",
    "        \"Authorization\": f\"Bearer {os.environ['OPENAI_API_KEY']}\"\n",
    "    },\n",
    "    json={\n",
    "        \"model\": \"computer-use-preview\",\n",
    "        \"tools\": [{\n",
    "            \"type\": \"computer_use_preview\",\n",
    "            \"display_width\": 1024,\n",
    "            \"display_height\": 768,\n",
    "            \"environment\": \"mac\"  # other possible values: \"windows\", \"ubuntu\"\n",
    "        }],\n",
    "        \"input\": [\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": \"Check the latest OpenAI news on bing.com.\"\n",
    "            }\n",
    "            # Optional: include a screenshot of the initial state of the environment\n",
    "            # {\n",
    "            #     type: \"input_image\",\n",
    "            #     image_url: f\"data:image/png;base64,{screenshot_base64}\"\n",
    "            # }\n",
    "        ],\n",
    "        \"reasoning\": {\n",
    "            \"generate_summary\": \"concise\",\n",
    "        },\n",
    "        \"truncation\": \"auto\"\n",
    "    }\n",
    ")\n",
    "\n",
    "print(response.json())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n"
     ]
    }
   ],
   "source": [
    "from openai import OpenAI\n",
    "client = OpenAI()  # assumes OPENAI_API_KEY is set in env\n",
    "\n",
    "def has_model_starting_with(prefix=\"computer\"):\n",
    "    models = client.models.list().data\n",
    "    return any(model.id.startswith(prefix) for model in models)\n",
    "\n",
    "print(has_model_starting_with())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import os\n",
    "\n",
    "response = requests.post(\n",
    "    \"https://api.openai.com/v1/responses\",\n",
    "    headers={\n",
    "        \"Content-Type\": \"application/json\",\n",
    "        \"Authorization\": f\"Bearer {os.environ['OPENAI_API_KEY']}\"\n",
    "    },\n",
    "    json={\n",
    "        \"model\": \"gpt-4o\",\n",
    "        \"input\": \"Tell me a three sentence bedtime story about a unicorn.\"\n",
    "    }\n",
    ")\n",
    "print(response.json())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}