diff --git a/examples/agent_examples.py b/examples/agent_examples.py index beb75265..cfb7dd52 100644 --- a/examples/agent_examples.py +++ b/examples/agent_examples.py @@ -29,9 +29,11 @@ async def run_agent_example(): # Create agent with loop and provider agent = ComputerAgent( computer=computer, + loop=AgentLoop.OPENAI, # loop=AgentLoop.ANTHROPIC, - loop=AgentLoop.OMNI, - model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"), + # loop=AgentLoop.OMNI, + model=LLM(provider=LLMProvider.OPENAI), # No model name for Operator CUA + # model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"), # model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"), save_trajectory=True, only_n_most_recent_images=3, diff --git a/libs/agent/agent/__init__.py b/libs/agent/agent/__init__.py index a0ce6c17..a6ac1216 100644 --- a/libs/agent/agent/__init__.py +++ b/libs/agent/agent/__init__.py @@ -49,7 +49,7 @@ except Exception as e: logger.warning(f"Error initializing telemetry: {e}") from .providers.omni.types import LLMProvider, LLM -from .core.loop import AgentLoop -from .core.computer_agent import ComputerAgent +from .core.factory import AgentLoop +from .core.agent import ComputerAgent __all__ = ["AgentLoop", "LLMProvider", "LLM", "ComputerAgent"] diff --git a/libs/agent/agent/core/__init__.py b/libs/agent/agent/core/__init__.py index b4657510..7fa4aae1 100644 --- a/libs/agent/agent/core/__init__.py +++ b/libs/agent/agent/core/__init__.py @@ -1,6 +1,6 @@ """Core agent components.""" -from .loop import BaseLoop +from .factory import BaseLoop from .messages import ( BaseMessageManager, ImageRetentionConfig, diff --git a/libs/agent/agent/core/computer_agent.py b/libs/agent/agent/core/agent.py similarity index 75% rename from libs/agent/agent/core/computer_agent.py rename to libs/agent/agent/core/agent.py index 8ad58ebb..372b5bf1 100644 --- a/libs/agent/agent/core/computer_agent.py +++ b/libs/agent/agent/core/agent.py @@ -3,32 +3,18 @@ import asyncio import 
logging import os -from typing import Any, AsyncGenerator, Dict, Optional, cast, List +from typing import AsyncGenerator, Optional from computer import Computer -from ..providers.anthropic.loop import AnthropicLoop -from ..providers.omni.loop import OmniLoop -from ..providers.omni.parser import OmniParser -from ..providers.omni.types import LLMProvider, LLM +from ..providers.omni.types import LLM from .. import AgentLoop -from .messages import StandardMessageManager, ImageRetentionConfig from .types import AgentResponse +from .factory import LoopFactory +from .provider_config import DEFAULT_MODELS, ENV_VARS logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -# Default models for different providers -DEFAULT_MODELS = { - LLMProvider.OPENAI: "gpt-4o", - LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219", -} - -# Map providers to their environment variable names -ENV_VARS = { - LLMProvider.OPENAI: "OPENAI_API_KEY", - LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY", -} - class ComputerAgent: """A computer agent that can perform automated tasks using natural language instructions.""" @@ -98,35 +84,27 @@ class ComputerAgent: f"No model specified for provider {self.provider} and no default found" ) - # Ensure computer is properly cast for typing purposes - computer_instance = self.computer - # Get API key from environment if not provided actual_api_key = api_key or os.environ.get(ENV_VARS[self.provider], "") if not actual_api_key: raise ValueError(f"No API key provided for {self.provider}") - # Initialize the appropriate loop based on the loop parameter - if loop == AgentLoop.ANTHROPIC: - self._loop = AnthropicLoop( - api_key=actual_api_key, - model=actual_model_name, - computer=computer_instance, - save_trajectory=save_trajectory, - base_dir=trajectory_dir, - only_n_most_recent_images=only_n_most_recent_images, - ) - else: - self._loop = OmniLoop( + # Create the appropriate loop using the factory + try: + # Let the factory create the appropriate 
loop with needed components + self._loop = LoopFactory.create_loop( + loop_type=loop, provider=self.provider, + computer=self.computer, + model_name=actual_model_name, api_key=actual_api_key, - model=actual_model_name, - computer=computer_instance, save_trajectory=save_trajectory, - base_dir=trajectory_dir, + trajectory_dir=trajectory_dir, only_n_most_recent_images=only_n_most_recent_images, - parser=OmniParser(), ) + except ValueError as e: + logger.error(f"Failed to create loop: {str(e)}") + raise # Initialize the message manager from the loop self.message_manager = self._loop.message_manager @@ -152,21 +130,6 @@ class ComputerAgent: else: logger.info("Computer already initialized, skipping initialization") - # Take a test screenshot to verify the computer is working - logger.info("Testing computer with a screenshot...") - try: - test_screenshot = await self.computer.interface.screenshot() - # Determine the screenshot size based on its type - if isinstance(test_screenshot, (bytes, bytearray, memoryview)): - size = len(test_screenshot) - elif hasattr(test_screenshot, "base64_image"): - size = len(test_screenshot.base64_image) - else: - size = "unknown" - logger.info(f"Screenshot test successful, size: {size}") - except Exception as e: - logger.error(f"Screenshot test failed: {str(e)}") - # Even though screenshot failed, we continue since some tests might not need it except Exception as e: logger.error(f"Error initializing computer in __aenter__: {str(e)}") raise @@ -232,7 +195,6 @@ class ComputerAgent: # Execute the task and yield results async for result in self._loop.run(self.message_manager.messages): - # Yield the result to the caller yield result except Exception as e: diff --git a/libs/agent/agent/core/loop.py b/libs/agent/agent/core/base.py similarity index 89% rename from libs/agent/agent/core/loop.py rename to libs/agent/agent/core/base.py index 31196632..fb91d855 100644 --- a/libs/agent/agent/core/loop.py +++ b/libs/agent/agent/core/base.py @@ -1,35 
+1,21 @@ -"""Base agent loop implementation.""" +"""Base loop definitions.""" import logging import asyncio from abc import ABC, abstractmethod -from enum import Enum, auto -from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple -from datetime import datetime +from typing import Any, AsyncGenerator, Dict, List, Optional from computer import Computer -from .experiment import ExperimentManager from .messages import StandardMessageManager, ImageRetentionConfig from .types import AgentResponse +from .experiment import ExperimentManager logger = logging.getLogger(__name__) -class AgentLoop(Enum): - """Enumeration of available loop types.""" - - ANTHROPIC = auto() # Anthropic implementation - OMNI = auto() # OmniLoop implementation - # Add more loop types as needed - - class BaseLoop(ABC): """Base class for agent loops that handle message processing and tool execution.""" - ########################################### - # INITIALIZATION AND CONFIGURATION - ########################################### - def __init__( self, computer: Computer, @@ -68,6 +54,11 @@ class BaseLoop(ABC): self.only_n_most_recent_images = only_n_most_recent_images self._kwargs = kwargs + # Initialize message manager + self.message_manager = StandardMessageManager( + config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images) + ) + # Initialize experiment manager if self.save_trajectory and self.base_dir: self.experiment_manager = ExperimentManager( @@ -110,8 +101,7 @@ class BaseLoop(ABC): ) raise RuntimeError(f"Failed to initialize: {str(e)}") - ########################################### - + ########################################### # ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES ########################################### @@ -125,17 +115,14 @@ class BaseLoop(ABC): raise NotImplementedError @abstractmethod - async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]: + def run(self, messages: List[Dict[str, Any]]) -> 
AsyncGenerator[AgentResponse, None]: """Run the agent loop with provided messages. - This method handles the main agent loop including message processing, - API calls, response handling, and action execution. - Args: messages: List of message objects - Yields: - Agent response format + Returns: + An async generator that yields agent responses """ raise NotImplementedError diff --git a/libs/agent/agent/core/factory.py b/libs/agent/agent/core/factory.py new file mode 100644 index 00000000..5b77cdc7 --- /dev/null +++ b/libs/agent/agent/core/factory.py @@ -0,0 +1,104 @@ +"""Base agent loop implementation.""" + +import logging +import importlib.util +from typing import Dict, Optional, Type, TYPE_CHECKING, Any, cast, Callable, Awaitable + +from computer import Computer +from .types import AgentLoop +from .base import BaseLoop + +# For type checking only +if TYPE_CHECKING: + from ..providers.omni.types import LLMProvider + +logger = logging.getLogger(__name__) + + +class LoopFactory: + """Factory class for creating agent loops.""" + + # Registry to store loop implementations + _loop_registry: Dict[AgentLoop, Type[BaseLoop]] = {} + + @classmethod + def create_loop( + cls, + loop_type: AgentLoop, + api_key: str, + model_name: str, + computer: Computer, + provider: Any = None, + save_trajectory: bool = True, + trajectory_dir: str = "trajectories", + only_n_most_recent_images: Optional[int] = None, + acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None, + ) -> BaseLoop: + """Create and return an appropriate loop instance based on type.""" + if loop_type == AgentLoop.ANTHROPIC: + # Lazy import AnthropicLoop only when needed + try: + from ..providers.anthropic.loop import AnthropicLoop + except ImportError: + raise ImportError( + "The 'anthropic' provider is not installed. 
" + "Install it with 'pip install cua-agent[anthropic]'" + ) + + return AnthropicLoop( + api_key=api_key, + model=model_name, + computer=computer, + save_trajectory=save_trajectory, + base_dir=trajectory_dir, + only_n_most_recent_images=only_n_most_recent_images, + ) + elif loop_type == AgentLoop.OPENAI: + # Lazy import OpenAILoop only when needed + try: + from ..providers.openai.loop import OpenAILoop + except ImportError: + raise ImportError( + "The 'openai' provider is not installed. " + "Install it with 'pip install cua-agent[openai]'" + ) + + return OpenAILoop( + api_key=api_key, + model=model_name, + computer=computer, + save_trajectory=save_trajectory, + base_dir=trajectory_dir, + only_n_most_recent_images=only_n_most_recent_images, + acknowledge_safety_check_callback=acknowledge_safety_check_callback, + ) + elif loop_type == AgentLoop.OMNI: + # Lazy import OmniLoop and related classes only when needed + try: + from ..providers.omni.loop import OmniLoop + from ..providers.omni.parser import OmniParser + from ..providers.omni.types import LLMProvider + except ImportError: + raise ImportError( + "The 'omni' provider is not installed. 
" + "Install it with 'pip install cua-agent[all]'" + ) + + if provider is None: + raise ValueError("Provider is required for OMNI loop type") + + # We know provider is the correct type at this point, so cast it + provider_instance = cast(LLMProvider, provider) + + return OmniLoop( + provider=provider_instance, + api_key=api_key, + model=model_name, + computer=computer, + save_trajectory=save_trajectory, + base_dir=trajectory_dir, + only_n_most_recent_images=only_n_most_recent_images, + parser=OmniParser(), + ) + else: + raise ValueError(f"Unsupported loop type: {loop_type}") diff --git a/libs/agent/agent/core/provider_config.py b/libs/agent/agent/core/provider_config.py new file mode 100644 index 00000000..f7078f3f --- /dev/null +++ b/libs/agent/agent/core/provider_config.py @@ -0,0 +1,15 @@ +"""Provider-specific configurations and constants.""" + +from ..providers.omni.types import LLMProvider + +# Default models for different providers +DEFAULT_MODELS = { + LLMProvider.OPENAI: "gpt-4o", + LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219", +} + +# Map providers to their environment variable names +ENV_VARS = { + LLMProvider.OPENAI: "OPENAI_API_KEY", + LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY", +} diff --git a/libs/agent/agent/core/types.py b/libs/agent/agent/core/types.py index e80d24c7..ae2af868 100644 --- a/libs/agent/agent/core/types.py +++ b/libs/agent/agent/core/types.py @@ -1,6 +1,16 @@ """Core type definitions.""" from typing import Any, Dict, List, Optional, TypedDict, Union +from enum import Enum, auto + + +class AgentLoop(Enum): + """Enumeration of available loop types.""" + + ANTHROPIC = auto() # Anthropic implementation + OMNI = auto() # OmniLoop implementation + OPENAI = auto() # OpenAI implementation + # Add more loop types as needed class AgentResponse(TypedDict, total=False): diff --git a/libs/agent/agent/providers/anthropic/loop.py b/libs/agent/agent/providers/anthropic/loop.py index 9bcaf233..0ccdc79a 100644 --- 
a/libs/agent/agent/providers/anthropic/loop.py +++ b/libs/agent/agent/providers/anthropic/loop.py @@ -16,7 +16,7 @@ from datetime import datetime from computer import Computer # Base imports -from ...core.loop import BaseLoop +from ...core.base import BaseLoop from ...core.messages import StandardMessageManager, ImageRetentionConfig from ...core.types import AgentResponse diff --git a/libs/agent/agent/providers/anthropic/messages/manager.py b/libs/agent/agent/providers/anthropic/messages/manager.py deleted file mode 100644 index f29af1b7..00000000 --- a/libs/agent/agent/providers/anthropic/messages/manager.py +++ /dev/null @@ -1,112 +0,0 @@ -from dataclasses import dataclass -from typing import cast -from anthropic.types.beta import ( - BetaMessageParam, - BetaCacheControlEphemeralParam, - BetaToolResultBlockParam, -) - - -@dataclass -class ImageRetentionConfig: - """Configuration for image retention in messages.""" - - num_images_to_keep: int | None = None - min_removal_threshold: int = 1 - enable_caching: bool = True - - def should_retain_images(self) -> bool: - """Check if image retention is enabled.""" - return self.num_images_to_keep is not None and self.num_images_to_keep > 0 - - -class MessageManager: - """Manages message preparation, including image retention and caching.""" - - def __init__(self, image_retention_config: ImageRetentionConfig): - """Initialize the message manager. 
- - Args: - image_retention_config: Configuration for image retention - """ - if image_retention_config.min_removal_threshold < 1: - raise ValueError("min_removal_threshold must be at least 1") - self.image_retention_config = image_retention_config - - def prepare_messages(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]: - """Prepare messages by applying image retention and caching as configured.""" - if self.image_retention_config.should_retain_images(): - self._filter_images(messages) - if self.image_retention_config.enable_caching: - self._inject_caching(messages) - return messages - - def _filter_images(self, messages: list[BetaMessageParam]) -> None: - """Filter messages to retain only the specified number of most recent images.""" - tool_result_blocks = cast( - list[BetaToolResultBlockParam], - [ - item - for message in messages - for item in (message["content"] if isinstance(message["content"], list) else []) - if isinstance(item, dict) and item.get("type") == "tool_result" - ], - ) - - total_images = sum( - 1 - for tool_result in tool_result_blocks - for content in tool_result.get("content", []) - if isinstance(content, dict) and content.get("type") == "image" - ) - - images_to_remove = total_images - (self.image_retention_config.num_images_to_keep or 0) - # Round down to nearest min_removal_threshold for better cache behavior - images_to_remove -= images_to_remove % self.image_retention_config.min_removal_threshold - - # Remove oldest images first - for tool_result in tool_result_blocks: - if isinstance(tool_result.get("content"), list): - new_content = [] - for content in tool_result.get("content", []): - if isinstance(content, dict) and content.get("type") == "image": - if images_to_remove > 0: - images_to_remove -= 1 - continue - new_content.append(content) - tool_result["content"] = new_content - - def _inject_caching(self, messages: list[BetaMessageParam]) -> None: - """Inject caching control for the most recent turns, limited to 3 
blocks max to avoid API errors.""" - # Anthropic API allows a maximum of 4 blocks with cache_control - # We use 3 here to be safe, as the system block may also have cache_control - blocks_with_cache_control = 0 - max_cache_control_blocks = 3 - - for message in reversed(messages): - if message["role"] == "user" and isinstance(content := message["content"], list): - # Only add cache control to the latest message in each turn - if blocks_with_cache_control < max_cache_control_blocks: - blocks_with_cache_control += 1 - # Add cache control to the last content block only - if content and len(content) > 0: - content[-1]["cache_control"] = BetaCacheControlEphemeralParam( - type="ephemeral" - ) - else: - # Remove any existing cache control - if content and len(content) > 0: - content[-1].pop("cache_control", None) - - # Ensure we're not exceeding the limit by checking the total - if blocks_with_cache_control > max_cache_control_blocks: - # If we somehow exceeded the limit, remove excess cache controls - excess = blocks_with_cache_control - max_cache_control_blocks - for message in messages: - if excess <= 0: - break - - if message["role"] == "user" and isinstance(content := message["content"], list): - if content and len(content) > 0 and "cache_control" in content[-1]: - content[-1].pop("cache_control", None) - excess -= 1 diff --git a/libs/agent/agent/providers/anthropic/response_handler.py b/libs/agent/agent/providers/anthropic/response_handler.py index fd213dce..2b8d17e1 100644 --- a/libs/agent/agent/providers/anthropic/response_handler.py +++ b/libs/agent/agent/providers/anthropic/response_handler.py @@ -1,14 +1,11 @@ """Response and tool handling for Anthropic provider.""" import logging -from typing import Any, Dict, List, Optional, Tuple, cast +from typing import Any, Dict, List, Tuple, cast from anthropic.types.beta import ( BetaMessage, - BetaMessageParam, BetaTextBlock, - BetaTextBlockParam, - BetaToolUseBlockParam, BetaContentBlockParam, ) diff --git 
a/libs/agent/agent/providers/anthropic/utils.py b/libs/agent/agent/providers/anthropic/utils.py index 6f592838..c0afcd0f 100644 --- a/libs/agent/agent/providers/anthropic/utils.py +++ b/libs/agent/agent/providers/anthropic/utils.py @@ -1,14 +1,12 @@ """Utility functions for Anthropic message handling.""" -import time import logging import re from typing import Any, Dict, List, Optional, Tuple, cast -from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaTextBlock +from anthropic.types.beta import BetaMessage from ..omni.parser import ParseResult from ...core.types import AgentResponse from datetime import datetime -import json # Configure module logger logger = logging.getLogger(__name__) diff --git a/libs/agent/agent/providers/omni/loop.py b/libs/agent/agent/providers/omni/loop.py index 4a65ccbf..3223583e 100644 --- a/libs/agent/agent/providers/omni/loop.py +++ b/libs/agent/agent/providers/omni/loop.py @@ -10,7 +10,7 @@ from httpx import ConnectError, ReadTimeout from typing import cast from .parser import OmniParser, ParseResult -from ...core.loop import BaseLoop +from ...core.base import BaseLoop from ...core.visualization import VisualizationHelper from ...core.messages import StandardMessageManager, ImageRetentionConfig from .utils import to_openai_agent_response_format diff --git a/libs/agent/agent/providers/omni/types.py b/libs/agent/agent/providers/omni/types.py index 734a4af6..1f3aae93 100644 --- a/libs/agent/agent/providers/omni/types.py +++ b/libs/agent/agent/providers/omni/types.py @@ -9,8 +9,10 @@ class LLMProvider(StrEnum): """Supported LLM providers.""" ANTHROPIC = "anthropic" + OMNI = "omni" OPENAI = "openai" + @dataclass class LLM: """Configuration for LLM model and provider.""" diff --git a/libs/agent/agent/providers/openai/__init__.py b/libs/agent/agent/providers/openai/__init__.py new file mode 100644 index 00000000..c23cf80a --- /dev/null +++ b/libs/agent/agent/providers/openai/__init__.py @@ -0,0 +1,6 @@ +"""OpenAI Agent Response 
API provider for computer control.""" + +from .types import LLMProvider +from .loop import OpenAILoop + +__all__ = ["OpenAILoop", "LLMProvider"] diff --git a/libs/agent/agent/providers/openai/api/__init__.py b/libs/agent/agent/providers/openai/api/__init__.py new file mode 100644 index 00000000..201ea70e --- /dev/null +++ b/libs/agent/agent/providers/openai/api/__init__.py @@ -0,0 +1,5 @@ +"""OpenAI API client module.""" + +from .client import OpenAIClient + +__all__ = ["OpenAIClient"] diff --git a/libs/agent/agent/providers/openai/api/client.py b/libs/agent/agent/providers/openai/api/client.py new file mode 100644 index 00000000..3a795266 --- /dev/null +++ b/libs/agent/agent/providers/openai/api/client.py @@ -0,0 +1,137 @@ +"""OpenAI API client for Agent Response API.""" + +import logging +import json +import os +import httpx +from typing import Dict, List, Optional, Any, Union + +logger = logging.getLogger(__name__) + + +class OpenAIClient: + """Client for OpenAI's Agent Response API.""" + + def __init__( + self, + api_key: str, + model: str = "computer-use-preview", + base_url: str = "https://api.openai.com/v1", + max_retries: int = 3, + timeout: int = 120, + **kwargs, + ): + """Initialize OpenAI API client. 
+ + Args: + api_key: OpenAI API key + model: Model to use for completions (should always be computer-use-preview) + base_url: Base URL for API requests + max_retries: Maximum number of retries for API calls + timeout: Timeout for API calls in seconds + **kwargs: Additional arguments to pass to the httpx client + """ + self.api_key = api_key + + # Always use computer-use-preview model + if model != "computer-use-preview": + logger.warning( + f"Overriding provided model '{model}' with required model 'computer-use-preview'" + ) + model = "computer-use-preview" + + self.model = model + self.base_url = base_url + self.max_retries = max_retries + self.timeout = timeout + + # Create httpx client with auth and timeout + self.client = httpx.AsyncClient( + timeout=timeout, + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + "OpenAI-Beta": "computer-use-2023-09-30", # Required beta header for computer use + }, + **kwargs, + ) + + # Additional initialization for organization if available + openai_org = os.environ.get("OPENAI_ORG") + if openai_org: + self.client.headers["OpenAI-Organization"] = openai_org + + logger.info(f"Initialized OpenAI client with model {model}") + + async def create_response( + self, + input: List[Dict[str, Any]], + tools: Optional[List[Dict[str, Any]]] = None, + truncation: str = "auto", + temperature: float = 0.7, + top_p: float = 1.0, + **kwargs, + ) -> Dict[str, Any]: + """Create a response using the OpenAI Agent Response API. 
+ + Args: + input: List of messages in the conversation (must be in Agent Response API format) + tools: List of tools available to the agent + truncation: How to handle truncation (auto, truncate) + temperature: Sampling temperature + top_p: Nucleus sampling parameter + **kwargs: Additional parameters to include in the request + + Returns: + Response from the API + """ + url = f"{self.base_url}/responses" + + # Prepare request payload + payload = { + "model": self.model, + "input": input, + "temperature": temperature, + "top_p": top_p, + "truncation": truncation, + **kwargs, + } + + # Add tools if provided + if tools: + payload["tools"] = tools + + try: + logger.debug(f"Sending request to {url}") + + # Make API call + response = await self.client.post(url, json=payload) + + # Check for errors + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + error_detail = e.response.text + try: + # Try to parse the error as JSON for better debugging + error_json = json.loads(error_detail) + logger.error(f"HTTP error from OpenAI API: {json.dumps(error_json, indent=2)}") + except: + logger.error(f"HTTP error from OpenAI API: {error_detail}") + raise + + result = response.json() + logger.debug("Received successful response") + return result + + except httpx.HTTPStatusError as e: + error_detail = e.response.text if hasattr(e, "response") else str(e) + logger.error(f"HTTP error from OpenAI API: {error_detail}") + raise RuntimeError(f"OpenAI API error: {error_detail}") + except Exception as e: + logger.error(f"Error calling OpenAI API: {str(e)}") + raise RuntimeError(f"Error calling OpenAI API: {str(e)}") + + async def close(self): + """Close the httpx client.""" + await self.client.aclose() diff --git a/libs/agent/agent/providers/openai/api_handler.py b/libs/agent/agent/providers/openai/api_handler.py new file mode 100644 index 00000000..ac435146 --- /dev/null +++ b/libs/agent/agent/providers/openai/api_handler.py @@ -0,0 +1,453 @@ +"""API handler for the 
OpenAI provider.""" + +import logging +import requests +import os +from typing import Any, Dict, List, Optional, TYPE_CHECKING +from datetime import datetime + +if TYPE_CHECKING: + from .loop import OpenAILoop + +logger = logging.getLogger(__name__) + + +class OpenAIAPIHandler: + """Handler for OpenAI API interactions.""" + + def __init__(self, loop: "OpenAILoop"): + """Initialize the API handler. + + Args: + loop: OpenAI loop instance + """ + self.loop = loop + self.api_key = os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("OPENAI_API_KEY environment variable not set") + + self.api_base = "https://api.openai.com/v1" + self.headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + + # Add organization if specified + org_id = os.getenv("OPENAI_ORG") + if org_id: + self.headers["OpenAI-Organization"] = org_id + + logger.info("Initialized OpenAI API handler") + + async def send_initial_request( + self, + messages: List[Dict[str, Any]], + display_width: str, + display_height: str, + previous_response_id: Optional[str] = None, + ) -> Dict[str, Any]: + """Send an initial request to the OpenAI API with a screenshot. 
+ + Args: + messages: List of message objects in standard format + display_width: Width of the display in pixels + display_height: Height of the display in pixels + previous_response_id: Optional ID of the previous response to link requests + + Returns: + API response + """ + # Convert display dimensions to integers + try: + width = int(display_width) + height = int(display_height) + except (ValueError, TypeError) as e: + logger.error(f"Failed to convert display dimensions to integers: {str(e)}") + raise ValueError( + f"Display dimensions must be integers: width={display_width}, height={display_height}" + ) + + # Extract the latest text message and screenshot from messages + latest_text = None + latest_screenshot = None + + for msg in reversed(messages): + if not isinstance(msg, dict): + continue + + content = msg.get("content", []) + + if isinstance(content, str) and not latest_text: + latest_text = content + continue + + if not isinstance(content, list): + continue + + for item in content: + if not isinstance(item, dict): + continue + + # Look for text if we don't have it yet + if not latest_text and item.get("type") == "text" and "text" in item: + latest_text = item.get("text", "") + + # Look for an image if we don't have it yet + if not latest_screenshot and item.get("type") == "image": + source = item.get("source", {}) + if source.get("type") == "base64" and "data" in source: + latest_screenshot = source["data"] + + # Prepare the input array + input_array = [] + + # Add the text message if found + if latest_text: + input_array.append({"role": "user", "content": latest_text}) + + # Add the screenshot if found and no previous_response_id is provided + if latest_screenshot and not previous_response_id: + input_array.append( + { + "type": "message", + "role": "user", + "content": [ + { + "type": "input_image", + "image_url": f"data:image/png;base64,{latest_screenshot}", + } + ], + } + ) + + # Prepare the request payload - using minimal format from docs + payload = 
{ + "model": "computer-use-preview", + "tools": [ + { + "type": "computer_use_preview", + "display_width": width, + "display_height": height, + "environment": "mac", # We're on macOS + } + ], + "input": input_array, + "truncation": "auto", + } + + # Add previous_response_id if provided + if previous_response_id: + payload["previous_response_id"] = previous_response_id + + # Log the request using the BaseLoop's log_api_call method + self.loop._log_api_call("request", payload) + + # Log for debug purposes + logger.info("Sending initial request to OpenAI API") + logger.debug(f"Request payload: {self._sanitize_response(payload)}") + + # Send the request + response = requests.post( + f"{self.api_base}/responses", + headers=self.headers, + json=payload, + ) + + if response.status_code != 200: + error_message = f"OpenAI API error: {response.status_code} {response.text}" + logger.error(error_message) + # Log the error using the BaseLoop's log_api_call method + self.loop._log_api_call("error", payload, error=Exception(error_message)) + raise Exception(error_message) + + response_data = response.json() + + # Log the response using the BaseLoop's log_api_call method + self.loop._log_api_call("response", payload, response_data) + + # Log for debug purposes + logger.info("Received response from OpenAI API") + logger.debug(f"Response data: {self._sanitize_response(response_data)}") + + return response_data + + async def send_computer_call_request( + self, + messages: List[Dict[str, Any]], + display_width: str, + display_height: str, + previous_response_id: str, + ) -> Dict[str, Any]: + """Send a request to the OpenAI API with computer_call_output. 
+ + Args: + messages: List of message objects in standard format + display_width: Width of the display in pixels + display_height: Height of the display in pixels + system_prompt: System prompt to include + previous_response_id: ID of the previous response to link requests + + Returns: + API response + """ + # Convert display dimensions to integers + try: + width = int(display_width) + height = int(display_height) + except (ValueError, TypeError) as e: + logger.error(f"Failed to convert display dimensions to integers: {str(e)}") + raise ValueError( + f"Display dimensions must be integers: width={display_width}, height={display_height}" + ) + + # Find the most recent computer_call_output with call_id + call_id = None + screenshot_base64 = None + + # Look for call_id and screenshot in messages + for msg in reversed(messages): + if not isinstance(msg, dict): + continue + + # Check if the message itself has a call_id + if "call_id" in msg and not call_id: + call_id = msg["call_id"] + + content = msg.get("content", []) + if not isinstance(content, list): + continue + + for item in content: + if not isinstance(item, dict): + continue + + # Look for call_id + if not call_id and "call_id" in item: + call_id = item["call_id"] + + # Look for screenshot in computer_call_output + if not screenshot_base64 and item.get("type") == "computer_call_output": + output = item.get("output", {}) + if isinstance(output, dict) and "image_url" in output: + image_url = output.get("image_url", "") + if image_url.startswith("data:image/png;base64,"): + screenshot_base64 = image_url[len("data:image/png;base64,") :] + + # Look for screenshot in image type + if not screenshot_base64 and item.get("type") == "image": + source = item.get("source", {}) + if source.get("type") == "base64" and "data" in source: + screenshot_base64 = source["data"] + + if not call_id or not screenshot_base64: + logger.error("Missing call_id or screenshot for computer_call_output") + logger.error(f"Last message: 
{messages[-1] if messages else None}") + raise ValueError("Cannot create computer call request: missing call_id or screenshot") + + # Prepare the request payload using minimal format from docs + payload = { + "model": "computer-use-preview", + "previous_response_id": previous_response_id, + "tools": [ + { + "type": "computer_use_preview", + "display_width": width, + "display_height": height, + "environment": "mac", # We're on macOS + } + ], + "input": [ + { + "type": "computer_call_output", + "call_id": call_id, + "output": { + "type": "input_image", + "image_url": f"data:image/png;base64,{screenshot_base64}", + }, + } + ], + "truncation": "auto", + } + + # Log the request using the BaseLoop's log_api_call method + self.loop._log_api_call("request", payload) + + # Log for debug purposes + logger.info("Sending computer call request to OpenAI API") + logger.debug(f"Request payload: {self._sanitize_response(payload)}") + + # Send the request + response = requests.post( + f"{self.api_base}/responses", + headers=self.headers, + json=payload, + ) + + if response.status_code != 200: + error_message = f"OpenAI API error: {response.status_code} {response.text}" + logger.error(error_message) + # Log the error using the BaseLoop's log_api_call method + self.loop._log_api_call("error", payload, error=Exception(error_message)) + raise Exception(error_message) + + response_data = response.json() + + # Log the response using the BaseLoop's log_api_call method + self.loop._log_api_call("response", payload, response_data) + + # Log for debug purposes + logger.info("Received response from OpenAI API") + logger.debug(f"Response data: {self._sanitize_response(response_data)}") + + return response_data + + def _format_messages_for_agent_response( + self, messages: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + """Format messages for the OpenAI Agent Response API. 
+ + The Agent Response API requires specific content types: + - For user messages: use "input_text", "input_image", etc. + - For assistant messages: use "output_text" only + + Additionally, when using the computer tool, only one image can be sent. + + Args: + messages: List of standard messages + + Returns: + Messages formatted for the Agent Response API + """ + formatted_messages = [] + has_image = False # Track if we've already included an image + + # We need to process messages in reverse to ensure we keep the most recent image + # but preserve the original order in the final output + reversed_messages = list(reversed(messages)) + temp_formatted = [] + + for msg in reversed_messages: + if not msg: + continue + + role = msg.get("role", "user") + content = msg.get("content", "") + + logger.debug(f"Processing message - Role: {role}, Content type: {type(content)}") + if isinstance(content, list): + logger.debug( + f"List content items: {[item.get('type') for item in content if isinstance(item, dict)]}" + ) + + if isinstance(content, str): + # For string content, create a message with the appropriate text type + if role == "user": + temp_formatted.append( + {"role": role, "content": [{"type": "input_text", "text": content}]} + ) + elif role == "assistant": + # For assistant, we need explicit output_text + temp_formatted.append( + {"role": role, "content": [{"type": "output_text", "text": content}]} + ) + elif role == "system": + # System messages need to be formatted as input_text as well + temp_formatted.append( + {"role": role, "content": [{"type": "input_text", "text": content}]} + ) + elif isinstance(content, list): + # For list content, convert each item to the correct type based on role + formatted_content = [] + has_image_in_this_message = False + + for item in content: + if not isinstance(item, dict): + continue + + item_type = item.get("type") + + if role == "user": + # Handle user message formatting + if item_type == "text" or item_type == "input_text": + # 
Text from user is input_text + formatted_content.append( + {"type": "input_text", "text": item.get("text", "")} + ) + elif (item_type == "image" or item_type == "image_url") and not has_image: + # Only include the first/most recent image we encounter + if item_type == "image": + # Image from user is input_image + source = item.get("source", {}) + if source.get("type") == "base64" and "data" in source: + formatted_content.append( + { + "type": "input_image", + "image_url": f"data:image/png;base64,{source['data']}", + } + ) + has_image = True + has_image_in_this_message = True + elif item_type == "image_url": + # Convert "image_url" to "input_image" + formatted_content.append( + { + "type": "input_image", + "image_url": item.get("image_url", {}).get("url", ""), + } + ) + has_image = True + has_image_in_this_message = True + elif role == "assistant": + # Handle assistant message formatting - only output_text is supported + if item_type == "text" or item_type == "output_text": + formatted_content.append( + {"type": "output_text", "text": item.get("text", "")} + ) + + if formatted_content: + # If this message had an image, mark it for inclusion + temp_formatted.append( + { + "role": role, + "content": formatted_content, + "_had_image": has_image_in_this_message, # Temporary marker + } + ) + + # Reverse back to original order and cleanup + for msg in reversed(temp_formatted): + # Remove our temporary marker + if "_had_image" in msg: + del msg["_had_image"] + formatted_messages.append(msg) + + # Log summary for debugging + num_images = sum( + 1 + for msg in formatted_messages + for item in (msg.get("content", []) if isinstance(msg.get("content"), list) else []) + if isinstance(item, dict) and item.get("type") == "input_image" + ) + logger.info(f"Formatted {len(messages)} messages for OpenAI API with {num_images} images") + + return formatted_messages + + def _sanitize_response(self, response: Dict[str, Any]) -> Dict[str, Any]: + """Sanitize response for logging by 
class OpenAILoop(BaseLoop):
    """OpenAI-specific implementation of the agent loop.

    This class extends BaseLoop to provide specialized support for OpenAI's Agent Response API
    with computer control capabilities.
    """

    ###########################################
    # INITIALIZATION AND CONFIGURATION
    ###########################################

    def __init__(
        self,
        api_key: str,
        computer: Computer,
        model: str = "computer-use-preview",
        only_n_most_recent_images: Optional[int] = 2,
        base_dir: Optional[str] = "trajectories",
        max_retries: int = 3,
        retry_delay: float = 1.0,
        save_trajectory: bool = True,
        acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
        **kwargs,
    ):
        """Initialize the OpenAI loop.

        Args:
            api_key: OpenAI API key
            model: Model name (ignored, always uses computer-use-preview)
            computer: Computer instance
            only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
            base_dir: Base directory for saving experiment data
            max_retries: Maximum number of retries for API calls
            retry_delay: Delay between retries in seconds
            save_trajectory: Whether to save trajectory data
            acknowledge_safety_check_callback: Optional callback for safety check acknowledgment.
                NOTE(review): the loop awaits it with the list of pending safety
                checks, not a string - confirm the intended annotation.
            **kwargs: Additional provider-specific arguments
        """
        # Always use computer-use-preview model
        if model != "computer-use-preview":
            logger.info(
                f"Overriding provided model '{model}' with required model 'computer-use-preview'"
            )

        # Initialize base class with core config
        super().__init__(
            computer=computer,
            model="computer-use-preview",  # Always use computer-use-preview
            api_key=api_key,
            max_retries=max_retries,
            retry_delay=retry_delay,
            base_dir=base_dir,
            save_trajectory=save_trajectory,
            only_n_most_recent_images=only_n_most_recent_images,
            **kwargs,
        )

        # Initialize message manager
        self.message_manager = StandardMessageManager(
            config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
        )

        # OpenAI-specific attributes
        self.provider = LLMProvider.OPENAI
        self.client = None
        self.retry_count = 0
        self.acknowledge_safety_check_callback = acknowledge_safety_check_callback
        self.queue = asyncio.Queue()  # Initialize queue
        self.last_response_id = None  # Store the last response ID across runs

        # Initialize handlers
        self.api_handler = OpenAIAPIHandler(self)
        self.response_handler = OpenAIResponseHandler(self)

        # Initialize tool manager with callback
        self.tool_manager = ToolManager(
            computer=computer, acknowledge_safety_check_callback=acknowledge_safety_check_callback
        )

    ###########################################
    # INTERNAL HELPERS
    ###########################################

    @staticmethod
    def _error_item(content: str) -> Dict[str, Any]:
        """Build the assistant-style error message that is streamed to the caller."""
        return {"role": "assistant", "content": content, "metadata": {"title": "❌ Error"}}

    @staticmethod
    def _to_base64(screenshot: Any) -> Any:
        """Return raw ``bytes`` screenshots base64-encoded; pass other values through."""
        if isinstance(screenshot, bytes):
            return base64.b64encode(screenshot).decode("utf-8")
        return screenshot

    def _remember_response_id(self, response: Any) -> None:
        """Store the response ID for the next request, keeping the old one if absent."""
        if isinstance(response, dict) and "id" in response:
            self.last_response_id = response["id"]
            logger.info(f"Received response with ID: {self.last_response_id}")
        else:
            # Don't reset last_response_id to None - keep the previous value if available
            logger.warning(f"Could not find response ID in OpenAI response: {type(response)}")

    ###########################################
    # CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
    ###########################################

    async def initialize_client(self) -> None:
        """Initialize the OpenAI API client and tools.

        Implements abstract method from BaseLoop to set up the OpenAI-specific
        client and tool manager.

        Raises:
            RuntimeError: If the client or the tool manager cannot be initialized.
        """
        try:
            logger.info(f"Initializing OpenAI client with model {self.model}...")

            # Initialize client
            self.client = OpenAIClient(api_key=self.api_key, model=self.model)

            # Initialize tool manager
            await self.tool_manager.initialize()

            logger.info(f"Initialized OpenAI client with model {self.model}")
        except Exception as e:
            logger.error(f"Error initializing OpenAI client: {str(e)}")
            self.client = None
            # Chain the cause so the original traceback is not lost
            raise RuntimeError(f"Failed to initialize OpenAI client: {str(e)}") from e

    ###########################################
    # MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
    ###########################################

    async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
        """Run the agent loop with provided messages.

        Starts ``_run_loop`` as a background task and yields every item it puts
        on the queue until the ``None`` stop signal arrives, followed by a final
        completion message.

        Args:
            messages: List of message objects in standard format

        Yields:
            Agent response format
        """
        try:
            logger.info("Starting OpenAI loop run")

            # Create queue for response streaming
            queue = asyncio.Queue()

            # Ensure client is initialized
            if self.client is None:
                logger.info("Initializing client...")
                await self.initialize_client()
                if self.client is None:
                    raise RuntimeError("Failed to initialize client")
                logger.info("Client initialized successfully")

            # Start loop in background task
            loop_task = asyncio.create_task(self._run_loop(queue, messages))

            # Process and yield messages as they arrive
            while True:
                try:
                    item = await queue.get()
                    if item is None:  # Stop signal
                        break
                    yield item
                    queue.task_done()
                except Exception as e:
                    logger.error(f"Error processing queue item: {str(e)}")
                    continue

            # Wait for loop to complete
            await loop_task

            # Send completion message
            yield {
                "role": "assistant",
                "content": "Task completed successfully.",
                "metadata": {"title": "✅ Complete"},
            }

        except Exception as e:
            logger.error(f"Error executing task: {str(e)}")
            yield self._error_item(f"Error: {str(e)}")

    ###########################################
    # AGENT LOOP IMPLEMENTATION
    ###########################################

    async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
        """Drive one conversation: initial request, then execute computer calls.

        Every API response and every error message is put on ``queue``; a final
        ``None`` is always put to tell ``run`` that streaming is over.
        ``self.last_response_id`` is reused and updated so follow-up requests
        (and later runs) can be linked to the previous response.

        Args:
            queue: Queue for response streaming
            messages: List of messages in standard format
        """
        try:
            # Capture initial screenshot; only failures of this step are
            # reported as screenshot errors.
            try:
                screenshot = await self.computer.interface.screenshot()
                logger.info("Screenshot captured successfully")

                # Convert to base64 if needed
                screenshot_base64 = self._to_base64(screenshot)

                # Save screenshot if requested
                if self.save_trajectory:
                    # Ensure screenshot_base64 is a string
                    if not isinstance(screenshot_base64, str):
                        logger.warning(
                            "Converting non-string screenshot_base64 to string for _save_screenshot"
                        )
                        if isinstance(screenshot_base64, (bytearray, memoryview)):
                            screenshot_base64 = base64.b64encode(screenshot_base64).decode("utf-8")
                    self._save_screenshot(screenshot_base64, action_type="state")
                    logger.info("Screenshot saved to trajectory")
            except Exception as e:
                logger.error(f"Error capturing initial screenshot: {str(e)}")
                await queue.put(self._error_item(f"Error capturing screenshot: {str(e)}"))
                await queue.put(None)  # Signal that we're done
                return

            # First add the user's original query that was passed to run()
            user_query = None
            for msg in messages:
                if msg.get("role") == "user":
                    user_content = msg.get("content", "")
                    if isinstance(user_content, str) and user_content:
                        user_query = user_content
                        self.message_manager.add_user_message(
                            [{"type": "text", "text": user_content}]
                        )
                        break

            # Add screenshot to message manager
            message_content: List[Dict[str, Any]] = [
                {
                    "type": "image",
                    "source": {
                        "type": "base64",
                        "media_type": "image/png",
                        "data": screenshot_base64,
                    },
                }
            ]

            # Accompany the screenshot with the query text, but never send a
            # text item whose text is None.
            if user_query:
                message_content.append({"type": "text", "text": user_query})

            self.message_manager.add_user_message(message_content)

            # Process user request using the accumulated standard-format history
            messages = self.message_manager.messages
            logger.info(f"Starting agent loop with {len(messages)} messages")

            # Create initial turn directory
            if self.save_trajectory:
                self._create_turn_dir()

            # Call API
            screen_size = await self.computer.interface.get_screen_size()
            response = await self.api_handler.send_initial_request(
                messages=messages,
                display_width=str(screen_size["width"]),
                display_height=str(screen_size["height"]),
                previous_response_id=self.last_response_id,
            )

            # Store response ID for next request
            self._remember_response_id(response)

            # Process API response
            await queue.put(response)

            # Keep executing computer calls until the model stops asking for them
            while True:
                output_items = response.get("output", []) or []
                computer_calls = [
                    item for item in output_items if item.get("type") == "computer_call"
                ]

                if not computer_calls:
                    logger.info("No computer calls in response, task may be complete.")
                    break

                # Process the first computer call
                computer_call = computer_calls[0]
                action = computer_call.get("action", {})
                call_id = computer_call.get("call_id")

                # Check for safety checks
                pending_safety_checks = computer_call.get("pending_safety_checks", [])
                acknowledged_safety_checks = []

                if pending_safety_checks:
                    for check in pending_safety_checks:
                        logger.warning(
                            f"Safety check: {check.get('code')} - {check.get('message')}"
                        )

                    # If we have a callback, use it to acknowledge safety checks
                    if self.acknowledge_safety_check_callback:
                        acknowledged = await self.acknowledge_safety_check_callback(
                            pending_safety_checks
                        )
                        if not acknowledged:
                            logger.warning("Safety check acknowledgment failed")
                            await queue.put(
                                {
                                    "role": "assistant",
                                    "content": "Safety checks were not acknowledged. Cannot proceed with action.",
                                    "metadata": {"title": "⚠️ Safety Warning"},
                                }
                            )
                            # Nothing changes between iterations, so retrying
                            # would ask the same question forever: stop here.
                            break
                        acknowledged_safety_checks = pending_safety_checks

                # Execute the action
                try:
                    # Create a new turn directory for this action if saving trajectories
                    if self.save_trajectory:
                        self._create_turn_dir()

                    # Execute the tool
                    await self.tool_manager.execute_tool("computer", action)

                    # Take screenshot after action
                    screenshot_base64 = self._to_base64(
                        await self.computer.interface.screenshot()
                    )

                    # Create computer_call_output
                    computer_call_output = {
                        "type": "computer_call_output",
                        "call_id": call_id,
                        "output": {
                            "type": "input_image",
                            "image_url": f"data:image/png;base64,{screenshot_base64}",
                        },
                    }

                    # Add acknowledged safety checks if any
                    if acknowledged_safety_checks:
                        computer_call_output["acknowledged_safety_checks"] = (
                            acknowledged_safety_checks
                        )

                    # Save to message manager for history
                    self.message_manager.add_system_message(
                        f"[Computer action executed: {action.get('type')}]"
                    )
                    self.message_manager.add_user_message([computer_call_output])

                    # Follow-up requests are linked through previous_response_id.
                    # Without one the response can never be refreshed, which
                    # would re-execute the same action forever.
                    if not isinstance(self.last_response_id, str):
                        logger.error("No previous response ID available for follow-up request")
                        await queue.put(
                            self._error_item(
                                "Error executing action: missing previous response ID"
                            )
                        )
                        break

                    # The API handler extracts the computer_call_output from the history
                    response = await self.api_handler.send_computer_call_request(
                        messages=self.message_manager.messages,
                        display_width=str(screen_size["width"]),
                        display_height=str(screen_size["height"]),
                        previous_response_id=self.last_response_id,
                    )

                    # Store response ID for next request
                    self._remember_response_id(response)

                    # Process the response
                    await queue.put(response)
                except Exception as e:
                    logger.error(f"Error executing computer action: {str(e)}")
                    await queue.put(self._error_item(f"Error executing action: {str(e)}"))
                    break

            # Signal that we're done
            await queue.put(None)

        except Exception as e:
            logger.error(f"Error in _run_loop: {str(e)}")
            await queue.put(self._error_item(f"Error: {str(e)}"))
            await queue.put(None)  # Signal that we're done

    def get_last_response_id(self) -> Optional[str]:
        """Get the last response ID.

        Returns:
            The last response ID or None if no response has been received
        """
        return self.last_response_id

    def set_last_response_id(self, response_id: str) -> None:
        """Set the last response ID.

        Args:
            response_id: OpenAI response ID to set
        """
        self.last_response_id = response_id
        logger.info(f"Manually set response ID to: {self.last_response_id}")
