mirror of
https://github.com/trycua/computer.git
synced 2026-04-30 12:11:35 -05:00
Add OpenAI CUA
This commit is contained in:
@@ -49,7 +49,7 @@ except Exception as e:
|
||||
logger.warning(f"Error initializing telemetry: {e}")
|
||||
|
||||
from .providers.omni.types import LLMProvider, LLM
|
||||
from .core.loop import AgentLoop
|
||||
from .core.computer_agent import ComputerAgent
|
||||
from .core.factory import AgentLoop
|
||||
from .core.agent import ComputerAgent
|
||||
|
||||
__all__ = ["AgentLoop", "LLMProvider", "LLM", "ComputerAgent"]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Core agent components."""
|
||||
|
||||
from .loop import BaseLoop
|
||||
from .factory import BaseLoop
|
||||
from .messages import (
|
||||
BaseMessageManager,
|
||||
ImageRetentionConfig,
|
||||
|
||||
@@ -3,32 +3,18 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Dict, Optional, cast, List
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from computer import Computer
|
||||
from ..providers.anthropic.loop import AnthropicLoop
|
||||
from ..providers.omni.loop import OmniLoop
|
||||
from ..providers.omni.parser import OmniParser
|
||||
from ..providers.omni.types import LLMProvider, LLM
|
||||
from ..providers.omni.types import LLM
|
||||
from .. import AgentLoop
|
||||
from .messages import StandardMessageManager, ImageRetentionConfig
|
||||
from .types import AgentResponse
|
||||
from .factory import LoopFactory
|
||||
from .provider_config import DEFAULT_MODELS, ENV_VARS
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default models for different providers
|
||||
DEFAULT_MODELS = {
|
||||
LLMProvider.OPENAI: "gpt-4o",
|
||||
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
|
||||
}
|
||||
|
||||
# Map providers to their environment variable names
|
||||
ENV_VARS = {
|
||||
LLMProvider.OPENAI: "OPENAI_API_KEY",
|
||||
LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
|
||||
}
|
||||
|
||||
|
||||
class ComputerAgent:
|
||||
"""A computer agent that can perform automated tasks using natural language instructions."""
|
||||
@@ -98,35 +84,27 @@ class ComputerAgent:
|
||||
f"No model specified for provider {self.provider} and no default found"
|
||||
)
|
||||
|
||||
# Ensure computer is properly cast for typing purposes
|
||||
computer_instance = self.computer
|
||||
|
||||
# Get API key from environment if not provided
|
||||
actual_api_key = api_key or os.environ.get(ENV_VARS[self.provider], "")
|
||||
if not actual_api_key:
|
||||
raise ValueError(f"No API key provided for {self.provider}")
|
||||
|
||||
# Initialize the appropriate loop based on the loop parameter
|
||||
if loop == AgentLoop.ANTHROPIC:
|
||||
self._loop = AnthropicLoop(
|
||||
api_key=actual_api_key,
|
||||
model=actual_model_name,
|
||||
computer=computer_instance,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
)
|
||||
else:
|
||||
self._loop = OmniLoop(
|
||||
# Create the appropriate loop using the factory
|
||||
try:
|
||||
# Let the factory create the appropriate loop with needed components
|
||||
self._loop = LoopFactory.create_loop(
|
||||
loop_type=loop,
|
||||
provider=self.provider,
|
||||
computer=self.computer,
|
||||
model_name=actual_model_name,
|
||||
api_key=actual_api_key,
|
||||
model=actual_model_name,
|
||||
computer=computer_instance,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
trajectory_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
parser=OmniParser(),
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to create loop: {str(e)}")
|
||||
raise
|
||||
|
||||
# Initialize the message manager from the loop
|
||||
self.message_manager = self._loop.message_manager
|
||||
@@ -152,21 +130,6 @@ class ComputerAgent:
|
||||
else:
|
||||
logger.info("Computer already initialized, skipping initialization")
|
||||
|
||||
# Take a test screenshot to verify the computer is working
|
||||
logger.info("Testing computer with a screenshot...")
|
||||
try:
|
||||
test_screenshot = await self.computer.interface.screenshot()
|
||||
# Determine the screenshot size based on its type
|
||||
if isinstance(test_screenshot, (bytes, bytearray, memoryview)):
|
||||
size = len(test_screenshot)
|
||||
elif hasattr(test_screenshot, "base64_image"):
|
||||
size = len(test_screenshot.base64_image)
|
||||
else:
|
||||
size = "unknown"
|
||||
logger.info(f"Screenshot test successful, size: {size}")
|
||||
except Exception as e:
|
||||
logger.error(f"Screenshot test failed: {str(e)}")
|
||||
# Even though screenshot failed, we continue since some tests might not need it
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing computer in __aenter__: {str(e)}")
|
||||
raise
|
||||
@@ -232,7 +195,6 @@ class ComputerAgent:
|
||||
|
||||
# Execute the task and yield results
|
||||
async for result in self._loop.run(self.message_manager.messages):
|
||||
# Yield the result to the caller
|
||||
yield result
|
||||
|
||||
except Exception as e:
|
||||
@@ -1,35 +1,21 @@
|
||||
"""Base agent loop implementation."""
|
||||
"""Base loop definitions."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, auto
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from computer import Computer
|
||||
from .experiment import ExperimentManager
|
||||
from .messages import StandardMessageManager, ImageRetentionConfig
|
||||
from .types import AgentResponse
|
||||
from .experiment import ExperimentManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentLoop(Enum):
|
||||
"""Enumeration of available loop types."""
|
||||
|
||||
ANTHROPIC = auto() # Anthropic implementation
|
||||
OMNI = auto() # OmniLoop implementation
|
||||
# Add more loop types as needed
|
||||
|
||||
|
||||
class BaseLoop(ABC):
|
||||
"""Base class for agent loops that handle message processing and tool execution."""
|
||||
|
||||
###########################################
|
||||
# INITIALIZATION AND CONFIGURATION
|
||||
###########################################
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
computer: Computer,
|
||||
@@ -68,6 +54,11 @@ class BaseLoop(ABC):
|
||||
self.only_n_most_recent_images = only_n_most_recent_images
|
||||
self._kwargs = kwargs
|
||||
|
||||
# Initialize message manager
|
||||
self.message_manager = StandardMessageManager(
|
||||
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
|
||||
)
|
||||
|
||||
# Initialize experiment manager
|
||||
if self.save_trajectory and self.base_dir:
|
||||
self.experiment_manager = ExperimentManager(
|
||||
@@ -110,8 +101,7 @@ class BaseLoop(ABC):
|
||||
)
|
||||
raise RuntimeError(f"Failed to initialize: {str(e)}")
|
||||
|
||||
###########################################
|
||||
|
||||
###########################################
|
||||
# ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES
|
||||
###########################################
|
||||
|
||||
@@ -125,17 +115,14 @@ class BaseLoop(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
|
||||
def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
This method handles the main agent loop including message processing,
|
||||
API calls, response handling, and action execution.
|
||||
|
||||
Args:
|
||||
messages: List of message objects
|
||||
|
||||
Yields:
|
||||
Agent response format
|
||||
Returns:
|
||||
An async generator that yields agent responses
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
"""Factory for creating agent loops."""
|
||||
|
||||
import logging
|
||||
import importlib.util
|
||||
from typing import Dict, Optional, Type, TYPE_CHECKING, Any, cast, Callable, Awaitable
|
||||
|
||||
from computer import Computer
|
||||
from .types import AgentLoop
|
||||
from .base import BaseLoop
|
||||
|
||||
# For type checking only
|
||||
if TYPE_CHECKING:
|
||||
from ..providers.omni.types import LLMProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopFactory:
    """Factory class for creating agent loops.

    Provider modules are imported lazily inside ``create_loop`` so that only
    the provider that is actually requested needs to be importable.
    """

    # Registry to store loop implementations.
    # NOTE: create_loop below does not consult this mapping.
    _loop_registry: Dict[AgentLoop, Type[BaseLoop]] = {}

    @classmethod
    def create_loop(
        cls,
        loop_type: AgentLoop,
        api_key: str,
        model_name: str,
        computer: Computer,
        provider: Any = None,
        save_trajectory: bool = True,
        trajectory_dir: str = "trajectories",
        only_n_most_recent_images: Optional[int] = None,
        acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
    ) -> BaseLoop:
        """Create and return an appropriate loop instance based on type.

        Args:
            loop_type: Which loop implementation to instantiate.
            api_key: Passed to the loop as ``api_key``.
            model_name: Passed to the loop as ``model``.
            computer: Passed to the loop as ``computer``.
            provider: Required (non-None) for ``AgentLoop.OMNI``, where it is
                forwarded to ``OmniLoop``; not used for the other loop types.
            save_trajectory: Passed to the loop as ``save_trajectory``.
            trajectory_dir: Passed to the loop as ``base_dir``.
            only_n_most_recent_images: Passed to the loop unchanged.
            acknowledge_safety_check_callback: Forwarded only to ``OpenAILoop``.

        Returns:
            The constructed loop instance.

        Raises:
            ImportError: If the requested provider module cannot be imported.
                The original import failure is attached as ``__cause__``.
            ValueError: If ``loop_type`` is not supported, or if ``provider``
                is None for ``AgentLoop.OMNI``.
        """
        # Arguments every loop constructor receives, built once.
        common_kwargs: Dict[str, Any] = {
            "api_key": api_key,
            "model": model_name,
            "computer": computer,
            "save_trajectory": save_trajectory,
            "base_dir": trajectory_dir,
            "only_n_most_recent_images": only_n_most_recent_images,
        }

        if loop_type == AgentLoop.ANTHROPIC:
            # Lazy import AnthropicLoop only when needed
            try:
                from ..providers.anthropic.loop import AnthropicLoop
            except ImportError as err:
                # Chain the cause so a broken transitive import stays visible.
                raise ImportError(
                    "The 'anthropic' provider is not installed. "
                    "Install it with 'pip install cua-agent[anthropic]'"
                ) from err

            return AnthropicLoop(**common_kwargs)

        if loop_type == AgentLoop.OPENAI:
            # Lazy import OpenAILoop only when needed
            try:
                from ..providers.openai.loop import OpenAILoop
            except ImportError as err:
                raise ImportError(
                    "The 'openai' provider is not installed. "
                    "Install it with 'pip install cua-agent[openai]'"
                ) from err

            return OpenAILoop(
                acknowledge_safety_check_callback=acknowledge_safety_check_callback,
                **common_kwargs,
            )

        if loop_type == AgentLoop.OMNI:
            # Lazy import OmniLoop and related classes only when needed
            try:
                from ..providers.omni.loop import OmniLoop
                from ..providers.omni.parser import OmniParser
                from ..providers.omni.types import LLMProvider
            except ImportError as err:
                raise ImportError(
                    "The 'omni' provider is not installed. "
                    "Install it with 'pip install cua-agent[all]'"
                ) from err

            if provider is None:
                raise ValueError("Provider is required for OMNI loop type")

            # cast() only informs the type checker; no runtime validation happens here.
            provider_instance = cast(LLMProvider, provider)

            return OmniLoop(
                provider=provider_instance,
                parser=OmniParser(),
                **common_kwargs,
            )

        raise ValueError(f"Unsupported loop type: {loop_type}")
|
||||
@@ -0,0 +1,15 @@
|
||||
"""Provider-specific configurations and constants."""
|
||||
|
||||
from ..providers.omni.types import LLMProvider
|
||||
|
||||
# Default model name per provider
DEFAULT_MODELS: dict[LLMProvider, str] = {
    LLMProvider.OPENAI: "gpt-4o",
    LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
}

# Name of the environment variable associated with each provider
ENV_VARS: dict[LLMProvider, str] = {
    LLMProvider.OPENAI: "OPENAI_API_KEY",
    LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
}
|
||||
@@ -1,6 +1,16 @@
|
||||
"""Core type definitions."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class AgentLoop(Enum):
    """Loop implementations that can be selected when creating an agent."""

    # Values are spelled out; they are the same 1, 2, 3 that auto() assigned.
    ANTHROPIC = 1  # Anthropic implementation
    OMNI = 2  # OmniLoop implementation
    OPENAI = 3  # OpenAI implementation
    # Add more loop types as needed
|
||||
|
||||
|
||||
class AgentResponse(TypedDict, total=False):
|
||||
|
||||
@@ -16,7 +16,7 @@ from datetime import datetime
|
||||
from computer import Computer
|
||||
|
||||
# Base imports
|
||||
from ...core.loop import BaseLoop
|
||||
from ...core.base import BaseLoop
|
||||
from ...core.messages import StandardMessageManager, ImageRetentionConfig
|
||||
from ...core.types import AgentResponse
|
||||
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from anthropic.types.beta import (
|
||||
BetaMessageParam,
|
||||
BetaCacheControlEphemeralParam,
|
||||
BetaToolResultBlockParam,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageRetentionConfig:
|
||||
"""Configuration for image retention in messages."""
|
||||
|
||||
num_images_to_keep: int | None = None
|
||||
min_removal_threshold: int = 1
|
||||
enable_caching: bool = True
|
||||
|
||||
def should_retain_images(self) -> bool:
|
||||
"""Check if image retention is enabled."""
|
||||
return self.num_images_to_keep is not None and self.num_images_to_keep > 0
|
||||
|
||||
|
||||
class MessageManager:
    """Manages message preparation, including image retention and caching."""

    def __init__(self, image_retention_config: ImageRetentionConfig):
        """Initialize the message manager.

        Args:
            image_retention_config: Configuration for image retention

        Raises:
            ValueError: If ``min_removal_threshold`` is smaller than 1.
        """
        if image_retention_config.min_removal_threshold < 1:
            raise ValueError("min_removal_threshold must be at least 1")
        self.image_retention_config = image_retention_config

    def prepare_messages(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]:
        """Prepare messages by applying image retention and caching as configured.

        The messages are modified in place and the same list is returned.
        """
        if self.image_retention_config.should_retain_images():
            self._filter_images(messages)
        if self.image_retention_config.enable_caching:
            self._inject_caching(messages)
        return messages

    def _filter_images(self, messages: list[BetaMessageParam]) -> None:
        """Filter messages to retain only the specified number of most recent images.

        Only images inside ``tool_result`` blocks are counted and removed.
        """
        tool_result_blocks = cast(
            list[BetaToolResultBlockParam],
            [
                item
                for message in messages
                for item in (message["content"] if isinstance(message["content"], list) else [])
                if isinstance(item, dict) and item.get("type") == "tool_result"
            ],
        )

        def is_image(entry) -> bool:
            return isinstance(entry, dict) and entry.get("type") == "image"

        # Only list-valued content can hold image entries; skipping anything
        # else also avoids iterating over a string or a None payload.
        total_images = sum(
            1
            for tool_result in tool_result_blocks
            if isinstance(tool_result.get("content"), list)
            for content in tool_result["content"]
            if is_image(content)
        )

        images_to_remove = total_images - (self.image_retention_config.num_images_to_keep or 0)
        # Round down to nearest min_removal_threshold for better cache behavior
        images_to_remove -= images_to_remove % self.image_retention_config.min_removal_threshold

        if images_to_remove <= 0:
            # Nothing to drop: leave every content list untouched.
            return

        # Remove oldest images first
        for tool_result in tool_result_blocks:
            if images_to_remove <= 0:
                break
            content_list = tool_result.get("content")
            if not isinstance(content_list, list):
                continue
            kept = []
            for content in content_list:
                if images_to_remove > 0 and is_image(content):
                    images_to_remove -= 1
                    continue
                kept.append(content)
            tool_result["content"] = kept

    def _inject_caching(self, messages: list[BetaMessageParam]) -> None:
        """Inject caching control for the most recent turns, limited to 3 blocks max to avoid API errors.

        Walking from newest to oldest, the last content block of each user
        message with a non-empty content list receives an ephemeral
        ``cache_control`` marker until the limit is reached; for older user
        messages any marker on the last block is removed.
        """
        # Anthropic API allows a maximum of 4 blocks with cache_control
        # We use 3 here to be safe, as the system block may also have cache_control
        remaining_slots = 3

        for message in reversed(messages):
            if message["role"] != "user":
                continue
            content = message["content"]
            # An empty content list has no block to mark, so it must not
            # consume one of the available slots.
            if not isinstance(content, list) or not content:
                continue

            last_block = content[-1]
            if remaining_slots > 0:
                remaining_slots -= 1
                last_block["cache_control"] = BetaCacheControlEphemeralParam(type="ephemeral")
            else:
                # Remove any existing cache control
                last_block.pop("cache_control", None)
|
||||
@@ -1,14 +1,11 @@
|
||||
"""Response and tool handling for Anthropic provider."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from typing import Any, Dict, List, Tuple, cast
|
||||
|
||||
from anthropic.types.beta import (
|
||||
BetaMessage,
|
||||
BetaMessageParam,
|
||||
BetaTextBlock,
|
||||
BetaTextBlockParam,
|
||||
BetaToolUseBlockParam,
|
||||
BetaContentBlockParam,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
"""Utility functions for Anthropic message handling."""
|
||||
|
||||
import time
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaTextBlock
|
||||
from anthropic.types.beta import BetaMessage
|
||||
from ..omni.parser import ParseResult
|
||||
from ...core.types import AgentResponse
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
# Configure module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,7 +10,7 @@ from httpx import ConnectError, ReadTimeout
|
||||
from typing import cast
|
||||
|
||||
from .parser import OmniParser, ParseResult
|
||||
from ...core.loop import BaseLoop
|
||||
from ...core.base import BaseLoop
|
||||
from ...core.visualization import VisualizationHelper
|
||||
from ...core.messages import StandardMessageManager, ImageRetentionConfig
|
||||
from .utils import to_openai_agent_response_format
|
||||
|
||||
@@ -9,8 +9,10 @@ class LLMProvider(StrEnum):
|
||||
"""Supported LLM providers."""
|
||||
|
||||
ANTHROPIC = "anthropic"
|
||||
OMNI = "omni"
|
||||
OPENAI = "openai"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLM:
|
||||
"""Configuration for LLM model and provider."""
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
"""OpenAI Agent Response API provider for computer control."""
|
||||
|
||||
from .types import LLMProvider
|
||||
from .loop import OpenAILoop
|
||||
|
||||
__all__ = ["OpenAILoop", "LLMProvider"]
|
||||
@@ -0,0 +1,5 @@
|
||||
"""OpenAI API client module."""
|
||||
|
||||
from .client import OpenAIClient
|
||||
|
||||
__all__ = ["OpenAIClient"]
|
||||
@@ -0,0 +1,137 @@
|
||||
"""OpenAI API client for Agent Response API."""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import httpx
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIClient:
    """Client for OpenAI's Agent Response API."""

    def __init__(
        self,
        api_key: str,
        model: str = "computer-use-preview",
        base_url: str = "https://api.openai.com/v1",
        max_retries: int = 3,
        timeout: int = 120,
        **kwargs,
    ):
        """Initialize OpenAI API client.

        Args:
            api_key: OpenAI API key
            model: Model to use for completions. Any value other than
                "computer-use-preview" is replaced by it and a warning is logged.
            base_url: Base URL for API requests
            max_retries: Maximum number of retries for API calls. Stored on the
                instance; no method of this class reads it.
            timeout: Timeout for API calls in seconds
            **kwargs: Additional arguments to pass to the httpx client
        """
        self.api_key = api_key

        # Always use computer-use-preview model
        if model != "computer-use-preview":
            logger.warning(
                f"Overriding provided model '{model}' with required model 'computer-use-preview'"
            )
            model = "computer-use-preview"

        self.model = model
        self.base_url = base_url
        self.max_retries = max_retries
        self.timeout = timeout

        # Create httpx client with auth and timeout
        self.client = httpx.AsyncClient(
            timeout=timeout,
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
                "OpenAI-Beta": "computer-use-2023-09-30",  # Required beta header for computer use
            },
            **kwargs,
        )

        # Additional initialization for organization if available
        openai_org = os.environ.get("OPENAI_ORG")
        if openai_org:
            self.client.headers["OpenAI-Organization"] = openai_org

        logger.info(f"Initialized OpenAI client with model {model}")

    async def create_response(
        self,
        input: List[Dict[str, Any]],
        tools: Optional[List[Dict[str, Any]]] = None,
        truncation: str = "auto",
        temperature: float = 0.7,
        top_p: float = 1.0,
        **kwargs,
    ) -> Dict[str, Any]:
        """Create a response using the OpenAI Agent Response API.

        Args:
            input: List of messages in the conversation (must be in Agent Response API format)
            tools: List of tools available to the agent
            truncation: How to handle truncation (auto, truncate)
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            **kwargs: Additional parameters to include in the request

        Returns:
            Response from the API

        Raises:
            RuntimeError: If the API answers with an error status or the call
                fails for any other reason. The underlying exception is
                attached as ``__cause__``.
        """
        url = f"{self.base_url}/responses"

        # Prepare request payload; entries in kwargs override the named fields.
        payload = {
            "model": self.model,
            "input": input,
            "temperature": temperature,
            "top_p": top_p,
            "truncation": truncation,
            **kwargs,
        }

        # Add tools if provided
        if tools:
            payload["tools"] = tools

        try:
            logger.debug(f"Sending request to {url}")

            # Make API call
            response = await self.client.post(url, json=payload)
            response.raise_for_status()

            result = response.json()
            logger.debug("Received successful response")
            return result

        except httpx.HTTPStatusError as e:
            error_detail = e.response.text
            # Pretty-print the error body when it is JSON; otherwise log it raw.
            # json.JSONDecodeError is a ValueError subclass, so a narrow clause
            # replaces the former bare `except:`.
            try:
                readable_detail = json.dumps(json.loads(error_detail), indent=2)
            except ValueError:
                readable_detail = error_detail
            # Logged once here (previously an inner and an outer handler both logged it).
            logger.error(f"HTTP error from OpenAI API: {readable_detail}")
            raise RuntimeError(f"OpenAI API error: {error_detail}") from e
        except Exception as e:
            logger.error(f"Error calling OpenAI API: {str(e)}")
            raise RuntimeError(f"Error calling OpenAI API: {str(e)}") from e

    async def close(self):
        """Close the httpx client."""
        await self.client.aclose()
|
||||
@@ -0,0 +1,453 @@
|
||||
"""API handler for the OpenAI provider."""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .loop import OpenAILoop
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIAPIHandler:
|
||||
"""Handler for OpenAI API interactions."""
|
||||
|
||||
def __init__(self, loop: "OpenAILoop"):
    """Set up the credentials and default HTTP headers used for API requests.

    Args:
        loop: OpenAI loop instance

    Raises:
        ValueError: If the OPENAI_API_KEY environment variable is unset or empty.
    """
    self.loop = loop

    key = os.getenv("OPENAI_API_KEY")
    self.api_key = key
    if not key:
        raise ValueError("OPENAI_API_KEY environment variable not set")

    self.api_base = "https://api.openai.com/v1"

    request_headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }

    # The organization header is sent only when OPENAI_ORG is set
    organization = os.getenv("OPENAI_ORG")
    if organization:
        request_headers["OpenAI-Organization"] = organization
    self.headers = request_headers

    logger.info("Initialized OpenAI API handler")
