From e32b64590ab6309d5575a47999a7c217e322450d Mon Sep 17 00:00:00 2001 From: f-trycua Date: Sun, 23 Mar 2025 23:40:18 +0100 Subject: [PATCH] Standardize Agent Loop --- examples/agent_examples.py | 36 +- libs/agent/agent/core/__init__.py | 5 - libs/agent/agent/core/computer_agent.py | 54 +- libs/agent/agent/core/loop.py | 320 ++++-- libs/agent/agent/core/messages.py | 356 ++++-- libs/agent/agent/core/visualization.py | 197 ++++ .../agent/providers/anthropic/api_handler.py | 141 +++ .../providers/anthropic/callbacks/__init__.py | 5 + libs/agent/agent/providers/anthropic/loop.py | 432 ++++--- .../providers/anthropic/response_handler.py | 223 ++++ .../agent/providers/anthropic/tools/bash.py | 96 -- libs/agent/agent/providers/omni/__init__.py | 6 - .../agent/providers/omni/action_executor.py | 264 +++++ .../agent/agent/providers/omni/api_handler.py | 42 + .../agent/providers/omni/clients/anthropic.py | 4 + .../agent/providers/omni/clients/groq.py | 101 -- libs/agent/agent/providers/omni/experiment.py | 276 ----- libs/agent/agent/providers/omni/loop.py | 1017 ++++++++--------- libs/agent/agent/providers/omni/messages.py | 171 --- .../agent/providers/omni/visualization.py | 130 --- notebooks/openai_cua_nb.ipynb | 134 +++ 21 files changed, 2243 insertions(+), 1767 deletions(-) create mode 100644 libs/agent/agent/core/visualization.py create mode 100644 libs/agent/agent/providers/anthropic/api_handler.py create mode 100644 libs/agent/agent/providers/anthropic/callbacks/__init__.py create mode 100644 libs/agent/agent/providers/anthropic/response_handler.py create mode 100644 libs/agent/agent/providers/omni/action_executor.py create mode 100644 libs/agent/agent/providers/omni/api_handler.py delete mode 100644 libs/agent/agent/providers/omni/clients/groq.py delete mode 100644 libs/agent/agent/providers/omni/experiment.py delete mode 100644 libs/agent/agent/providers/omni/messages.py delete mode 100644 libs/agent/agent/providers/omni/visualization.py create mode 100644 notebooks/openai_cua_nb.ipynb diff --git a/examples/agent_examples.py b/examples/agent_examples.py index 2cf11587..77363eca 100644 --- a/examples/agent_examples.py +++ b/examples/agent_examples.py @@ -6,6 +6,7 @@ import logging import traceback from pathlib import Path import signal +import json from computer import Computer @@ -32,42 +33,31 @@ async def run_omni_agent_example(): # Create agent with loop and provider agent = ComputerAgent( computer=computer, - loop=AgentLoop.ANTHROPIC, - # loop=AgentLoop.OMNI, + # loop=AgentLoop.ANTHROPIC, + loop=AgentLoop.OMNI, # model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"), model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"), save_trajectory=True, - trajectory_dir=str(Path("trajectories")), only_n_most_recent_images=3, - verbosity=logging.INFO, + verbosity=logging.DEBUG, ) tasks = [ - """ -1. Look for a repository named trycua/lume on GitHub. -2. Check the open issues, open the most recent one and read it. -3. Clone the repository in users/lume/projects if it doesn't exist yet. -4. Open the repository with an app named Cursor (on the dock, black background and white cube icon). -5. From Cursor, open Composer if not already open. -6. Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue. 
-""" + "Look for a repository named trycua/cua on GitHub.", + "Check the open issues, open the most recent one and read it.", + "Clone the repository in users/lume/projects if it doesn't exist yet.", + "Open the repository with an app named Cursor (on the dock, black background and white cube icon).", + "From Cursor, open Composer if not already open.", + "Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.", ] async with agent: - for i, task in enumerate(tasks, 1): + for i, task in enumerate(tasks): print(f"\nExecuting task {i}/{len(tasks)}: {task}") async for result in agent.run(task): - # Check if result has the expected structure - if "role" in result and "content" in result and "metadata" in result: - title = result["metadata"].get("title", "Screen Analysis") - content = result["content"] - else: - title = result.get("metadata", {}).get("title", "Screen Analysis") - content = result.get("content", str(result)) + print(result) - print(f"\n{title}") - print(content) - print(f"Task {i} completed") + print(f"\n✅ Task {i+1}/{len(tasks)} completed: {task}") except Exception as e: logger.error(f"Error in run_omni_agent_example: {e}") diff --git a/libs/agent/agent/core/__init__.py b/libs/agent/agent/core/__init__.py index 19a57b5f..b4657510 100644 --- a/libs/agent/agent/core/__init__.py +++ b/libs/agent/agent/core/__init__.py @@ -2,11 +2,6 @@ from .loop import BaseLoop from .messages import ( - create_user_message, - create_assistant_message, - create_system_message, - create_image_message, - create_screen_message, BaseMessageManager, ImageRetentionConfig, ) diff --git a/libs/agent/agent/core/computer_agent.py b/libs/agent/agent/core/computer_agent.py index 0702ef11..f65a0ac7 100644 --- a/libs/agent/agent/core/computer_agent.py +++ b/libs/agent/agent/core/computer_agent.py @@ -3,8 +3,7 @@ import asyncio import logging import os -from typing import Any, AsyncGenerator, Dict, Optional, cast -from dataclasses import dataclass +from typing import Any, AsyncGenerator, Dict, Optional, cast, List from computer import Computer from ..providers.anthropic.loop import AnthropicLoop @@ -12,6 +11,7 @@ from ..providers.omni.loop import OmniLoop from ..providers.omni.parser import OmniParser from ..providers.omni.types import LLMProvider, LLM from .. import AgentLoop +from .messages import StandardMessageManager, ImageRetentionConfig logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -44,7 +44,6 @@ class ComputerAgent: save_trajectory: bool = True, trajectory_dir: str = "trajectories", only_n_most_recent_images: Optional[int] = None, - parser: Optional[OmniParser] = None, verbosity: int = logging.INFO, ): """Initialize the ComputerAgent. @@ -61,7 +60,6 @@ class ComputerAgent: save_trajectory: Whether to save the trajectory. trajectory_dir: Directory to save the trajectory. only_n_most_recent_images: Maximum number of recent screenshots to include in API requests. - parser: Parser instance for the OmniLoop. Only used if provider is not ANTHROPIC. verbosity: Logging level. 
""" # Basic agent configuration @@ -74,6 +72,11 @@ class ComputerAgent: self._initialized = False self._in_context = False + # Initialize the message manager for standardized message handling + self.message_manager = StandardMessageManager( + config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images) + ) + # Set logging level logger.setLevel(verbosity) @@ -118,10 +121,6 @@ class ComputerAgent: only_n_most_recent_images=only_n_most_recent_images, ) else: - # Default to OmniLoop for other loop types - # Initialize parser if not provided - actual_parser = parser or OmniParser() - self._loop = OmniLoop( provider=self.provider, api_key=actual_api_key, @@ -130,7 +129,7 @@ class ComputerAgent: save_trajectory=save_trajectory, base_dir=trajectory_dir, only_n_most_recent_images=only_n_most_recent_images, - parser=actual_parser, + parser=OmniParser(), ) logger.info( @@ -224,13 +223,25 @@ class ComputerAgent: """ try: logger.info(f"Running task: {task}") + logger.info( + f"Message history before task has {len(self.message_manager.messages)} messages" + ) # Initialize the computer if needed if not self._initialized: await self.initialize() - # Format task as a message - messages = [{"role": "user", "content": task}] + # Add task as a user message using the message manager + self.message_manager.add_user_message([{"type": "text", "text": task}]) + logger.info( + f"Added task message. Message history now has {len(self.message_manager.messages)} messages" + ) + + # Log message history types to help with debugging + message_types = [ + f"{i}: {msg['role']}" for i, msg in enumerate(self.message_manager.messages) + ] + logger.info(f"Message history roles: {', '.join(message_types)}") # Pass properly formatted messages to the loop if self._loop is None: @@ -239,9 +250,28 @@ class ComputerAgent: return # Execute the task and yield results - async for result in self._loop.run(messages): + async for result in self._loop.run(self.message_manager.messages): + # Extract the assistant message from the result and add it to our history + assistant_response = result["response"]["choices"][0].get("message", None) + if assistant_response and assistant_response.get("role") == "assistant": + # Extract the content from the assistant response + content = assistant_response.get("content") + self.message_manager.add_assistant_message(content) + + logger.info("Added assistant response to message history") + + # Yield the result to the caller yield result + # Logging the message history for debugging + logger.info( + f"Updated message history now has {len(self.message_manager.messages)} messages" + ) + message_types = [ + f"{i}: {msg['role']}" for i, msg in enumerate(self.message_manager.messages) + ] + logger.info(f"Updated message history roles: {', '.join(message_types)}") + except Exception as e: logger.error(f"Error in agent run method: {str(e)}") yield { diff --git a/libs/agent/agent/core/loop.py b/libs/agent/agent/core/loop.py index ad6c3cca..cbd7637f 100644 --- a/libs/agent/agent/core/loop.py +++ b/libs/agent/agent/core/loop.py @@ -2,12 +2,9 @@ import logging import asyncio -import json -import os from abc import ABC, abstractmethod from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple from datetime import datetime -import base64 from computer import Computer from .experiment import ExperimentManager @@ -18,6 +15,10 @@ logger = logging.getLogger(__name__) class BaseLoop(ABC): """Base class for agent loops that handle message processing and tool execution.""" + 
########################################### + # INITIALIZATION AND CONFIGURATION + ########################################### + def __init__( self, computer: Computer, @@ -75,6 +76,64 @@ class BaseLoop(ABC): # Initialize basic tracking self.turn_count = 0 + async def initialize(self) -> None: + """Initialize both the API client and computer interface with retries.""" + for attempt in range(self.max_retries): + try: + logger.info( + f"Starting initialization (attempt {attempt + 1}/{self.max_retries})..." + ) + + # Initialize API client + await self.initialize_client() + + logger.info("Initialization complete.") + return + except Exception as e: + if attempt < self.max_retries - 1: + logger.warning( + f"Initialization failed (attempt {attempt + 1}/{self.max_retries}): {str(e)}. Retrying..." + ) + await asyncio.sleep(self.retry_delay) + else: + logger.error( + f"Initialization failed after {self.max_retries} attempts: {str(e)}" + ) + raise RuntimeError(f"Failed to initialize: {str(e)}") + + ########################################### + + # ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES + ########################################### + + @abstractmethod + async def initialize_client(self) -> None: + """Initialize the API client and any provider-specific components. + + This method must be implemented by subclasses to set up + provider-specific clients and tools. + """ + raise NotImplementedError + + @abstractmethod + async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]: + """Run the agent loop with provided messages. + + This method handles the main agent loop including message processing, + API calls, response handling, and action execution. + + Args: + messages: List of message objects + + Yields: + Dict containing response data + """ + raise NotImplementedError + + ########################################### + # EXPERIMENT AND TRAJECTORY MANAGEMENT + ########################################### + def _setup_experiment_dirs(self) -> None: """Setup the experiment directory structure.""" if self.experiment_manager: @@ -100,10 +159,13 @@ class BaseLoop(ABC): ) -> None: """Log API call details to file. + Preserves provider-specific formats for requests and responses to ensure + accurate logging for debugging and analysis purposes. + Args: call_type: Type of API call (e.g., 'request', 'response', 'error') - request: The API request data - response: Optional API response data + request: The API request data in provider-specific format + response: Optional API response data in provider-specific format error: Optional error information """ if self.experiment_manager: @@ -130,119 +192,155 @@ class BaseLoop(ABC): if self.experiment_manager: self.experiment_manager.save_screenshot(img_base64, action_type) - async def initialize(self) -> None: - """Initialize both the API client and computer interface with retries.""" - for attempt in range(self.max_retries): - try: - logger.info( - f"Starting initialization (attempt {attempt + 1}/{self.max_retries})..." - ) + def _create_openai_compatible_response( + self, response: Any, messages: List[Dict[str, Any]] + ) -> Dict[str, Any]: + """Create an OpenAI computer use agent compatible response format. - # Initialize API client - await self.initialize_client() - - logger.info("Initialization complete.") - return - except Exception as e: - if attempt < self.max_retries - 1: - logger.warning( - f"Initialization failed (attempt {attempt + 1}/{self.max_retries}): {str(e)}. Retrying..." 
- ) - await asyncio.sleep(self.retry_delay) - else: - logger.error( - f"Initialization failed after {self.max_retries} attempts: {str(e)}" - ) - raise RuntimeError(f"Failed to initialize: {str(e)}") - - async def _get_parsed_screen_som(self) -> Dict[str, Any]: - """Get parsed screen information. + Args: + response: The original API response + messages: List of messages in standard OpenAI format Returns: - Dict containing screen information + A response formatted according to OpenAI's computer use agent standard """ - try: - # Take screenshot - screenshot = await self.computer.interface.screenshot() + import json - # Initialize with default values - width, height = 1024, 768 - base64_image = "" + # Create a unique ID for this response + response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(response)}" + reasoning_id = f"rs_{response_id}" + action_id = f"cu_{response_id}" + call_id = f"call_{response_id}" - # Handle different types of screenshot returns - if isinstance(screenshot, (bytes, bytearray, memoryview)): - # Raw bytes screenshot - base64_image = base64.b64encode(screenshot).decode("utf-8") - elif hasattr(screenshot, "base64_image"): - # Object-style screenshot with attributes - # Type checking can't infer these attributes, but they exist at runtime - # on certain screenshot return types - base64_image = getattr(screenshot, "base64_image") - width = ( - getattr(screenshot, "width", width) if hasattr(screenshot, "width") else width - ) - height = ( - getattr(screenshot, "height", height) - if hasattr(screenshot, "height") - else height - ) + # Extract the last assistant message + assistant_msg = None + for msg in reversed(messages): + if msg["role"] == "assistant": + assistant_msg = msg + break - # Create parsed screen data - parsed_screen = { - "width": width, - "height": height, - "parsed_content_list": [], - "timestamp": datetime.now().isoformat(), - "screenshot_base64": base64_image, + if not assistant_msg: + # If no assistant message found, create a default one + assistant_msg = {"role": "assistant", "content": "No response available"} + + # Initialize output array + output_items = [] + + # Extract reasoning and action details from the response + content = assistant_msg["content"] + reasoning_text = None + action_details = None + + # Extract reasoning and action from different content formats + if isinstance(content, str): + try: + # Try to parse JSON + parsed_content = json.loads(content) + reasoning_text = parsed_content.get("Explanation", "") + + # Extract action details + action = parsed_content.get("Action", "") + position = parsed_content.get("Position", {}) + text_input = parsed_content.get("Text", "") + + if action.lower() == "click" and position: + action_details = { + "type": "click", + "button": "left", + "x": position.get("x", 100), + "y": position.get("y", 100), + } + elif action.lower() == "type" and text_input: + action_details = { + "type": "type", + "text": text_input, + } + elif action.lower() == "scroll": + action_details = { + "type": "scroll", + "x": 100, + "y": 100, + "scroll_x": position.get("delta_x", 0), + "scroll_y": position.get("delta_y", 0), + } + except json.JSONDecodeError: + # If not valid JSON, use the content as reasoning + reasoning_text = content + elif isinstance(content, list): + # Handle list of content blocks (like Anthropic format) + for item in content: + if isinstance(item, dict): + if item.get("type") == "text": + # Collect text blocks for reasoning + if reasoning_text is None: + reasoning_text = "" + reasoning_text += 
item.get("text", "") + elif item.get("type") == "tool_use": + # Extract action from tool_use (similar to Anthropic format) + tool_input = item.get("input", {}) + if "click" in tool_input or "position" in tool_input: + position = tool_input.get("click", tool_input.get("position", {})) + if isinstance(position, dict) and "x" in position and "y" in position: + action_details = { + "type": "click", + "button": "left", + "x": position.get("x", 100), + "y": position.get("y", 100), + } + elif "type" in tool_input or "text" in tool_input: + action_details = { + "type": "type", + "text": tool_input.get("type", tool_input.get("text", "")), + } + elif "scroll" in tool_input: + scroll = tool_input.get("scroll", {}) + action_details = { + "type": "scroll", + "x": 100, + "y": 100, + "scroll_x": scroll.get("x", 0), + "scroll_y": scroll.get("y", 0), + } + + # Add reasoning item if we have text content + if reasoning_text: + output_items.append( + { + "type": "reasoning", + "id": reasoning_id, + "summary": [ + { + "type": "summary_text", + "text": reasoning_text[:200], # Truncate to reasonable length + } + ], + } + ) + + # If no action details extracted, use default + if not action_details: + action_details = { + "type": "click", + "button": "left", + "x": 100, + "y": 100, } - # Save screenshot if requested - if self.save_trajectory and self.experiment_manager: - try: - img_data = base64_image - if "," in img_data: - img_data = img_data.split(",")[1] - self._save_screenshot(img_data, action_type="state") - except Exception as e: - logger.error(f"Error saving screenshot: {str(e)}") + # Add computer_call item + computer_call = { + "type": "computer_call", + "id": action_id, + "call_id": call_id, + "action": action_details, + "pending_safety_checks": [], + "status": "completed", + } + output_items.append(computer_call) - return parsed_screen - except Exception as e: - logger.error(f"Error taking screenshot: {str(e)}") - return { - "width": 1024, - "height": 768, - "parsed_content_list": [], - "timestamp": datetime.now().isoformat(), - "error": f"Error taking screenshot: {str(e)}", - "screenshot_base64": "", - } - - @abstractmethod - async def initialize_client(self) -> None: - """Initialize the API client and any provider-specific components.""" - raise NotImplementedError - - @abstractmethod - async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]: - """Run the agent loop with provided messages. - - Args: - messages: List of message objects - - Yields: - Dict containing response data - """ - raise NotImplementedError - - @abstractmethod - async def _process_screen( - self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]] - ) -> None: - """Process screen information and add to messages. 
- - Args: - parsed_screen: Dictionary containing parsed screen info - messages: List of messages to update - """ - raise NotImplementedError + # Create the OpenAI-compatible response format + return { + "output": output_items, + "id": response_id, + # Include the original response for compatibility + "response": {"choices": [{"message": assistant_msg, "finish_reason": "stop"}]}, + } diff --git a/libs/agent/agent/core/messages.py b/libs/agent/agent/core/messages.py index d9a24e7b..0ae37b3e 100644 --- a/libs/agent/agent/core/messages.py +++ b/libs/agent/agent/core/messages.py @@ -4,9 +4,10 @@ import base64 from datetime import datetime from io import BytesIO import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, cast, Tuple from PIL import Image from dataclasses import dataclass +import re logger = logging.getLogger(__name__) @@ -123,123 +124,278 @@ class BaseMessageManager: break -def create_user_message(text: str) -> Dict[str, str]: - """Create a user message. +class StandardMessageManager: + """Manages messages in a standardized OpenAI format across different providers.""" - Args: - text: The message text + def __init__(self, config: Optional[ImageRetentionConfig] = None): + """Initialize message manager. - Returns: - Message dictionary - """ - return { - "role": "user", - "content": text, - } + Args: + config: Configuration for image retention + """ + self.messages: List[Dict[str, Any]] = [] + self.config = config or ImageRetentionConfig() + def add_user_message(self, content: Union[str, List[Dict[str, Any]]]) -> None: + """Add a user message. -def create_assistant_message(text: str) -> Dict[str, str]: - """Create an assistant message. + Args: + content: Message content (text or multimodal content) + """ + self.messages.append({"role": "user", "content": content}) - Args: - text: The message text + def add_assistant_message(self, content: Union[str, List[Dict[str, Any]]]) -> None: + """Add an assistant message. - Returns: - Message dictionary - """ - return { - "role": "assistant", - "content": text, - } + Args: + content: Message content (text or multimodal content) + """ + self.messages.append({"role": "assistant", "content": content}) + def add_system_message(self, content: str) -> None: + """Add a system message. -def create_system_message(text: str) -> Dict[str, str]: - """Create a system message. + Args: + content: System message content + """ + self.messages.append({"role": "system", "content": content}) - Args: - text: The message text + def get_messages(self) -> List[Dict[str, Any]]: + """Get all messages in standard format. - Returns: - Message dictionary - """ - return { - "role": "system", - "content": text, - } + Returns: + List of messages + """ + # If image retention is configured, apply it + if self.config.num_images_to_keep is not None: + return self._apply_image_retention(self.messages) + return self.messages + def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Apply image retention policy to messages. -def create_image_message( - image_base64: Optional[str] = None, - image_path: Optional[str] = None, - image_obj: Optional[Image.Image] = None, -) -> Dict[str, Union[str, List[Dict[str, Any]]]]: - """Create a message with an image. 
+ Args: + messages: List of messages - Args: - image_base64: Base64 encoded image - image_path: Path to image file - image_obj: PIL Image object + Returns: + List of messages with image retention applied + """ + if not self.config.num_images_to_keep: + return messages - Returns: - Message dictionary with content list + # Find user messages with images + image_messages = [] + for msg in messages: + if msg["role"] == "user" and isinstance(msg["content"], list): + has_image = any( + item.get("type") == "image_url" or item.get("type") == "image" + for item in msg["content"] + ) + if has_image: + image_messages.append(msg) - Raises: - ValueError: If no image source is provided - """ - if not any([image_base64, image_path, image_obj]): - raise ValueError("Must provide one of image_base64, image_path, or image_obj") + # If we don't have more images than the limit, return all messages + if len(image_messages) <= self.config.num_images_to_keep: + return messages - # Convert to base64 if needed - if image_path and not image_base64: - with open(image_path, "rb") as f: - image_bytes = f.read() - image_base64 = base64.b64encode(image_bytes).decode("utf-8") - elif image_obj and not image_base64: - buffer = BytesIO() - image_obj.save(buffer, format="PNG") - image_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") + # Get the most recent N images to keep + images_to_keep = image_messages[-self.config.num_images_to_keep :] + images_to_remove = image_messages[: -self.config.num_images_to_keep] - return { - "role": "user", - "content": [ - {"type": "image", "image_url": {"url": f"data:image/png;base64,{image_base64}"}} - ], - } + # Create a new message list without the older images + result = [] + for msg in messages: + if msg in images_to_remove: + # Skip this message + continue + result.append(msg) + return result -def create_screen_message( - parsed_screen: Dict[str, Any], - include_raw: bool = False, -) -> Dict[str, Union[str, List[Dict[str, Any]]]]: - """Create a message with screen information. + def to_anthropic_format( + self, messages: List[Dict[str, Any]] + ) -> Tuple[List[Dict[str, Any]], str]: + """Convert standard OpenAI format messages to Anthropic format. 
- Args: - parsed_screen: Dictionary containing parsed screen info - include_raw: Whether to include raw screenshot base64 + Args: + messages: List of messages in OpenAI format - Returns: - Message dictionary with content - """ - if include_raw and "screenshot_base64" in parsed_screen: - # Create content list with both image and text - return { - "role": "user", - "content": [ - { - "type": "image", - "image_url": { - "url": f"data:image/png;base64,{parsed_screen['screenshot_base64']}" - }, - }, - { - "type": "text", - "text": f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}", - }, - ], - } - else: - # Create text-only message with screen info - return { - "role": "user", - "content": f"Screen dimensions: {parsed_screen['width']}x{parsed_screen['height']}", - } + Returns: + Tuple containing (anthropic_messages, system_content) + """ + result = [] + system_content = "" + + # Process messages in order to maintain conversation flow + previous_assistant_tool_use_ids = ( + set() + ) # Track tool_use_ids in the previous assistant message + + for i, msg in enumerate(messages): + role = msg.get("role", "") + content = msg.get("content", "") + + if role == "system": + # Collect system messages for later use + system_content += content + "\n" + continue + + if role == "assistant": + # Track tool_use_ids in this assistant message for the next user message + previous_assistant_tool_use_ids = set() + if isinstance(content, list): + for item in content: + if ( + isinstance(item, dict) + and item.get("type") == "tool_use" + and "id" in item + ): + previous_assistant_tool_use_ids.add(item["id"]) + + logger.info( + f"Tool use IDs in assistant message #{i}: {previous_assistant_tool_use_ids}" + ) + + if role in ["user", "assistant"]: + anthropic_msg = {"role": role} + + # Convert content based on type + if isinstance(content, str): + # Simple text content + anthropic_msg["content"] = [{"type": "text", "text": content}] + elif isinstance(content, list): + # Convert complex content + anthropic_content = [] + for item in content: + item_type = item.get("type", "") + + if item_type == "text": + anthropic_content.append({"type": "text", "text": item.get("text", "")}) + elif item_type == "image_url": + # Convert OpenAI image format to Anthropic + image_url = item.get("image_url", {}).get("url", "") + if image_url.startswith("data:"): + # Extract base64 data and media type + match = re.match(r"data:(.+);base64,(.+)", image_url) + if match: + media_type, data = match.groups() + anthropic_content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": data, + }, + } + ) + else: + # Regular URL + anthropic_content.append( + { + "type": "image", + "source": { + "type": "url", + "url": image_url, + }, + } + ) + elif item_type == "tool_use": + # Always include tool_use blocks + anthropic_content.append(item) + elif item_type == "tool_result": + # Check if this is a user message AND if the tool_use_id exists in the previous assistant message + tool_use_id = item.get("tool_use_id") + + # Only include tool_result if it references a tool_use from the immediately preceding assistant message + if ( + role == "user" + and tool_use_id + and tool_use_id in previous_assistant_tool_use_ids + ): + anthropic_content.append(item) + logger.info( + f"Including tool_result with tool_use_id: {tool_use_id}" + ) + else: + # Convert to text to preserve information + logger.warning( + f"Converting tool_result to text. 
Tool use ID {tool_use_id} not found in previous assistant message" + ) + content_text = "Tool Result: " + if "content" in item: + if isinstance(item["content"], list): + for content_item in item["content"]: + if ( + isinstance(content_item, dict) + and content_item.get("type") == "text" + ): + content_text += content_item.get("text", "") + elif isinstance(item["content"], str): + content_text += item["content"] + anthropic_content.append({"type": "text", "text": content_text}) + + anthropic_msg["content"] = anthropic_content + + result.append(anthropic_msg) + + return result, system_content + + def from_anthropic_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert Anthropic format messages to standard OpenAI format. + + Args: + messages: List of messages in Anthropic format + + Returns: + List of messages in OpenAI format + """ + result = [] + + for msg in messages: + role = msg.get("role", "") + content = msg.get("content", []) + + if role in ["user", "assistant"]: + openai_msg = {"role": role} + + # Simple case: single text block + if len(content) == 1 and content[0].get("type") == "text": + openai_msg["content"] = content[0].get("text", "") + else: + # Complex case: multiple blocks or non-text + openai_content = [] + for item in content: + item_type = item.get("type", "") + + if item_type == "text": + openai_content.append({"type": "text", "text": item.get("text", "")}) + elif item_type == "image": + # Convert Anthropic image to OpenAI format + source = item.get("source", {}) + if source.get("type") == "base64": + media_type = source.get("media_type", "image/png") + data = source.get("data", "") + openai_content.append( + { + "type": "image_url", + "image_url": {"url": f"data:{media_type};base64,{data}"}, + } + ) + else: + # URL + openai_content.append( + { + "type": "image_url", + "image_url": {"url": source.get("url", "")}, + } + ) + elif item_type in ["tool_use", "tool_result"]: + # Pass through tool-related content + openai_content.append(item) + + openai_msg["content"] = openai_content + + result.append(openai_msg) + + return result diff --git a/libs/agent/agent/core/visualization.py b/libs/agent/agent/core/visualization.py new file mode 100644 index 00000000..8e9108a3 --- /dev/null +++ b/libs/agent/agent/core/visualization.py @@ -0,0 +1,197 @@ +"""Core visualization utilities for agents.""" + +import logging +import base64 +from typing import Dict, Tuple +from PIL import Image, ImageDraw +from io import BytesIO + +logger = logging.getLogger(__name__) + + +def visualize_click(x: int, y: int, img_base64: str) -> Image.Image: + """Visualize a click action by drawing a circle on the screenshot. 
+ + Args: + x: X coordinate of the click + y: Y coordinate of the click + img_base64: Base64-encoded screenshot + + Returns: + PIL Image with visualization + """ + try: + # Decode the base64 image + image_data = base64.b64decode(img_base64) + img = Image.open(BytesIO(image_data)) + + # Create a copy to draw on + draw_img = img.copy() + draw = ImageDraw.Draw(draw_img) + + # Draw a circle at the click location + radius = 15 + draw.ellipse([(x - radius, y - radius), (x + radius, y + radius)], outline="red", width=3) + + # Draw crosshairs + line_length = 20 + draw.line([(x - line_length, y), (x + line_length, y)], fill="red", width=3) + draw.line([(x, y - line_length), (x, y + line_length)], fill="red", width=3) + + return draw_img + except Exception as e: + logger.error(f"Error visualizing click: {str(e)}") + # Return a blank image as fallback + return Image.new("RGB", (800, 600), "white") + + +def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image: + """Visualize a scroll action by drawing arrows on the screenshot. + + Args: + direction: Direction of scroll ('up' or 'down') + clicks: Number of scroll clicks + img_base64: Base64-encoded screenshot + + Returns: + PIL Image with visualization + """ + try: + # Decode the base64 image + image_data = base64.b64decode(img_base64) + img = Image.open(BytesIO(image_data)) + + # Create a copy to draw on + draw_img = img.copy() + draw = ImageDraw.Draw(draw_img) + + # Calculate parameters for visualization + width, height = img.size + center_x = width // 2 + + # Draw arrows to indicate scrolling + arrow_length = min(100, height // 4) + arrow_width = 30 + num_arrows = min(clicks, 3) # Don't draw too many arrows + + # Calculate starting position + if direction == "down": + start_y = height // 3 + arrow_dir = 1 # Down + else: + start_y = height * 2 // 3 + arrow_dir = -1 # Up + + # Draw the arrows + for i in range(num_arrows): + y_pos = start_y + (i * arrow_length * arrow_dir * 0.7) + arrow_top = (center_x, y_pos) + arrow_bottom = (center_x, y_pos + arrow_length * arrow_dir) + + # Draw the main line + draw.line([arrow_top, arrow_bottom], fill="red", width=5) + + # Draw the arrowhead + arrowhead_size = 20 + if direction == "down": + draw.line( + [ + (center_x - arrow_width // 2, arrow_bottom[1] - arrowhead_size), + arrow_bottom, + (center_x + arrow_width // 2, arrow_bottom[1] - arrowhead_size), + ], + fill="red", + width=5, + ) + else: + draw.line( + [ + (center_x - arrow_width // 2, arrow_bottom[1] + arrowhead_size), + arrow_bottom, + (center_x + arrow_width // 2, arrow_bottom[1] + arrowhead_size), + ], + fill="red", + width=5, + ) + + return draw_img + except Exception as e: + logger.error(f"Error visualizing scroll: {str(e)}") + # Return a blank image as fallback + return Image.new("RGB", (800, 600), "white") + + +def calculate_element_center(bbox: Dict[str, float], width: int, height: int) -> Tuple[int, int]: + """Calculate the center point of a UI element. + + Args: + bbox: Bounding box dictionary with x1, y1, x2, y2 coordinates (0-1 normalized) + width: Screen width in pixels + height: Screen height in pixels + + Returns: + (x, y) tuple with pixel coordinates + """ + center_x = int((bbox["x1"] + bbox["x2"]) / 2 * width) + center_y = int((bbox["y1"] + bbox["y2"]) / 2 * height) + return center_x, center_y + + +class VisualizationHelper: + """Helper class for visualizing agent actions.""" + + def __init__(self, agent): + """Initialize visualization helper. 
+ + Args: + agent: Reference to the agent that will use this helper + """ + self.agent = agent + + def visualize_action(self, x: int, y: int, img_base64: str) -> None: + """Visualize a click action by drawing on the screenshot.""" + if ( + not self.agent.save_trajectory + or not hasattr(self.agent, "experiment_manager") + or not self.agent.experiment_manager + ): + return + + try: + # Use the visualization utility + img = visualize_click(x, y, img_base64) + + # Save the visualization + self.agent.experiment_manager.save_action_visualization(img, "click", f"x{x}_y{y}") + except Exception as e: + logger.error(f"Error visualizing action: {str(e)}") + + def visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None: + """Visualize a scroll action by drawing arrows on the screenshot.""" + if ( + not self.agent.save_trajectory + or not hasattr(self.agent, "experiment_manager") + or not self.agent.experiment_manager + ): + return + + try: + # Use the visualization utility + img = visualize_scroll(direction, clicks, img_base64) + + # Save the visualization + self.agent.experiment_manager.save_action_visualization( + img, "scroll", f"{direction}_{clicks}" + ) + except Exception as e: + logger.error(f"Error visualizing scroll: {str(e)}") + + def save_action_visualization( + self, img: Image.Image, action_name: str, details: str = "" + ) -> str: + """Save a visualization of an action.""" + if hasattr(self.agent, "experiment_manager") and self.agent.experiment_manager: + return self.agent.experiment_manager.save_action_visualization( + img, action_name, details + ) + return "" diff --git a/libs/agent/agent/providers/anthropic/api_handler.py b/libs/agent/agent/providers/anthropic/api_handler.py new file mode 100644 index 00000000..8a5302fe --- /dev/null +++ b/libs/agent/agent/providers/anthropic/api_handler.py @@ -0,0 +1,141 @@ +"""API call handling for Anthropic provider.""" + +import logging +import asyncio +from typing import Any, Dict, List, Optional +from httpx import ConnectError, ReadTimeout + +from anthropic.types.beta import ( + BetaMessage, + BetaMessageParam, + BetaTextBlockParam, +) + +from .types import LLMProvider +from .prompts import SYSTEM_PROMPT + +# Constants +COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24" +PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31" + +logger = logging.getLogger(__name__) + + +class AnthropicAPIHandler: + """Handles API calls to Anthropic's API with structured error handling and retries.""" + + def __init__(self, loop): + """Initialize the API handler. + + Args: + loop: Reference to the parent loop instance that provides context + """ + self.loop = loop + + async def make_api_call( + self, messages: List[BetaMessageParam], system_prompt: str = SYSTEM_PROMPT + ) -> BetaMessage: + """Make API call to Anthropic with retry logic. + + Args: + messages: List of messages to send to the API + system_prompt: System prompt to use (default: SYSTEM_PROMPT) + + Returns: + API response + + Raises: + RuntimeError: If API call fails after all retries + """ + if self.loop.client is None: + raise RuntimeError("Client not initialized. Call initialize_client() first.") + if self.loop.tool_manager is None: + raise RuntimeError("Tool manager not initialized. 
Call initialize_client() first.") + + last_error = None + + # Add detailed debug logging to examine messages + logger.info(f"Sending {len(messages)} messages to Anthropic API") + + # Log tool use IDs and tool result IDs for debugging + tool_use_ids = set() + tool_result_ids = set() + + for i, msg in enumerate(messages): + logger.info(f"Message {i}: role={msg.get('role')}") + if isinstance(msg.get("content"), list): + for content_block in msg.get("content", []): + if isinstance(content_block, dict): + block_type = content_block.get("type") + if block_type == "tool_use" and "id" in content_block: + tool_id = content_block.get("id") + tool_use_ids.add(tool_id) + logger.info(f" - Found tool_use with ID: {tool_id}") + elif block_type == "tool_result" and "tool_use_id" in content_block: + result_id = content_block.get("tool_use_id") + tool_result_ids.add(result_id) + logger.info(f" - Found tool_result referencing ID: {result_id}") + + # Check for mismatches + missing_tool_uses = tool_result_ids - tool_use_ids + if missing_tool_uses: + logger.warning( + f"Found tool_result IDs without matching tool_use IDs: {missing_tool_uses}" + ) + + for attempt in range(self.loop.max_retries): + try: + # Log request + request_data = { + "messages": messages, + "max_tokens": self.loop.max_tokens, + "system": system_prompt, + } + # Let ExperimentManager handle sanitization + self.loop._log_api_call("request", request_data) + + # Setup betas and system + system = BetaTextBlockParam( + type="text", + text=system_prompt, + ) + + betas = [COMPUTER_USE_BETA_FLAG] + # Add prompt caching if enabled in the message manager's config + if self.loop.message_manager.config.enable_caching: + betas.append(PROMPT_CACHING_BETA_FLAG) + system["cache_control"] = {"type": "ephemeral"} + + # Make API call + response = await self.loop.client.create_message( + messages=messages, + system=[system], + tools=self.loop.tool_manager.get_tool_params(), + max_tokens=self.loop.max_tokens, + betas=betas, + ) + + # Let ExperimentManager handle sanitization + self.loop._log_api_call("response", request_data, response) + + return response + except Exception as e: + last_error = e + logger.error( + f"Error in API call (attempt {attempt + 1}/{self.loop.max_retries}): {str(e)}" + ) + self.loop._log_api_call("error", {"messages": messages}, error=e) + + if attempt < self.loop.max_retries - 1: + await asyncio.sleep( + self.loop.retry_delay * (attempt + 1) + ) # Exponential backoff + continue + + # If we get here, all retries failed + error_message = f"API call failed after {self.loop.max_retries} attempts" + if last_error: + error_message += f": {str(last_error)}" + + logger.error(error_message) + raise RuntimeError(error_message) diff --git a/libs/agent/agent/providers/anthropic/callbacks/__init__.py b/libs/agent/agent/providers/anthropic/callbacks/__init__.py new file mode 100644 index 00000000..0bb54ff9 --- /dev/null +++ b/libs/agent/agent/providers/anthropic/callbacks/__init__.py @@ -0,0 +1,5 @@ +"""Anthropic callbacks package.""" + +from .manager import CallbackManager + +__all__ = ["CallbackManager"] diff --git a/libs/agent/agent/providers/anthropic/loop.py b/libs/agent/agent/providers/anthropic/loop.py index af60138a..b6af9f04 100644 --- a/libs/agent/agent/providers/anthropic/loop.py +++ b/libs/agent/agent/providers/anthropic/loop.py @@ -2,40 +2,35 @@ import logging import asyncio -import json -import os from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, cast -import base64 -from datetime import datetime -from httpx import 
ConnectError, ReadTimeout - -# Anthropic-specific imports -from anthropic import AsyncAnthropic from anthropic.types.beta import ( BetaMessage, BetaMessageParam, BetaTextBlock, - BetaTextBlockParam, - BetaToolUseBlockParam, BetaContentBlockParam, ) +import base64 +from datetime import datetime # Computer from computer import Computer # Base imports from ...core.loop import BaseLoop -from ...core.messages import ImageRetentionConfig as CoreImageRetentionConfig +from ...core.messages import StandardMessageManager, ImageRetentionConfig # Anthropic provider-specific imports from .api.client import AnthropicClientFactory, BaseAnthropicClient from .tools.manager import ToolManager -from .messages.manager import MessageManager, ImageRetentionConfig -from .callbacks.manager import CallbackManager from .prompts import SYSTEM_PROMPT from .types import LLMProvider from .tools import ToolResult +# Import the new modules we created +from .api_handler import AnthropicAPIHandler +from .response_handler import AnthropicResponseHandler +from .callbacks.manager import CallbackManager + # Constants COMPUTER_USE_BETA_FLAG = "computer-use-2025-01-24" PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31" @@ -44,13 +39,22 @@ logger = logging.getLogger(__name__) class AnthropicLoop(BaseLoop): - """Anthropic-specific implementation of the agent loop.""" + """Anthropic-specific implementation of the agent loop. + + This class extends BaseLoop to provide specialized support for Anthropic's Claude models + with their unique tool-use capabilities, custom message formatting, and + callback-driven approach to handling responses. + """ + + ########################################### + # INITIALIZATION AND CONFIGURATION + ########################################### def __init__( self, api_key: str, computer: Computer, - model: str = "claude-3-7-sonnet-20250219", # Fixed model + model: str = "claude-3-7-sonnet-20250219", only_n_most_recent_images: Optional[int] = 2, base_dir: Optional[str] = "trajectories", max_retries: int = 3, @@ -83,27 +87,37 @@ class AnthropicLoop(BaseLoop): **kwargs, ) - # Ensure model is always the fixed one - self.model = "claude-3-7-sonnet-20250219" - # Anthropic-specific attributes self.provider = LLMProvider.ANTHROPIC self.client = None self.retry_count = 0 self.tool_manager = None - self.message_manager = None self.callback_manager = None - # Configure image retention with core config - self.image_retention_config = CoreImageRetentionConfig( - num_images_to_keep=only_n_most_recent_images + # Initialize standard message manager with image retention config + self.message_manager = StandardMessageManager( + config=ImageRetentionConfig( + num_images_to_keep=only_n_most_recent_images, enable_caching=True + ) ) - # Message history + # Message history (standard OpenAI format) self.message_history = [] + # Initialize handlers + self.api_handler = AnthropicAPIHandler(self) + self.response_handler = AnthropicResponseHandler(self) + + ########################################### + # CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD + ########################################### + async def initialize_client(self) -> None: - """Initialize the Anthropic API client and tools.""" + """Initialize the Anthropic API client and tools. + + Implements abstract method from BaseLoop to set up the Anthropic-specific + client, tool manager, message manager, and callback handlers. 
+ """ try: logger.info(f"Initializing Anthropic client with model {self.model}...") @@ -112,14 +126,7 @@ class AnthropicLoop(BaseLoop): provider=self.provider, api_key=self.api_key, model=self.model ) - # Initialize message manager - self.message_manager = MessageManager( - image_retention_config=ImageRetentionConfig( - num_images_to_keep=self.only_n_most_recent_images, enable_caching=True - ) - ) - - # Initialize callback manager + # Initialize callback manager with our callback handlers self.callback_manager = CallbackManager( content_callback=self._handle_content, tool_callback=self._handle_tool_result, @@ -136,51 +143,18 @@ class AnthropicLoop(BaseLoop): self.client = None raise RuntimeError(f"Failed to initialize Anthropic client: {str(e)}") - async def _process_screen( - self, parsed_screen: Dict[str, Any], messages: List[Dict[str, Any]] - ) -> None: - """Process screen information and add to messages. - - Args: - parsed_screen: Dictionary containing parsed screen info - messages: List of messages to update - """ - try: - # Extract screenshot from parsed screen - screenshot_base64 = parsed_screen.get("screenshot_base64") - - if screenshot_base64: - # Remove data URL prefix if present - if "," in screenshot_base64: - screenshot_base64 = screenshot_base64.split(",")[1] - - # Create Anthropic-compatible message with image - screen_info_msg = { - "role": "user", - "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": screenshot_base64, - }, - } - ], - } - - # Add screen info message to messages - messages.append(screen_info_msg) - - except Exception as e: - logger.error(f"Error processing screen info: {str(e)}") - raise + ########################################### + # MAIN LOOP - IMPLEMENTING ABSTRACT METHOD + ########################################### async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]: """Run the agent loop with provided messages. + Implements abstract method from BaseLoop to handle the main agent loop + for the AnthropicLoop implementation, using async queues and callbacks. + Args: - messages: List of message objects + messages: List of message objects in standard OpenAI format Yields: Dict containing response data @@ -188,7 +162,7 @@ class AnthropicLoop(BaseLoop): try: logger.info("Starting Anthropic loop run") - # Reset message history and add new messages + # Reset message history and add new messages in standard format self.message_history = [] self.message_history.extend(messages) @@ -236,6 +210,10 @@ class AnthropicLoop(BaseLoop): "metadata": {"title": "❌ Error"}, } + ########################################### + # AGENT LOOP IMPLEMENTATION + ########################################### + async def _run_loop(self, queue: asyncio.Queue) -> None: """Run the agent loop with current message history. @@ -244,31 +222,65 @@ class AnthropicLoop(BaseLoop): """ try: while True: - # Get up-to-date screen information - parsed_screen = await self._get_parsed_screen_som() + # Capture screenshot + try: + # Take screenshot - always returns raw PNG bytes + screenshot = await self.computer.interface.screenshot() - # Process screen info and update messages - await self._process_screen(parsed_screen, self.message_history) + # Convert PNG bytes to base64 + base64_image = base64.b64encode(screenshot).decode("utf-8") - # Prepare messages and make API call - if self.message_manager is None: - raise RuntimeError( - "Message manager not initialized. Call initialize_client() first." 
- ) - prepared_messages = self.message_manager.prepare_messages( - cast(List[BetaMessageParam], self.message_history.copy()) - ) + # Save screenshot if requested + if self.save_trajectory and self.experiment_manager: + try: + self._save_screenshot(base64_image, action_type="state") + except Exception as e: + logger.error(f"Error saving screenshot: {str(e)}") + + # Add screenshot to message history in OpenAI format + screen_info_msg = { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{base64_image}"}, + } + ], + } + self.message_history.append(screen_info_msg) + except Exception as e: + logger.error(f"Error capturing or processing screenshot: {str(e)}") + raise # Create new turn directory for this API call self._create_turn_dir() - # Use _make_api_call instead of direct client call to ensure logging - response = await self._make_api_call(prepared_messages) + # Convert standard messages to Anthropic format + anthropic_messages, system_content = self.message_manager.to_anthropic_format( + self.message_history.copy() + ) - # Handle the response - if not await self._handle_response(response, self.message_history): + # Use API handler to make API call with Anthropic format + response = await self.api_handler.make_api_call( + messages=cast(List[BetaMessageParam], anthropic_messages), + system_prompt=system_content or SYSTEM_PROMPT, + ) + + # Use response handler to handle the response and convert to standard format + # This adds the response to message_history + if not await self.response_handler.handle_response(response, self.message_history): break + # Get the last assistant message and convert it to OpenAI computer use format + for msg in reversed(self.message_history): + if msg["role"] == "assistant": + # Create OpenAI-compatible response and add to queue + openai_compatible_response = self._create_openai_compatible_response( + msg, response + ) + await queue.put(openai_compatible_response) + break + # Signal completion await queue.put(None) @@ -283,98 +295,128 @@ class AnthropicLoop(BaseLoop): ) await queue.put(None) - async def _make_api_call(self, messages: List[BetaMessageParam]) -> BetaMessage: - """Make API call to Anthropic with retry logic. + def _create_openai_compatible_response( + self, assistant_msg: Dict[str, Any], original_response: Any + ) -> Dict[str, Any]: + """Create an OpenAI computer use agent compatible response format. Args: - messages: List of messages to send to the API + assistant_msg: The assistant message in standard OpenAI format + original_response: The original API response object for ID generation Returns: - API response + A response formatted according to OpenAI's computer use agent standard """ - if self.client is None: - raise RuntimeError("Client not initialized. Call initialize_client() first.") - if self.tool_manager is None: - raise RuntimeError("Tool manager not initialized. 
Call initialize_client() first.") + # Create a unique ID for this response + response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(original_response)}" + reasoning_id = f"rs_{response_id}" + action_id = f"cu_{response_id}" + call_id = f"call_{response_id}" - last_error = None + # Extract reasoning and action details from the response + content = assistant_msg["content"] - for attempt in range(self.max_retries): - try: - # Log request - request_data = { - "messages": messages, - "max_tokens": self.max_tokens, - "system": SYSTEM_PROMPT, + # Initialize output array + output_items = [] + + # Add reasoning item if we have text content + reasoning_text = None + action_details = None + + # AnthropicLoop expects a list of content blocks with type "text" or "tool_use" + if isinstance(content, list): + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + reasoning_text = item.get("text", "") + elif isinstance(item, dict) and item.get("type") == "tool_use": + action_details = item + else: + # Fallback for string content + reasoning_text = content if isinstance(content, str) else None + + # If we have reasoning text, add reasoning item + if reasoning_text: + output_items.append( + { + "type": "reasoning", + "id": reasoning_id, + "summary": [ + { + "type": "summary_text", + "text": reasoning_text[:200], # Truncate to reasonable length + } + ], } - # Let ExperimentManager handle sanitization - self._log_api_call("request", request_data) + ) - # Setup betas and system - system = BetaTextBlockParam( - type="text", - text=SYSTEM_PROMPT, - ) + # Add computer_call item with action details if available + computer_call = { + "type": "computer_call", + "id": action_id, + "call_id": call_id, + "action": {"type": "click", "button": "left", "x": 100, "y": 100}, # Default action + "pending_safety_checks": [], + "status": "completed", + } - betas = [COMPUTER_USE_BETA_FLAG] - # Temporarily disable prompt caching due to "A maximum of 4 blocks with cache_control may be provided" error - # if self.message_manager.image_retention_config.enable_caching: - # betas.append(PROMPT_CACHING_BETA_FLAG) - # system["cache_control"] = {"type": "ephemeral"} + # If we have action details from a tool_use, update the computer_call + if action_details: + # Try to map tool_use to computer_call action + tool_input = action_details.get("input", {}) + if "click" in tool_input or "position" in tool_input: + position = tool_input.get("click", tool_input.get("position", {})) + if isinstance(position, dict) and "x" in position and "y" in position: + computer_call["action"] = { + "type": "click", + "button": "left", + "x": position.get("x", 100), + "y": position.get("y", 100), + } + elif "type" in tool_input or "text" in tool_input: + computer_call["action"] = { + "type": "type", + "text": tool_input.get("type", tool_input.get("text", "")), + } + elif "scroll" in tool_input: + scroll = tool_input.get("scroll", {}) + computer_call["action"] = { + "type": "scroll", + "x": 100, + "y": 100, + "scroll_x": scroll.get("x", 0), + "scroll_y": scroll.get("y", 0), + } - # Make API call - response = await self.client.create_message( - messages=messages, - system=[system], - tools=self.tool_manager.get_tool_params(), - max_tokens=self.max_tokens, - betas=betas, - ) + output_items.append(computer_call) - # Let ExperimentManager handle sanitization - self._log_api_call("response", request_data, response) + # Create the OpenAI-compatible response format + return { + "output": output_items, + "id": response_id, + # 
Include the original format for backward compatibility + "response": {"choices": [{"message": assistant_msg, "finish_reason": "stop"}]}, + } - return response - except Exception as e: - last_error = e - logger.error( - f"Error in API call (attempt {attempt + 1}/{self.max_retries}): {str(e)}" - ) - self._log_api_call("error", {"messages": messages}, error=e) - - if attempt < self.max_retries - 1: - await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff - continue - - # If we get here, all retries failed - error_message = f"API call failed after {self.max_retries} attempts" - if last_error: - error_message += f": {str(last_error)}" - - logger.error(error_message) - raise RuntimeError(error_message) + ########################################### + # RESPONSE AND CALLBACK HANDLING + ########################################### async def _handle_response(self, response: BetaMessage, messages: List[Dict[str, Any]]) -> bool: """Handle the Anthropic API response. Args: response: API response - messages: List of messages to update + messages: List of messages to update in standard OpenAI format Returns: True if the loop should continue, False otherwise """ try: - # Convert response to parameter format - response_params = self._response_to_params(response) + # Convert Anthropic response to standard OpenAI format + response_blocks = self._response_to_blocks(response) - # Add response to messages - messages.append( - { - "role": "assistant", - "content": response_params, - } - ) + # Add response to standard message history + messages.append({"role": "assistant", "content": response_blocks}) if self.callback_manager is None: raise RuntimeError( @@ -383,31 +425,33 @@ class AnthropicLoop(BaseLoop): # Handle tool use blocks and collect results tool_result_content = [] - for content_block in response_params: + for content_block in response.content: # Notify callback of content self.callback_manager.on_content(cast(BetaContentBlockParam, content_block)) - # Handle tool use - if content_block.get("type") == "tool_use": + # Handle tool use - carefully check and access attributes + if hasattr(content_block, "type") and content_block.type == "tool_use": if self.tool_manager is None: raise RuntimeError( "Tool manager not initialized. Call initialize_client() first." 
) + + # Safely get attributes + tool_name = getattr(content_block, "name", "") + tool_input = getattr(content_block, "input", {}) + tool_id = getattr(content_block, "id", "") + result = await self.tool_manager.execute_tool( - name=content_block["name"], - tool_input=cast(Dict[str, Any], content_block["input"]), + name=tool_name, + tool_input=cast(Dict[str, Any], tool_input), ) - # Create tool result and add to content - tool_result = self._make_tool_result( - cast(ToolResult, result), content_block["id"] - ) + # Create tool result + tool_result = self._make_tool_result(cast(ToolResult, result), tool_id) tool_result_content.append(tool_result) # Notify callback of tool result - self.callback_manager.on_tool_result( - cast(ToolResult, result), content_block["id"] - ) + self.callback_manager.on_tool_result(cast(ToolResult, result), tool_id) # If no tool results, we're done if not tool_result_content: @@ -415,8 +459,8 @@ class AnthropicLoop(BaseLoop): self.callback_manager.on_content({"type": "text", "text": ""}) return False - # Add tool results to message history - messages.append({"content": tool_result_content, "role": "user"}) + # Add tool results to message history in standard format + messages.append({"role": "user", "content": tool_result_content}) return True except Exception as e: @@ -429,28 +473,41 @@ class AnthropicLoop(BaseLoop): ) return False - def _response_to_params( - self, - response: BetaMessage, - ) -> List[Dict[str, Any]]: - """Convert API response to message parameters. + def _response_to_blocks(self, response: BetaMessage) -> List[Dict[str, Any]]: + """Convert Anthropic API response to standard blocks format. Args: response: API response message Returns: - List of content blocks + List of content blocks in standard format """ result = [] for block in response.content: if isinstance(block, BetaTextBlock): result.append({"type": "text", "text": block.text}) + elif hasattr(block, "type") and block.type == "tool_use": + # Safely access attributes after confirming it's a tool_use + result.append( + { + "type": "tool_use", + "id": getattr(block, "id", ""), + "name": getattr(block, "name", ""), + "input": getattr(block, "input", {}), + } + ) else: - result.append(cast(Dict[str, Any], block.model_dump())) + # For other block types, convert to dict + block_dict = {} + for key, value in vars(block).items(): + if not key.startswith("_"): + block_dict[key] = value + result.append(block_dict) + return result def _make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]: - """Convert a tool result to API format. + """Convert a tool result to standard format. 
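+
+        A minimal sketch of the returned shape (illustrative values only): a
+        successful text-only result maps to
+
+            {"type": "tool_result",
+             "content": [{"type": "text", "text": "..."}],
+             "tool_use_id": "toolu_123",
+             "is_error": False}
+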
Args: result: Tool execution result @@ -489,12 +546,8 @@ class AnthropicLoop(BaseLoop): if result.base64_image: tool_result_content.append( { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": result.base64_image, - }, + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{result.base64_image}"}, } ) @@ -519,16 +572,19 @@ class AnthropicLoop(BaseLoop): result_text = f"{result.system}\n{result_text}" return result_text - def _handle_content(self, content: BetaContentBlockParam) -> None: + ########################################### + # CALLBACK HANDLERS + ########################################### + + def _handle_content(self, content): """Handle content updates from the assistant.""" if content.get("type") == "text": - text_content = cast(BetaTextBlockParam, content) - text = text_content["text"] + text = content.get("text", "") if text == "": return logger.info(f"Assistant: {text}") - def _handle_tool_result(self, result: ToolResult, tool_id: str) -> None: + def _handle_tool_result(self, result, tool_id): """Handle tool execution results.""" if result.error: logger.error(f"Tool {tool_id} error: {result.error}") diff --git a/libs/agent/agent/providers/anthropic/response_handler.py b/libs/agent/agent/providers/anthropic/response_handler.py new file mode 100644 index 00000000..d34560b4 --- /dev/null +++ b/libs/agent/agent/providers/anthropic/response_handler.py @@ -0,0 +1,223 @@ +"""Response and tool handling for Anthropic provider.""" + +import logging +from typing import Any, Dict, List, Optional, Tuple, cast + +from anthropic.types.beta import ( + BetaMessage, + BetaMessageParam, + BetaTextBlock, + BetaTextBlockParam, + BetaToolUseBlockParam, + BetaContentBlockParam, +) + +from .tools import ToolResult + +logger = logging.getLogger(__name__) + + +class AnthropicResponseHandler: + """Handles Anthropic API responses and tool execution results.""" + + def __init__(self, loop): + """Initialize the response handler. + + Args: + loop: Reference to the parent loop instance that provides context + """ + self.loop = loop + + async def handle_response(self, response: BetaMessage, messages: List[Dict[str, Any]]) -> bool: + """Handle the Anthropic API response. + + Args: + response: API response + messages: List of messages to update + + Returns: + True if the loop should continue, False otherwise + """ + try: + # Convert response to parameter format + response_params = self.response_to_params(response) + + # Collect all existing tool_use IDs from previous messages for validation + existing_tool_use_ids = set() + for msg in messages: + if msg.get("role") == "assistant" and isinstance(msg.get("content"), list): + for block in msg.get("content", []): + if ( + isinstance(block, dict) + and block.get("type") == "tool_use" + and "id" in block + ): + existing_tool_use_ids.add(block["id"]) + + # Also add new tool_use IDs from the current response + current_tool_use_ids = set() + for block in response_params: + if isinstance(block, dict) and block.get("type") == "tool_use" and "id" in block: + current_tool_use_ids.add(block["id"]) + existing_tool_use_ids.add(block["id"]) + + logger.info(f"Existing tool_use IDs in conversation: {existing_tool_use_ids}") + logger.info(f"New tool_use IDs in current response: {current_tool_use_ids}") + + # Add response to messages + messages.append( + { + "role": "assistant", + "content": response_params, + } + ) + + if self.loop.callback_manager is None: + raise RuntimeError( + "Callback manager not initialized. 
Call initialize_client() first." + ) + + # Handle tool use blocks and collect results + tool_result_content = [] + for content_block in response_params: + # Notify callback of content + self.loop.callback_manager.on_content(cast(BetaContentBlockParam, content_block)) + + # Handle tool use + if content_block.get("type") == "tool_use": + if self.loop.tool_manager is None: + raise RuntimeError( + "Tool manager not initialized. Call initialize_client() first." + ) + + # Execute the tool + result = await self.loop.tool_manager.execute_tool( + name=content_block["name"], + tool_input=cast(Dict[str, Any], content_block["input"]), + ) + + # Verify the tool_use ID exists in the conversation (which it should now) + tool_use_id = content_block["id"] + if tool_use_id in existing_tool_use_ids: + # Create tool result and add to content + tool_result = self.make_tool_result(cast(ToolResult, result), tool_use_id) + tool_result_content.append(tool_result) + + # Notify callback of tool result + self.loop.callback_manager.on_tool_result( + cast(ToolResult, result), content_block["id"] + ) + else: + logger.warning( + f"Tool use ID {tool_use_id} not found in previous messages. Skipping tool result." + ) + + # If no tool results, we're done + if not tool_result_content: + # Signal completion + self.loop.callback_manager.on_content({"type": "text", "text": ""}) + return False + + # Add tool results to message history + messages.append({"content": tool_result_content, "role": "user"}) + return True + + except Exception as e: + logger.error(f"Error handling response: {str(e)}") + messages.append( + { + "role": "assistant", + "content": f"Error: {str(e)}", + } + ) + return False + + def response_to_params( + self, + response: BetaMessage, + ) -> List[Dict[str, Any]]: + """Convert API response to message parameters. + + Args: + response: API response message + + Returns: + List of content blocks + """ + result = [] + for block in response.content: + if isinstance(block, BetaTextBlock): + result.append({"type": "text", "text": block.text}) + else: + result.append(cast(Dict[str, Any], block.model_dump())) + return result + + def make_tool_result(self, result: ToolResult, tool_use_id: str) -> Dict[str, Any]: + """Convert a tool result to API format. + + Args: + result: Tool execution result + tool_use_id: ID of the tool use + + Returns: + Formatted tool result + """ + if result.content: + return { + "type": "tool_result", + "content": result.content, + "tool_use_id": tool_use_id, + "is_error": bool(result.error), + } + + tool_result_content = [] + is_error = False + + if result.error: + is_error = True + tool_result_content = [ + { + "type": "text", + "text": self.maybe_prepend_system_tool_result(result, result.error), + } + ] + else: + if result.output: + tool_result_content.append( + { + "type": "text", + "text": self.maybe_prepend_system_tool_result(result, result.output), + } + ) + if result.base64_image: + tool_result_content.append( + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": result.base64_image, + }, + } + ) + + return { + "type": "tool_result", + "content": tool_result_content, + "tool_use_id": tool_use_id, + "is_error": is_error, + } + + def maybe_prepend_system_tool_result(self, result: ToolResult, result_text: str) -> str: + """Prepend system information to tool result if available. 
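+
+        For example (an illustrative sketch), result.system == "tool must be
+        restarted" and result_text == "command failed" yield
+        "tool must be restarted\ncommand failed".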
+
+        Args:
+            result: Tool execution result
+            result_text: Text to prepend to
+
+        Returns:
+            Text with system information prepended if available
+        """
+        if result.system:
+            result_text = f"{result.system}\n{result_text}"
+        return result_text
diff --git a/libs/agent/agent/providers/anthropic/tools/bash.py b/libs/agent/agent/providers/anthropic/tools/bash.py
index 703327e6..babbacfd 100644
--- a/libs/agent/agent/providers/anthropic/tools/bash.py
+++ b/libs/agent/agent/providers/anthropic/tools/bash.py
@@ -7,101 +7,6 @@ from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
 from ....core.tools.bash import BaseBashTool
 
 
-class _BashSession:
-    """A session of a bash shell."""
-
-    _started: bool
-    _process: asyncio.subprocess.Process
-
-    command: str = "/bin/bash"
-    _output_delay: float = 0.2  # seconds
-    _timeout: float = 120.0  # seconds
-    _sentinel: str = "<<exit>>"
-
-    def __init__(self):
-        self._started = False
-        self._timed_out = False
-
-    async def start(self):
-        if self._started:
-            return
-
-        self._process = await asyncio.create_subprocess_shell(
-            self.command,
-            preexec_fn=os.setsid,
-            shell=True,
-            bufsize=0,
-            stdin=asyncio.subprocess.PIPE,
-            stdout=asyncio.subprocess.PIPE,
-            stderr=asyncio.subprocess.PIPE,
-        )
-
-        self._started = True
-
-    def stop(self):
-        """Terminate the bash shell."""
-        if not self._started:
-            raise ToolError("Session has not started.")
-        if self._process.returncode is not None:
-            return
-        self._process.terminate()
-
-    async def run(self, command: str):
-        """Execute a command in the bash shell."""
-        if not self._started:
-            raise ToolError("Session has not started.")
-        if self._process.returncode is not None:
-            return ToolResult(
-                system="tool must be restarted",
-                error=f"bash has exited with returncode {self._process.returncode}",
-            )
-        if self._timed_out:
-            raise ToolError(
-                f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
-            )
-
-        # we know these are not None because we created the process with PIPEs
-        assert self._process.stdin
-        assert self._process.stdout
-        assert self._process.stderr
-
-        # send command to the process
-        self._process.stdin.write(command.encode() + f"; echo '{self._sentinel}'\n".encode())
-        await self._process.stdin.drain()
-
-        # read output from the process, until the sentinel is found
-        try:
-            async with asyncio.timeout(self._timeout):
-                while True:
-                    await asyncio.sleep(self._output_delay)
-                    # Read from stdout using the proper API
-                    output_bytes = await self._process.stdout.read()
-                    if output_bytes:
-                        output = output_bytes.decode()
-                        if self._sentinel in output:
-                            # strip the sentinel and break
-                            output = output[: output.index(self._sentinel)]
-                            break
-        except asyncio.TimeoutError:
-            self._timed_out = True
-            raise ToolError(
-                f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
-            ) from None
-
-        if output and output.endswith("\n"):
-            output = output[:-1]
-
-        # Read from stderr using the proper API
-        error_bytes = await self._process.stderr.read()
-        error = error_bytes.decode() if error_bytes else ""
-        if error and error.endswith("\n"):
-            error = error[:-1]
-
-        # No need to clear buffers as we're using read() which consumes the data
-
-        return CLIResult(output=output, error=error)
-
-
 class BashTool(BaseBashTool, BaseAnthropicTool):
     """
     A tool that allows the agent to run bash commands.
@@ -123,7 +28,6 @@ class BashTool(BaseBashTool, BaseAnthropicTool): # Then initialize the Anthropic tool BaseAnthropicTool.__init__(self) # Initialize bash session - self._session = _BashSession() async def __call__(self, command: str | None = None, restart: bool = False, **kwargs): """Execute a bash command. diff --git a/libs/agent/agent/providers/omni/__init__.py b/libs/agent/agent/providers/omni/__init__.py index 8706c658..f6f0f338 100644 --- a/libs/agent/agent/providers/omni/__init__.py +++ b/libs/agent/agent/providers/omni/__init__.py @@ -3,8 +3,6 @@ # The OmniComputerAgent has been replaced by the unified ComputerAgent # which can be found in agent.core.agent from .types import LLMProvider -from .experiment import ExperimentManager -from .visualization import visualize_click, visualize_scroll, calculate_element_center from .image_utils import ( decode_base64_image, encode_image_base64, @@ -15,10 +13,6 @@ from .image_utils import ( __all__ = [ "LLMProvider", - "ExperimentManager", - "visualize_click", - "visualize_scroll", - "calculate_element_center", "decode_base64_image", "encode_image_base64", "clean_base64_data", diff --git a/libs/agent/agent/providers/omni/action_executor.py b/libs/agent/agent/providers/omni/action_executor.py new file mode 100644 index 00000000..33c231d8 --- /dev/null +++ b/libs/agent/agent/providers/omni/action_executor.py @@ -0,0 +1,264 @@ +"""Action execution for the Omni agent.""" + +import logging +from typing import Dict, Any, Tuple +import json + +from .parser import ParseResult +from ...core.visualization import calculate_element_center + +logger = logging.getLogger(__name__) + + +class ActionExecutor: + """Executes UI actions based on model instructions.""" + + def __init__(self, loop): + """Initialize the action executor. + + Args: + loop: Reference to the parent loop instance that provides context + """ + self.loop = loop + + async def execute_action(self, content: Dict[str, Any], parsed_screen: ParseResult) -> bool: + """Execute the action specified in the content. 
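+
+        The content dict is expected to follow the model's JSON action schema,
+        for example (illustrative): {"Action": "left_click", "Box ID": "12"}
+        or {"Action": "type_text", "Value": "hello"}.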
+ + Args: + content: Dictionary containing the action details + parsed_screen: Current parsed screen information + + Returns: + Whether an action-specific screenshot was saved + """ + try: + action = content.get("Action", "").lower() + if not action: + return False + + # Track if we saved an action-specific screenshot + action_screenshot_saved = False + + try: + # Prepare kwargs based on action type + kwargs = {} + + if action in ["left_click", "right_click", "double_click", "move_cursor"]: + try: + box_id = int(content["Box ID"]) + logger.info(f"Processing Box ID: {box_id}") + + # Calculate click coordinates + x, y = await self.calculate_click_coordinates(box_id, parsed_screen) + logger.info(f"Calculated coordinates: x={x}, y={y}") + + kwargs["x"] = x + kwargs["y"] = y + + # Visualize action if screenshot is available + if parsed_screen.annotated_image_base64: + img_data = parsed_screen.annotated_image_base64 + # Remove data URL prefix if present + if img_data.startswith("data:image"): + img_data = img_data.split(",")[1] + # Only save visualization for coordinate-based actions + self.loop.viz_helper.visualize_action(x, y, img_data) + action_screenshot_saved = True + + except ValueError as e: + logger.error(f"Error processing Box ID: {str(e)}") + return False + + elif action == "drag_to": + try: + box_id = int(content["Box ID"]) + x, y = await self.calculate_click_coordinates(box_id, parsed_screen) + kwargs.update( + { + "x": x, + "y": y, + "button": content.get("button", "left"), + "duration": float(content.get("duration", 0.5)), + } + ) + + # Visualize drag destination if screenshot is available + if parsed_screen.annotated_image_base64: + img_data = parsed_screen.annotated_image_base64 + # Remove data URL prefix if present + if img_data.startswith("data:image"): + img_data = img_data.split(",")[1] + # Only save visualization for coordinate-based actions + self.loop.viz_helper.visualize_action(x, y, img_data) + action_screenshot_saved = True + + except ValueError as e: + logger.error(f"Error processing drag coordinates: {str(e)}") + return False + + elif action == "type_text": + kwargs["text"] = content["Value"] + # For type_text, store the value in the action type + action_type = f"type_{content['Value'][:20]}" # Truncate if too long + elif action == "press_key": + kwargs["key"] = content["Value"] + action_type = f"press_{content['Value']}" + elif action == "hotkey": + if isinstance(content.get("Value"), list): + keys = content["Value"] + action_type = f"hotkey_{'_'.join(keys)}" + else: + # Simply split string format like "command+space" into a list + keys = [k.strip() for k in content["Value"].lower().split("+")] + action_type = f"hotkey_{content['Value'].replace('+', '_')}" + logger.info(f"Preparing hotkey with keys: {keys}") + # Get the method but call it with *args instead of **kwargs + method = getattr(self.loop.computer.interface, action) + await method(*keys) # Unpack the keys list as positional arguments + logger.info(f"Tool execution completed successfully: {action}") + + # For hotkeys, take a screenshot after the action + try: + # Get a new screenshot after the action and save it with the action type + new_parsed_screen = await self.loop._get_parsed_screen_som( + save_screenshot=False + ) + if new_parsed_screen and new_parsed_screen.annotated_image_base64: + img_data = new_parsed_screen.annotated_image_base64 + # Remove data URL prefix if present + if img_data.startswith("data:image"): + img_data = img_data.split(",")[1] + # Save with action type to indicate this is a 
post-action screenshot + self.loop._save_screenshot(img_data, action_type=action_type) + action_screenshot_saved = True + except Exception as screenshot_error: + logger.error( + f"Error taking post-hotkey screenshot: {str(screenshot_error)}" + ) + + return action_screenshot_saved + + elif action in ["scroll_down", "scroll_up"]: + clicks = int(content.get("amount", 1)) + kwargs["clicks"] = clicks + action_type = f"scroll_{action.split('_')[1]}_{clicks}" + + # Visualize scrolling if screenshot is available + if parsed_screen.annotated_image_base64: + img_data = parsed_screen.annotated_image_base64 + # Remove data URL prefix if present + if img_data.startswith("data:image"): + img_data = img_data.split(",")[1] + direction = "down" if action == "scroll_down" else "up" + # For scrolling, we only save the visualization to avoid duplicate images + self.loop.viz_helper.visualize_scroll(direction, clicks, img_data) + action_screenshot_saved = True + + else: + logger.warning(f"Unknown action: {action}") + return False + + # Execute tool and handle result + try: + method = getattr(self.loop.computer.interface, action) + logger.info(f"Found method for action '{action}': {method}") + await method(**kwargs) + logger.info(f"Tool execution completed successfully: {action}") + + # For non-coordinate based actions that don't already have visualizations, + # take a new screenshot after the action + if not action_screenshot_saved: + # Take a new screenshot + try: + # Get a new screenshot after the action and save it with the action type + new_parsed_screen = await self.loop._get_parsed_screen_som( + save_screenshot=False + ) + if new_parsed_screen and new_parsed_screen.annotated_image_base64: + img_data = new_parsed_screen.annotated_image_base64 + # Remove data URL prefix if present + if img_data.startswith("data:image"): + img_data = img_data.split(",")[1] + # Save with action type to indicate this is a post-action screenshot + if "action_type" in locals(): + self.loop._save_screenshot(img_data, action_type=action_type) + else: + self.loop._save_screenshot(img_data, action_type=action) + # Update the action screenshot flag for this turn + action_screenshot_saved = True + except Exception as screenshot_error: + logger.error( + f"Error taking post-action screenshot: {str(screenshot_error)}" + ) + + except AttributeError as e: + logger.error(f"Method not found for action '{action}': {str(e)}") + return False + except Exception as tool_error: + logger.error(f"Tool execution failed: {str(tool_error)}") + return False + + return action_screenshot_saved + + except Exception as e: + logger.error(f"Error executing action {action}: {str(e)}") + return False + + except Exception as e: + logger.error(f"Error in execute_action: {str(e)}") + return False + + async def calculate_click_coordinates( + self, box_id: int, parsed_screen: ParseResult + ) -> Tuple[int, int]: + """Calculate click coordinates based on box ID. 
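+
+        As a sketch (assuming normalized bbox coordinates, which is what the
+        scaling by screen width/height below implies): an element with bbox
+        {"x1": 0.25, "y1": 0.5, "x2": 0.35, "y2": 0.6} on a 1920x1080 screen
+        resolves to roughly (576, 594).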
+ + Args: + box_id: The ID of the box to click + parsed_screen: The parsed screen information + + Returns: + Tuple of (x, y) coordinates + + Raises: + ValueError: If box_id is invalid or missing from parsed screen + """ + # First try to use structured elements data + logger.info(f"Elements count: {len(parsed_screen.elements)}") + + # Try to find element with matching ID + for element in parsed_screen.elements: + if element.id == box_id: + logger.info(f"Found element with ID {box_id}: {element}") + bbox = element.bbox + + # Get screen dimensions from the metadata if available, or fallback + width = parsed_screen.metadata.width if parsed_screen.metadata else 1920 + height = parsed_screen.metadata.height if parsed_screen.metadata else 1080 + logger.info(f"Screen dimensions: width={width}, height={height}") + + # Create a dictionary from the element's bbox for calculate_element_center + bbox_dict = {"x1": bbox.x1, "y1": bbox.y1, "x2": bbox.x2, "y2": bbox.y2} + center_x, center_y = calculate_element_center(bbox_dict, width, height) + logger.info(f"Calculated center: ({center_x}, {center_y})") + + # Validate coordinates - if they're (0,0) or unreasonably small, + # use a default position in the center of the screen + if center_x == 0 and center_y == 0: + logger.warning("Got (0,0) coordinates, using fallback position") + center_x = width // 2 + center_y = height // 2 + logger.info(f"Using fallback center: ({center_x}, {center_y})") + + return center_x, center_y + + # If we couldn't find the box, use center of screen + logger.error( + f"Box ID {box_id} not found in structured elements (count={len(parsed_screen.elements)})" + ) + + # Use center of screen as fallback + width = parsed_screen.metadata.width if parsed_screen.metadata else 1920 + height = parsed_screen.metadata.height if parsed_screen.metadata else 1080 + logger.warning(f"Using fallback position in center of screen ({width//2}, {height//2})") + return width // 2, height // 2 diff --git a/libs/agent/agent/providers/omni/api_handler.py b/libs/agent/agent/providers/omni/api_handler.py new file mode 100644 index 00000000..8256c1f6 --- /dev/null +++ b/libs/agent/agent/providers/omni/api_handler.py @@ -0,0 +1,42 @@ +"""API handling for Omni provider.""" + +import logging +from typing import Any, Dict, List + +from .prompts import SYSTEM_PROMPT + +logger = logging.getLogger(__name__) + + +class OmniAPIHandler: + """Handler for Omni API calls.""" + + def __init__(self, loop): + """Initialize the API handler. + + Args: + loop: Parent loop instance + """ + self.loop = loop + + async def make_api_call( + self, messages: List[Dict[str, Any]], system_prompt: str = SYSTEM_PROMPT + ) -> Any: + """Make an API call to the appropriate provider. 
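+
+        Messages are expected in the standard OpenAI chat format, e.g.
+        (illustrative): [{"role": "user", "content": [{"type": "text",
+        "text": "Open the browser"}]}].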
+
+        Args:
+            messages: List of messages in standard OpenAI format
+            system_prompt: System prompt to use
+
+        Returns:
+            API response
+        """
+        if not hasattr(self.loop, "_make_api_call"):
+            raise RuntimeError("Loop does not have _make_api_call method")
+
+        try:
+            # Use the loop's _make_api_call method with standard messages
+            return await self.loop._make_api_call(messages=messages, system_prompt=system_prompt)
+        except Exception as e:
+            logger.error(f"Error making API call: {str(e)}")
+            raise
diff --git a/libs/agent/agent/providers/omni/clients/anthropic.py b/libs/agent/agent/providers/omni/clients/anthropic.py
index 6d835277..5cdd3c46 100644
--- a/libs/agent/agent/providers/omni/clients/anthropic.py
+++ b/libs/agent/agent/providers/omni/clients/anthropic.py
@@ -44,6 +44,10 @@ class AnthropicClient(BaseOmniClient):
         anthropic_messages = []
 
         for message in messages:
+            # Skip messages with empty content
+            if not message.get("content"):
+                continue
+
             if message["role"] == "user":
                 anthropic_messages.append({"role": "user", "content": message["content"]})
             elif message["role"] == "assistant":
diff --git a/libs/agent/agent/providers/omni/clients/groq.py b/libs/agent/agent/providers/omni/clients/groq.py
deleted file mode 100644
index a7d6776b..00000000
--- a/libs/agent/agent/providers/omni/clients/groq.py
+++ /dev/null
@@ -1,101 +0,0 @@
-"""Groq client implementation."""
-
-import os
-import logging
-from typing import Dict, List, Optional, Any, Tuple
-
-from groq import Groq
-import re
-from .utils import is_image_path
-from .base import BaseOmniClient
-
-logger = logging.getLogger(__name__)
-
-
-class GroqClient(BaseOmniClient):
-    """Client for making Groq API calls."""
-
-    def __init__(
-        self,
-        api_key: Optional[str] = None,
-        model: str = "deepseek-r1-distill-llama-70b",
-        max_tokens: int = 4096,
-        temperature: float = 0.6,
-    ):
-        """Initialize Groq client.
-
-        Args:
-            api_key: Groq API key (if not provided, will try to get from env)
-            model: Model name to use
-            max_tokens: Maximum tokens to generate
-            temperature: Temperature for sampling
-        """
-        super().__init__(api_key=api_key, model=model)
-        self.api_key = api_key or os.getenv("GROQ_API_KEY")
-        if not self.api_key:
-            raise ValueError("No Groq API key provided")
-
-        self.max_tokens = max_tokens
-        self.temperature = temperature
-        self.client = Groq(api_key=self.api_key)
-        self.model: str = model  # Add explicit type annotation
-
-    def run_interleaved(
-        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
-    ) -> tuple[str, int]:
-        """Run interleaved chat completion.
-
-        Args:
-            messages: List of message dicts
-            system: System prompt
-            max_tokens: Optional max tokens override
-
-        Returns:
-            Tuple of (response text, token usage)
-        """
-        # Avoid using system messages for R1
-        final_messages = [{"role": "user", "content": system}]
-
-        # Process messages
-        if isinstance(messages, list):
-            for item in messages:
-                if isinstance(item, dict):
-                    # For dict items, concatenate all text content, ignoring images
-                    text_contents = []
-                    for cnt in item["content"]:
-                        if isinstance(cnt, str):
-                            if not is_image_path(cnt):  # Skip image paths
-                                text_contents.append(cnt)
-                        else:
-                            text_contents.append(str(cnt))
-
-                    if text_contents:  # Only add if there's text content
-                        message = {"role": "user", "content": " ".join(text_contents)}
-                        final_messages.append(message)
-                else:  # str
-                    message = {"role": "user", "content": item}
-                    final_messages.append(message)
-
-        elif isinstance(messages, str):
-            final_messages.append({"role": "user", "content": messages})
-
-        try:
-            completion = self.client.chat.completions.create(  # type: ignore
-                model=self.model,
-                messages=final_messages,  # type: ignore
-                temperature=self.temperature,
-                max_tokens=max_tokens or self.max_tokens,
-                top_p=0.95,
-                stream=False,
-            )
-
-            response = completion.choices[0].message.content
-            final_answer = response.split("\n")[-1] if "</think>" in response else response
-            final_answer = final_answer.replace("<think>", "").replace("</think>", "")
-            token_usage = completion.usage.total_tokens
-
-            return final_answer, token_usage
-
-        except Exception as e:
-            logger.error(f"Error in Groq API call: {e}")
-            raise
diff --git a/libs/agent/agent/providers/omni/experiment.py b/libs/agent/agent/providers/omni/experiment.py
deleted file mode 100644
index 7e5c3a1d..00000000
--- a/libs/agent/agent/providers/omni/experiment.py
+++ /dev/null
@@ -1,276 +0,0 @@
-"""Experiment management for the Cua provider."""
-
-import os
-import logging
-import copy
-import base64
-from io import BytesIO
-from datetime import datetime
-from typing import Any, Dict, List, Optional
-from PIL import Image
-import json
-import time
-
-logger = logging.getLogger(__name__)
-
-
-class ExperimentManager:
-    """Manages experiment directories and logging for the agent."""
-
-    def __init__(
-        self,
-        base_dir: Optional[str] = None,
-        only_n_most_recent_images: Optional[int] = None,
-    ):
-        """Initialize the experiment manager.
- - Args: - base_dir: Base directory for saving experiment data - only_n_most_recent_images: Maximum number of recent screenshots to include in API requests - """ - self.base_dir = base_dir - self.only_n_most_recent_images = only_n_most_recent_images - self.run_dir = None - self.current_turn_dir = None - self.turn_count = 0 - self.screenshot_count = 0 - # Track all screenshots for potential API request inclusion - self.screenshot_paths = [] - - # Set up experiment directories if base_dir is provided - if self.base_dir: - self.setup_experiment_dirs() - - def setup_experiment_dirs(self) -> None: - """Setup the experiment directory structure.""" - if not self.base_dir: - return - - # Create base experiments directory if it doesn't exist - os.makedirs(self.base_dir, exist_ok=True) - - # Use the base_dir directly as the run_dir - self.run_dir = self.base_dir - logger.info(f"Using directory for experiment: {self.run_dir}") - - # Create first turn directory - self.create_turn_dir() - - def create_turn_dir(self) -> None: - """Create a new directory for the current turn.""" - if not self.run_dir: - return - - self.turn_count += 1 - self.current_turn_dir = os.path.join(self.run_dir, f"turn_{self.turn_count:03d}") - os.makedirs(self.current_turn_dir, exist_ok=True) - logger.info(f"Created turn directory: {self.current_turn_dir}") - - def sanitize_log_data(self, data: Any) -> Any: - """Sanitize data for logging by removing large base64 strings. - - Args: - data: Data to sanitize (dict, list, or primitive) - - Returns: - Sanitized copy of the data - """ - if isinstance(data, dict): - result = copy.deepcopy(data) - - # Handle nested dictionaries and lists - for key, value in result.items(): - # Process content arrays that contain image data - if key == "content" and isinstance(value, list): - for i, item in enumerate(value): - if isinstance(item, dict): - # Handle Anthropic format - if item.get("type") == "image" and isinstance(item.get("source"), dict): - source = item["source"] - if "data" in source and isinstance(source["data"], str): - # Replace base64 data with a placeholder and length info - data_len = len(source["data"]) - source["data"] = f"[BASE64_IMAGE_DATA_LENGTH_{data_len}]" - - # Handle OpenAI format - elif item.get("type") == "image_url" and isinstance( - item.get("image_url"), dict - ): - url_dict = item["image_url"] - if "url" in url_dict and isinstance(url_dict["url"], str): - url = url_dict["url"] - if url.startswith("data:"): - # Replace base64 data with placeholder - data_len = len(url) - url_dict["url"] = f"[BASE64_IMAGE_URL_LENGTH_{data_len}]" - - # Handle other nested structures recursively - if isinstance(value, dict): - result[key] = self.sanitize_log_data(value) - elif isinstance(value, list): - result[key] = [self.sanitize_log_data(item) for item in value] - - return result - elif isinstance(data, list): - return [self.sanitize_log_data(item) for item in data] - else: - return data - - def save_debug_image(self, image_data: str, filename: str) -> None: - """Save a debug image to the experiment directory. - - Args: - image_data: Base64 encoded image data - filename: Filename to save the image as - """ - # Since we no longer want to use the images/ folder, we'll skip this functionality - return - - def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]: - """Save a screenshot to the experiment directory. 
- - Args: - img_base64: Base64 encoded screenshot - action_type: Type of action that triggered the screenshot - - Returns: - Optional[str]: Path to the saved screenshot, or None if saving failed - """ - if not self.current_turn_dir: - return None - - try: - # Increment screenshot counter - self.screenshot_count += 1 - - # Create a descriptive filename - timestamp = int(time.time() * 1000) - action_suffix = f"_{action_type}" if action_type else "" - filename = f"screenshot_{self.screenshot_count:03d}{action_suffix}_{timestamp}.png" - - # Save directly to the turn directory (no screenshots subdirectory) - filepath = os.path.join(self.current_turn_dir, filename) - - # Save the screenshot - img_data = base64.b64decode(img_base64) - with open(filepath, "wb") as f: - f.write(img_data) - - # Keep track of the file path for reference - self.screenshot_paths.append(filepath) - - return filepath - except Exception as e: - logger.error(f"Error saving screenshot: {str(e)}") - return None - - def should_save_debug_image(self) -> bool: - """Determine if debug images should be saved. - - Returns: - Boolean indicating if debug images should be saved - """ - # We no longer need to save debug images, so always return False - return False - - def save_action_visualization( - self, img: Image.Image, action_name: str, details: str = "" - ) -> str: - """Save a visualization of an action. - - Args: - img: Image to save - action_name: Name of the action - details: Additional details about the action - - Returns: - Path to the saved image - """ - if not self.current_turn_dir: - return "" - - try: - # Create a descriptive filename - timestamp = int(time.time() * 1000) - details_suffix = f"_{details}" if details else "" - filename = f"vis_{action_name}{details_suffix}_{timestamp}.png" - - # Save directly to the turn directory (no visualizations subdirectory) - filepath = os.path.join(self.current_turn_dir, filename) - - # Save the image - img.save(filepath) - - # Keep track of the file path for cleanup - self.screenshot_paths.append(filepath) - - return filepath - except Exception as e: - logger.error(f"Error saving action visualization: {str(e)}") - return "" - - def extract_and_save_images(self, data: Any, prefix: str) -> None: - """Extract and save images from response data. - - Args: - data: Response data to extract images from - prefix: Prefix for saved image filenames - """ - # Since we no longer want to save extracted images separately, - # we'll skip this functionality entirely - return - - def log_api_call( - self, - call_type: str, - request: Any, - provider: str, - model: str, - response: Any = None, - error: Optional[Exception] = None, - ) -> None: - """Log API call details to file. 
- - Args: - call_type: Type of API call (e.g., 'request', 'response', 'error') - request: The API request data - provider: The AI provider used - model: The AI model used - response: Optional API response data - error: Optional error information - """ - if not self.current_turn_dir: - return - - try: - # Create a unique filename with timestamp - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - filename = f"api_call_{timestamp}_{call_type}.json" - filepath = os.path.join(self.current_turn_dir, filename) - - # Sanitize data to remove large base64 strings - sanitized_request = self.sanitize_log_data(request) - sanitized_response = self.sanitize_log_data(response) if response is not None else None - - # Prepare log data - log_data = { - "timestamp": timestamp, - "provider": provider, - "model": model, - "type": call_type, - "request": sanitized_request, - } - - if sanitized_response is not None: - log_data["response"] = sanitized_response - if error is not None: - log_data["error"] = str(error) - - # Write to file - with open(filepath, "w") as f: - json.dump(log_data, f, indent=2, default=str) - - logger.info(f"Logged API {call_type} to {filepath}") - - except Exception as e: - logger.error(f"Error logging API call: {str(e)}") diff --git a/libs/agent/agent/providers/omni/loop.py b/libs/agent/agent/providers/omni/loop.py index cb547148..b73c5826 100644 --- a/libs/agent/agent/providers/omni/loop.py +++ b/libs/agent/agent/providers/omni/loop.py @@ -2,33 +2,28 @@ import logging from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator, Union -import base64 from PIL import Image -from io import BytesIO import json import re import os from datetime import datetime import asyncio from httpx import ConnectError, ReadTimeout -import shutil -import copy from typing import cast from .parser import OmniParser, ParseResult, ParserMetadata, UIElement from ...core.loop import BaseLoop +from ...core.visualization import VisualizationHelper +from ...core.messages import StandardMessageManager, ImageRetentionConfig from computer import Computer from .types import LLMProvider -from .clients.base import BaseOmniClient from .clients.openai import OpenAIClient -from .clients.groq import GroqClient from .clients.anthropic import AnthropicClient from .prompts import SYSTEM_PROMPT from .utils import compress_image_base64 -from .visualization import visualize_click, visualize_scroll, calculate_element_center from .image_utils import decode_base64_image, clean_base64_data -from ...core.messages import ImageRetentionConfig -from .messages import OmniMessageManager +from .api_handler import OmniAPIHandler +from .action_executor import ActionExecutor logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -42,7 +37,16 @@ def extract_data(input_string: str, data_type: str) -> str: class OmniLoop(BaseLoop): - """Omni-specific implementation of the agent loop.""" + """Omni-specific implementation of the agent loop. + + This class extends BaseLoop to provide support for multimodal models + from various providers (OpenAI, Anthropic, etc.) with UI parsing + and desktop automation capabilities. 
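+
+    Typical flow (an illustrative summary of the code below): initialize_client()
+    sets up the provider client, _get_parsed_screen_som() captures and parses
+    the screen, and the API handler and action executor turn model output into
+    executed UI actions.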
+ """ + + ########################################### + # INITIALIZATION AND CONFIGURATION + ########################################### def __init__( self, @@ -77,8 +81,9 @@ class OmniLoop(BaseLoop): self.provider = provider # Initialize message manager with image retention config - image_retention_config = ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images) - self.message_manager = OmniMessageManager(config=image_retention_config) + self.message_manager = StandardMessageManager( + config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images) + ) # Initialize base class (which will set up experiment manager) super().__init__( @@ -97,87 +102,23 @@ class OmniLoop(BaseLoop): self.client = None self.retry_count = 0 - def _should_save_debug_image(self) -> bool: - """Check if debug images should be saved. + # Initialize handlers + self.api_handler = OmniAPIHandler(self) + self.action_executor = ActionExecutor(self) + self.viz_helper = VisualizationHelper(self) - Returns: - bool: Always returns False as debug image saving has been disabled. - """ - # Debug image saving functionality has been removed - return False + logger.info("OmniLoop initialized with StandardMessageManager") - def _extract_and_save_images(self, data: Any, prefix: str) -> None: - """Extract and save images from API data. - - This method is now a no-op as image extraction functionality has been removed. - - Args: - data: Data to extract images from - prefix: Prefix for the extracted image filenames - """ - # Image extraction functionality has been removed - return - - def _save_debug_image(self, image_data: str, filename: str) -> None: - """Save a debug image to the current turn directory. - - This method is now a no-op as debug image saving functionality has been removed. 
- - Args: - image_data: Base64 encoded image data - filename: Name to use for the saved image - """ - # Debug image saving functionality has been removed - return - - def _visualize_action(self, x: int, y: int, img_base64: str) -> None: - """Visualize an action by drawing on the screenshot.""" - if ( - not self.save_trajectory - or not hasattr(self, "experiment_manager") - or not self.experiment_manager - ): - return - - try: - # Use the visualization utility - img = visualize_click(x, y, img_base64) - - # Save the visualization - self.experiment_manager.save_action_visualization(img, "click", f"x{x}_y{y}") - except Exception as e: - logger.error(f"Error visualizing action: {str(e)}") - - def _visualize_scroll(self, direction: str, clicks: int, img_base64: str) -> None: - """Visualize a scroll action by drawing arrows on the screenshot.""" - if ( - not self.save_trajectory - or not hasattr(self, "experiment_manager") - or not self.experiment_manager - ): - return - - try: - # Use the visualization utility - img = visualize_scroll(direction, clicks, img_base64) - - # Save the visualization - self.experiment_manager.save_action_visualization( - img, "scroll", f"{direction}_{clicks}" - ) - except Exception as e: - logger.error(f"Error visualizing scroll: {str(e)}") - - def _save_action_visualization( - self, img: Image.Image, action_name: str, details: str = "" - ) -> str: - """Save a visualization of an action.""" - if hasattr(self, "experiment_manager") and self.experiment_manager: - return self.experiment_manager.save_action_visualization(img, action_name, details) - return "" + ########################################### + # CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD + ########################################### async def initialize_client(self) -> None: - """Initialize the appropriate client based on provider.""" + """Initialize the appropriate client based on provider. + + Implements abstract method from BaseLoop to set up the specific + provider client (OpenAI, Anthropic, etc.). 
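+
+        For example, LLMProvider.OPENAI is served by OpenAIClient and
+        LLMProvider.ANTHROPIC by AnthropicClient (see the client imports above).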
+ """ try: logger.info(f"Initializing {self.provider} client with model {self.model}...") @@ -199,6 +140,10 @@ class OmniLoop(BaseLoop): self.client = None raise RuntimeError(f"Failed to initialize client: {str(e)}") + ########################################### + # API CALL HANDLING + ########################################### + async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any: """Make API call to provider with retry logic.""" # Create new turn directory for this API call @@ -218,68 +163,73 @@ class OmniLoop(BaseLoop): if self.client is None: raise RuntimeError("Failed to initialize client") - # Set the provider in message manager based on current provider - provider_name = str(self.provider).split(".")[-1].lower() # Extract name from enum - self.message_manager.set_provider(provider_name) + # Get messages in standard format from the message manager + self.message_manager.messages = messages.copy() + prepared_messages = self.message_manager.get_messages() - # Apply image retention and prepare messages - # This will limit the number of images based on only_n_most_recent_images - prepared_messages = self.message_manager.get_formatted_messages(provider_name) - - # Filter out system messages for Anthropic + # Special handling for Anthropic if self.provider == LLMProvider.ANTHROPIC: + # Convert to Anthropic format + anthropic_messages, anthropic_system = self.message_manager.to_anthropic_format( + prepared_messages + ) + + # Filter out any empty/invalid messages filtered_messages = [ - msg for msg in prepared_messages if msg["role"] != "system" + msg + for msg in anthropic_messages + if msg.get("role") in ["user", "assistant"] ] + + # Ensure there's at least one message for Anthropic + if not filtered_messages: + logger.warning( + "No valid messages found for Anthropic API call. Adding a default user message." + ) + filtered_messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Please help with this task."} + ], + } + ] + + # Combine system prompts if needed + final_system_prompt = anthropic_system or system_prompt + + # Log request + request_data = { + "messages": filtered_messages, + "max_tokens": self.max_tokens, + "system": final_system_prompt, + } + + self._log_api_call("request", request_data) + + # Make API call + response = await self.client.run_interleaved( + messages=filtered_messages, + system=final_system_prompt, + max_tokens=self.max_tokens, + ) else: - filtered_messages = prepared_messages + # For OpenAI and others, use standard format directly + # Log request + request_data = { + "messages": prepared_messages, + "max_tokens": self.max_tokens, + "system": system_prompt, + } - # Log request - request_data = {"messages": filtered_messages, "max_tokens": self.max_tokens} + self._log_api_call("request", request_data) - if self.provider == LLMProvider.ANTHROPIC: - request_data["system"] = self._get_system_prompt() - else: - request_data["system"] = system_prompt - - self._log_api_call("request", request_data) - - # Make API call with appropriate parameters - if self.client is None: - raise RuntimeError("Client not initialized. 
Call initialize_client() first.") - - # Check if the method is async by inspecting the client implementation - run_method = self.client.run_interleaved - is_async = asyncio.iscoroutinefunction(run_method) - - if is_async: - # For async implementations (AnthropicClient) - if self.provider == LLMProvider.ANTHROPIC: - response = await run_method( - messages=filtered_messages, - system=self._get_system_prompt(), - max_tokens=self.max_tokens, - ) - else: - response = await run_method( - messages=messages, - system=system_prompt, - max_tokens=self.max_tokens, - ) - else: - # For non-async implementations (GroqClient, etc.) - if self.provider == LLMProvider.ANTHROPIC: - response = run_method( - messages=filtered_messages, - system=self._get_system_prompt(), - max_tokens=self.max_tokens, - ) - else: - response = run_method( - messages=messages, - system=system_prompt, - max_tokens=self.max_tokens, - ) + # Make API call + response = await self.client.run_interleaved( + messages=prepared_messages, + system=system_prompt, + max_tokens=self.max_tokens, + ) # Log success response self._log_api_call("response", request_data, response) @@ -327,6 +277,10 @@ class OmniLoop(BaseLoop): logger.error(error_message) raise RuntimeError(error_message) + ########################################### + # RESPONSE AND ACTION HANDLING + ########################################### + async def _handle_response( self, response: Any, messages: List[Dict[str, Any]], parsed_screen: ParseResult ) -> Tuple[bool, bool]: @@ -341,21 +295,61 @@ class OmniLoop(BaseLoop): Tuple of (should_continue, action_screenshot_saved) """ action_screenshot_saved = False + + # Helper function to safely add assistant messages using the message manager + def add_assistant_message(content): + if isinstance(content, str): + # Convert string to proper format + formatted_content = [{"type": "text", "text": content}] + self.message_manager.add_assistant_message(formatted_content) + logger.info("Added formatted text assistant message") + elif isinstance(content, list): + # Already in proper format + self.message_manager.add_assistant_message(content) + logger.info("Added structured assistant message") + else: + # Default case - convert to string + formatted_content = [{"type": "text", "text": str(content)}] + self.message_manager.add_assistant_message(formatted_content) + logger.info("Added converted assistant message") + try: # Handle Anthropic response format if self.provider == LLMProvider.ANTHROPIC: if hasattr(response, "content") and isinstance(response.content, list): - # Extract text from content blocks + # First convert Anthropic response to standard format + standard_content = [] + for block in response.content: + if hasattr(block, "type"): + if block.type == "text": + standard_content.append({"type": "text", "text": block.text}) + content = block.text + else: + # Add other block types + block_dict = {} + for key, value in vars(block).items(): + if not key.startswith("_"): + block_dict[key] = value + standard_content.append(block_dict) + continue + + # Add standard format response to messages using the message manager + add_assistant_message(standard_content) + + # Now extract JSON from the content for action execution + # Try to find JSON in the text blocks + json_content = None + parsed_content = None + for block in response.content: if hasattr(block, "type") and block.type == "text": content = block.text - - # Try to find JSON in the content try: # First look for JSON block json_content = extract_data(content, "json") parsed_content = 
json.loads(json_content) logger.info("Successfully parsed JSON from code block") + break except (json.JSONDecodeError, IndexError): # If no JSON block, try to find JSON object in the text try: @@ -366,61 +360,51 @@ class OmniLoop(BaseLoop): json_str = json_match.group(0) parsed_content = json.loads(json_str) logger.info("Successfully parsed JSON from text") - else: - logger.error(f"No JSON found in content: {content}") - continue - except json.JSONDecodeError as e: - logger.error(f"Failed to parse JSON from text: {str(e)}") + break + except json.JSONDecodeError: continue - # Clean up Box ID format - if "Box ID" in parsed_content and isinstance( - parsed_content["Box ID"], str - ): - parsed_content["Box ID"] = parsed_content["Box ID"].replace( - "Box #", "" - ) + if parsed_content: + # Clean up Box ID format + if "Box ID" in parsed_content and isinstance(parsed_content["Box ID"], str): + parsed_content["Box ID"] = parsed_content["Box ID"].replace("Box #", "") - # Add any explanatory text as reasoning if not present - if "Explanation" not in parsed_content: - # Extract any text before the JSON as reasoning + # Add any explanatory text as reasoning if not present + if "Explanation" not in parsed_content: + # Extract any text before the JSON as reasoning + if content: text_before_json = content.split("{")[0].strip() if text_before_json: parsed_content["Explanation"] = text_before_json - # Log the parsed content for debugging - logger.info(f"Parsed content: {json.dumps(parsed_content, indent=2)}") + # Log the parsed content for debugging + logger.info(f"Parsed content: {json.dumps(parsed_content, indent=2)}") - # Add response to messages - messages.append( - {"role": "assistant", "content": json.dumps(parsed_content)} + try: + # Execute action with current parsed screen info using the ActionExecutor + action_screenshot_saved = await self.action_executor.execute_action( + parsed_content, cast(ParseResult, parsed_screen) ) + except Exception as e: + logger.error(f"Error executing action: {str(e)}") + # Update the last assistant message with error + error_message = [ + {"type": "text", "text": f"Error executing action: {str(e)}"} + ] + # Replace the last assistant message with the error + self.message_manager.add_assistant_message(error_message) + return False, action_screenshot_saved - try: - # Execute action with current parsed screen info - await self._execute_action( - parsed_content, cast(ParseResult, parsed_screen) - ) - action_screenshot_saved = True - except Exception as e: - logger.error(f"Error executing action: {str(e)}") - # Add error message to conversation - messages.append( - { - "role": "assistant", - "content": f"Error executing action: {str(e)}", - "metadata": {"title": "❌ Error"}, - } - ) - return False, action_screenshot_saved + # Check if task is complete + if parsed_content.get("Action") == "None": + return False, action_screenshot_saved + return True, action_screenshot_saved + else: + logger.warning("No JSON found in response content") + return True, action_screenshot_saved - # Check if task is complete - if parsed_content.get("Action") == "None": - return False, action_screenshot_saved - return True, action_screenshot_saved - - logger.warning("No text block found in Anthropic response") - return True, action_screenshot_saved + logger.warning("No text block found in Anthropic response") + return True, action_screenshot_saved # Handle other providers' response formats if isinstance(response, dict) and "choices" in response: @@ -464,23 +448,19 @@ class OmniLoop(BaseLoop): if 
text_before_json: parsed_content["Explanation"] = text_before_json - # Add response to messages with stringified content - messages.append({"role": "assistant", "content": json.dumps(parsed_content)}) + # Add response to messages with stringified content using our helper + add_assistant_message([{"type": "text", "text": json.dumps(parsed_content)}]) try: - # Execute action with current parsed screen info - await self._execute_action(parsed_content, cast(ParseResult, parsed_screen)) - action_screenshot_saved = True + # Execute action with current parsed screen info using the ActionExecutor + action_screenshot_saved = await self.action_executor.execute_action( + parsed_content, cast(ParseResult, parsed_screen) + ) except Exception as e: logger.error(f"Error executing action: {str(e)}") - # Add error message to conversation - messages.append( - { - "role": "assistant", - "content": f"Error executing action: {str(e)}", - "metadata": {"title": "❌ Error"}, - } - ) + # Add error message using the message manager + error_message = [{"type": "text", "text": f"Error executing action: {str(e)}"}] + self.message_manager.add_assistant_message(error_message) return False, action_screenshot_saved # Check if task is complete @@ -490,22 +470,18 @@ class OmniLoop(BaseLoop): return True, action_screenshot_saved elif isinstance(content, dict): # Handle case where content is already a dictionary - messages.append({"role": "assistant", "content": json.dumps(content)}) + add_assistant_message([{"type": "text", "text": json.dumps(content)}]) try: - # Execute action with current parsed screen info - await self._execute_action(content, cast(ParseResult, parsed_screen)) - action_screenshot_saved = True + # Execute action with current parsed screen info using the ActionExecutor + action_screenshot_saved = await self.action_executor.execute_action( + content, cast(ParseResult, parsed_screen) + ) except Exception as e: logger.error(f"Error executing action: {str(e)}") - # Add error message to conversation - messages.append( - { - "role": "assistant", - "content": f"Error executing action: {str(e)}", - "metadata": {"title": "❌ Error"}, - } - ) + # Add error message using the message manager + error_message = [{"type": "text", "text": f"Error executing action: {str(e)}"}] + self.message_manager.add_assistant_message(error_message) return False, action_screenshot_saved # Check if task is complete @@ -518,17 +494,20 @@ class OmniLoop(BaseLoop): except Exception as e: logger.error(f"Error handling response: {str(e)}") - messages.append( - { - "role": "assistant", - "content": f"Error: {str(e)}", - "metadata": {"title": "❌ Error"}, - } - ) + # Add error message using the message manager + error_message = [{"type": "text", "text": f"Error: {str(e)}"}] + self.message_manager.add_assistant_message(error_message) raise + ########################################### + # SCREEN PARSING - IMPLEMENTING ABSTRACT METHOD + ########################################### + async def _get_parsed_screen_som(self, save_screenshot: bool = True) -> ParseResult: - """Get parsed screen information with SOM. + """Get parsed screen information with Screen Object Model. + + Extends the base class method to use the OmniParser to parse the screen + and extract UI elements. 
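+
+        The returned ParseResult carries a typed elements list (each element
+        exposing the id and bbox used by the action executor) and, when
+        available, an annotated_image_base64 screenshot.
+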
Args: save_screenshot: Whether to save the screenshot (set to False when screenshots will be saved elsewhere) @@ -563,337 +542,29 @@ class OmniLoop(BaseLoop): logger.error(f"Error getting parsed screen: {str(e)}") raise - async def _process_screen( - self, parsed_screen: ParseResult, messages: List[Dict[str, Any]] - ) -> None: - """Process and add screen info to messages.""" - try: - # Only add message if we have an image and provider supports it - if self.provider in [LLMProvider.OPENAI, LLMProvider.ANTHROPIC]: - image = parsed_screen.annotated_image_base64 or None - if image: - # Save screen info to current turn directory - if self.current_turn_dir: - # Save elements as JSON - elements_path = os.path.join(self.current_turn_dir, "elements.json") - with open(elements_path, "w") as f: - # Convert elements to dicts for JSON serialization - elements_json = [elem.model_dump() for elem in parsed_screen.elements] - json.dump(elements_json, f, indent=2) - logger.info(f"Saved elements to {elements_path}") - - # Format the image content based on the provider - if self.provider == LLMProvider.ANTHROPIC: - # Compress the image before sending to Anthropic (5MB limit) - image_size = len(image) - logger.info(f"Image base64 is present, length: {image_size}") - - # Anthropic has a 5MB limit - check against base64 string length - # which is what matters for the API call payload - # Use slightly smaller limit (4.9MB) to account for request overhead - max_size = int(4.9 * 1024 * 1024) # 4.9MB - - # Default media type (will be overridden if compression is needed) - media_type = "image/png" - - # Check if the image already has a media type prefix - if image.startswith("data:"): - parts = image.split(",", 1) - if len(parts) == 2 and "image/jpeg" in parts[0].lower(): - media_type = "image/jpeg" - elif len(parts) == 2 and "image/png" in parts[0].lower(): - media_type = "image/png" - - if image_size > max_size: - logger.info( - f"Image size ({image_size} bytes) exceeds Anthropic limit ({max_size} bytes), compressing..." - ) - image, media_type = compress_image_base64(image, max_size) - logger.info( - f"Image compressed to {len(image)} bytes with media_type {media_type}" - ) - - # Anthropic uses "type": "image" - screen_info_msg = { - "role": "user", - "content": [ - { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": image, - }, - } - ], - } - else: - # OpenAI and others use "type": "image_url" - screen_info_msg = { - "role": "user", - "content": [ - { - "type": "image_url", - "image_url": {"url": f"data:image/png;base64,{image}"}, - } - ], - } - messages.append(screen_info_msg) - - except Exception as e: - logger.error(f"Error processing screen info: {str(e)}") - raise - def _get_system_prompt(self) -> str: """Get the system prompt for the model.""" return SYSTEM_PROMPT - async def _execute_action(self, content: Dict[str, Any], parsed_screen: ParseResult) -> None: - """Execute the action specified in the content using the tool manager. 
- - Args: - content: Dictionary containing the action details - parsed_screen: Current parsed screen information - """ - try: - action = content.get("Action", "").lower() - if not action: - return - - # Track if we saved an action-specific screenshot - action_screenshot_saved = False - - try: - # Prepare kwargs based on action type - kwargs = {} - - if action in ["left_click", "right_click", "double_click", "move_cursor"]: - try: - box_id = int(content["Box ID"]) - logger.info(f"Processing Box ID: {box_id}") - - # Calculate click coordinates - x, y = await self._calculate_click_coordinates(box_id, parsed_screen) - logger.info(f"Calculated coordinates: x={x}, y={y}") - - kwargs["x"] = x - kwargs["y"] = y - - # Visualize action if screenshot is available - if parsed_screen.annotated_image_base64: - img_data = parsed_screen.annotated_image_base64 - # Remove data URL prefix if present - if img_data.startswith("data:image"): - img_data = img_data.split(",")[1] - # Only save visualization for coordinate-based actions - self._visualize_action(x, y, img_data) - action_screenshot_saved = True - - except ValueError as e: - logger.error(f"Error processing Box ID: {str(e)}") - return - - elif action == "drag_to": - try: - box_id = int(content["Box ID"]) - x, y = await self._calculate_click_coordinates(box_id, parsed_screen) - kwargs.update( - { - "x": x, - "y": y, - "button": content.get("button", "left"), - "duration": float(content.get("duration", 0.5)), - } - ) - - # Visualize drag destination if screenshot is available - if parsed_screen.annotated_image_base64: - img_data = parsed_screen.annotated_image_base64 - # Remove data URL prefix if present - if img_data.startswith("data:image"): - img_data = img_data.split(",")[1] - # Only save visualization for coordinate-based actions - self._visualize_action(x, y, img_data) - action_screenshot_saved = True - - except ValueError as e: - logger.error(f"Error processing drag coordinates: {str(e)}") - return - - elif action == "type_text": - kwargs["text"] = content["Value"] - # For type_text, store the value in the action type - action_type = f"type_{content['Value'][:20]}" # Truncate if too long - elif action == "press_key": - kwargs["key"] = content["Value"] - action_type = f"press_{content['Value']}" - elif action == "hotkey": - if isinstance(content.get("Value"), list): - keys = content["Value"] - action_type = f"hotkey_{'_'.join(keys)}" - else: - # Simply split string format like "command+space" into a list - keys = [k.strip() for k in content["Value"].lower().split("+")] - action_type = f"hotkey_{content['Value'].replace('+', '_')}" - logger.info(f"Preparing hotkey with keys: {keys}") - # Get the method but call it with *args instead of **kwargs - method = getattr(self.computer.interface, action) - await method(*keys) # Unpack the keys list as positional arguments - logger.info(f"Tool execution completed successfully: {action}") - - # For hotkeys, take a screenshot after the action - try: - # Get a new screenshot after the action and save it with the action type - new_parsed_screen = await self._get_parsed_screen_som(save_screenshot=False) - if new_parsed_screen and new_parsed_screen.annotated_image_base64: - img_data = new_parsed_screen.annotated_image_base64 - # Remove data URL prefix if present - if img_data.startswith("data:image"): - img_data = img_data.split(",")[1] - # Save with action type to indicate this is a post-action screenshot - self._save_screenshot(img_data, action_type=action_type) - action_screenshot_saved = True - except Exception 
as screenshot_error: - logger.error( - f"Error taking post-hotkey screenshot: {str(screenshot_error)}" - ) - - return - - elif action in ["scroll_down", "scroll_up"]: - clicks = int(content.get("amount", 1)) - kwargs["clicks"] = clicks - action_type = f"scroll_{action.split('_')[1]}_{clicks}" - - # Visualize scrolling if screenshot is available - if parsed_screen.annotated_image_base64: - img_data = parsed_screen.annotated_image_base64 - # Remove data URL prefix if present - if img_data.startswith("data:image"): - img_data = img_data.split(",")[1] - direction = "down" if action == "scroll_down" else "up" - # For scrolling, we only save the visualization to avoid duplicate images - self._visualize_scroll(direction, clicks, img_data) - action_screenshot_saved = True - - else: - logger.warning(f"Unknown action: {action}") - return - - # Execute tool and handle result - try: - method = getattr(self.computer.interface, action) - logger.info(f"Found method for action '{action}': {method}") - await method(**kwargs) - logger.info(f"Tool execution completed successfully: {action}") - - # For non-coordinate based actions that don't already have visualizations, - # take a new screenshot after the action - if not action_screenshot_saved: - # Take a new screenshot - try: - # Get a new screenshot after the action and save it with the action type - new_parsed_screen = await self._get_parsed_screen_som( - save_screenshot=False - ) - if new_parsed_screen and new_parsed_screen.annotated_image_base64: - img_data = new_parsed_screen.annotated_image_base64 - # Remove data URL prefix if present - if img_data.startswith("data:image"): - img_data = img_data.split(",")[1] - # Save with action type to indicate this is a post-action screenshot - if "action_type" in locals(): - self._save_screenshot(img_data, action_type=action_type) - else: - self._save_screenshot(img_data, action_type=action) - # Update the action screenshot flag for this turn - action_screenshot_saved = True - except Exception as screenshot_error: - logger.error( - f"Error taking post-action screenshot: {str(screenshot_error)}" - ) - - except AttributeError as e: - logger.error(f"Method not found for action '{action}': {str(e)}") - return - except Exception as tool_error: - logger.error(f"Tool execution failed: {str(tool_error)}") - return - - except Exception as e: - logger.error(f"Error executing action {action}: {str(e)}") - return - - except Exception as e: - logger.error(f"Error in _execute_action: {str(e)}") - return - - async def _calculate_click_coordinates( - self, box_id: int, parsed_screen: ParseResult - ) -> Tuple[int, int]: - """Calculate click coordinates based on box ID. 
- - Args: - box_id: The ID of the box to click - parsed_screen: The parsed screen information - - Returns: - Tuple of (x, y) coordinates - - Raises: - ValueError: If box_id is invalid or missing from parsed screen - """ - # First try to use structured elements data - logger.info(f"Elements count: {len(parsed_screen.elements)}") - - # Try to find element with matching ID - for element in parsed_screen.elements: - if element.id == box_id: - logger.info(f"Found element with ID {box_id}: {element}") - bbox = element.bbox - - # Get screen dimensions from the metadata if available, or fallback - width = parsed_screen.metadata.width if parsed_screen.metadata else 1920 - height = parsed_screen.metadata.height if parsed_screen.metadata else 1080 - logger.info(f"Screen dimensions: width={width}, height={height}") - - # Calculate center of the box in pixels - center_x = int((bbox.x1 + bbox.x2) / 2 * width) - center_y = int((bbox.y1 + bbox.y2) / 2 * height) - logger.info(f"Calculated center: ({center_x}, {center_y})") - - # Validate coordinates - if they're (0,0) or unreasonably small, - # use a default position in the center of the screen - if center_x == 0 and center_y == 0: - logger.warning("Got (0,0) coordinates, using fallback position") - center_x = width // 2 - center_y = height // 2 - logger.info(f"Using fallback center: ({center_x}, {center_y})") - - return center_x, center_y - - # If we couldn't find the box, use center of screen - logger.error( - f"Box ID {box_id} not found in structured elements (count={len(parsed_screen.elements)})" - ) - - # Use center of screen as fallback - width = parsed_screen.metadata.width if parsed_screen.metadata else 1920 - height = parsed_screen.metadata.height if parsed_screen.metadata else 1080 - logger.warning(f"Using fallback position in center of screen ({width//2}, {height//2})") - return width // 2, height // 2 + ########################################### + # MAIN LOOP - IMPLEMENTING ABSTRACT METHOD + ########################################### async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]: """Run the agent loop with provided messages. + Implements abstract method from BaseLoop to handle the main agent loop + for the OmniLoop implementation. 
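+
+        Illustrative caller sketch (loop construction and task text are
+        assumptions for illustration):
+            async for result in loop.run([{"role": "user", "content": "Open Safari"}]):
+                for item in result.get("output", []):
+                    print(item["type"])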
+ Args: - messages: List of message objects + messages: List of message objects in standard OpenAI format Yields: Dict containing response data """ - # Keep track of conversation history - conversation_history = messages.copy() + # Initialize the message manager with the provided messages + self.message_manager.messages = messages.copy() + logger.info(f"Starting OmniLoop run with {len(self.message_manager.messages)} messages") # Continue running until explicitly told to stop running = True @@ -922,26 +593,64 @@ class OmniLoop(BaseLoop): # Get up-to-date screen information parsed_screen = await self._get_parsed_screen_som() - # Process screen info and update messages - await self._process_screen(parsed_screen, conversation_history) + # Process screen info and update messages in standard format + try: + # Get image from parsed screen + image = parsed_screen.annotated_image_base64 or None + if image: + # Save elements as JSON if we have a turn directory + if self.current_turn_dir and hasattr(parsed_screen, "elements"): + elements_path = os.path.join(self.current_turn_dir, "elements.json") + with open(elements_path, "w") as f: + # Convert elements to dicts for JSON serialization + elements_json = [ + elem.model_dump() for elem in parsed_screen.elements + ] + json.dump(elements_json, f, indent=2) + logger.info(f"Saved elements to {elements_path}") + + # Remove data URL prefix if present + if "," in image: + image = image.split(",")[1] + + # Add screenshot to message history using message manager + self.message_manager.add_user_message( + [ + { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{image}"}, + } + ] + ) + logger.info("Added screenshot to message history") + except Exception as e: + logger.error(f"Error processing screen info: {str(e)}") + raise # Get system prompt system_prompt = self._get_system_prompt() - # Make API call with retries - response = await self._make_api_call(conversation_history, system_prompt) + # Make API call with retries using the APIHandler + response = await self.api_handler.make_api_call( + self.message_manager.messages, system_prompt + ) # Handle the response (may execute actions) # Returns: (should_continue, action_screenshot_saved) should_continue, new_screenshot_saved = await self._handle_response( - response, conversation_history, parsed_screen + response, self.message_manager.messages, parsed_screen ) # Update whether an action screenshot was saved this turn action_screenshot_saved = action_screenshot_saved or new_screenshot_saved + # Create OpenAI-compatible response format + openai_compatible_response = await self._create_openai_compatible_response( + response, self.message_manager.messages, parsed_screen + ) + # Yield the response to the caller - yield {"response": response} + yield openai_compatible_response # Check if we should continue this conversation running = should_continue @@ -969,3 +678,215 @@ class OmniLoop(BaseLoop): # Create a brief delay before retrying await asyncio.sleep(1) + + async def _create_openai_compatible_response( + self, + response: Any, + messages: List[Dict[str, Any]], + parsed_screen: Optional[ParseResult] = None, + ) -> Dict[str, Any]: + """Create an OpenAI computer use agent compatible response format. 
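+
+        The returned dict mirrors OpenAI's /v1/responses schema; a trimmed,
+        illustrative example of the shape produced:
+            {"id": "resp_...", "object": "response", "status": "completed",
+             "output": [{"type": "reasoning", "summary": [...]},
+                        {"type": "computer_call",
+                         "action": {"type": "click", "x": 100, "y": 100}}]}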
+ + Args: + response: The original API response + messages: List of messages in standard OpenAI format + parsed_screen: Optional pre-parsed screen information + + Returns: + A response formatted according to OpenAI's computer use agent standard + """ + from datetime import datetime + import time + + # Create a unique ID for this response + response_id = f"resp_{datetime.now().strftime('%Y%m%d%H%M%S')}_{id(response)}" + reasoning_id = f"rs_{response_id}" + action_id = f"cu_{response_id}" + call_id = f"call_{response_id}" + + # Extract the last assistant message + assistant_msg = None + for msg in reversed(messages): + if msg["role"] == "assistant": + assistant_msg = msg + break + + if not assistant_msg: + # If no assistant message found, create a default one + assistant_msg = {"role": "assistant", "content": "No response available"} + + # Initialize output array + output_items = [] + + # Extract reasoning and action details from the response + content = assistant_msg["content"] + reasoning_text = None + action_details = None + + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + try: + # Try to parse JSON from text block + text_content = item.get("text", "") + parsed_json = json.loads(text_content) + + # Get reasoning text + if reasoning_text is None: + reasoning_text = parsed_json.get("Explanation", "") + + # Extract action details + action = parsed_json.get("Action", "").lower() + text_input = parsed_json.get("Text", "") + value = parsed_json.get("Value", "") # Also handle Value field + box_id = parsed_json.get("Box ID") # Extract Box ID + + if action in ["click", "left_click"]: + # Always calculate coordinates from Box ID for click actions + x, y = 100, 100 # Default fallback values + + if parsed_screen and box_id is not None: + try: + box_id_int = ( + box_id + if isinstance(box_id, int) + else int(str(box_id)) if str(box_id).isdigit() else None + ) + if box_id_int is not None: + # Use the ActionExecutor's method to calculate coordinates with await + x, y = await self.action_executor.calculate_click_coordinates( + box_id_int, parsed_screen + ) + logger.info( + f"Extracted coordinates for Box ID {box_id_int}: ({x}, {y})" + ) + except Exception as e: + logger.error( + f"Error extracting coordinates for Box ID {box_id}: {str(e)}" + ) + + action_details = { + "type": "click", + "button": "left", + "box_id": ( + ( + box_id + if isinstance(box_id, int) + else int(box_id) if str(box_id).isdigit() else None + ) + if box_id is not None + else None + ), + "x": x, + "y": y, + } + elif action in ["type", "type_text"] and (text_input or value): + action_details = { + "type": "type", + "text": text_input or value, + } + elif action == "hotkey" and value: + action_details = { + "type": "hotkey", + "keys": value, + } + elif action == "scroll": + # Use default coordinates for scrolling + delta_x = 0 + delta_y = 0 + # Try to extract scroll delta values from content if available + scroll_data = parsed_json.get("Scroll", {}) + if scroll_data: + delta_x = scroll_data.get("delta_x", 0) + delta_y = scroll_data.get("delta_y", 0) + action_details = { + "type": "scroll", + "x": 100, + "y": 100, + "scroll_x": delta_x, + "scroll_y": delta_y, + } + elif action == "none": + # Handle case when action is None (task completion) + action_details = {"type": "none", "description": "Task completed"} + except json.JSONDecodeError: + # If not JSON, just use as reasoning text + if reasoning_text is None: + reasoning_text = "" + reasoning_text += item.get("text", "") + + # Add reasoning item if 
we have text content + if reasoning_text: + output_items.append( + { + "type": "reasoning", + "id": reasoning_id, + "summary": [ + { + "type": "summary_text", + "text": reasoning_text[:200], # Truncate to reasonable length + } + ], + } + ) + + # If no action details extracted, use default + if not action_details: + action_details = { + "type": "click", + "button": "left", + "x": 100, + "y": 100, + } + + # Add computer_call item + computer_call = { + "type": "computer_call", + "id": action_id, + "call_id": call_id, + "action": action_details, + "pending_safety_checks": [], + "status": "completed", + } + output_items.append(computer_call) + + # Create the OpenAI-compatible response format with all expected fields + return { + "id": response_id, + "object": "response", + "created_at": int(time.time()), + "status": "completed", + "error": None, + "incomplete_details": None, + "instructions": None, + "max_output_tokens": None, + "model": self.model, + "output": output_items, + "parallel_tool_calls": True, + "previous_response_id": None, + "reasoning": {"effort": "medium", "generate_summary": "concise"}, + "store": True, + "temperature": 1.0, + "text": {"format": {"type": "text"}}, + "tool_choice": "auto", + "tools": [ + { + "type": "computer_use_preview", + "display_height": 768, + "display_width": 1024, + "environment": "mac", + } + ], + "top_p": 1.0, + "truncation": "auto", + "usage": { + "input_tokens": 0, # Placeholder values + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens": 0, # Placeholder values + "output_tokens_details": {"reasoning_tokens": 0}, + "total_tokens": 0, # Placeholder values + }, + "user": None, + "metadata": {}, + # Include the original response for backward compatibility + "response": {"choices": [{"message": assistant_msg, "finish_reason": "stop"}]}, + } diff --git a/libs/agent/agent/providers/omni/messages.py b/libs/agent/agent/providers/omni/messages.py deleted file mode 100644 index 8c1824d7..00000000 --- a/libs/agent/agent/providers/omni/messages.py +++ /dev/null @@ -1,171 +0,0 @@ -"""Omni message manager implementation.""" - -import base64 -from typing import Any, Dict, List, Optional -from io import BytesIO -from PIL import Image - -from ...core.messages import BaseMessageManager, ImageRetentionConfig - - -class OmniMessageManager(BaseMessageManager): - """Message manager for multi-provider support.""" - - def __init__(self, config: Optional[ImageRetentionConfig] = None): - """Initialize the message manager. - - Args: - config: Optional configuration for image retention - """ - super().__init__(config) - self.messages: List[Dict[str, Any]] = [] - self.config = config - - def add_user_message(self, content: str, images: Optional[List[bytes]] = None) -> None: - """Add a user message to the history. 
- - Args: - content: Message content - images: Optional list of image data - """ - # Add images if present - if images: - # Initialize with proper typing for mixed content - message_content: List[Dict[str, Any]] = [{"type": "text", "text": content}] - - # Add each image - for img in images: - message_content.append( - { - "type": "image_url", - "image_url": { - "url": f"data:image/png;base64,{base64.b64encode(img).decode()}" - }, - } - ) - - message = {"role": "user", "content": message_content} - else: - # Simple text message - message = {"role": "user", "content": content} - - self.messages.append(message) - - # Apply retention policy - if self.config and self.config.num_images_to_keep: - self._apply_image_retention_policy() - - def add_assistant_message(self, content: str) -> None: - """Add an assistant message to the history. - - Args: - content: Message content - """ - self.messages.append({"role": "assistant", "content": content}) - - def add_system_message(self, content: str) -> None: - """Add a system message to the history. - - Args: - content: Message content - """ - self.messages.append({"role": "system", "content": content}) - - def _apply_image_retention_policy(self) -> None: - """Apply image retention policy to message history.""" - if not self.config or not self.config.num_images_to_keep: - return - - # Count images from newest to oldest - image_count = 0 - for message in reversed(self.messages): - if message["role"] != "user": - continue - - # Handle multimodal messages - if isinstance(message["content"], list): - new_content = [] - for item in message["content"]: - if item["type"] == "text": - new_content.append(item) - elif item["type"] == "image_url": - if image_count < self.config.num_images_to_keep: - new_content.append(item) - image_count += 1 - message["content"] = new_content - - def get_formatted_messages(self, provider: str) -> List[Dict[str, Any]]: - """Get messages formatted for specific provider. 
- - Args: - provider: Provider name to format messages for - - Returns: - List of formatted messages - """ - # Set the provider for message formatting - self.set_provider(provider) - - if provider == "anthropic": - return self._format_for_anthropic() - elif provider == "openai": - return self._format_for_openai() - elif provider == "groq": - return self._format_for_groq() - elif provider == "qwen": - return self._format_for_qwen() - else: - raise ValueError(f"Unsupported provider: {provider}") - - def _format_for_anthropic(self) -> List[Dict[str, Any]]: - """Format messages for Anthropic API.""" - formatted = [] - for msg in self.messages: - formatted_msg = {"role": msg["role"]} - - # Handle multimodal content - if isinstance(msg["content"], list): - formatted_msg["content"] = [] - for item in msg["content"]: - if item["type"] == "text": - formatted_msg["content"].append({"type": "text", "text": item["text"]}) - elif item["type"] == "image_url": - formatted_msg["content"].append( - { - "type": "image", - "source": { - "type": "base64", - "media_type": "image/png", - "data": item["image_url"]["url"].split(",")[1], - }, - } - ) - else: - formatted_msg["content"] = msg["content"] - - formatted.append(formatted_msg) - return formatted - - def _format_for_openai(self) -> List[Dict[str, Any]]: - """Format messages for OpenAI API.""" - # OpenAI already uses the same format - return self.messages - - def _format_for_groq(self) -> List[Dict[str, Any]]: - """Format messages for Groq API.""" - # Groq uses OpenAI-compatible format - return self.messages - - def _format_for_qwen(self) -> List[Dict[str, Any]]: - """Format messages for Qwen API.""" - formatted = [] - for msg in self.messages: - if isinstance(msg["content"], list): - # Convert multimodal content to text-only - text_content = next( - (item["text"] for item in msg["content"] if item["type"] == "text"), "" - ) - formatted.append({"role": msg["role"], "content": text_content}) - else: - formatted.append(msg) - return formatted diff --git a/libs/agent/agent/providers/omni/visualization.py b/libs/agent/agent/providers/omni/visualization.py deleted file mode 100644 index 5d856457..00000000 --- a/libs/agent/agent/providers/omni/visualization.py +++ /dev/null @@ -1,130 +0,0 @@ -"""Visualization utilities for the Cua provider.""" - -import base64 -import logging -from io import BytesIO -from typing import Tuple -from PIL import Image, ImageDraw - -logger = logging.getLogger(__name__) - - -def visualize_click(x: int, y: int, img_base64: str) -> Image.Image: - """Visualize a click action by drawing on the screenshot. 
- - Args: - x: X coordinate of the click - y: Y coordinate of the click - img_base64: Base64 encoded image to draw on - - Returns: - PIL Image with visualization - """ - try: - # Decode the base64 image - img_data = base64.b64decode(img_base64) - img = Image.open(BytesIO(img_data)) - - # Create a drawing context - draw = ImageDraw.Draw(img) - - # Draw concentric circles at the click position - small_radius = 10 - large_radius = 30 - - # Draw filled inner circle - draw.ellipse( - [(x - small_radius, y - small_radius), (x + small_radius, y + small_radius)], - fill="red", - ) - - # Draw outlined outer circle - draw.ellipse( - [(x - large_radius, y - large_radius), (x + large_radius, y + large_radius)], - outline="red", - width=3, - ) - - return img - - except Exception as e: - logger.error(f"Error visualizing click: {str(e)}") - # Return a blank image in case of error - return Image.new("RGB", (800, 600), color="white") - - -def visualize_scroll(direction: str, clicks: int, img_base64: str) -> Image.Image: - """Visualize a scroll action by drawing arrows on the screenshot. - - Args: - direction: 'up' or 'down' - clicks: Number of scroll clicks - img_base64: Base64 encoded image to draw on - - Returns: - PIL Image with visualization - """ - try: - # Decode the base64 image - img_data = base64.b64decode(img_base64) - img = Image.open(BytesIO(img_data)) - - # Get image dimensions - width, height = img.size - - # Create a drawing context - draw = ImageDraw.Draw(img) - - # Determine arrow direction and positions - center_x = width // 2 - arrow_width = 100 - - if direction.lower() == "up": - # Draw up arrow in the middle of the screen - arrow_y = height // 2 - # Arrow points - points = [ - (center_x, arrow_y - 50), # Top point - (center_x - arrow_width // 2, arrow_y + 50), # Bottom left - (center_x + arrow_width // 2, arrow_y + 50), # Bottom right - ] - color = "blue" - else: # down - # Draw down arrow in the middle of the screen - arrow_y = height // 2 - # Arrow points - points = [ - (center_x, arrow_y + 50), # Bottom point - (center_x - arrow_width // 2, arrow_y - 50), # Top left - (center_x + arrow_width // 2, arrow_y - 50), # Top right - ] - color = "green" - - # Draw filled arrow - draw.polygon(points, fill=color) - - # Add text showing number of clicks - text_y = arrow_y + 70 if direction.lower() == "down" else arrow_y - 70 - draw.text((center_x - 40, text_y), f"{clicks} clicks", fill="black") - - return img - - except Exception as e: - logger.error(f"Error visualizing scroll: {str(e)}") - # Return a blank image in case of error - return Image.new("RGB", (800, 600), color="white") - - -def calculate_element_center(box: Tuple[int, int, int, int]) -> Tuple[int, int]: - """Calculate the center coordinates of a bounding box. 
- - Args: - box: Tuple of (left, top, right, bottom) coordinates - - Returns: - Tuple of (center_x, center_y) coordinates - """ - left, top, right, bottom = box - center_x = (left + right) // 2 - center_y = (top + bottom) // 2 - return center_x, center_y diff --git a/notebooks/openai_cua_nb.ipynb b/notebooks/openai_cua_nb.ipynb new file mode 100644 index 00000000..0ce86162 --- /dev/null +++ b/notebooks/openai_cua_nb.ipynb @@ -0,0 +1,134 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install openai\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "import os\n", + "\n", + "response = requests.post(\n", + " \"https://api.openai.com/v1/responses\",\n", + " headers={\n", + " \"Content-Type\": \"application/json\", \n", + " \"Authorization\": f\"Bearer {os.environ['OPENAI_API_KEY']}\"\n", + " },\n", + " json={\n", + " \"model\": \"computer-use-preview\",\n", + " \"tools\": [{\n", + " \"type\": \"computer_use_preview\",\n", + " \"display_width\": 1024,\n", + " \"display_height\": 768,\n", + " \"environment\": \"mac\" # other possible values: \"mac\", \"windows\", \"ubuntu\"\n", + " }],\n", + " \"input\": [\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"Check the latest OpenAI news on bing.com.\"\n", + " }\n", + " # Optional: include a screenshot of the initial state of the environment\n", + " # {\n", + " # type: \"input_image\", \n", + " # image_url: f\"data:image/png;base64,{screenshot_base64}\"\n", + " # }\n", + " ],\n", + " \"reasoning\": {\n", + " \"generate_summary\": \"concise\",\n", + " },\n", + " \"truncation\": \"auto\"\n", + " }\n", + ")\n", + "\n", + "print(response.json())" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + } + ], + "source": [ + "from openai import OpenAI\n", + "client = OpenAI() # assumes OPENAI_API_KEY is set in env\n", + "\n", + "def has_model_starting_with(prefix=\"computer\"):\n", + " models = client.models.list().data\n", + " return any(model.id.startswith(prefix) for model in models)\n", + "\n", + "print(has_model_starting_with())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import requests\n", + "import os\n", + "\n", + "response = requests.post(\n", + " \"https://api.openai.com/v1/responses\",\n", + " headers={\n", + " \"Content-Type\": \"application/json\",\n", + " \"Authorization\": f\"Bearer {os.environ['OPENAI_API_KEY']}\"\n", + " },\n", + " json={\n", + " \"model\": \"gpt-4o\",\n", + " \"input\": \"Tell me a three sentence bedtime story about a unicorn.\"\n", + " }\n", + ")\n", + "print(response.json())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
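
For reference, a minimal sketch of how a caller might inspect the OpenAI-compatible
response dicts yielded by OmniLoop.run above. The helper name and the sample dict are
assumptions for illustration only; the "output" item shapes follow
_create_openai_compatible_response in this patch.

    def summarize_response(response: dict) -> None:
        """Print the reasoning and computer_call items from an
        OpenAI /v1/responses-style dict."""
        for item in response.get("output", []):
            if item.get("type") == "reasoning":
                # "summary" holds short summary_text entries
                print("reasoning:", item["summary"][0]["text"])
            elif item.get("type") == "computer_call":
                # "action" carries the concrete UI action to perform
                print("action:", item["action"])

    # Hypothetical sample mirroring the format built in this patch
    summarize_response(
        {
            "output": [
                {"type": "reasoning",
                 "summary": [{"type": "summary_text", "text": "Click the button"}]},
                {"type": "computer_call",
                 "action": {"type": "click", "button": "left", "x": 100, "y": 100}},
            ]
        }
    )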