class OpenAIResponseHandler:
    """Turns OpenAI API responses into items on the agent's streaming queue."""

    def __init__(self, loop: "OpenAILoop"):
        """Remember the owning loop.

        Args:
            loop: OpenAI loop instance
        """
        self.loop = loop
        logger.info("Initialized OpenAI response handler")

    async def process_response(self, response: Dict[str, Any], queue: asyncio.Queue) -> None:
        """Forward the output items of an API response to the queue.

        Computer calls and messages are forwarded untouched (the loop executes
        the actions). Reasoning items are forwarded too, preceded by a short
        assistant message when a summary text is present.

        Args:
            response: Response from the API
            queue: Queue for response streaming
        """
        try:
            for entry in response.get("output", []) or []:
                if not isinstance(entry, dict):
                    continue

                kind = entry.get("type")

                if kind in (ResponseItemType.COMPUTER_CALL, ResponseItemType.MESSAGE):
                    # Passed through as-is for the consumer to handle
                    await queue.put(entry)
                    continue

                if kind != ResponseItemType.REASONING:
                    continue

                # First summary_text entry (if any) provides the summary
                summaries = entry.get("summary")
                summary = None
                if isinstance(summaries, list):
                    summary = next(
                        (
                            part.get("text")
                            for part in summaries
                            if isinstance(part, dict) and part.get("type") == "summary_text"
                        ),
                        None,
                    )

                if summary:
                    logger.info(f"Reasoning summary: {summary}")
                    await queue.put(
                        {
                            "role": "assistant",
                            "content": f"[Reasoning: {summary}]",
                            "metadata": {"title": "💭 Reasoning", "is_summary": True},
                        }
                    )

                # The raw reasoning item follows for complete context
                await queue.put(entry)

        except Exception as e:
            logger.error(f"Error processing response: {str(e)}")
            await queue.put(
                {
                    "role": "assistant",
                    "content": f"Error processing response: {str(e)}",
                    "metadata": {"title": "❌ Error"},
                }
            )

    def _process_message_item(self, item: Dict[str, Any]) -> AgentResponse:
        """Collapse a message item into a single assistant text response.

        Args:
            item: Message item from the response

        Returns:
            Assistant message whose content is the concatenation of all
            "output_text" parts, or a fixed fallback sentence when there is none
        """
        parts = item.get("content", []) or []
        text = "".join(
            part.get("text", "")
            for part in parts
            if isinstance(part, dict) and part.get("type") == "output_text"
        )

        if not text:
            text = "I don't have a response for that right now."

        return {
            "role": "assistant",
            "content": text,
            "metadata": {"title": "💬 Response"},
        }

    async def _process_computer_call(self, item: Dict[str, Any], queue: asyncio.Queue) -> None:
        """Execute a computer call and record its screenshot in the history.

        Args:
            item: Computer call item
            queue: Queue to add responses to (only used to report failures)
        """
        try:
            action = item.get("action") or {}
            if not isinstance(action, dict):
                logger.warning(f"Expected dict for action, got {type(action)}")
                action = {}

            action_type = action.get("type", "unknown")
            logger.info(f"Processing computer call: {action_type}")

            result = await self.loop.tool_manager.execute_tool("computer", action)

            if result and result.base64_image:
                # A short text note first, then the resulting screenshot
                note = [{"type": "text", "text": f"[Computer action completed: {action_type}]"}]
                self.loop.message_manager.add_user_message(note)

                image = [
                    {
                        "type": "image",
                        "source": {
                            "type": "base64",
                            "media_type": "image/png",
                            "data": result.base64_image,
                        },
                    }
                ]
                self.loop.message_manager.add_user_message(image)

            logger.info(f"Computer call {action_type} executed successfully")

        except Exception as e:
            logger.error(f"Error executing computer call: {str(e)}")
            logger.debug(traceback.format_exc())

            failure_text = f"Error executing computer action: {str(e)}"

            # Keep the failure in the conversation and tell the consumer
            self.loop.message_manager.add_user_message([{"type": "text", "text": failure_text}])
            await queue.put(
                {
                    "role": "assistant",
                    "content": failure_text,
                    "metadata": {"title": "❌ Error"},
                }
            )