|
||||
|
||||
async def send_initial_request(
    self,
    messages: List[Dict[str, Any]],
    display_width: str,
    display_height: str,
    previous_response_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Send an initial request to the OpenAI API with a screenshot.

    The most recent text and the most recent base64 image found in
    ``messages`` are sent as input. The image is omitted when
    ``previous_response_id`` is given.

    Args:
        messages: List of message objects in standard format
        display_width: Width of the display in pixels
        display_height: Height of the display in pixels
        previous_response_id: Optional ID of the previous response to link requests

    Returns:
        API response

    Raises:
        ValueError: If the display dimensions cannot be converted to integers.
        Exception: If the API responds with a status code other than 200.
    """
    # Local import keeps this block self-contained.
    import asyncio

    # Convert display dimensions to integers
    try:
        width = int(display_width)
        height = int(display_height)
    except (ValueError, TypeError) as e:
        logger.error(f"Failed to convert display dimensions to integers: {str(e)}")
        raise ValueError(
            f"Display dimensions must be integers: width={display_width}, height={display_height}"
        ) from e

    # Extract the latest text message and screenshot from messages
    latest_text = None
    latest_screenshot = None

    for msg in reversed(messages):
        # Both values are only ever set once, so scanning further is pointless.
        if latest_text and latest_screenshot:
            break

        if not isinstance(msg, dict):
            continue

        content = msg.get("content", [])

        if isinstance(content, str) and not latest_text:
            latest_text = content
            continue

        if not isinstance(content, list):
            continue

        for item in content:
            if not isinstance(item, dict):
                continue

            # Look for text if we don't have it yet
            if not latest_text and item.get("type") == "text" and "text" in item:
                latest_text = item.get("text", "")

            # Look for an image if we don't have it yet
            if not latest_screenshot and item.get("type") == "image":
                source = item.get("source", {})
                if source.get("type") == "base64" and "data" in source:
                    latest_screenshot = source["data"]

    # Prepare the input array
    input_array = []

    # Add the text message if found
    if latest_text:
        input_array.append({"role": "user", "content": latest_text})

    # Add the screenshot if found and no previous_response_id is provided
    if latest_screenshot and not previous_response_id:
        input_array.append(
            {
                "type": "message",
                "role": "user",
                "content": [
                    {
                        "type": "input_image",
                        "image_url": f"data:image/png;base64,{latest_screenshot}",
                    }
                ],
            }
        )

    # Prepare the request payload - using minimal format from docs
    payload = {
        "model": "computer-use-preview",
        "tools": [
            {
                "type": "computer_use_preview",
                "display_width": width,
                "display_height": height,
                "environment": "mac",  # We're on macOS
            }
        ],
        "input": input_array,
        "truncation": "auto",
    }

    # Add previous_response_id if provided
    if previous_response_id:
        payload["previous_response_id"] = previous_response_id

    # Record the request through the loop's _log_api_call hook
    self.loop._log_api_call("request", payload)

    # Log for debug purposes
    logger.info("Sending initial request to OpenAI API")
    logger.debug(f"Request payload: {self._sanitize_response(payload)}")

    # requests.post blocks, so run it in a worker thread rather than stalling
    # the event loop this coroutine runs on.
    # NOTE(review): no timeout is passed, so a stalled connection can still
    # wait indefinitely - consider adding one.
    response = await asyncio.to_thread(
        requests.post,
        f"{self.api_base}/responses",
        headers=self.headers,
        json=payload,
    )

    if response.status_code != 200:
        error_message = f"OpenAI API error: {response.status_code} {response.text}"
        logger.error(error_message)
        # Record the error through the loop's _log_api_call hook
        self.loop._log_api_call("error", payload, error=Exception(error_message))
        raise Exception(error_message)

    response_data = response.json()

    # Record the response through the loop's _log_api_call hook
    self.loop._log_api_call("response", payload, response_data)

    # Log for debug purposes
    logger.info("Received response from OpenAI API")
    logger.debug(f"Response data: {self._sanitize_response(response_data)}")

    return response_data
|
||||
|
||||
async def send_computer_call_request(
    self,
    messages: List[Dict[str, Any]],
    display_width: str,
    display_height: str,
    previous_response_id: str,
) -> Dict[str, Any]:
    """Send a follow-up request carrying a ``computer_call_output`` item.

    The message history is scanned from newest to oldest for the most recent
    ``call_id`` and the most recent PNG screenshot. Only those two values
    (plus any safety checks acknowledged for that call) are sent; the rest of
    the history is referenced through ``previous_response_id``.

    Args:
        messages: List of message objects in standard format
        display_width: Width of the display in pixels (anything ``int()`` accepts)
        display_height: Height of the display in pixels (anything ``int()`` accepts)
        previous_response_id: ID of the previous response to link requests

    Returns:
        The decoded JSON body of the API response

    Raises:
        ValueError: If the display dimensions are not integers, or if no
            ``call_id`` / screenshot can be found in ``messages``
        Exception: If the API answers with a non-200 status code
    """
    # Convert display dimensions to integers
    try:
        width = int(display_width)
        height = int(display_height)
    except (ValueError, TypeError) as e:
        logger.error(f"Failed to convert display dimensions to integers: {str(e)}")
        # Chain the original error so the traceback shows what int() rejected
        raise ValueError(
            f"Display dimensions must be integers: width={display_width}, height={display_height}"
        ) from e

    png_prefix = "data:image/png;base64,"
    call_id = None
    screenshot_base64 = None
    acknowledged_safety_checks = None

    # Walk the history newest-first so the first hit is the most recent one
    for msg in reversed(messages):
        if not isinstance(msg, dict):
            continue

        # The message itself may carry the call_id
        if "call_id" in msg and not call_id:
            call_id = msg["call_id"]
            acknowledged_safety_checks = msg.get("acknowledged_safety_checks")

        content = msg.get("content", [])
        if not isinstance(content, list):
            continue

        for item in content:
            if not isinstance(item, dict):
                continue

            if not call_id and "call_id" in item:
                call_id = item["call_id"]
                # Safety checks are only taken from the item that supplied the
                # call_id, so acknowledgements of older calls are never reused
                acknowledged_safety_checks = item.get("acknowledged_safety_checks")

            if screenshot_base64:
                continue

            item_type = item.get("type")
            if item_type == "computer_call_output":
                output = item.get("output", {})
                image_url = output.get("image_url", "") if isinstance(output, dict) else ""
                if isinstance(image_url, str) and image_url.startswith(png_prefix):
                    screenshot_base64 = image_url[len(png_prefix):]
            elif item_type == "image":
                source = item.get("source", {})
                if (
                    isinstance(source, dict)
                    and source.get("type") == "base64"
                    and "data" in source
                ):
                    screenshot_base64 = source["data"]

    if not call_id or not screenshot_base64:
        logger.error("Missing call_id or screenshot for computer_call_output")
        logger.error(f"Last message: {messages[-1] if messages else None}")
        raise ValueError("Cannot create computer call request: missing call_id or screenshot")

    call_output: Dict[str, Any] = {
        "type": "computer_call_output",
        "call_id": call_id,
        "output": {
            "type": "input_image",
            "image_url": f"{png_prefix}{screenshot_base64}",
        },
    }
    # Forward acknowledged safety checks; previously they were stored in the
    # history but dropped when the payload was rebuilt here
    if acknowledged_safety_checks:
        call_output["acknowledged_safety_checks"] = acknowledged_safety_checks

    # Prepare the request payload using minimal format from docs
    payload = {
        "model": "computer-use-preview",
        "previous_response_id": previous_response_id,
        "tools": [
            {
                "type": "computer_use_preview",
                "display_width": width,
                "display_height": height,
                "environment": "mac",  # We're on macOS
            }
        ],
        "input": [call_output],
        "truncation": "auto",
    }

    # Log the request using the BaseLoop's log_api_call method
    self.loop._log_api_call("request", payload)

    logger.info("Sending computer call request to OpenAI API")
    logger.debug(f"Request payload: {self._sanitize_response(payload)}")

    # NOTE(review): requests.post is a synchronous call inside a coroutine, so
    # the event loop is blocked until the HTTP round trip finishes.
    response = requests.post(
        f"{self.api_base}/responses",
        headers=self.headers,
        json=payload,
    )

    if response.status_code != 200:
        error_message = f"OpenAI API error: {response.status_code} {response.text}"
        logger.error(error_message)
        # Log the error using the BaseLoop's log_api_call method
        self.loop._log_api_call("error", payload, error=Exception(error_message))
        raise Exception(error_message)

    response_data = response.json()

    # Log the response using the BaseLoop's log_api_call method
    self.loop._log_api_call("response", payload, response_data)

    logger.info("Received response from OpenAI API")
    logger.debug(f"Response data: {self._sanitize_response(response_data)}")

    return response_data
|
||||
|
||||
def _format_messages_for_agent_response(
    self, messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Format messages for the OpenAI Agent Response API.

    The Agent Response API requires specific content types:
    - For user messages: use "input_text", "input_image", etc.
    - For assistant messages: use "output_text" only

    Additionally, when using the computer tool, only one image can be sent,
    so only the most recent usable user image is kept.

    What is dropped:
    - Falsy or non-dict entries in ``messages``.
    - String content whose role is not user, assistant or system.
    - List content items of any type not handled below (for system and
      other roles that means every item).
    - Messages whose list content ends up empty.

    Args:
        messages: List of standard messages

    Returns:
        Messages formatted for the Agent Response API, in the original order
    """
    # Messages are visited newest-first so the most recent image wins the
    # single image slot; the result is reversed again before returning.
    newest_first: List[Dict[str, Any]] = []
    has_image = False  # Track if we've already included an image

    for msg in reversed(messages):
        if not msg or not isinstance(msg, dict):
            continue

        role = msg.get("role", "user")
        content = msg.get("content", "")

        logger.debug(f"Processing message - Role: {role}, Content type: {type(content)}")
        if isinstance(content, list):
            logger.debug(
                f"List content items: {[item.get('type') for item in content if isinstance(item, dict)]}"
            )

        if isinstance(content, str):
            # User and system text is input_text; assistant text is output_text
            if role in ("user", "system"):
                newest_first.append(
                    {"role": role, "content": [{"type": "input_text", "text": content}]}
                )
            elif role == "assistant":
                newest_first.append(
                    {"role": role, "content": [{"type": "output_text", "text": content}]}
                )
            continue

        if not isinstance(content, list):
            continue

        formatted_content = []
        for item in content:
            if not isinstance(item, dict):
                continue

            item_type = item.get("type")

            if role == "user":
                if item_type in ("text", "input_text"):
                    formatted_content.append(
                        {"type": "input_text", "text": item.get("text", "")}
                    )
                elif item_type in ("image", "image_url") and not has_image:
                    image_url = ""
                    if item_type == "image":
                        source = item.get("source", {})
                        if (
                            isinstance(source, dict)
                            and source.get("type") == "base64"
                            and "data" in source
                        ):
                            image_url = f"data:image/png;base64,{source['data']}"
                    else:
                        # "image_url" may hold {"url": ...} or the URL string itself
                        raw_url = item.get("image_url", "")
                        image_url = raw_url.get("url", "") if isinstance(raw_url, dict) else raw_url

                    # An empty or non-string URL must not use up the only
                    # image slot; keep looking at older messages instead.
                    if isinstance(image_url, str) and image_url:
                        formatted_content.append(
                            {"type": "input_image", "image_url": image_url}
                        )
                        has_image = True
            elif role == "assistant":
                # Assistant messages only support output_text
                if item_type in ("text", "output_text"):
                    formatted_content.append(
                        {"type": "output_text", "text": item.get("text", "")}
                    )

        if formatted_content:
            newest_first.append({"role": role, "content": formatted_content})

    # Restore chronological order
    formatted_messages = list(reversed(newest_first))

    # Log summary for debugging
    num_images = sum(
        1
        for msg in formatted_messages
        for item in msg["content"]
        if item.get("type") == "input_image"
    )
    logger.info(f"Formatted {len(messages)} messages for OpenAI API with {num_images} images")

    return formatted_messages
|
||||
|
||||
def _sanitize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``response`` that is safe to write to the log.

    Only a top-level ``"output"`` list is processed: each of its items is
    passed through ``sanitize_message``. Every other key is carried over
    untouched.

    Args:
        response: Response to sanitize

    Returns:
        Sanitized response
    """
    from .utils import sanitize_message

    # Shallow copy: the "output" entry is replaced with a new list rather
    # than edited in place, so the caller's dict keeps its original list.
    cleaned = response.copy()

    output_items = cleaned.get("output")
    if isinstance(output_items, list):
        cleaned["output"] = list(map(sanitize_message, output_items))

    return cleaned
|
||||
@@ -0,0 +1,454 @@
|
||||
"""OpenAI Agent Response API provider implementation."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional, AsyncGenerator, Callable, Awaitable, TYPE_CHECKING
|
||||
|
||||
from computer import Computer
|
||||
from ...core.base import BaseLoop
|
||||
from ...core.types import AgentResponse
|
||||
from ...core.messages import StandardMessageManager, ImageRetentionConfig
|
||||
|
||||
from .api.client import OpenAIClient
|
||||
from .api_handler import OpenAIAPIHandler
|
||||
from .response_handler import OpenAIResponseHandler
|
||||
from .tools.manager import ToolManager
|
||||
from .types import LLMProvider, ResponseItemType
|
||||
from .prompts import SYSTEM_PROMPT
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAILoop(BaseLoop):
|
||||
"""OpenAI-specific implementation of the agent loop.
|
||||
|
||||
This class extends BaseLoop to provide specialized support for OpenAI's Agent Response API
|
||||
with computer control capabilities.
|
||||
"""
|
||||
|
||||
###########################################
|
||||
# INITIALIZATION AND CONFIGURATION
|
||||
###########################################
|
||||
|
||||
def __init__(
    self,
    api_key: str,
    computer: Computer,
    model: str = "computer-use-preview",
    only_n_most_recent_images: Optional[int] = 2,
    base_dir: Optional[str] = "trajectories",
    max_retries: int = 3,
    retry_delay: float = 1.0,
    save_trajectory: bool = True,
    acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
    **kwargs,
):
    """Set up the loop, its message manager, handlers and tool manager.

    Args:
        api_key: OpenAI API key
        computer: Computer instance
        model: Model name (ignored, always uses computer-use-preview)
        only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
        base_dir: Base directory for saving experiment data
        max_retries: Maximum number of retries for API calls
        retry_delay: Delay between retries in seconds
        save_trajectory: Whether to save trajectory data
        acknowledge_safety_check_callback: Optional callback for safety check acknowledgment
        **kwargs: Additional provider-specific arguments, forwarded to the base class
    """
    # The requested model is never used; tell the caller when it differs.
    required_model = "computer-use-preview"
    if model != required_model:
        logger.info(
            f"Overriding provided model '{model}' with required model 'computer-use-preview'"
        )

    # Core configuration lives on the base class
    super().__init__(
        computer=computer,
        model=required_model,
        api_key=api_key,
        max_retries=max_retries,
        retry_delay=retry_delay,
        base_dir=base_dir,
        save_trajectory=save_trajectory,
        only_n_most_recent_images=only_n_most_recent_images,
        **kwargs,
    )

    # Message history with image retention
    retention = ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
    self.message_manager = StandardMessageManager(config=retention)

    # OpenAI-specific state
    self.provider = LLMProvider.OPENAI
    self.client = None  # created in initialize_client()
    self.retry_count = 0
    self.acknowledge_safety_check_callback = acknowledge_safety_check_callback
    self.queue = asyncio.Queue()
    self.last_response_id = None  # kept on the instance so it survives across runs

    # Request/response helpers that hold a reference back to this loop
    self.api_handler = OpenAIAPIHandler(self)
    self.response_handler = OpenAIResponseHandler(self)

    # Tool manager shares the same safety-check callback
    self.tool_manager = ToolManager(
        computer=computer, acknowledge_safety_check_callback=acknowledge_safety_check_callback
    )
|
||||
|
||||
###########################################
|
||||
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def initialize_client(self) -> None:
    """Initialize the OpenAI API client and tools.

    Implements the abstract method from BaseLoop: creates the
    ``OpenAIClient`` and initializes the tool manager. On any failure
    ``self.client`` is reset to ``None``.

    Raises:
        RuntimeError: If the client or the tool manager cannot be
            initialized; the original exception is attached as the cause.
    """
    try:
        logger.info(f"Initializing OpenAI client with model {self.model}...")

        # Initialize client
        self.client = OpenAIClient(api_key=self.api_key, model=self.model)

        # Initialize tool manager
        await self.tool_manager.initialize()

        logger.info(f"Initialized OpenAI client with model {self.model}")
    except Exception as e:
        logger.error(f"Error initializing OpenAI client: {str(e)}")
        self.client = None
        # Chain the cause so the original traceback is not lost
        raise RuntimeError(f"Failed to initialize OpenAI client: {str(e)}") from e
|
||||
|
||||
###########################################
|
||||
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
    """Run the agent loop with provided messages.

    ``_run_loop`` is started as a background task that pushes items onto a
    queue; this generator yields each item until the ``None`` stop signal
    arrives, then yields a final completion message. If the consumer stops
    iterating early (or the generator is cancelled), the background task is
    cancelled so it does not keep driving the computer unattended.

    Args:
        messages: List of message objects in standard format

    Yields:
        Agent response format
    """
    loop_task = None
    try:
        logger.info("Starting OpenAI loop run")

        # Create queue for response streaming
        queue: asyncio.Queue = asyncio.Queue()

        # Ensure client is initialized
        if self.client is None:
            logger.info("Initializing client...")
            await self.initialize_client()
            if self.client is None:
                raise RuntimeError("Failed to initialize client")
            logger.info("Client initialized successfully")

        # Start loop in background task
        loop_task = asyncio.create_task(self._run_loop(queue, messages))

        # Process and yield messages as they arrive
        while True:
            try:
                item = await queue.get()
                if item is None:  # Stop signal
                    break
                yield item
                queue.task_done()
            except Exception as e:
                logger.error(f"Error processing queue item: {str(e)}")
                continue

        # Wait for loop to complete
        await loop_task

        # Send completion message
        yield {
            "role": "assistant",
            "content": "Task completed successfully.",
            "metadata": {"title": "✅ Complete"},
        }

    except Exception as e:
        logger.error(f"Error executing task: {str(e)}")
        yield {
            "role": "assistant",
            "content": f"Error: {str(e)}",
            "metadata": {"title": "❌ Error"},
        }
    finally:
        # On normal completion the task is already done and this is a no-op.
        # It only matters when the generator is closed or cancelled early.
        if loop_task is not None and not loop_task.done():
            loop_task.cancel()
|
||||
|
||||
###########################################
|
||||
# AGENT LOOP IMPLEMENTATION
|
||||
###########################################
|
||||
|
||||
async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
    """Drive the request / action / screenshot cycle and stream results.

    Steps:
    1. Capture (and optionally save) an initial screenshot.
    2. Record the first non-empty string user message plus the screenshot
       in the message manager and send the initial request.
    3. While the latest response contains a ``computer_call``: run safety
       check acknowledgment, execute the first call, capture a new
       screenshot and send it back as ``computer_call_output``.

    Every API response and every error message is put on ``queue``; a final
    ``None`` is always put last as the stop signal. ``self.last_response_id``
    is updated after each response so it persists between runs.

    Args:
        queue: Queue for response streaming
        messages: List of messages in standard format
    """
    try:
        # Only the screenshot capture/save is reported as a screenshot
        # error; later failures are reported by the handlers further down.
        try:
            screenshot_base64 = await self._capture_screenshot_base64()
            logger.info("Screenshot captured successfully")

            if self.save_trajectory:
                self._save_screenshot(screenshot_base64, action_type="state")
                logger.info("Screenshot saved to trajectory")
        except Exception as e:
            logger.error(f"Error capturing initial screenshot: {str(e)}")
            await queue.put(
                {
                    "role": "assistant",
                    "content": f"Error capturing screenshot: {str(e)}",
                    "metadata": {"title": "❌ Error"},
                }
            )
            await queue.put(None)  # Signal that we're done
            return

        # Use the first user message that has non-empty string content
        user_query = None
        for msg in messages:
            if msg.get("role") == "user":
                user_content = msg.get("content", "")
                if isinstance(user_content, str) and user_content:
                    user_query = user_content
                    self.message_manager.add_user_message(
                        [{"type": "text", "text": user_content}]
                    )
                    break

        # Screenshot message, accompanied by the query text when there is one
        message_content: List[Dict[str, Any]] = [
            {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": screenshot_base64,
                },
            }
        ]
        if user_query:
            # Never send a text item whose text is None
            message_content.append({"type": "text", "text": user_query})
        self.message_manager.add_user_message(message_content)

        messages = self.message_manager.messages
        logger.info(f"Starting agent loop with {len(messages)} messages")

        # Create initial turn directory
        if self.save_trajectory:
            self._create_turn_dir()

        # Initial API call
        screen_size = await self.computer.interface.get_screen_size()
        response = await self.api_handler.send_initial_request(
            messages=messages,
            display_width=str(screen_size["width"]),
            display_height=str(screen_size["height"]),
            previous_response_id=self.last_response_id,
        )
        self._update_last_response_id(response)
        await queue.put(response)

        # Keep going until a response contains no computer call
        while True:
            output_items = response.get("output", []) or []
            computer_calls = [
                item for item in output_items if item.get("type") == "computer_call"
            ]

            if not computer_calls:
                logger.info("No computer calls in response, task may be complete.")
                break

            # Only the first computer call of a response is processed
            computer_call = computer_calls[0]
            action = computer_call.get("action", {})
            call_id = computer_call.get("call_id")

            pending_safety_checks = computer_call.get("pending_safety_checks", [])
            acknowledged_safety_checks = []

            if pending_safety_checks:
                for check in pending_safety_checks:
                    logger.warning(
                        f"Safety check: {check.get('code')} - {check.get('message')}"
                    )

                # If we have a callback, use it to acknowledge safety checks
                if self.acknowledge_safety_check_callback:
                    acknowledged = await self.acknowledge_safety_check_callback(
                        pending_safety_checks
                    )
                    if not acknowledged:
                        logger.warning("Safety check acknowledgment failed")
                        await queue.put(
                            {
                                "role": "assistant",
                                "content": "Safety checks were not acknowledged. Cannot proceed with action.",
                                "metadata": {"title": "⚠️ Safety Warning"},
                            }
                        )
                        # Stop here: the response is unchanged, so looping
                        # again would ask about the same call forever.
                        break
                    acknowledged_safety_checks = pending_safety_checks
                # NOTE(review): without a callback the action proceeds with
                # no acknowledged checks - confirm this is intended.

            try:
                # Create a new turn directory for this action if saving trajectories
                if self.save_trajectory:
                    self._create_turn_dir()

                await self.tool_manager.execute_tool("computer", action)

                # Screenshot after the action becomes the call output
                screenshot_base64 = await self._capture_screenshot_base64()
                computer_call_output = {
                    "type": "computer_call_output",
                    "call_id": call_id,
                    "output": {
                        "type": "input_image",
                        "image_url": f"data:image/png;base64,{screenshot_base64}",
                    },
                }
                if acknowledged_safety_checks:
                    computer_call_output["acknowledged_safety_checks"] = (
                        acknowledged_safety_checks
                    )

                # Save to message manager for history
                self.message_manager.add_system_message(
                    f"[Computer action executed: {action.get('type')}]"
                )
                self.message_manager.add_user_message([computer_call_output])

                # Without a response ID there is nothing to link the follow-up
                # to; fail instead of re-running the same action forever.
                if not isinstance(self.last_response_id, str):
                    raise RuntimeError(
                        "Cannot send computer_call_output: no previous response ID available"
                    )

                # The API handler extracts the computer_call_output from the
                # message history; the rest is linked via previous_response_id.
                response = await self.api_handler.send_computer_call_request(
                    messages=self.message_manager.messages,
                    display_width=str(screen_size["width"]),
                    display_height=str(screen_size["height"]),
                    previous_response_id=self.last_response_id,
                )
                self._update_last_response_id(response)
                await queue.put(response)
            except Exception as e:
                logger.error(f"Error executing computer action: {str(e)}")
                await queue.put(
                    {
                        "role": "assistant",
                        "content": f"Error executing action: {str(e)}",
                        "metadata": {"title": "❌ Error"},
                    }
                )
                break

        # Signal that we're done
        await queue.put(None)

    except Exception as e:
        logger.error(f"Error in _run_loop: {str(e)}")
        await queue.put(
            {
                "role": "assistant",
                "content": f"Error: {str(e)}",
                "metadata": {"title": "❌ Error"},
            }
        )
        await queue.put(None)  # Signal that we're done

async def _capture_screenshot_base64(self) -> str:
    """Take a screenshot and return it base64-encoded when it is raw bytes."""
    screenshot = await self.computer.interface.screenshot()
    if isinstance(screenshot, (bytes, bytearray, memoryview)):
        return base64.b64encode(screenshot).decode("utf-8")
    # Anything else is passed through unchanged
    return screenshot

def _update_last_response_id(self, response: Any) -> None:
    """Store ``response["id"]`` on the instance, keeping the old ID if absent."""
    if isinstance(response, dict) and "id" in response:
        self.last_response_id = response["id"]
        logger.info(f"Received response with ID: {self.last_response_id}")
    else:
        logger.warning(f"Could not find response ID in OpenAI response: {type(response)}")
|
||||
|
||||
def get_last_response_id(self) -> Optional[str]:
    """Return the response ID currently stored on this loop.

    Returns:
        The stored ``last_response_id``; ``None`` when none has been stored
    """
    response_id = self.last_response_id
    return response_id
|
||||
|
||||
def set_last_response_id(self, response_id: str) -> None:
    """Overwrite the stored response ID and log the new value.

    Args:
        response_id: OpenAI response ID to set
    """
    setattr(self, "last_response_id", response_id)
    logger.info(f"Manually set response ID to: {self.last_response_id}")