+from abc import ABCMeta, abstractmethod +from dataclasses import dataclass, fields, replace +from typing import Any, Dict, List, Optional + +from ....core.tools.base import BaseTool + + +class BaseOpenAITool(BaseTool, metaclass=ABCMeta): + """Abstract base class for OpenAI-defined tools.""" + + def __init__(self): + """Initialize the base OpenAI tool.""" + # No specific initialization needed yet, but included for future extensibility + pass + + @abstractmethod + async def __call__(self, **kwargs) -> Any: + """Executes the tool with the given arguments.""" + ... + + @abstractmethod + def to_params(self) -> Dict[str, Any]: + """Convert tool to OpenAI-specific API parameters. + + Returns: + Dictionary with tool parameters for OpenAI API + """ + raise NotImplementedError + + +@dataclass(kw_only=True, frozen=True) +class ToolResult: + """Represents the result of a tool execution.""" + + output: str | None = None + error: str | None = None + base64_image: str | None = None + system: str | None = None + content: list[dict] | None = None + + def __bool__(self): + return any(getattr(self, field.name) for field in fields(self)) + + def __add__(self, other: "ToolResult"): + def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True): + if field and other_field: + if concatenate: + return field + other_field + raise ValueError("Cannot combine tool results") + return field or other_field + + return ToolResult( + output=combine_fields(self.output, other.output), + error=combine_fields(self.error, other.error), + base64_image=combine_fields(self.base64_image, other.base64_image, False), + system=combine_fields(self.system, other.system), + content=self.content or other.content, # Use first non-None content + ) + + def replace(self, **kwargs): + """Returns a new ToolResult with the given fields replaced.""" + return replace(self, **kwargs) + + +class CLIResult(ToolResult): + """A ToolResult that can be rendered as a CLI output.""" + + +class 
class ToolError(Exception):
    """Raised when a tool encounters an error.

    Attributes are documented inline; ``str(error)`` yields the message.
    """

    def __init__(self, message):
        # Chain to Exception so args/str() are set by the base initializer
        # rather than relying on BaseException.__new__ alone.
        super().__init__(message)
        # Human-readable description of the failure
        self.message = message
    def __init__(self, computer: Computer):
        """Initialize the computer tool.

        Screen dimensions are left unset here; they are filled in later from
        the computer interface (see ``initialize_dimensions``).

        Args:
            computer: Computer instance
        """
        self.computer = computer
        self.width = None
        self.height = None
        self.logger = logging.getLogger(__name__)

        # Both bases are initialized explicitly (no cooperative super() chain):
        # Initialize the base computer tool first
        BaseComputerTool.__init__(self, computer)
        # Then initialize the OpenAI tool
        BaseOpenAITool.__init__(self)

        # Additional initialization
        # NOTE(review): width/height are reset again after the base
        # initializers ran - presumably to undo anything they set; confirm.
        self.width = None  # Will be initialized from computer interface
        self.height = None  # Will be initialized from computer interface
        self.display_num = None