|
||||
@@ -0,0 +1,20 @@
|
||||
"""Prompts for OpenAI Agent Response API."""
|
||||
|
||||
# System prompt to be used when no specific system prompt is provided
|
||||
SYSTEM_PROMPT = """
|
||||
You are a helpful assistant that can control a computer to help users accomplish tasks.
|
||||
You have access to a computer where you can:
|
||||
- Click, scroll, and type to interact with the interface
|
||||
- Use keyboard shortcuts and special keys
|
||||
- Read text and images from the screen
|
||||
- Navigate and interact with applications
|
||||
|
||||
A few important rules to follow:
|
||||
1. Only perform actions that the user has requested or that directly support their task
|
||||
2. If uncertain about what the user wants, ask for clarification
|
||||
3. Explain your steps clearly when working on complex tasks
|
||||
4. Be careful when interacting with sensitive data or performing potentially destructive actions
|
||||
5. Always respect user privacy and avoid accessing personal information unless necessary for the task
|
||||
|
||||
When in doubt about how to accomplish something, try to break it down into simpler steps using available computer actions.
|
||||
"""
|
||||
@@ -0,0 +1,205 @@
|
||||
"""Response handler for the OpenAI provider."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, AsyncGenerator
|
||||
import base64
|
||||
|
||||
from ...core.types import AgentResponse
|
||||
from .types import ResponseItemType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .loop import OpenAILoop
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIResponseHandler:
|
||||
"""Handler for OpenAI API responses."""
|
||||
|
||||
def __init__(self, loop: "OpenAILoop"):
    """Create a response handler bound to an OpenAI loop.

    Args:
        loop: OpenAI loop instance this handler keeps a reference to
    """
    logger.info("Initialized OpenAI response handler")
    self.loop = loop
|
||||
|
||||
async def process_response(self, response: Dict[str, Any], queue: asyncio.Queue) -> None:
    """Forward the output items of an API response to the queue.

    ``computer_call`` and ``message`` items are queued unchanged. For a
    ``reasoning`` item, the first ``summary_text`` entry (if it has text)
    is logged and queued as an assistant message, and then the original
    item is queued as well. Other item types and non-dict items are ignored.
    Any exception is reported as an error message on the queue.

    Args:
        response: Response from the API
        queue: Queue for response streaming
    """
    try:
        for entry in response.get("output", []) or []:
            if not isinstance(entry, dict):
                continue

            kind = entry.get("type")

            # computer_call items are only queued here; the loop executes the
            # action and builds the computer_call_output itself.
            if kind in (ResponseItemType.COMPUTER_CALL, ResponseItemType.MESSAGE):
                await queue.put(entry)
                continue

            if kind != ResponseItemType.REASONING:
                continue

            # Text of the first summary_text entry, if any
            summary_entries = entry["summary"] if "summary" in entry else None
            summary = None
            if isinstance(summary_entries, list):
                summary = next(
                    (
                        part.get("text")
                        for part in summary_entries
                        if isinstance(part, dict) and part.get("type") == "summary_text"
                    ),
                    None,
                )

            if summary:
                logger.info(f"Reasoning summary: {summary}")
                reasoning_message = {
                    "role": "assistant",
                    "content": f"[Reasoning: {summary}]",
                    "metadata": {"title": "💭 Reasoning", "is_summary": True},
                }
                await queue.put(reasoning_message)

            # The original reasoning item follows for complete context
            await queue.put(entry)

    except Exception as e:
        logger.error(f"Error processing response: {str(e)}")
        error_message = {
            "role": "assistant",
            "content": f"Error processing response: {str(e)}",
            "metadata": {"title": "❌ Error"},
        }
        await queue.put(error_message)
|
||||
|
||||
def _process_message_item(self, item: Dict[str, Any]) -> AgentResponse:
|
||||
"""Process a message item from the response.
|
||||
|
||||
Args:
|
||||
item: Message item from the response
|
||||
|
||||
Returns:
|
||||
Processed message in AgentResponse format
|
||||
"""
|
||||
# Extract content items - add null check
|
||||
content_items = item.get("content", []) or []
|
||||
|
||||
# Extract text from content items - use output_text type from OpenAI
|
||||
text = ""
|
||||
for content_item in content_items:
|
||||
# Skip if content_item is None or not a dict
|
||||
if content_item is None or not isinstance(content_item, dict):
|
||||
continue
|
||||
|
||||
# In OpenAI Agent Response API, text content is in "output_text" type items
|
||||
if content_item.get("type") == "output_text":
|
||||
text += content_item.get("text", "")
|
||||
|
||||
# Create agent response
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": text
|
||||
or "I don't have a response for that right now.", # Provide fallback when text is empty
|
||||
"metadata": {"title": "💬 Response"},
|
||||
}
|
||||
|
||||
async def _process_computer_call(self, item: Dict[str, Any], queue: asyncio.Queue) -> None:
    """Run the action of a ``computer_call`` item and record its outcome.

    On success nothing is put on ``queue``; when the tool result carries a
    screenshot, a completion note and the image are appended to the
    message history. On failure the error text is appended to the message
    history and an assistant-role error response is put on ``queue``.

    Args:
        item: Computer call item; its ``action`` dict is passed to the tool.
        queue: Queue that receives the error response if execution fails.
    """
    try:
        # Normalise the action: missing/None becomes {}, and any non-dict
        # value is replaced by {} after a warning.
        raw_action = item.get("action", {}) or {}
        if isinstance(raw_action, dict):
            action = raw_action
        else:
            logger.warning(f"Expected dict for action, got {type(raw_action)}")
            action = {}

        action_type = action.get("type", "unknown")
        logger.info(f"Processing computer call: {action_type}")

        # Delegate the actual execution to the loop's tool manager.
        result = await self.loop.tool_manager.execute_tool("computer", action)

        if result and result.base64_image:
            # First a short text marker for the completed action ...
            completion_note = {
                "type": "text",
                "text": f"[Computer action completed: {action_type}]",
            }
            self.loop.message_manager.add_user_message([completion_note])

            # ... then the screenshot itself as a base64 PNG image part.
            screenshot_part = {
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": "image/png",
                    "data": result.base64_image,
                },
            }
            self.loop.message_manager.add_user_message([screenshot_part])

            # NOTE: appending the current URL for browser environments was
            # sketched here previously but is disabled.

        logger.info(f"Computer call {action_type} executed successfully")

    except Exception as e:
        logger.error(f"Error executing computer call: {str(e)}")
        logger.debug(traceback.format_exc())

        failure_text = f"Error executing computer action: {str(e)}"

        # Record the failure in the conversation history ...
        self.loop.message_manager.add_user_message([{"type": "text", "text": failure_text}])

        # ... and surface it to whoever consumes the queue.
        await queue.put(
            {
                "role": "assistant",
                "content": failure_text,
                "metadata": {"title": "❌ Error"},
            }
        )
|
||||
@@ -0,0 +1,15 @@
|
||||
"""OpenAI tools module for computer control."""
|
||||
|
||||
from .manager import ToolManager
|
||||
from .computer import ComputerTool
|
||||
from .base import BaseOpenAITool, ToolResult, ToolError, ToolFailure, CLIResult
|
||||
|
||||
__all__ = [
|
||||
"ToolManager",
|
||||
"ComputerTool",
|
||||
"BaseOpenAITool",
|
||||
"ToolResult",
|
||||
"ToolError",
|
||||
"ToolFailure",
|
||||
"CLIResult",
|
||||
]
|
||||
@@ -0,0 +1,79 @@
|
||||
"""OpenAI-specific tool base classes."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ....core.tools.base import BaseTool
|
||||
|
||||
|
||||
class BaseOpenAITool(BaseTool, metaclass=ABCMeta):
    """Abstract base for tools exposed through the OpenAI provider.

    Concrete tools must implement ``__call__`` (execution) and
    ``to_params`` (the parameter dictionary describing the tool).
    """

    def __init__(self):
        """Set up the base tool; there is currently no state to initialise."""
        # Kept as an explicit hook so subclasses have a stable initialiser
        # to call if shared state is added later.
        return None

    @abstractmethod
    async def __call__(self, **kwargs) -> Any:
        """Execute the tool with the given keyword arguments."""

    @abstractmethod
    def to_params(self) -> Dict[str, Any]:
        """Describe this tool as OpenAI-specific API parameters.

        Returns:
            Dictionary with tool parameters for OpenAI API

        Raises:
            NotImplementedError: When a subclass delegates to this base
                implementation instead of providing its own.
        """
        raise NotImplementedError()
|
||||
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
class ToolResult:
|
||||
"""Represents the result of a tool execution."""
|
||||
|
||||
output: str | None = None
|
||||
error: str | None = None
|
||||
base64_image: str | None = None
|
||||
system: str | None = None
|
||||
content: list[dict] | None = None
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field.name) for field in fields(self))
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
content=self.content or other.content, # Use first non-None content
|
||||
)
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Returns a new ToolResult with the given fields replaced."""
|
||||
return replace(self, **kwargs)
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
    """A ToolResult that can be rendered as a CLI output.

    Adds no fields or behaviour of its own; it only distinguishes the
    result by type.
    """
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
    """A ToolResult that represents a failure.

    Adds no fields or behaviour of its own; it only distinguishes the
    result by type.
    """
|
||||
|
||||
|
||||
class ToolError(Exception):
    """Raised when a tool encounters an error.

    Attributes:
        message: Human-readable description of the failure.
    """

    def __init__(self, message):
        """Create the error.

        Args:
            message: Human-readable description of the failure.
        """
        # Forward the message to Exception so that ``str(err)`` and
        # ``err.args`` carry it even when it is passed by keyword
        # (``ToolError(message=...)``), which would otherwise leave
        # ``args`` empty and stringify to "".
        super().__init__(message)
        self.message = message
|
||||
@@ -0,0 +1,319 @@
|
||||
"""Computer tool for OpenAI."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from typing import Literal, Any, Dict, Optional, List, Union
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseOpenAITool, ToolError, ToolResult
|
||||
from ....core.tools.computer import BaseComputerTool
|
||||
|
||||
TYPING_DELAY_MS = 12
|
||||
TYPING_GROUP_SIZE = 50
|
||||
|
||||
# Key mapping for special keys
|
||||
KEY_MAPPING = {
|
||||
"enter": "return",
|
||||
"backspace": "delete",
|
||||
"delete": "forwarddelete",
|
||||
"escape": "esc",
|
||||
"pageup": "page_up",
|
||||
"pagedown": "page_down",
|
||||
"arrowup": "up",
|
||||
"arrowdown": "down",
|
||||
"arrowleft": "left",
|
||||
"arrowright": "right",
|
||||
"home": "home",
|
||||
"end": "end",
|
||||
"tab": "tab",
|
||||
"space": "space",
|
||||
"shift": "shift",
|
||||
"control": "control",
|
||||
"alt": "alt",
|
||||
"meta": "command",
|
||||
}
|
||||
|
||||
Action = Literal[
|
||||
"key",
|
||||
"type",
|
||||
"mouse_move",
|
||||
"left_click",
|
||||
"right_click",
|
||||
"double_click",
|
||||
"screenshot",
|
||||
"scroll",
|
||||
]
|
||||
|
||||
|
||||
class ComputerTool(BaseComputerTool, BaseOpenAITool):
    """
    A tool that allows the agent to interact with the screen, keyboard, and mouse of the current computer.

    Actions arrive as keyword arguments (``type`` plus action-specific
    fields). Every successful action returns a :class:`ToolResult` that
    carries a base64-encoded screenshot taken after the action.
    """

    name: Literal["computer"] = "computer"
    api_type: Literal["computer_use_preview"] = "computer_use_preview"
    width: Optional[int] = None
    height: Optional[int] = None
    display_num: Optional[int] = None
    computer: Computer  # The CUA Computer instance
    logger = logging.getLogger(__name__)

    _screenshot_delay = 1.0  # macOS is generally faster than X11
    _scaling_enabled = True

    def __init__(self, computer: Computer):
        """Initialize the computer tool.

        Args:
            computer: Computer instance
        """
        self.computer = computer
        self.logger = logging.getLogger(__name__)

        # Initialize the base computer tool first, then the OpenAI tool.
        BaseComputerTool.__init__(self, computer)
        BaseOpenAITool.__init__(self)

        # Dimensions stay unset until initialize_dimensions() has run.
        self.width = None
        self.height = None
        self.display_num = None

    def to_params(self) -> Dict[str, Any]:
        """Convert tool to API parameters.

        Returns:
            Dictionary with tool parameters

        Raises:
            RuntimeError: If the screen dimensions have not been initialized.
        """
        if self.width is None or self.height is None:
            raise RuntimeError(
                "Screen dimensions not initialized. Call initialize_dimensions() first."
            )
        return {
            "type": self.api_type,
            "display_width": self.width,
            "display_height": self.height,
            "display_number": self.display_num,
        }

    async def initialize_dimensions(self):
        """Initialize screen dimensions from the computer interface.

        Falls back to 1024x768 when the size cannot be read or is not a
        pair of integers.
        """
        try:
            display_size = await self.computer.interface.get_screen_size()
            width, height = display_size["width"], display_size["height"]
            # Explicit check instead of ``assert`` so the validation is not
            # stripped when Python runs with -O.
            if not (isinstance(width, int) and isinstance(height, int)):
                raise TypeError(f"Screen size must be integers, got {width!r}x{height!r}")
            self.width, self.height = width, height
            self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
        except Exception as e:
            # Fall back to defaults if we can't get accurate dimensions
            self.width = 1024
            self.height = 768
            self.logger.warning(
                f"Failed to get screen dimensions, using defaults: {self.width}x{self.height}. Error: {e}"
            )

    @staticmethod
    def _require_point(params: Dict[str, Any], action: str):
        """Return ``(x, y)`` from ``params`` or raise ToolError naming ``action``."""
        x = params.get("x")
        y = params.get("y")
        if x is None or y is None:
            raise ToolError(f"x and y coordinates are required for {action} action")
        return x, y

    async def _capture_base64(self) -> str:
        """Take a screenshot through the computer interface and base64-encode it."""
        raw = await self.computer.interface.screenshot()
        return base64.b64encode(raw).decode("utf-8")

    async def __call__(
        self,
        *,
        type: str,  # OpenAI uses 'type' instead of 'action'
        text: Optional[str] = None,
        **kwargs,
    ):
        """Dispatch one computer action.

        Supported ``type`` values: ``type``, ``click``, ``double_click``,
        ``keypress``, ``move``, ``mouse_move``, ``scroll``, ``screenshot``
        and ``wait``.

        Args:
            type: Name of the action to perform.
            text: Text for ``type``; key or key combination for ``keypress``.
            **kwargs: Action-specific fields (``x``, ``y``, ``button``,
                ``keys``, ``coordinates``, ``scroll_x``, ``scroll_y``,
                ``duration``).

        Returns:
            The ToolResult produced by the matching handler.

        Raises:
            ToolError: For unsupported actions, missing fields, or any
                failure while executing the action.
        """
        try:
            # Ensure dimensions are initialized
            if self.width is None or self.height is None:
                await self.initialize_dimensions()
                if self.width is None or self.height is None:
                    raise ToolError("Failed to initialize screen dimensions")

            if type == "type":
                if text is None:
                    raise ToolError("text is required for type action")
                return await self.handle_typing(text)
            elif type == "click":
                button = kwargs.get("button")
                if button is None:
                    raise ToolError("button is required for click action")
                x, y = self._require_point(kwargs, "click")
                return await self.handle_click(button, x, y)
            elif type == "double_click":
                # Same code path as a "double" button click.
                x, y = self._require_point(kwargs, "double_click")
                return await self.handle_click("double", x, y)
            elif type == "keypress":
                # Prefer 'text'; otherwise accept a list under 'keys'.
                if text is None:
                    if "keys" in kwargs and isinstance(kwargs["keys"], list):
                        # Pass the keys list directly instead of joining and then splitting
                        return await self.handle_key(kwargs["keys"])
                    else:
                        raise ToolError("Either 'text' or 'keys' is required for keypress action")
                return await self.handle_key(text)
            elif type == "move":
                x, y = self._require_point(kwargs, "move")
                return await self.handle_mouse_move(x, y)
            elif type == "mouse_move":
                if "coordinates" not in kwargs:
                    raise ToolError("coordinates is required for mouse_move action")
                return await self.handle_mouse_move(
                    kwargs["coordinates"][0], kwargs["coordinates"][1]
                )
            elif type == "scroll":
                x, y = self._require_point(kwargs, "scroll")
                scroll_x = kwargs.get("scroll_x", 0)
                scroll_y = kwargs.get("scroll_y", 0)
                return await self.handle_scroll(x, y, scroll_x, scroll_y)
            elif type == "screenshot":
                return await self.screenshot()
            elif type == "wait":
                # Sleep, then report the screen state afterwards.
                duration = kwargs.get("duration", 1.0)
                await asyncio.sleep(duration)
                return await self.screenshot()
            else:
                raise ToolError(f"Unsupported action: {type}")

        except Exception as e:
            self.logger.error(f"Error in ComputerTool.__call__: {str(e)}")
            raise ToolError(f"Failed to execute {type}: {str(e)}") from e

    async def handle_click(self, button: str, x: int, y: int) -> ToolResult:
        """Perform a click and return a screenshot taken afterwards.

        Args:
            button: ``"left"``, ``"right"`` or ``"double"``.
            x: Horizontal click position.
            y: Vertical click position.

        Raises:
            ToolError: If the button is unsupported or the click fails.
        """
        try:
            if button == "left":
                await self.computer.interface.left_click(x, y)
            elif button == "right":
                await self.computer.interface.right_click(x, y)
            elif button == "double":
                await self.computer.interface.double_click(x, y)
            else:
                # Do not report success for a click that never happened.
                raise ToolError(f"Unsupported click button: {button}")

            # Wait for UI to update
            await asyncio.sleep(0.5)

            return ToolResult(
                output=f"Performed {button} click at ({x}, {y})",
                base64_image=await self._capture_base64(),
            )
        except Exception as e:
            self.logger.error(f"Error in handle_click: {str(e)}")
            raise ToolError(f"Failed to perform {button} click at ({x}, {y}): {str(e)}") from e

    async def handle_typing(self, text: str) -> ToolResult:
        """Type ``text`` through the computer interface, then take a screenshot.

        Raises:
            ToolError: If typing or the screenshot fails.
        """
        try:
            await self.computer.interface.type_text(text)

            # Give the UI a moment to reflect the typed text.
            await asyncio.sleep(0.3)

            return ToolResult(output=f"Typed: {text}", base64_image=await self._capture_base64())
        except Exception as e:
            self.logger.error(f"Error in handle_typing: {str(e)}")
            raise ToolError(f"Failed to type '{text}': {str(e)}") from e

    async def handle_key(self, key: Union[str, List[str]]) -> ToolResult:
        """Handle key press, supporting both single keys and combinations.

        Args:
            key: Either a string (e.g. "ctrl+c") or a list of keys (e.g. ["ctrl", "c"])

        Raises:
            ToolError: If no key is given or a key press fails.
        """
        try:
            # A string is split on "+"; a list is used as given.
            raw_keys = key if isinstance(key, list) else key.split("+")
            keys = [k.strip().lower() for k in raw_keys]
            if not keys:
                raise ToolError("At least one key is required")

            # Translate key names through KEY_MAPPING; unknown names pass through.
            mapped_keys = [KEY_MAPPING.get(k, k) for k in keys]

            if len(mapped_keys) > 1:
                # NOTE(review): a combination is sent as one press per key in
                # order, then one press per key in reverse order. This looks
                # like it is meant as hold/release; confirm that press_key
                # behaves that way for modifier combinations.
                for k in mapped_keys:
                    await self.computer.interface.press_key(k)
                    await asyncio.sleep(0.1)
                for k in reversed(mapped_keys):
                    await self.computer.interface.press_key(k)
            else:
                # Single key press
                await self.computer.interface.press_key(mapped_keys[0])

            # Wait briefly
            await asyncio.sleep(0.3)

            return ToolResult(
                output=f"Pressed key: {key}", base64_image=await self._capture_base64()
            )
        except Exception as e:
            self.logger.error(f"Error in handle_key: {str(e)}")
            raise ToolError(f"Failed to press key '{key}': {str(e)}") from e

    async def handle_mouse_move(self, x: int, y: int) -> ToolResult:
        """Move the cursor to ``(x, y)`` and return a screenshot taken afterwards.

        Raises:
            ToolError: If the move or the screenshot fails.
        """
        try:
            await self.computer.interface.move_cursor(x, y)

            # Wait briefly
            await asyncio.sleep(0.2)

            return ToolResult(
                output=f"Moved cursor to ({x}, {y})", base64_image=await self._capture_base64()
            )
        except Exception as e:
            self.logger.error(f"Error in handle_mouse_move: {str(e)}")
            raise ToolError(f"Failed to move cursor to ({x}, {y}): {str(e)}") from e

    async def handle_scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> ToolResult:
        """Scroll vertically at ``(x, y)`` and return a screenshot taken afterwards.

        Only ``scroll_y`` is acted on: positive values scroll down, negative
        values scroll up. A non-zero ``scroll_x`` is logged and ignored.

        Raises:
            ToolError: If moving, scrolling or the screenshot fails.
        """
        try:
            # Move cursor to position first
            await self.computer.interface.move_cursor(x, y)

            if scroll_x:
                # Make the ignored horizontal component visible in the logs.
                self.logger.warning(
                    f"Horizontal scrolling is not supported; ignoring scroll_x={scroll_x}"
                )

            if scroll_y > 0:
                await self.computer.interface.scroll_down(abs(scroll_y))
            elif scroll_y < 0:
                await self.computer.interface.scroll_up(abs(scroll_y))

            # Wait for UI to update
            await asyncio.sleep(0.5)

            return ToolResult(
                output=f"Scrolled at ({x}, {y}) with delta ({scroll_x}, {scroll_y})",
                base64_image=await self._capture_base64(),
            )
        except Exception as e:
            self.logger.error(f"Error in handle_scroll: {str(e)}")
            raise ToolError(f"Failed to scroll at ({x}, {y}): {str(e)}") from e

    async def screenshot(self) -> ToolResult:
        """Take a screenshot and return it base64-encoded in a ToolResult.

        Raises:
            ToolError: If the screenshot cannot be taken.
        """
        try:
            return ToolResult(output="Screenshot taken", base64_image=await self._capture_base64())
        except Exception as e:
            self.logger.error(f"Error in screenshot: {str(e)}")
            raise ToolError(f"Failed to take screenshot: {str(e)}") from e