+ ) + return { + "type": self.api_type, + "display_width": self.width, + "display_height": self.height, + "display_number": self.display_num, + } + + async def initialize_dimensions(self): + """Initialize screen dimensions from the computer interface.""" + try: + display_size = await self.computer.interface.get_screen_size() + self.width = display_size["width"] + self.height = display_size["height"] + assert isinstance(self.width, int) and isinstance(self.height, int) + self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}") + except Exception as e: + # Fall back to defaults if we can't get accurate dimensions + self.width = 1024 + self.height = 768 + self.logger.warning( + f"Failed to get screen dimensions, using defaults: {self.width}x{self.height}. Error: {e}" + ) + + async def __call__( + self, + *, + type: str, # OpenAI uses 'type' instead of 'action' + text: Optional[str] = None, + **kwargs, + ): + try: + # Ensure dimensions are initialized + if self.width is None or self.height is None: + await self.initialize_dimensions() + if self.width is None or self.height is None: + raise ToolError("Failed to initialize screen dimensions") + + if type == "type": + if text is None: + raise ToolError("text is required for type action") + return await self.handle_typing(text) + elif type == "click": + # Map button to correct action name + button = kwargs.get("button") + if button is None: + raise ToolError("button is required for click action") + return await self.handle_click(button, kwargs["x"], kwargs["y"]) + elif type == "keypress": + # Check for keys in kwargs if text is None + if text is None: + if "keys" in kwargs and isinstance(kwargs["keys"], list): + # Pass the keys list directly instead of joining and then splitting + return await self.handle_key(kwargs["keys"]) + else: + raise ToolError("Either 'text' or 'keys' is required for keypress action") + return await self.handle_key(text) + elif type == "mouse_move": + if "coordinates" not in 
kwargs: + raise ToolError("coordinates is required for mouse_move action") + return await self.handle_mouse_move( + kwargs["coordinates"][0], kwargs["coordinates"][1] + ) + elif type == "scroll": + # Get x, y coordinates directly from kwargs + x = kwargs.get("x") + y = kwargs.get("y") + if x is None or y is None: + raise ToolError("x and y coordinates are required for scroll action") + scroll_x = kwargs.get("scroll_x", 0) + scroll_y = kwargs.get("scroll_y", 0) + return await self.handle_scroll(x, y, scroll_x, scroll_y) + elif type == "screenshot": + return await self.screenshot() + elif type == "wait": + duration = kwargs.get("duration", 1.0) + await asyncio.sleep(duration) + return await self.screenshot() + else: + raise ToolError(f"Unsupported action: {type}") + + except Exception as e: + self.logger.error(f"Error in ComputerTool.__call__: {str(e)}") + raise ToolError(f"Failed to execute {type}: {str(e)}") + + async def handle_click(self, button: str, x: int, y: int) -> ToolResult: + """Handle different click actions.""" + try: + # Perform requested click action + if button == "left": + await self.computer.interface.left_click(x, y) + elif button == "right": + await self.computer.interface.right_click(x, y) + elif button == "double": + await self.computer.interface.double_click(x, y) + + # Wait for UI to update + await asyncio.sleep(0.5) + + # Take screenshot after action + screenshot = await self.computer.interface.screenshot() + base64_screenshot = base64.b64encode(screenshot).decode("utf-8") + + return ToolResult( + output=f"Performed {button} click at ({x}, {y})", + base64_image=base64_screenshot, + ) + except Exception as e: + self.logger.error(f"Error in handle_click: {str(e)}") + raise ToolError(f"Failed to perform {button} click at ({x}, {y}): {str(e)}") + + async def handle_typing(self, text: str) -> ToolResult: + """Handle typing text with a small delay between characters.""" + try: + # Type the text with a small delay + await 
self.computer.interface.type_text(text) + + await asyncio.sleep(0.3) + + # Take screenshot after typing + screenshot = await self.computer.interface.screenshot() + base64_screenshot = base64.b64encode(screenshot).decode("utf-8") + + return ToolResult(output=f"Typed: {text}", base64_image=base64_screenshot) + except Exception as e: + self.logger.error(f"Error in handle_typing: {str(e)}") + raise ToolError(f"Failed to type '{text}': {str(e)}") + + async def handle_key(self, key: Union[str, List[str]]) -> ToolResult: + """Handle key press, supporting both single keys and combinations. + + Args: + key: Either a string (e.g. "ctrl+c") or a list of keys (e.g. ["ctrl", "c"]) + """ + try: + # Check if key is already a list + if isinstance(key, list): + keys = [k.strip().lower() for k in key] + else: + # Split key string into list if it's a combination (e.g. "ctrl+c") + keys = [k.strip().lower() for k in key.split("+")] + + # Map each key + mapped_keys = [KEY_MAPPING.get(k, k) for k in keys] + + if len(mapped_keys) > 1: + # For key combinations (like Ctrl+C) + for k in mapped_keys: + await self.computer.interface.press_key(k) + await asyncio.sleep(0.1) + for k in reversed(mapped_keys): + await self.computer.interface.press_key(k) + else: + # Single key press + await self.computer.interface.press_key(mapped_keys[0]) + + # Wait briefly + await asyncio.sleep(0.3) + + # Take screenshot after action + screenshot = await self.computer.interface.screenshot() + base64_screenshot = base64.b64encode(screenshot).decode("utf-8") + + return ToolResult(output=f"Pressed key: {key}", base64_image=base64_screenshot) + except Exception as e: + self.logger.error(f"Error in handle_key: {str(e)}") + raise ToolError(f"Failed to press key '{key}': {str(e)}") + + async def handle_mouse_move(self, x: int, y: int) -> ToolResult: + """Handle mouse movement.""" + try: + # Move cursor to position + await self.computer.interface.move_cursor(x, y) + + # Wait briefly + await asyncio.sleep(0.2) + + # Take 
screenshot after action + screenshot = await self.computer.interface.screenshot() + base64_screenshot = base64.b64encode(screenshot).decode("utf-8") + + return ToolResult(output=f"Moved cursor to ({x}, {y})", base64_image=base64_screenshot) + except Exception as e: + self.logger.error(f"Error in handle_mouse_move: {str(e)}") + raise ToolError(f"Failed to move cursor to ({x}, {y}): {str(e)}") + + async def handle_scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> ToolResult: + """Handle scrolling.""" + try: + # Move cursor to position first + await self.computer.interface.move_cursor(x, y) + + # Scroll based on direction + if scroll_y > 0: + await self.computer.interface.scroll_down(abs(scroll_y)) + elif scroll_y < 0: + await self.computer.interface.scroll_up(abs(scroll_y)) + + # Wait for UI to update + await asyncio.sleep(0.5) + + # Take screenshot after action + screenshot = await self.computer.interface.screenshot() + base64_screenshot = base64.b64encode(screenshot).decode("utf-8") + + return ToolResult( + output=f"Scrolled at ({x}, {y}) with delta ({scroll_x}, {scroll_y})", + base64_image=base64_screenshot, + ) + except Exception as e: + self.logger.error(f"Error in handle_scroll: {str(e)}") + raise ToolError(f"Failed to scroll at ({x}, {y}): {str(e)}") + + async def screenshot(self) -> ToolResult: + """Take a screenshot.""" + try: + # Take screenshot + screenshot = await self.computer.interface.screenshot() + base64_screenshot = base64.b64encode(screenshot).decode("utf-8") + + return ToolResult(output="Screenshot taken", base64_image=base64_screenshot) + except Exception as e: + self.logger.error(f"Error in screenshot: {str(e)}") + raise ToolError(f"Failed to take screenshot: {str(e)}") diff --git a/libs/agent/agent/providers/openai/tools/manager.py b/libs/agent/agent/providers/openai/tools/manager.py new file mode 100644 index 00000000..a7387aaf --- /dev/null +++ b/libs/agent/agent/providers/openai/tools/manager.py @@ -0,0 +1,106 @@ +"""Tool manager 
for the OpenAI provider.""" + +import logging +from typing import Dict, Any, Optional, List, Callable, Awaitable, Union + +from computer import Computer +from ..types import ComputerAction, ResponseItemType +from .computer import ComputerTool +from ....core.tools.base import ToolResult, ToolFailure +from ....core.tools.collection import ToolCollection + +logger = logging.getLogger(__name__) + + +class ToolManager: + """Manager for computer tools in the OpenAI agent.""" + + def __init__( + self, + computer: Computer, + acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None, + ): + """Initialize the tool manager. + + Args: + computer: Computer instance + acknowledge_safety_check_callback: Optional callback for safety check acknowledgment + """ + self.computer = computer + self.acknowledge_safety_check_callback = acknowledge_safety_check_callback + self._initialized = False + self.computer_tool = ComputerTool(computer) + self.tools = None + logger.info("Initialized OpenAI ToolManager") + + async def initialize(self) -> None: + """Initialize the tool manager.""" + if not self._initialized: + logger.info("Initializing OpenAI ToolManager") + + # Initialize the computer tool + await self.computer_tool.initialize_dimensions() + + # Initialize tool collection + self.tools = ToolCollection(self.computer_tool) + + self._initialized = True + logger.info("OpenAI ToolManager initialized") + + async def get_tools_definition(self) -> List[Dict[str, Any]]: + """Get the tools definition for the OpenAI agent. + + Returns: + Tools definition for the OpenAI agent + """ + if not self.tools: + raise RuntimeError("Tools not initialized. 
Call initialize() first.") + + # For the OpenAI Agent Response API, we use a special "computer-preview" tool + # which provides the correct interface for computer control + display_width, display_height = await self._get_computer_dimensions() + + # Get environment, using "mac" as default since we're on macOS + environment = getattr(self.computer, "environment", "mac") + + # Ensure environment is one of the allowed values + if environment not in ["windows", "mac", "linux", "browser"]: + logger.warning(f"Invalid environment value: {environment}, using 'mac' instead") + environment = "mac" + + return [ + { + "type": "computer-preview", + "display_width": display_width, + "display_height": display_height, + "environment": environment, + } + ] + + async def _get_computer_dimensions(self) -> tuple[int, int]: + """Get the dimensions of the computer display. + + Returns: + Tuple of (width, height) + """ + # If computer tool is initialized, use its dimensions + if self.computer_tool.width is not None and self.computer_tool.height is not None: + return (self.computer_tool.width, self.computer_tool.height) + + # Try to get from computer.interface if available + screen_size = await self.computer.interface.get_screen_size() + return (int(screen_size["width"]), int(screen_size["height"])) + + async def execute_tool(self, name: str, tool_input: Dict[str, Any]) -> ToolResult: + """Execute a tool with the given input. + + Args: + name: Name of the tool to execute + tool_input: Input parameters for the tool + + Returns: + Result of the tool execution + """ + if not self.tools: + raise RuntimeError("Tools not initialized. 
Call initialize() first.") + return await self.tools.run(name=name, tool_input=tool_input) diff --git a/libs/agent/agent/providers/openai/types.py b/libs/agent/agent/providers/openai/types.py new file mode 100644 index 00000000..7c7839b4 --- /dev/null +++ b/libs/agent/agent/providers/openai/types.py @@ -0,0 +1,36 @@ +"""Type definitions for the OpenAI provider.""" + +from enum import StrEnum, auto +from typing import Dict, List, Optional, Union, Any +from dataclasses import dataclass + + +class LLMProvider(StrEnum): + """OpenAI LLM provider types.""" + + OPENAI = "openai" + + +class ResponseItemType(StrEnum): + """Types of items in OpenAI Agent Response output.""" + + MESSAGE = "message" + COMPUTER_CALL = "computer_call" + COMPUTER_CALL_OUTPUT = "computer_call_output" + REASONING = "reasoning" + + +@dataclass +class ComputerAction: + """Represents a computer action to be performed.""" + + type: str + x: Optional[int] = None + y: Optional[int] = None + text: Optional[str] = None + button: Optional[str] = None + keys: Optional[List[str]] = None + ms: Optional[int] = None + scroll_x: Optional[int] = None + scroll_y: Optional[int] = None + path: Optional[List[Dict[str, int]]] = None diff --git a/libs/agent/agent/providers/openai/utils.py b/libs/agent/agent/providers/openai/utils.py new file mode 100644 index 00000000..58cb06e8 --- /dev/null +++ b/libs/agent/agent/providers/openai/utils.py @@ -0,0 +1,98 @@ +"""Utility functions for the OpenAI provider.""" + +import logging +import json +import base64 +from typing import Any, Dict, List, Optional + +from ...core.types import AgentResponse + +logger = logging.getLogger(__name__) + + +def format_images_for_openai(images_base64: List[str]) -> List[Dict[str, Any]]: + """Format images for OpenAI Agent Response API. 
+ + Args: + images_base64: List of base64 encoded images + + Returns: + List of formatted image items for Agent Response API + """ + return [ + {"type": "input_image", "image_url": f"data:image/png;base64,{image}"} + for image in images_base64 + ] + + +def extract_message_content(message: Dict[str, Any]) -> str: + """Extract text content from a message. + + Args: + message: Message to extract content from + + Returns: + Text content from the message + """ + if isinstance(message.get("content"), str): + return message["content"] + + if isinstance(message.get("content"), list): + text = "" + role = message.get("role", "user") + + for item in message["content"]: + if isinstance(item, dict): + # For user messages + if role == "user" and item.get("type") == "input_text": + text += item.get("text", "") + # For standard format + elif item.get("type") == "text": + text += item.get("text", "") + # For assistant messages in Agent Response API format + elif item.get("type") == "output_text": + text += item.get("text", "") + return text + + return "" + + +def sanitize_message(msg: Dict[str, Any]) -> Dict[str, Any]: + """Sanitize a message for logging by removing large image data. 
+ + Args: + msg: Message to sanitize + + Returns: + Sanitized message + """ + if not isinstance(msg, dict): + return msg + + sanitized = msg.copy() + + # Handle message content + if isinstance(sanitized.get("content"), list): + sanitized_content = [] + for item in sanitized["content"]: + if isinstance(item, dict): + # Handle various image types + if item.get("type") == "image_url" and "image_url" in item: + sanitized_content.append({"type": "image_url", "image_url": "[omitted]"}) + elif item.get("type") == "input_image" and "image_url" in item: + sanitized_content.append({"type": "input_image", "image_url": "[omitted]"}) + elif item.get("type") == "image" and "source" in item: + sanitized_content.append({"type": "image", "source": "[omitted]"}) + else: + sanitized_content.append(item) + else: + sanitized_content.append(item) + sanitized["content"] = sanitized_content + + # Handle computer_call_output + if sanitized.get("type") == "computer_call_output" and "output" in sanitized: + output = sanitized["output"] + if isinstance(output, dict) and "image_url" in output: + sanitized["output"] = {**output, "image_url": "[omitted]"} + + return sanitized diff --git a/libs/agent/pyproject.toml b/libs/agent/pyproject.toml index dc2dd04d..dd6ea55a 100644 --- a/libs/agent/pyproject.toml +++ b/libs/agent/pyproject.toml @@ -30,6 +30,10 @@ anthropic = [ "anthropic>=0.49.0", "boto3>=1.35.81,<2.0.0", ] +openai = [ + "openai>=1.14.0,<2.0.0", + "httpx>=0.27.0,<0.29.0", +] som = [ "torch>=2.2.1", "torchvision>=0.17.1", diff --git a/libs/agent/tests/test_agent.py b/libs/agent/tests/test_agent.py deleted file mode 100644 index 3030bd1f..00000000 --- a/libs/agent/tests/test_agent.py +++ /dev/null @@ -1,91 +0,0 @@ -# """Basic tests for the agent package.""" - -# import pytest -# from agent import OmniComputerAgent, LLMProvider -# from agent.base.agent import BaseComputerAgent -# from computer import Computer - -# def test_agent_import(): -# """Test that we can import the OmniComputerAgent 
class.""" -# assert OmniComputerAgent is not None -# assert LLMProvider is not None - -# def test_agent_init(): -# """Test that we can create an OmniComputerAgent instance.""" -# agent = OmniComputerAgent( -# provider=LLMProvider.OPENAI, -# use_host_computer_server=True -# ) -# assert agent is not None - -# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed") -# def test_computer_agent_anthropic(): -# """Test creating an Anthropic agent.""" -# agent = ComputerAgent(provider=Provider.ANTHROPIC) -# assert isinstance(agent._agent, BaseComputerAgent) - -# def test_computer_agent_invalid_provider(): -# """Test creating an agent with an invalid provider.""" -# with pytest.raises(ValueError, match="Unsupported provider"): -# ComputerAgent(provider="invalid_provider") - -# def test_computer_agent_uninstalled_provider(): -# """Test creating an agent with an uninstalled provider.""" -# with pytest.raises(NotImplementedError, match="OpenAI provider not yet implemented"): -# # OpenAI provider is not implemented yet -# ComputerAgent(provider=Provider.OPENAI) - -# @pytest.mark.asyncio -# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed") -# async def test_agent_cleanup(): -# """Test agent cleanup.""" -# agent = ComputerAgent(provider=Provider.ANTHROPIC) -# await agent.cleanup() # Should not raise any errors - -# @pytest.mark.asyncio -# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed") -# async def test_agent_direct_initialization(): -# """Test direct initialization of the agent.""" -# # Create with default computer -# agent = ComputerAgent(provider=Provider.ANTHROPIC) -# try: -# # Should not raise any errors -# await agent.run("test task") -# finally: -# await agent.cleanup() - -# # Create with custom computer -# custom_computer = Computer( -# display="1920x1080", -# memory="8GB", -# 
cpu="4", -# os="macos", -# use_host_computer_server=False, -# ) -# agent = ComputerAgent(provider=Provider.ANTHROPIC, computer=custom_computer) -# try: -# # Should not raise any errors -# await agent.run("test task") -# finally: -# await agent.cleanup() - -# @pytest.mark.asyncio -# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed") -# async def test_agent_context_manager(): -# """Test context manager initialization of the agent.""" -# # Test with default computer -# async with ComputerAgent(provider=Provider.ANTHROPIC) as agent: -# # Should not raise any errors -# await agent.run("test task") - -# # Test with custom computer -# custom_computer = Computer( -# display="1920x1080", -# memory="8GB", -# cpu="4", -# os="macos", -# use_host_computer_server=False, -# ) -# async with ComputerAgent(provider=Provider.ANTHROPIC, computer=custom_computer) as agent: -# # Should not raise any errors -# await agent.run("test task") diff --git a/notebooks/computer_nb.ipynb b/notebooks/computer_nb.ipynb index c776fbfe..b49e53ef 100644 --- a/notebooks/computer_nb.ipynb +++ b/notebooks/computer_nb.ipynb @@ -193,7 +193,51 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computer API Server not ready yet. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. 
Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. 
Will retry automatically.\n" + ] + } + ], "source": [ "async with Computer(\n", " # name=\"my_vm\", # optional, in case you want to use any other VM created using lume\n", @@ -217,7 +261,51 @@ "cell_type": "code", "execution_count": null, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Computer API Server not ready yet. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. 
Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n", + "Computer API Server connection lost. Will retry automatically.\n" + ] + } + ], "source": [ "computer = Computer(\n", " display=\"1024x768\",\n",