|
||||
@@ -0,0 +1,106 @@
|
||||
"""Tool manager for the OpenAI provider."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Callable, Awaitable, Union
|
||||
|
||||
from computer import Computer
|
||||
from ..types import ComputerAction, ResponseItemType
|
||||
from .computer import ComputerTool
|
||||
from ....core.tools.base import ToolResult, ToolFailure
|
||||
from ....core.tools.collection import ToolCollection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolManager:
    """Owns the computer tool used by the OpenAI agent and runs tool calls."""

    def __init__(
        self,
        computer: Computer,
        acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
    ):
        """Create the manager; ``initialize()`` must be awaited before use.

        Args:
            computer: Computer instance
            acknowledge_safety_check_callback: Optional callback for safety
                check acknowledgment. It is stored here; this class does
                not call it.
        """
        self.computer = computer
        self.computer_tool = ComputerTool(computer)
        self.acknowledge_safety_check_callback = acknowledge_safety_check_callback
        # The tool collection is built lazily by initialize().
        self.tools = None
        self._initialized = False
        logger.info("Initialized OpenAI ToolManager")

    async def initialize(self) -> None:
        """Prepare the computer tool and build the tool collection (once)."""
        if self._initialized:
            return

        logger.info("Initializing OpenAI ToolManager")

        # The computer tool needs its screen dimensions before it is usable.
        await self.computer_tool.initialize_dimensions()
        self.tools = ToolCollection(self.computer_tool)

        self._initialized = True
        logger.info("OpenAI ToolManager initialized")

    async def get_tools_definition(self) -> List[Dict[str, Any]]:
        """Build the tools definition sent to the OpenAI agent.

        Returns:
            A single-element list describing the "computer-preview" tool
            with display size and environment.

        Raises:
            RuntimeError: If ``initialize()`` has not been called.
        """
        if not self.tools:
            raise RuntimeError("Tools not initialized. Call initialize() first.")

        display_width, display_height = await self._get_computer_dimensions()

        # "mac" is both the default and the fallback for unknown values.
        environment = getattr(self.computer, "environment", "mac")
        allowed_environments = ["windows", "mac", "linux", "browser"]
        if environment not in allowed_environments:
            logger.warning(f"Invalid environment value: {environment}, using 'mac' instead")
            environment = "mac"

        computer_tool_spec = {
            "type": "computer-preview",
            "display_width": display_width,
            "display_height": display_height,
            "environment": environment,
        }
        return [computer_tool_spec]

    async def _get_computer_dimensions(self) -> tuple[int, int]:
        """Return the display size as ``(width, height)``.

        Uses the computer tool's dimensions when both are set, otherwise
        asks the computer interface for the screen size.
        """
        tool = self.computer_tool
        if tool.width is None or tool.height is None:
            screen_size = await self.computer.interface.get_screen_size()
            return (int(screen_size["width"]), int(screen_size["height"]))
        return (tool.width, tool.height)

    async def execute_tool(self, name: str, tool_input: Dict[str, Any]) -> ToolResult:
        """Run the named tool with the given input.

        Args:
            name: Name of the tool to execute
            tool_input: Input parameters for the tool

        Returns:
            Result of the tool execution

        Raises:
            RuntimeError: If ``initialize()`` has not been called.
        """
        if not self.tools:
            raise RuntimeError("Tools not initialized. Call initialize() first.")
        return await self.tools.run(name=name, tool_input=tool_input)
|
||||
@@ -0,0 +1,36 @@
|
||||
"""Type definitions for the OpenAI provider."""
|
||||
|
||||
from enum import StrEnum, auto
|
||||
from typing import Dict, List, Optional, Union, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class LLMProvider(StrEnum):
    """OpenAI LLM provider types.

    A ``StrEnum``, so each member is also a ``str`` equal to its value.
    """

    # The only provider declared in this module.
    OPENAI = "openai"
|
||||
|
||||
|
||||
class ResponseItemType(StrEnum):
    """Types of items in OpenAI Agent Response output.

    A ``StrEnum``, so each member is also a ``str`` equal to its value.
    """

    # NOTE(review): presumably compared against an output item's "type"
    # field; confirm against the response-handling code.
    MESSAGE = "message"
    COMPUTER_CALL = "computer_call"
    COMPUTER_CALL_OUTPUT = "computer_call_output"
    REASONING = "reasoning"
|
||||
|
||||
|
||||
@dataclass
class ComputerAction:
    """Represents a computer action to be performed.

    Only ``type`` is required; every other field is optional and defaults
    to ``None``.
    """

    # Name of the action.
    type: str
    # NOTE(review): presumably the screen position the action targets - confirm.
    x: Optional[int] = None
    y: Optional[int] = None
    # NOTE(review): presumably the text to type - confirm.
    text: Optional[str] = None
    # NOTE(review): presumably the mouse button for click actions - confirm.
    button: Optional[str] = None
    # NOTE(review): presumably the keys of a keypress action - confirm.
    keys: Optional[List[str]] = None
    # NOTE(review): looks like a duration in milliseconds - confirm.
    ms: Optional[int] = None
    # NOTE(review): presumably horizontal / vertical scroll deltas - confirm.
    scroll_x: Optional[int] = None
    scroll_y: Optional[int] = None
    # NOTE(review): presumably a list of points (e.g. for a drag) - confirm.
    path: Optional[List[Dict[str, int]]] = None
|
||||
@@ -0,0 +1,98 @@
|
||||
"""Utility functions for the OpenAI provider."""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ...core.types import AgentResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def format_images_for_openai(images_base64: List[str]) -> List[Dict[str, Any]]:
    """Wrap base64 images as ``input_image`` items for the Agent Response API.

    Args:
        images_base64: List of base64 encoded images

    Returns:
        One ``{"type": "input_image", "image_url": ...}`` item per input
        image, in order, with the data embedded as a PNG data URL.
    """
    formatted: List[Dict[str, Any]] = []
    for encoded in images_base64:
        # Each image is embedded inline as a data URL labelled image/png.
        data_url = f"data:image/png;base64,{encoded}"
        formatted.append({"type": "input_image", "image_url": data_url})
    return formatted
|
||||
|
||||
|
||||
def extract_message_content(message: Dict[str, Any]) -> str:
    """Return the text carried by a message.

    Args:
        message: Message to extract content from

    Returns:
        The content itself when it is a string; the concatenated text of
        the recognised text parts when it is a list; otherwise "".
    """
    content = message.get("content")

    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return ""

    # "input_text" parts only count for user messages (the default role).
    from_user = message.get("role", "user") == "user"

    pieces = []
    for part in content:
        if not isinstance(part, dict):
            continue
        part_type = part.get("type")
        is_user_input = from_user and part_type == "input_text"
        # "text" is the standard format, "output_text" the Agent Response
        # API format for assistant messages.
        if is_user_input or part_type == "text" or part_type == "output_text":
            pieces.append(part.get("text", ""))

    return "".join(pieces)
|
||||
|
||||
|
||||
def sanitize_message(msg: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of a message that is safe to log, with image data removed.

    Args:
        msg: Message to sanitize

    Returns:
        A shallow copy of ``msg`` in which image payloads are replaced by
        the placeholder "[omitted]". Non-dict input is returned unchanged.
    """
    if not isinstance(msg, dict):
        return msg

    # (item type, key holding the payload) for every image shape handled.
    image_shapes = (
        ("image_url", "image_url"),
        ("input_image", "image_url"),
        ("image", "source"),
    )

    def _scrub(part):
        # Non-dict parts and parts without a matching image shape pass
        # through untouched; matching parts are reduced to type + placeholder.
        if isinstance(part, dict):
            part_type = part.get("type")
            for shape_type, payload_key in image_shapes:
                if part_type == shape_type and payload_key in part:
                    return {"type": shape_type, payload_key: "[omitted]"}
        return part

    cleaned = msg.copy()

    # Message content: scrub each part of a list-valued "content".
    parts = cleaned.get("content")
    if isinstance(parts, list):
        cleaned["content"] = [_scrub(part) for part in parts]

    # computer_call_output items keep their screenshot under output.image_url.
    if cleaned.get("type") == "computer_call_output" and "output" in cleaned:
        call_output = cleaned["output"]
        if isinstance(call_output, dict) and "image_url" in call_output:
            cleaned["output"] = {**call_output, "image_url": "[omitted]"}

    return cleaned
|
||||
@@ -30,6 +30,10 @@ anthropic = [
|
||||
"anthropic>=0.49.0",
|
||||
"boto3>=1.35.81,<2.0.0",
|
||||
]
|
||||
openai = [
|
||||
"openai>=1.14.0,<2.0.0",
|
||||
"httpx>=0.27.0,<0.29.0",
|
||||
]
|
||||
som = [
|
||||
"torch>=2.2.1",
|
||||
"torchvision>=0.17.1",
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
# """Basic tests for the agent package."""
|
||||
|
||||
# import pytest
|
||||
# from agent import OmniComputerAgent, LLMProvider
|
||||
# from agent.base.agent import BaseComputerAgent
|
||||
# from computer import Computer
|
||||
|
||||
# def test_agent_import():
|
||||
# """Test that we can import the OmniComputerAgent class."""
|
||||
# assert OmniComputerAgent is not None
|
||||
# assert LLMProvider is not None
|
||||
|
||||
# def test_agent_init():
|
||||
# """Test that we can create an OmniComputerAgent instance."""
|
||||
# agent = OmniComputerAgent(
|
||||
# provider=LLMProvider.OPENAI,
|
||||
# use_host_computer_server=True
|
||||
# )
|
||||
# assert agent is not None
|
||||
|
||||
# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed")
|
||||
# def test_computer_agent_anthropic():
|
||||
# """Test creating an Anthropic agent."""
|
||||
# agent = ComputerAgent(provider=Provider.ANTHROPIC)
|
||||
# assert isinstance(agent._agent, BaseComputerAgent)
|
||||
|
||||
# def test_computer_agent_invalid_provider():
|
||||
# """Test creating an agent with an invalid provider."""
|
||||
# with pytest.raises(ValueError, match="Unsupported provider"):
|
||||
# ComputerAgent(provider="invalid_provider")
|
||||
|
||||
# def test_computer_agent_uninstalled_provider():
|
||||
# """Test creating an agent with an uninstalled provider."""
|
||||
# with pytest.raises(NotImplementedError, match="OpenAI provider not yet implemented"):
|
||||
# # OpenAI provider is not implemented yet
|
||||
# ComputerAgent(provider=Provider.OPENAI)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed")
|
||||
# async def test_agent_cleanup():
|
||||
# """Test agent cleanup."""
|
||||
# agent = ComputerAgent(provider=Provider.ANTHROPIC)
|
||||
# await agent.cleanup() # Should not raise any errors
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed")
|
||||
# async def test_agent_direct_initialization():
|
||||
# """Test direct initialization of the agent."""
|
||||
# # Create with default computer
|
||||
# agent = ComputerAgent(provider=Provider.ANTHROPIC)
|
||||
# try:
|
||||
# # Should not raise any errors
|
||||
# await agent.run("test task")
|
||||
# finally:
|
||||
# await agent.cleanup()
|
||||
|
||||
# # Create with custom computer
|
||||
# custom_computer = Computer(
|
||||
# display="1920x1080",
|
||||
# memory="8GB",
|
||||
# cpu="4",
|
||||
# os="macos",
|
||||
# use_host_computer_server=False,
|
||||
# )
|
||||
# agent = ComputerAgent(provider=Provider.ANTHROPIC, computer=custom_computer)
|
||||
# try:
|
||||
# # Should not raise any errors
|
||||
# await agent.run("test task")
|
||||
# finally:
|
||||
# await agent.cleanup()
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed")
|
||||
# async def test_agent_context_manager():
|
||||
# """Test context manager initialization of the agent."""
|
||||
# # Test with default computer
|
||||
# async with ComputerAgent(provider=Provider.ANTHROPIC) as agent:
|
||||
# # Should not raise any errors
|
||||
# await agent.run("test task")
|
||||
|
||||
# # Test with custom computer
|
||||
# custom_computer = Computer(
|
||||
# display="1920x1080",
|
||||
# memory="8GB",
|
||||
# cpu="4",
|
||||
# os="macos",
|
||||
# use_host_computer_server=False,
|
||||
# )
|
||||
# async with ComputerAgent(provider=Provider.ANTHROPIC, computer=custom_computer) as agent:
|
||||
# # Should not raise any errors
|
||||
# await agent.run("test task")
|
||||
Reference in New Issue
Block a user