mirror of
https://github.com/trycua/computer.git
synced 2026-02-24 23:39:53 -06:00
Merge pull request #68 from trycua/feature/agent/openai-cua
Add OpenAI CUA
This commit is contained in:
@@ -23,50 +23,73 @@ async def run_agent_example():
|
||||
print("\n=== Example: ComputerAgent with OpenAI and Omni provider ===")
|
||||
|
||||
try:
|
||||
# Create Computer instance with default parameters
|
||||
computer = Computer(verbosity=logging.DEBUG)
|
||||
# Create Computer instance with async context manager
|
||||
async with Computer(verbosity=logging.DEBUG) as macos_computer:
|
||||
# Create agent with loop and provider
|
||||
agent = ComputerAgent(
|
||||
computer=macos_computer,
|
||||
loop=AgentLoop.OPENAI,
|
||||
# loop=AgentLoop.ANTHROPIC,
|
||||
# loop=AgentLoop.OMNI,
|
||||
model=LLM(provider=LLMProvider.OPENAI), # No model name for Operator CUA
|
||||
# model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"),
|
||||
# model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"),
|
||||
save_trajectory=True,
|
||||
only_n_most_recent_images=3,
|
||||
verbosity=logging.DEBUG,
|
||||
)
|
||||
|
||||
# Create agent with loop and provider
|
||||
agent = ComputerAgent(
|
||||
computer=computer,
|
||||
# loop=AgentLoop.ANTHROPIC,
|
||||
loop=AgentLoop.OMNI,
|
||||
model=LLM(provider=LLMProvider.OPENAI, name="gpt-4.5-preview"),
|
||||
# model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"),
|
||||
save_trajectory=True,
|
||||
only_n_most_recent_images=3,
|
||||
verbosity=logging.DEBUG,
|
||||
)
|
||||
tasks = [
|
||||
"Look for a repository named trycua/cua on GitHub.",
|
||||
"Check the open issues, open the most recent one and read it.",
|
||||
"Clone the repository in users/lume/projects if it doesn't exist yet.",
|
||||
"Open the repository with an app named Cursor (on the dock, black background and white cube icon).",
|
||||
"From Cursor, open Composer if not already open.",
|
||||
"Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.",
|
||||
]
|
||||
|
||||
tasks = [
|
||||
"Look for a repository named trycua/cua on GitHub.",
|
||||
"Check the open issues, open the most recent one and read it.",
|
||||
"Clone the repository in users/lume/projects if it doesn't exist yet.",
|
||||
"Open the repository with an app named Cursor (on the dock, black background and white cube icon).",
|
||||
"From Cursor, open Composer if not already open.",
|
||||
"Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.",
|
||||
]
|
||||
for i, task in enumerate(tasks):
|
||||
print(f"\nExecuting task {i}/{len(tasks)}: {task}")
|
||||
async for result in agent.run(task):
|
||||
print("Response ID: ", result.get("id"))
|
||||
|
||||
for i, task in enumerate(tasks):
|
||||
print(f"\nExecuting task {i}/{len(tasks)}: {task}")
|
||||
async for result in agent.run(task):
|
||||
# print(result)
|
||||
pass
|
||||
# Print detailed usage information
|
||||
usage = result.get("usage")
|
||||
if usage:
|
||||
print("\nUsage Details:")
|
||||
print(f" Input Tokens: {usage.get('input_tokens')}")
|
||||
if "input_tokens_details" in usage:
|
||||
print(f" Input Tokens Details: {usage.get('input_tokens_details')}")
|
||||
print(f" Output Tokens: {usage.get('output_tokens')}")
|
||||
if "output_tokens_details" in usage:
|
||||
print(f" Output Tokens Details: {usage.get('output_tokens_details')}")
|
||||
print(f" Total Tokens: {usage.get('total_tokens')}")
|
||||
|
||||
print(f"\n✅ Task {i+1}/{len(tasks)} completed: {task}")
|
||||
print("Response Text: ", result.get("text"))
|
||||
|
||||
# Print tools information
|
||||
tools = result.get("tools")
|
||||
if tools:
|
||||
print("\nTools:")
|
||||
print(tools)
|
||||
|
||||
# Print reasoning and tool call outputs
|
||||
outputs = result.get("output", [])
|
||||
for output in outputs:
|
||||
output_type = output.get("type")
|
||||
if output_type == "reasoning":
|
||||
print("\nReasoning Output:")
|
||||
print(output)
|
||||
elif output_type == "computer_call":
|
||||
print("\nTool Call Output:")
|
||||
print(output)
|
||||
|
||||
print(f"\n✅ Task {i+1}/{len(tasks)} completed: {task}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in run_omni_agent_example: {e}")
|
||||
logger.error(f"Error in run_agent_example: {e}")
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
# Clean up resources
|
||||
if computer and computer._initialized:
|
||||
try:
|
||||
# await computer.stop()
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping computer: {e}")
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -15,9 +15,7 @@
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
**Agent** is a Computer Use (CUA) framework for running multi-app agentic workflows targeting macOS and Linux sandbox, supporting local (Ollama) and cloud model providers (OpenAI, Anthropic, Groq, DeepSeek, Qwen). The framework integrates with Microsoft's OmniParser for enhanced UI understanding and interaction.
|
||||
|
||||
> While our north star is to create a 1-click experience, this preview of Agent might be still a bit rough around the edges. We appreciate your patience as we work to improve the experience.
|
||||
**cua-agent** is a general Computer-Use framework for running multi-app agentic workflows targeting macOS and Linux sandbox created with Cua, supporting local (Ollama) and cloud model providers (OpenAI, Anthropic, Groq, DeepSeek, Qwen).
|
||||
|
||||
### Get started with Agent
|
||||
|
||||
@@ -27,18 +25,92 @@
|
||||
|
||||
## Install
|
||||
|
||||
### cua-agent
|
||||
|
||||
```bash
|
||||
pip install "cua-agent[all]"
|
||||
|
||||
# or install specific loop providers
|
||||
pip install "cua-agent[anthropic]"
|
||||
pip install "cua-agent[omni]"
|
||||
pip install "cua-agent[openai]" # OpenAI Cua Loop
|
||||
pip install "cua-agent[anthropic]" # Anthropic Cua Loop
|
||||
pip install "cua-agent[omni]" # Cua Loop based on OmniParser
|
||||
```
|
||||
|
||||
## Run
|
||||
|
||||
```bash
|
||||
async with Computer() as macos_computer:
|
||||
# Create agent with loop and provider
|
||||
agent = ComputerAgent(
|
||||
computer=macos_computer,
|
||||
loop=AgentLoop.OPENAI,
|
||||
model=LLM(provider=LLMProvider.OPENAI)
|
||||
)
|
||||
|
||||
tasks = [
|
||||
"Look for a repository named trycua/cua on GitHub.",
|
||||
"Check the open issues, open the most recent one and read it.",
|
||||
"Clone the repository in users/lume/projects if it doesn't exist yet.",
|
||||
"Open the repository with an app named Cursor (on the dock, black background and white cube icon).",
|
||||
"From Cursor, open Composer if not already open.",
|
||||
"Focus on the Composer text area, then write and submit a task to help resolve the GitHub issue.",
|
||||
]
|
||||
|
||||
for i, task in enumerate(tasks):
|
||||
print(f"\nExecuting task {i}/{len(tasks)}: {task}")
|
||||
async for result in agent.run(task):
|
||||
print(result)
|
||||
|
||||
print(f"\n✅ Task {i+1}/{len(tasks)} completed: {task}")
|
||||
```
|
||||
|
||||
Refer to these notebooks for step-by-step guides on how to use the Computer-Use Agent (CUA):
|
||||
|
||||
- [Agent Notebook](../../notebooks/agent_nb.ipynb) - Complete examples and workflows
|
||||
- [Agent Notebook](../../notebooks/agent_nb.ipynb) - Complete examples and workflows
|
||||
|
||||
## Agent Loops
|
||||
|
||||
The `cua-agent` package provides three agent loops variations, based on different CUA models providers and techniques:
|
||||
|
||||
| Agent Loop | Supported Models | Description | Set-Of-Marks |
|
||||
|:-----------|:-----------------|:------------|:-------------|
|
||||
| `AgentLoop.OPENAI` | • `computer_use_preview` | Use OpenAI Operator CUA model | Not Required |
|
||||
| `AgentLoop.ANTHROPIC` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219` | Use Anthropic Computer-Use | Not Required |
|
||||
| `AgentLoop.OMNI` <br>(preview) | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219`<br>• `gpt-4.5-preview`<br>• `gpt-4o`<br>• `gpt-4`<br>• `gpt-3.5-turbo` | Use OmniParser for element pixel-detection (SoM) and any VLMs | OmniParser |
|
||||
|
||||
## AgentResponse
|
||||
The `AgentResponse` class represents the structured output returned after each agent turn. It contains the agent's response, reasoning, tool usage, and other metadata. The response format aligns with the new [OpenAI Agent SDK specification](https://platform.openai.com/docs/api-reference/responses) for better consistency across different agent loops.
|
||||
|
||||
```python
|
||||
async for result in agent.run(task):
|
||||
print("Response ID: ", result.get("id"))
|
||||
|
||||
# Print detailed usage information
|
||||
usage = result.get("usage")
|
||||
if usage:
|
||||
print("\nUsage Details:")
|
||||
print(f" Input Tokens: {usage.get('input_tokens')}")
|
||||
if "input_tokens_details" in usage:
|
||||
print(f" Input Tokens Details: {usage.get('input_tokens_details')}")
|
||||
print(f" Output Tokens: {usage.get('output_tokens')}")
|
||||
if "output_tokens_details" in usage:
|
||||
print(f" Output Tokens Details: {usage.get('output_tokens_details')}")
|
||||
print(f" Total Tokens: {usage.get('total_tokens')}")
|
||||
|
||||
print("Response Text: ", result.get("text"))
|
||||
|
||||
# Print tools information
|
||||
tools = result.get("tools")
|
||||
if tools:
|
||||
print("\nTools:")
|
||||
print(tools)
|
||||
|
||||
# Print reasoning and tool call outputs
|
||||
outputs = result.get("output", [])
|
||||
for output in outputs:
|
||||
output_type = output.get("type")
|
||||
if output_type == "reasoning":
|
||||
print("\nReasoning Output:")
|
||||
print(output)
|
||||
elif output_type == "computer_call":
|
||||
print("\nTool Call Output:")
|
||||
print(output)
|
||||
```
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
# Agent Package Structure
|
||||
|
||||
## Overview
|
||||
The agent package provides a modular and extensible framework for AI-powered computer agents.
|
||||
|
||||
## Directory Structure
|
||||
```
|
||||
agent/
|
||||
├── __init__.py # Package exports
|
||||
├── core/ # Core functionality
|
||||
│ ├── __init__.py
|
||||
│ ├── computer_agent.py # Main entry point
|
||||
│ └── factory.py # Provider factory
|
||||
├── base/ # Base implementations
|
||||
│ ├── __init__.py
|
||||
│ ├── agent.py # Base agent class
|
||||
│ ├── core/ # Core components
|
||||
│ │ ├── callbacks.py
|
||||
│ │ ├── loop.py
|
||||
│ │ └── messages.py
|
||||
│ └── tools/ # Tool implementations
|
||||
├── providers/ # Provider implementations
|
||||
│ ├── __init__.py
|
||||
│ ├── anthropic/ # Anthropic provider
|
||||
│ │ ├── agent.py
|
||||
│ │ ├── loop.py
|
||||
│ │ └── tool_manager.py
|
||||
│ └── omni/ # Omni provider
|
||||
│ ├── agent.py
|
||||
│ ├── loop.py
|
||||
│ └── tool_manager.py
|
||||
└── types/ # Type definitions
|
||||
├── __init__.py
|
||||
├── base.py # Core types
|
||||
├── messages.py # Message types
|
||||
├── tools.py # Tool types
|
||||
└── providers/ # Provider-specific types
|
||||
├── anthropic.py
|
||||
└── omni.py
|
||||
```
|
||||
|
||||
## Key Components
|
||||
|
||||
### Core
|
||||
- `computer_agent.py`: Main entry point for creating and using agents
|
||||
- `factory.py`: Factory for creating provider-specific implementations
|
||||
|
||||
### Base
|
||||
- `agent.py`: Base agent implementation with shared functionality
|
||||
- `core/`: Core components used across providers
|
||||
- `tools/`: Shared tool implementations
|
||||
|
||||
### Providers
|
||||
Each provider follows the same structure:
|
||||
- `agent.py`: Provider-specific agent implementation
|
||||
- `loop.py`: Provider-specific message loop
|
||||
- `tool_manager.py`: Tool management for provider
|
||||
|
||||
### Types
|
||||
- `base.py`: Core type definitions
|
||||
- `messages.py`: Message-related types
|
||||
- `tools.py`: Tool-related types
|
||||
- `providers/`: Provider-specific type definitions
|
||||
@@ -49,7 +49,7 @@ except Exception as e:
|
||||
logger.warning(f"Error initializing telemetry: {e}")
|
||||
|
||||
from .providers.omni.types import LLMProvider, LLM
|
||||
from .core.loop import AgentLoop
|
||||
from .core.computer_agent import ComputerAgent
|
||||
from .core.factory import AgentLoop
|
||||
from .core.agent import ComputerAgent
|
||||
|
||||
__all__ = ["AgentLoop", "LLMProvider", "LLM", "ComputerAgent"]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Core agent components."""
|
||||
|
||||
from .loop import BaseLoop
|
||||
from .factory import BaseLoop
|
||||
from .messages import (
|
||||
BaseMessageManager,
|
||||
ImageRetentionConfig,
|
||||
|
||||
@@ -3,32 +3,18 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Dict, Optional, cast, List
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from computer import Computer
|
||||
from ..providers.anthropic.loop import AnthropicLoop
|
||||
from ..providers.omni.loop import OmniLoop
|
||||
from ..providers.omni.parser import OmniParser
|
||||
from ..providers.omni.types import LLMProvider, LLM
|
||||
from ..providers.omni.types import LLM
|
||||
from .. import AgentLoop
|
||||
from .messages import StandardMessageManager, ImageRetentionConfig
|
||||
from .types import AgentResponse
|
||||
from .factory import LoopFactory
|
||||
from .provider_config import DEFAULT_MODELS, ENV_VARS
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default models for different providers
|
||||
DEFAULT_MODELS = {
|
||||
LLMProvider.OPENAI: "gpt-4o",
|
||||
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
|
||||
}
|
||||
|
||||
# Map providers to their environment variable names
|
||||
ENV_VARS = {
|
||||
LLMProvider.OPENAI: "OPENAI_API_KEY",
|
||||
LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
|
||||
}
|
||||
|
||||
|
||||
class ComputerAgent:
|
||||
"""A computer agent that can perform automated tasks using natural language instructions."""
|
||||
@@ -98,35 +84,27 @@ class ComputerAgent:
|
||||
f"No model specified for provider {self.provider} and no default found"
|
||||
)
|
||||
|
||||
# Ensure computer is properly cast for typing purposes
|
||||
computer_instance = self.computer
|
||||
|
||||
# Get API key from environment if not provided
|
||||
actual_api_key = api_key or os.environ.get(ENV_VARS[self.provider], "")
|
||||
if not actual_api_key:
|
||||
raise ValueError(f"No API key provided for {self.provider}")
|
||||
|
||||
# Initialize the appropriate loop based on the loop parameter
|
||||
if loop == AgentLoop.ANTHROPIC:
|
||||
self._loop = AnthropicLoop(
|
||||
api_key=actual_api_key,
|
||||
model=actual_model_name,
|
||||
computer=computer_instance,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
)
|
||||
else:
|
||||
self._loop = OmniLoop(
|
||||
# Create the appropriate loop using the factory
|
||||
try:
|
||||
# Let the factory create the appropriate loop with needed components
|
||||
self._loop = LoopFactory.create_loop(
|
||||
loop_type=loop,
|
||||
provider=self.provider,
|
||||
computer=self.computer,
|
||||
model_name=actual_model_name,
|
||||
api_key=actual_api_key,
|
||||
model=actual_model_name,
|
||||
computer=computer_instance,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
trajectory_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
parser=OmniParser(),
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Failed to create loop: {str(e)}")
|
||||
raise
|
||||
|
||||
# Initialize the message manager from the loop
|
||||
self.message_manager = self._loop.message_manager
|
||||
@@ -152,21 +130,6 @@ class ComputerAgent:
|
||||
else:
|
||||
logger.info("Computer already initialized, skipping initialization")
|
||||
|
||||
# Take a test screenshot to verify the computer is working
|
||||
logger.info("Testing computer with a screenshot...")
|
||||
try:
|
||||
test_screenshot = await self.computer.interface.screenshot()
|
||||
# Determine the screenshot size based on its type
|
||||
if isinstance(test_screenshot, (bytes, bytearray, memoryview)):
|
||||
size = len(test_screenshot)
|
||||
elif hasattr(test_screenshot, "base64_image"):
|
||||
size = len(test_screenshot.base64_image)
|
||||
else:
|
||||
size = "unknown"
|
||||
logger.info(f"Screenshot test successful, size: {size}")
|
||||
except Exception as e:
|
||||
logger.error(f"Screenshot test failed: {str(e)}")
|
||||
# Even though screenshot failed, we continue since some tests might not need it
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing computer in __aenter__: {str(e)}")
|
||||
raise
|
||||
@@ -232,7 +195,6 @@ class ComputerAgent:
|
||||
|
||||
# Execute the task and yield results
|
||||
async for result in self._loop.run(self.message_manager.messages):
|
||||
# Yield the result to the caller
|
||||
yield result
|
||||
|
||||
except Exception as e:
|
||||
@@ -1,35 +1,21 @@
|
||||
"""Base agent loop implementation."""
|
||||
"""Base loop definitions."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum, auto
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from computer import Computer
|
||||
from .experiment import ExperimentManager
|
||||
from .messages import StandardMessageManager, ImageRetentionConfig
|
||||
from .types import AgentResponse
|
||||
from .experiment import ExperimentManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentLoop(Enum):
|
||||
"""Enumeration of available loop types."""
|
||||
|
||||
ANTHROPIC = auto() # Anthropic implementation
|
||||
OMNI = auto() # OmniLoop implementation
|
||||
# Add more loop types as needed
|
||||
|
||||
|
||||
class BaseLoop(ABC):
|
||||
"""Base class for agent loops that handle message processing and tool execution."""
|
||||
|
||||
###########################################
|
||||
# INITIALIZATION AND CONFIGURATION
|
||||
###########################################
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
computer: Computer,
|
||||
@@ -68,6 +54,11 @@ class BaseLoop(ABC):
|
||||
self.only_n_most_recent_images = only_n_most_recent_images
|
||||
self._kwargs = kwargs
|
||||
|
||||
# Initialize message manager
|
||||
self.message_manager = StandardMessageManager(
|
||||
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
|
||||
)
|
||||
|
||||
# Initialize experiment manager
|
||||
if self.save_trajectory and self.base_dir:
|
||||
self.experiment_manager = ExperimentManager(
|
||||
@@ -110,8 +101,7 @@ class BaseLoop(ABC):
|
||||
)
|
||||
raise RuntimeError(f"Failed to initialize: {str(e)}")
|
||||
|
||||
###########################################
|
||||
|
||||
###########################################
|
||||
# ABSTRACT METHODS TO BE IMPLEMENTED BY SUBCLASSES
|
||||
###########################################
|
||||
|
||||
@@ -125,17 +115,14 @@ class BaseLoop(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
|
||||
def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
This method handles the main agent loop including message processing,
|
||||
API calls, response handling, and action execution.
|
||||
|
||||
Args:
|
||||
messages: List of message objects
|
||||
|
||||
Yields:
|
||||
Agent response format
|
||||
Returns:
|
||||
An async generator that yields agent responses
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
104
libs/agent/agent/core/factory.py
Normal file
104
libs/agent/agent/core/factory.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Base agent loop implementation."""
|
||||
|
||||
import logging
|
||||
import importlib.util
|
||||
from typing import Dict, Optional, Type, TYPE_CHECKING, Any, cast, Callable, Awaitable
|
||||
|
||||
from computer import Computer
|
||||
from .types import AgentLoop
|
||||
from .base import BaseLoop
|
||||
|
||||
# For type checking only
|
||||
if TYPE_CHECKING:
|
||||
from ..providers.omni.types import LLMProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoopFactory:
|
||||
"""Factory class for creating agent loops."""
|
||||
|
||||
# Registry to store loop implementations
|
||||
_loop_registry: Dict[AgentLoop, Type[BaseLoop]] = {}
|
||||
|
||||
@classmethod
|
||||
def create_loop(
|
||||
cls,
|
||||
loop_type: AgentLoop,
|
||||
api_key: str,
|
||||
model_name: str,
|
||||
computer: Computer,
|
||||
provider: Any = None,
|
||||
save_trajectory: bool = True,
|
||||
trajectory_dir: str = "trajectories",
|
||||
only_n_most_recent_images: Optional[int] = None,
|
||||
acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
|
||||
) -> BaseLoop:
|
||||
"""Create and return an appropriate loop instance based on type."""
|
||||
if loop_type == AgentLoop.ANTHROPIC:
|
||||
# Lazy import AnthropicLoop only when needed
|
||||
try:
|
||||
from ..providers.anthropic.loop import AnthropicLoop
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'anthropic' provider is not installed. "
|
||||
"Install it with 'pip install cua-agent[anthropic]'"
|
||||
)
|
||||
|
||||
return AnthropicLoop(
|
||||
api_key=api_key,
|
||||
model=model_name,
|
||||
computer=computer,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
)
|
||||
elif loop_type == AgentLoop.OPENAI:
|
||||
# Lazy import OpenAILoop only when needed
|
||||
try:
|
||||
from ..providers.openai.loop import OpenAILoop
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'openai' provider is not installed. "
|
||||
"Install it with 'pip install cua-agent[openai]'"
|
||||
)
|
||||
|
||||
return OpenAILoop(
|
||||
api_key=api_key,
|
||||
model=model_name,
|
||||
computer=computer,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
acknowledge_safety_check_callback=acknowledge_safety_check_callback,
|
||||
)
|
||||
elif loop_type == AgentLoop.OMNI:
|
||||
# Lazy import OmniLoop and related classes only when needed
|
||||
try:
|
||||
from ..providers.omni.loop import OmniLoop
|
||||
from ..providers.omni.parser import OmniParser
|
||||
from ..providers.omni.types import LLMProvider
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"The 'omni' provider is not installed. "
|
||||
"Install it with 'pip install cua-agent[all]'"
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError("Provider is required for OMNI loop type")
|
||||
|
||||
# We know provider is the correct type at this point, so cast it
|
||||
provider_instance = cast(LLMProvider, provider)
|
||||
|
||||
return OmniLoop(
|
||||
provider=provider_instance,
|
||||
api_key=api_key,
|
||||
model=model_name,
|
||||
computer=computer,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
parser=OmniParser(),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported loop type: {loop_type}")
|
||||
15
libs/agent/agent/core/provider_config.py
Normal file
15
libs/agent/agent/core/provider_config.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""Provider-specific configurations and constants."""
|
||||
|
||||
from ..providers.omni.types import LLMProvider
|
||||
|
||||
# Default models for different providers
|
||||
DEFAULT_MODELS = {
|
||||
LLMProvider.OPENAI: "gpt-4o",
|
||||
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
|
||||
}
|
||||
|
||||
# Map providers to their environment variable names
|
||||
ENV_VARS = {
|
||||
LLMProvider.OPENAI: "OPENAI_API_KEY",
|
||||
LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
|
||||
}
|
||||
@@ -1,6 +1,16 @@
|
||||
"""Core type definitions."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, TypedDict, Union
|
||||
from enum import Enum, auto
|
||||
|
||||
|
||||
class AgentLoop(Enum):
|
||||
"""Enumeration of available loop types."""
|
||||
|
||||
ANTHROPIC = auto() # Anthropic implementation
|
||||
OMNI = auto() # OmniLoop implementation
|
||||
OPENAI = auto() # OpenAI implementation
|
||||
# Add more loop types as needed
|
||||
|
||||
|
||||
class AgentResponse(TypedDict, total=False):
|
||||
|
||||
@@ -16,7 +16,7 @@ from datetime import datetime
|
||||
from computer import Computer
|
||||
|
||||
# Base imports
|
||||
from ...core.loop import BaseLoop
|
||||
from ...core.base import BaseLoop
|
||||
from ...core.messages import StandardMessageManager, ImageRetentionConfig
|
||||
from ...core.types import AgentResponse
|
||||
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import cast
|
||||
from anthropic.types.beta import (
|
||||
BetaMessageParam,
|
||||
BetaCacheControlEphemeralParam,
|
||||
BetaToolResultBlockParam,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageRetentionConfig:
|
||||
"""Configuration for image retention in messages."""
|
||||
|
||||
num_images_to_keep: int | None = None
|
||||
min_removal_threshold: int = 1
|
||||
enable_caching: bool = True
|
||||
|
||||
def should_retain_images(self) -> bool:
|
||||
"""Check if image retention is enabled."""
|
||||
return self.num_images_to_keep is not None and self.num_images_to_keep > 0
|
||||
|
||||
|
||||
class MessageManager:
|
||||
"""Manages message preparation, including image retention and caching."""
|
||||
|
||||
def __init__(self, image_retention_config: ImageRetentionConfig):
|
||||
"""Initialize the message manager.
|
||||
|
||||
Args:
|
||||
image_retention_config: Configuration for image retention
|
||||
"""
|
||||
if image_retention_config.min_removal_threshold < 1:
|
||||
raise ValueError("min_removal_threshold must be at least 1")
|
||||
self.image_retention_config = image_retention_config
|
||||
|
||||
def prepare_messages(self, messages: list[BetaMessageParam]) -> list[BetaMessageParam]:
|
||||
"""Prepare messages by applying image retention and caching as configured."""
|
||||
if self.image_retention_config.should_retain_images():
|
||||
self._filter_images(messages)
|
||||
if self.image_retention_config.enable_caching:
|
||||
self._inject_caching(messages)
|
||||
return messages
|
||||
|
||||
def _filter_images(self, messages: list[BetaMessageParam]) -> None:
|
||||
"""Filter messages to retain only the specified number of most recent images."""
|
||||
tool_result_blocks = cast(
|
||||
list[BetaToolResultBlockParam],
|
||||
[
|
||||
item
|
||||
for message in messages
|
||||
for item in (message["content"] if isinstance(message["content"], list) else [])
|
||||
if isinstance(item, dict) and item.get("type") == "tool_result"
|
||||
],
|
||||
)
|
||||
|
||||
total_images = sum(
|
||||
1
|
||||
for tool_result in tool_result_blocks
|
||||
for content in tool_result.get("content", [])
|
||||
if isinstance(content, dict) and content.get("type") == "image"
|
||||
)
|
||||
|
||||
images_to_remove = total_images - (self.image_retention_config.num_images_to_keep or 0)
|
||||
# Round down to nearest min_removal_threshold for better cache behavior
|
||||
images_to_remove -= images_to_remove % self.image_retention_config.min_removal_threshold
|
||||
|
||||
# Remove oldest images first
|
||||
for tool_result in tool_result_blocks:
|
||||
if isinstance(tool_result.get("content"), list):
|
||||
new_content = []
|
||||
for content in tool_result.get("content", []):
|
||||
if isinstance(content, dict) and content.get("type") == "image":
|
||||
if images_to_remove > 0:
|
||||
images_to_remove -= 1
|
||||
continue
|
||||
new_content.append(content)
|
||||
tool_result["content"] = new_content
|
||||
|
||||
def _inject_caching(self, messages: list[BetaMessageParam]) -> None:
|
||||
"""Inject caching control for the most recent turns, limited to 3 blocks max to avoid API errors."""
|
||||
# Anthropic API allows a maximum of 4 blocks with cache_control
|
||||
# We use 3 here to be safe, as the system block may also have cache_control
|
||||
blocks_with_cache_control = 0
|
||||
max_cache_control_blocks = 3
|
||||
|
||||
for message in reversed(messages):
|
||||
if message["role"] == "user" and isinstance(content := message["content"], list):
|
||||
# Only add cache control to the latest message in each turn
|
||||
if blocks_with_cache_control < max_cache_control_blocks:
|
||||
blocks_with_cache_control += 1
|
||||
# Add cache control to the last content block only
|
||||
if content and len(content) > 0:
|
||||
content[-1]["cache_control"] = BetaCacheControlEphemeralParam(
|
||||
type="ephemeral"
|
||||
)
|
||||
else:
|
||||
# Remove any existing cache control
|
||||
if content and len(content) > 0:
|
||||
content[-1].pop("cache_control", None)
|
||||
|
||||
# Ensure we're not exceeding the limit by checking the total
|
||||
if blocks_with_cache_control > max_cache_control_blocks:
|
||||
# If we somehow exceeded the limit, remove excess cache controls
|
||||
excess = blocks_with_cache_control - max_cache_control_blocks
|
||||
for message in messages:
|
||||
if excess <= 0:
|
||||
break
|
||||
|
||||
if message["role"] == "user" and isinstance(content := message["content"], list):
|
||||
if content and len(content) > 0 and "cache_control" in content[-1]:
|
||||
content[-1].pop("cache_control", None)
|
||||
excess -= 1
|
||||
@@ -1,14 +1,11 @@
|
||||
"""Response and tool handling for Anthropic provider."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from typing import Any, Dict, List, Tuple, cast
|
||||
|
||||
from anthropic.types.beta import (
|
||||
BetaMessage,
|
||||
BetaMessageParam,
|
||||
BetaTextBlock,
|
||||
BetaTextBlockParam,
|
||||
BetaToolUseBlockParam,
|
||||
BetaContentBlockParam,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,14 +1,12 @@
|
||||
"""Utility functions for Anthropic message handling."""
|
||||
|
||||
import time
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Tuple, cast
|
||||
from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaTextBlock
|
||||
from anthropic.types.beta import BetaMessage
|
||||
from ..omni.parser import ParseResult
|
||||
from ...core.types import AgentResponse
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
# Configure module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -10,7 +10,7 @@ from httpx import ConnectError, ReadTimeout
|
||||
from typing import cast
|
||||
|
||||
from .parser import OmniParser, ParseResult
|
||||
from ...core.loop import BaseLoop
|
||||
from ...core.base import BaseLoop
|
||||
from ...core.visualization import VisualizationHelper
|
||||
from ...core.messages import StandardMessageManager, ImageRetentionConfig
|
||||
from .utils import to_openai_agent_response_format
|
||||
|
||||
@@ -9,8 +9,10 @@ class LLMProvider(StrEnum):
|
||||
"""Supported LLM providers."""
|
||||
|
||||
ANTHROPIC = "anthropic"
|
||||
OMNI = "omni"
|
||||
OPENAI = "openai"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLM:
|
||||
"""Configuration for LLM model and provider."""
|
||||
|
||||
6
libs/agent/agent/providers/openai/__init__.py
Normal file
6
libs/agent/agent/providers/openai/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""OpenAI Agent Response API provider for computer control."""
|
||||
|
||||
from .types import LLMProvider
|
||||
from .loop import OpenAILoop
|
||||
|
||||
__all__ = ["OpenAILoop", "LLMProvider"]
|
||||
453
libs/agent/agent/providers/openai/api_handler.py
Normal file
453
libs/agent/agent/providers/openai/api_handler.py
Normal file
@@ -0,0 +1,453 @@
|
||||
"""API handler for the OpenAI provider."""
|
||||
|
||||
import logging
|
||||
import requests
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
from datetime import datetime
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .loop import OpenAILoop
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIAPIHandler:
|
||||
"""Handler for OpenAI API interactions."""
|
||||
|
||||
def __init__(self, loop: "OpenAILoop"):
|
||||
"""Initialize the API handler.
|
||||
|
||||
Args:
|
||||
loop: OpenAI loop instance
|
||||
"""
|
||||
self.loop = loop
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("OPENAI_API_KEY environment variable not set")
|
||||
|
||||
self.api_base = "https://api.openai.com/v1"
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Add organization if specified
|
||||
org_id = os.getenv("OPENAI_ORG")
|
||||
if org_id:
|
||||
self.headers["OpenAI-Organization"] = org_id
|
||||
|
||||
logger.info("Initialized OpenAI API handler")
|
||||
|
||||
async def send_initial_request(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
display_width: str,
|
||||
display_height: str,
|
||||
previous_response_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send an initial request to the OpenAI API with a screenshot.
|
||||
|
||||
Args:
|
||||
messages: List of message objects in standard format
|
||||
display_width: Width of the display in pixels
|
||||
display_height: Height of the display in pixels
|
||||
previous_response_id: Optional ID of the previous response to link requests
|
||||
|
||||
Returns:
|
||||
API response
|
||||
"""
|
||||
# Convert display dimensions to integers
|
||||
try:
|
||||
width = int(display_width)
|
||||
height = int(display_height)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Failed to convert display dimensions to integers: {str(e)}")
|
||||
raise ValueError(
|
||||
f"Display dimensions must be integers: width={display_width}, height={display_height}"
|
||||
)
|
||||
|
||||
# Extract the latest text message and screenshot from messages
|
||||
latest_text = None
|
||||
latest_screenshot = None
|
||||
|
||||
for msg in reversed(messages):
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
|
||||
content = msg.get("content", [])
|
||||
|
||||
if isinstance(content, str) and not latest_text:
|
||||
latest_text = content
|
||||
continue
|
||||
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# Look for text if we don't have it yet
|
||||
if not latest_text and item.get("type") == "text" and "text" in item:
|
||||
latest_text = item.get("text", "")
|
||||
|
||||
# Look for an image if we don't have it yet
|
||||
if not latest_screenshot and item.get("type") == "image":
|
||||
source = item.get("source", {})
|
||||
if source.get("type") == "base64" and "data" in source:
|
||||
latest_screenshot = source["data"]
|
||||
|
||||
# Prepare the input array
|
||||
input_array = []
|
||||
|
||||
# Add the text message if found
|
||||
if latest_text:
|
||||
input_array.append({"role": "user", "content": latest_text})
|
||||
|
||||
# Add the screenshot if found and no previous_response_id is provided
|
||||
if latest_screenshot and not previous_response_id:
|
||||
input_array.append(
|
||||
{
|
||||
"type": "message",
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{latest_screenshot}",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# Prepare the request payload - using minimal format from docs
|
||||
payload = {
|
||||
"model": "computer-use-preview",
|
||||
"tools": [
|
||||
{
|
||||
"type": "computer_use_preview",
|
||||
"display_width": width,
|
||||
"display_height": height,
|
||||
"environment": "mac", # We're on macOS
|
||||
}
|
||||
],
|
||||
"input": input_array,
|
||||
"truncation": "auto",
|
||||
}
|
||||
|
||||
# Add previous_response_id if provided
|
||||
if previous_response_id:
|
||||
payload["previous_response_id"] = previous_response_id
|
||||
|
||||
# Log the request using the BaseLoop's log_api_call method
|
||||
self.loop._log_api_call("request", payload)
|
||||
|
||||
# Log for debug purposes
|
||||
logger.info("Sending initial request to OpenAI API")
|
||||
logger.debug(f"Request payload: {self._sanitize_response(payload)}")
|
||||
|
||||
# Send the request
|
||||
response = requests.post(
|
||||
f"{self.api_base}/responses",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_message = f"OpenAI API error: {response.status_code} {response.text}"
|
||||
logger.error(error_message)
|
||||
# Log the error using the BaseLoop's log_api_call method
|
||||
self.loop._log_api_call("error", payload, error=Exception(error_message))
|
||||
raise Exception(error_message)
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# Log the response using the BaseLoop's log_api_call method
|
||||
self.loop._log_api_call("response", payload, response_data)
|
||||
|
||||
# Log for debug purposes
|
||||
logger.info("Received response from OpenAI API")
|
||||
logger.debug(f"Response data: {self._sanitize_response(response_data)}")
|
||||
|
||||
return response_data
|
||||
|
||||
async def send_computer_call_request(
|
||||
self,
|
||||
messages: List[Dict[str, Any]],
|
||||
display_width: str,
|
||||
display_height: str,
|
||||
previous_response_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Send a request to the OpenAI API with computer_call_output.
|
||||
|
||||
Args:
|
||||
messages: List of message objects in standard format
|
||||
display_width: Width of the display in pixels
|
||||
display_height: Height of the display in pixels
|
||||
system_prompt: System prompt to include
|
||||
previous_response_id: ID of the previous response to link requests
|
||||
|
||||
Returns:
|
||||
API response
|
||||
"""
|
||||
# Convert display dimensions to integers
|
||||
try:
|
||||
width = int(display_width)
|
||||
height = int(display_height)
|
||||
except (ValueError, TypeError) as e:
|
||||
logger.error(f"Failed to convert display dimensions to integers: {str(e)}")
|
||||
raise ValueError(
|
||||
f"Display dimensions must be integers: width={display_width}, height={display_height}"
|
||||
)
|
||||
|
||||
# Find the most recent computer_call_output with call_id
|
||||
call_id = None
|
||||
screenshot_base64 = None
|
||||
|
||||
# Look for call_id and screenshot in messages
|
||||
for msg in reversed(messages):
|
||||
if not isinstance(msg, dict):
|
||||
continue
|
||||
|
||||
# Check if the message itself has a call_id
|
||||
if "call_id" in msg and not call_id:
|
||||
call_id = msg["call_id"]
|
||||
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
# Look for call_id
|
||||
if not call_id and "call_id" in item:
|
||||
call_id = item["call_id"]
|
||||
|
||||
# Look for screenshot in computer_call_output
|
||||
if not screenshot_base64 and item.get("type") == "computer_call_output":
|
||||
output = item.get("output", {})
|
||||
if isinstance(output, dict) and "image_url" in output:
|
||||
image_url = output.get("image_url", "")
|
||||
if image_url.startswith("data:image/png;base64,"):
|
||||
screenshot_base64 = image_url[len("data:image/png;base64,") :]
|
||||
|
||||
# Look for screenshot in image type
|
||||
if not screenshot_base64 and item.get("type") == "image":
|
||||
source = item.get("source", {})
|
||||
if source.get("type") == "base64" and "data" in source:
|
||||
screenshot_base64 = source["data"]
|
||||
|
||||
if not call_id or not screenshot_base64:
|
||||
logger.error("Missing call_id or screenshot for computer_call_output")
|
||||
logger.error(f"Last message: {messages[-1] if messages else None}")
|
||||
raise ValueError("Cannot create computer call request: missing call_id or screenshot")
|
||||
|
||||
# Prepare the request payload using minimal format from docs
|
||||
payload = {
|
||||
"model": "computer-use-preview",
|
||||
"previous_response_id": previous_response_id,
|
||||
"tools": [
|
||||
{
|
||||
"type": "computer_use_preview",
|
||||
"display_width": width,
|
||||
"display_height": height,
|
||||
"environment": "mac", # We're on macOS
|
||||
}
|
||||
],
|
||||
"input": [
|
||||
{
|
||||
"type": "computer_call_output",
|
||||
"call_id": call_id,
|
||||
"output": {
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{screenshot_base64}",
|
||||
},
|
||||
}
|
||||
],
|
||||
"truncation": "auto",
|
||||
}
|
||||
|
||||
# Log the request using the BaseLoop's log_api_call method
|
||||
self.loop._log_api_call("request", payload)
|
||||
|
||||
# Log for debug purposes
|
||||
logger.info("Sending computer call request to OpenAI API")
|
||||
logger.debug(f"Request payload: {self._sanitize_response(payload)}")
|
||||
|
||||
# Send the request
|
||||
response = requests.post(
|
||||
f"{self.api_base}/responses",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_message = f"OpenAI API error: {response.status_code} {response.text}"
|
||||
logger.error(error_message)
|
||||
# Log the error using the BaseLoop's log_api_call method
|
||||
self.loop._log_api_call("error", payload, error=Exception(error_message))
|
||||
raise Exception(error_message)
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# Log the response using the BaseLoop's log_api_call method
|
||||
self.loop._log_api_call("response", payload, response_data)
|
||||
|
||||
# Log for debug purposes
|
||||
logger.info("Received response from OpenAI API")
|
||||
logger.debug(f"Response data: {self._sanitize_response(response_data)}")
|
||||
|
||||
return response_data
|
||||
|
||||
def _format_messages_for_agent_response(
|
||||
self, messages: List[Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Format messages for the OpenAI Agent Response API.
|
||||
|
||||
The Agent Response API requires specific content types:
|
||||
- For user messages: use "input_text", "input_image", etc.
|
||||
- For assistant messages: use "output_text" only
|
||||
|
||||
Additionally, when using the computer tool, only one image can be sent.
|
||||
|
||||
Args:
|
||||
messages: List of standard messages
|
||||
|
||||
Returns:
|
||||
Messages formatted for the Agent Response API
|
||||
"""
|
||||
formatted_messages = []
|
||||
has_image = False # Track if we've already included an image
|
||||
|
||||
# We need to process messages in reverse to ensure we keep the most recent image
|
||||
# but preserve the original order in the final output
|
||||
reversed_messages = list(reversed(messages))
|
||||
temp_formatted = []
|
||||
|
||||
for msg in reversed_messages:
|
||||
if not msg:
|
||||
continue
|
||||
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
|
||||
logger.debug(f"Processing message - Role: {role}, Content type: {type(content)}")
|
||||
if isinstance(content, list):
|
||||
logger.debug(
|
||||
f"List content items: {[item.get('type') for item in content if isinstance(item, dict)]}"
|
||||
)
|
||||
|
||||
if isinstance(content, str):
|
||||
# For string content, create a message with the appropriate text type
|
||||
if role == "user":
|
||||
temp_formatted.append(
|
||||
{"role": role, "content": [{"type": "input_text", "text": content}]}
|
||||
)
|
||||
elif role == "assistant":
|
||||
# For assistant, we need explicit output_text
|
||||
temp_formatted.append(
|
||||
{"role": role, "content": [{"type": "output_text", "text": content}]}
|
||||
)
|
||||
elif role == "system":
|
||||
# System messages need to be formatted as input_text as well
|
||||
temp_formatted.append(
|
||||
{"role": role, "content": [{"type": "input_text", "text": content}]}
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
# For list content, convert each item to the correct type based on role
|
||||
formatted_content = []
|
||||
has_image_in_this_message = False
|
||||
|
||||
for item in content:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
item_type = item.get("type")
|
||||
|
||||
if role == "user":
|
||||
# Handle user message formatting
|
||||
if item_type == "text" or item_type == "input_text":
|
||||
# Text from user is input_text
|
||||
formatted_content.append(
|
||||
{"type": "input_text", "text": item.get("text", "")}
|
||||
)
|
||||
elif (item_type == "image" or item_type == "image_url") and not has_image:
|
||||
# Only include the first/most recent image we encounter
|
||||
if item_type == "image":
|
||||
# Image from user is input_image
|
||||
source = item.get("source", {})
|
||||
if source.get("type") == "base64" and "data" in source:
|
||||
formatted_content.append(
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{source['data']}",
|
||||
}
|
||||
)
|
||||
has_image = True
|
||||
has_image_in_this_message = True
|
||||
elif item_type == "image_url":
|
||||
# Convert "image_url" to "input_image"
|
||||
formatted_content.append(
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": item.get("image_url", {}).get("url", ""),
|
||||
}
|
||||
)
|
||||
has_image = True
|
||||
has_image_in_this_message = True
|
||||
elif role == "assistant":
|
||||
# Handle assistant message formatting - only output_text is supported
|
||||
if item_type == "text" or item_type == "output_text":
|
||||
formatted_content.append(
|
||||
{"type": "output_text", "text": item.get("text", "")}
|
||||
)
|
||||
|
||||
if formatted_content:
|
||||
# If this message had an image, mark it for inclusion
|
||||
temp_formatted.append(
|
||||
{
|
||||
"role": role,
|
||||
"content": formatted_content,
|
||||
"_had_image": has_image_in_this_message, # Temporary marker
|
||||
}
|
||||
)
|
||||
|
||||
# Reverse back to original order and cleanup
|
||||
for msg in reversed(temp_formatted):
|
||||
# Remove our temporary marker
|
||||
if "_had_image" in msg:
|
||||
del msg["_had_image"]
|
||||
formatted_messages.append(msg)
|
||||
|
||||
# Log summary for debugging
|
||||
num_images = sum(
|
||||
1
|
||||
for msg in formatted_messages
|
||||
for item in (msg.get("content", []) if isinstance(msg.get("content"), list) else [])
|
||||
if isinstance(item, dict) and item.get("type") == "input_image"
|
||||
)
|
||||
logger.info(f"Formatted {len(messages)} messages for OpenAI API with {num_images} images")
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _sanitize_response(self, response: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Sanitize response for logging by removing large image data.
|
||||
|
||||
Args:
|
||||
response: Response to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized response
|
||||
"""
|
||||
from .utils import sanitize_message
|
||||
|
||||
# Deep copy to avoid modifying the original
|
||||
sanitized = response.copy()
|
||||
|
||||
# Sanitize output items if present
|
||||
if "output" in sanitized and isinstance(sanitized["output"], list):
|
||||
sanitized["output"] = [sanitize_message(item) for item in sanitized["output"]]
|
||||
|
||||
return sanitized
|
||||
440
libs/agent/agent/providers/openai/loop.py
Normal file
440
libs/agent/agent/providers/openai/loop.py
Normal file
@@ -0,0 +1,440 @@
|
||||
"""OpenAI Agent Response API provider implementation."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional, AsyncGenerator, Callable, Awaitable, TYPE_CHECKING
|
||||
|
||||
from computer import Computer
|
||||
from ...core.base import BaseLoop
|
||||
from ...core.types import AgentResponse
|
||||
from ...core.messages import StandardMessageManager, ImageRetentionConfig
|
||||
|
||||
from .api_handler import OpenAIAPIHandler
|
||||
from .response_handler import OpenAIResponseHandler
|
||||
from .tools.manager import ToolManager
|
||||
from .types import LLMProvider, ResponseItemType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAILoop(BaseLoop):
|
||||
"""OpenAI-specific implementation of the agent loop.
|
||||
|
||||
This class extends BaseLoop to provide specialized support for OpenAI's Agent Response API
|
||||
with computer control capabilities.
|
||||
"""
|
||||
|
||||
###########################################
|
||||
# INITIALIZATION AND CONFIGURATION
|
||||
###########################################
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
computer: Computer,
|
||||
model: str = "computer-use-preview",
|
||||
only_n_most_recent_images: Optional[int] = 2,
|
||||
base_dir: Optional[str] = "trajectories",
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
save_trajectory: bool = True,
|
||||
acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the OpenAI loop.
|
||||
|
||||
Args:
|
||||
api_key: OpenAI API key
|
||||
model: Model name (ignored, always uses computer-use-preview)
|
||||
computer: Computer instance
|
||||
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
|
||||
base_dir: Base directory for saving experiment data
|
||||
max_retries: Maximum number of retries for API calls
|
||||
retry_delay: Delay between retries in seconds
|
||||
save_trajectory: Whether to save trajectory data
|
||||
acknowledge_safety_check_callback: Optional callback for safety check acknowledgment
|
||||
**kwargs: Additional provider-specific arguments
|
||||
"""
|
||||
# Always use computer-use-preview model
|
||||
if model != "computer-use-preview":
|
||||
logger.info(
|
||||
f"Overriding provided model '{model}' with required model 'computer-use-preview'"
|
||||
)
|
||||
|
||||
# Initialize base class with core config
|
||||
super().__init__(
|
||||
computer=computer,
|
||||
model="computer-use-preview", # Always use computer-use-preview
|
||||
api_key=api_key,
|
||||
max_retries=max_retries,
|
||||
retry_delay=retry_delay,
|
||||
base_dir=base_dir,
|
||||
save_trajectory=save_trajectory,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Initialize message manager
|
||||
self.message_manager = StandardMessageManager(
|
||||
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
|
||||
)
|
||||
|
||||
# OpenAI-specific attributes
|
||||
self.provider = LLMProvider.OPENAI
|
||||
self.client = None
|
||||
self.retry_count = 0
|
||||
self.acknowledge_safety_check_callback = acknowledge_safety_check_callback
|
||||
self.queue = asyncio.Queue() # Initialize queue
|
||||
self.last_response_id = None # Store the last response ID across runs
|
||||
|
||||
# Initialize handlers
|
||||
self.api_handler = OpenAIAPIHandler(self)
|
||||
self.response_handler = OpenAIResponseHandler(self)
|
||||
|
||||
# Initialize tool manager with callback
|
||||
self.tool_manager = ToolManager(
|
||||
computer=computer, acknowledge_safety_check_callback=acknowledge_safety_check_callback
|
||||
)
|
||||
|
||||
###########################################
|
||||
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def initialize_client(self) -> None:
|
||||
"""Initialize the OpenAI API client and tools.
|
||||
|
||||
Implements abstract method from BaseLoop to set up the OpenAI-specific
|
||||
client, tool manager, and message manager.
|
||||
"""
|
||||
try:
|
||||
# Initialize tool manager
|
||||
await self.tool_manager.initialize()
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing OpenAI client: {str(e)}")
|
||||
self.client = None
|
||||
raise RuntimeError(f"Failed to initialize OpenAI client: {str(e)}")
|
||||
|
||||
###########################################
|
||||
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
|
||||
###########################################
|
||||
|
||||
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[AgentResponse, None]:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
Args:
|
||||
messages: List of message objects in standard format
|
||||
|
||||
Yields:
|
||||
Agent response format
|
||||
"""
|
||||
try:
|
||||
logger.info("Starting OpenAI loop run")
|
||||
|
||||
# Create queue for response streaming
|
||||
queue = asyncio.Queue()
|
||||
|
||||
# Ensure tool manager is initialized
|
||||
await self.tool_manager.initialize()
|
||||
|
||||
# Start loop in background task
|
||||
loop_task = asyncio.create_task(self._run_loop(queue, messages))
|
||||
|
||||
# Process and yield messages as they arrive
|
||||
while True:
|
||||
try:
|
||||
item = await queue.get()
|
||||
if item is None: # Stop signal
|
||||
break
|
||||
yield item
|
||||
queue.task_done()
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing queue item: {str(e)}")
|
||||
continue
|
||||
|
||||
# Wait for loop to complete
|
||||
await loop_task
|
||||
|
||||
# Send completion message
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": "Task completed successfully.",
|
||||
"metadata": {"title": "✅ Complete"},
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing task: {str(e)}")
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
|
||||
###########################################
|
||||
# AGENT LOOP IMPLEMENTATION
|
||||
###########################################
|
||||
|
||||
async def _run_loop(self, queue: asyncio.Queue, messages: List[Dict[str, Any]]) -> None:
|
||||
"""Run the agent loop with provided messages.
|
||||
|
||||
Args:
|
||||
queue: Queue for response streaming
|
||||
messages: List of messages in standard format
|
||||
"""
|
||||
try:
|
||||
# Use the instance-level last_response_id instead of creating a local variable
|
||||
# This way it persists between runs
|
||||
|
||||
# Capture initial screenshot
|
||||
try:
|
||||
# Take screenshot
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
logger.info("Screenshot captured successfully")
|
||||
|
||||
# Convert to base64 if needed
|
||||
if isinstance(screenshot, bytes):
|
||||
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
|
||||
else:
|
||||
screenshot_base64 = screenshot
|
||||
|
||||
# Save screenshot if requested
|
||||
if self.save_trajectory:
|
||||
# Ensure screenshot_base64 is a string
|
||||
if not isinstance(screenshot_base64, str):
|
||||
logger.warning(
|
||||
"Converting non-string screenshot_base64 to string for _save_screenshot"
|
||||
)
|
||||
if isinstance(screenshot_base64, (bytearray, memoryview)):
|
||||
screenshot_base64 = base64.b64encode(screenshot_base64).decode("utf-8")
|
||||
self._save_screenshot(screenshot_base64, action_type="state")
|
||||
logger.info("Screenshot saved to trajectory")
|
||||
|
||||
# First add any existing user messages that were passed to run()
|
||||
user_query = None
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
user_content = msg.get("content", "")
|
||||
if isinstance(user_content, str) and user_content:
|
||||
user_query = user_content
|
||||
# Add the user's original query to the message manager
|
||||
self.message_manager.add_user_message(
|
||||
[{"type": "text", "text": user_content}]
|
||||
)
|
||||
break
|
||||
|
||||
# Add screenshot to message manager
|
||||
message_content = [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": screenshot_base64,
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
# Add appropriate text with the screenshot
|
||||
message_content.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": user_query,
|
||||
}
|
||||
)
|
||||
|
||||
# Add the screenshot and text to the message manager
|
||||
self.message_manager.add_user_message(message_content)
|
||||
|
||||
# Process user request and convert our standard message format to one OpenAI expects
|
||||
messages = self.message_manager.messages
|
||||
logger.info(f"Starting agent loop with {len(messages)} messages")
|
||||
|
||||
# Create initial turn directory
|
||||
if self.save_trajectory:
|
||||
self._create_turn_dir()
|
||||
|
||||
# Call API
|
||||
screen_size = await self.computer.interface.get_screen_size()
|
||||
response = await self.api_handler.send_initial_request(
|
||||
messages=messages,
|
||||
display_width=str(screen_size["width"]),
|
||||
display_height=str(screen_size["height"]),
|
||||
previous_response_id=self.last_response_id,
|
||||
)
|
||||
|
||||
# Store response ID for next request
|
||||
# OpenAI API response structure: the ID is in the response dictionary
|
||||
if isinstance(response, dict) and "id" in response:
|
||||
self.last_response_id = response["id"] # Update instance variable
|
||||
logger.info(f"Received response with ID: {self.last_response_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not find response ID in OpenAI response: {type(response)}"
|
||||
)
|
||||
# Don't reset last_response_id to None - keep the previous value if available
|
||||
|
||||
# Process API response
|
||||
await queue.put(response)
|
||||
|
||||
# Loop to continue processing responses until task is complete
|
||||
task_complete = False
|
||||
while not task_complete:
|
||||
# Check if there are any computer calls
|
||||
output_items = response.get("output", []) or []
|
||||
computer_calls = [
|
||||
item for item in output_items if item.get("type") == "computer_call"
|
||||
]
|
||||
|
||||
if not computer_calls:
|
||||
logger.info("No computer calls in response, task may be complete.")
|
||||
task_complete = True
|
||||
continue
|
||||
|
||||
# Process the first computer call
|
||||
computer_call = computer_calls[0]
|
||||
action = computer_call.get("action", {})
|
||||
call_id = computer_call.get("call_id")
|
||||
|
||||
# Check for safety checks
|
||||
pending_safety_checks = computer_call.get("pending_safety_checks", [])
|
||||
acknowledged_safety_checks = []
|
||||
|
||||
if pending_safety_checks:
|
||||
# Log safety checks
|
||||
for check in pending_safety_checks:
|
||||
logger.warning(
|
||||
f"Safety check: {check.get('code')} - {check.get('message')}"
|
||||
)
|
||||
|
||||
# If we have a callback, use it to acknowledge safety checks
|
||||
if self.acknowledge_safety_check_callback:
|
||||
acknowledged = await self.acknowledge_safety_check_callback(
|
||||
pending_safety_checks
|
||||
)
|
||||
if not acknowledged:
|
||||
logger.warning("Safety check acknowledgment failed")
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Safety checks were not acknowledged. Cannot proceed with action.",
|
||||
"metadata": {"title": "⚠️ Safety Warning"},
|
||||
}
|
||||
)
|
||||
continue
|
||||
acknowledged_safety_checks = pending_safety_checks
|
||||
|
||||
# Execute the action
|
||||
try:
|
||||
# Create a new turn directory for this action if saving trajectories
|
||||
if self.save_trajectory:
|
||||
self._create_turn_dir()
|
||||
|
||||
# Execute the tool
|
||||
result = await self.tool_manager.execute_tool("computer", action)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
if isinstance(screenshot, bytes):
|
||||
screenshot_base64 = base64.b64encode(screenshot).decode("utf-8")
|
||||
else:
|
||||
screenshot_base64 = screenshot
|
||||
|
||||
# Create computer_call_output
|
||||
computer_call_output = {
|
||||
"type": "computer_call_output",
|
||||
"call_id": call_id,
|
||||
"output": {
|
||||
"type": "input_image",
|
||||
"image_url": f"data:image/png;base64,{screenshot_base64}",
|
||||
},
|
||||
}
|
||||
|
||||
# Add acknowledged safety checks if any
|
||||
if acknowledged_safety_checks:
|
||||
computer_call_output["acknowledged_safety_checks"] = (
|
||||
acknowledged_safety_checks
|
||||
)
|
||||
|
||||
# Save to message manager for history
|
||||
self.message_manager.add_system_message(
|
||||
f"[Computer action executed: {action.get('type')}]"
|
||||
)
|
||||
self.message_manager.add_user_message([computer_call_output])
|
||||
|
||||
# For follow-up requests with previous_response_id, we only need to send
|
||||
# the computer_call_output, not the full message history
|
||||
# The API handler will extract this from the message history
|
||||
if isinstance(self.last_response_id, str):
|
||||
response = await self.api_handler.send_computer_call_request(
|
||||
messages=self.message_manager.messages,
|
||||
display_width=str(screen_size["width"]),
|
||||
display_height=str(screen_size["height"]),
|
||||
previous_response_id=self.last_response_id, # Use instance variable
|
||||
)
|
||||
|
||||
# Store response ID for next request
|
||||
if isinstance(response, dict) and "id" in response:
|
||||
self.last_response_id = response["id"] # Update instance variable
|
||||
logger.info(f"Received response with ID: {self.last_response_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Could not find response ID in OpenAI response: {type(response)}"
|
||||
)
|
||||
# Keep using the previous response ID if we can't find a new one
|
||||
|
||||
# Process the response
|
||||
# await self.response_handler.process_response(response, queue)
|
||||
await queue.put(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing computer action: {str(e)}")
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error executing action: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
)
|
||||
task_complete = True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error capturing initial screenshot: {str(e)}")
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error capturing screenshot: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
)
|
||||
await queue.put(None) # Signal that we're done
|
||||
return
|
||||
|
||||
# Signal that we're done
|
||||
await queue.put(None)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _run_loop: {str(e)}")
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
)
|
||||
await queue.put(None) # Signal that we're done
|
||||
|
||||
def get_last_response_id(self) -> Optional[str]:
|
||||
"""Get the last response ID.
|
||||
|
||||
Returns:
|
||||
The last response ID or None if no response has been received
|
||||
"""
|
||||
return self.last_response_id
|
||||
|
||||
def set_last_response_id(self, response_id: str) -> None:
|
||||
"""Set the last response ID.
|
||||
|
||||
Args:
|
||||
response_id: OpenAI response ID to set
|
||||
"""
|
||||
self.last_response_id = response_id
|
||||
logger.info(f"Manually set response ID to: {self.last_response_id}")
|
||||
205
libs/agent/agent/providers/openai/response_handler.py
Normal file
205
libs/agent/agent/providers/openai/response_handler.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Response handler for the OpenAI provider."""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING, AsyncGenerator
|
||||
import base64
|
||||
|
||||
from ...core.types import AgentResponse
|
||||
from .types import ResponseItemType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .loop import OpenAILoop
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenAIResponseHandler:
|
||||
"""Handler for OpenAI API responses."""
|
||||
|
||||
def __init__(self, loop: "OpenAILoop"):
|
||||
"""Initialize the response handler.
|
||||
|
||||
Args:
|
||||
loop: OpenAI loop instance
|
||||
"""
|
||||
self.loop = loop
|
||||
logger.info("Initialized OpenAI response handler")
|
||||
|
||||
async def process_response(self, response: Dict[str, Any], queue: asyncio.Queue) -> None:
|
||||
"""Process the response from the OpenAI API.
|
||||
|
||||
Args:
|
||||
response: Response from the API
|
||||
queue: Queue for response streaming
|
||||
"""
|
||||
try:
|
||||
# Get output items
|
||||
output_items = response.get("output", []) or []
|
||||
|
||||
# Process each output item
|
||||
for item in output_items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
item_type = item.get("type")
|
||||
|
||||
# For computer_call items, we only need to add to the queue
|
||||
# The loop is now handling executing the action and creating the computer_call_output
|
||||
if item_type == ResponseItemType.COMPUTER_CALL:
|
||||
# Send computer_call to queue so it can be processed
|
||||
await queue.put(item)
|
||||
|
||||
elif item_type == ResponseItemType.MESSAGE:
|
||||
# Send message to queue
|
||||
await queue.put(item)
|
||||
|
||||
elif item_type == ResponseItemType.REASONING:
|
||||
# Process reasoning summary
|
||||
summary = None
|
||||
if "summary" in item and isinstance(item["summary"], list):
|
||||
for summary_item in item["summary"]:
|
||||
if (
|
||||
isinstance(summary_item, dict)
|
||||
and summary_item.get("type") == "summary_text"
|
||||
):
|
||||
summary = summary_item.get("text")
|
||||
break
|
||||
|
||||
if summary:
|
||||
# Log the reasoning summary
|
||||
logger.info(f"Reasoning summary: {summary}")
|
||||
|
||||
# Send reasoning summary to queue with a special format
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"[Reasoning: {summary}]",
|
||||
"metadata": {"title": "💭 Reasoning", "is_summary": True},
|
||||
}
|
||||
)
|
||||
|
||||
# Also pass the original reasoning item to the queue for complete context
|
||||
await queue.put(item)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing response: {str(e)}")
|
||||
await queue.put(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"Error processing response: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
)
|
||||
|
||||
def _process_message_item(self, item: Dict[str, Any]) -> AgentResponse:
|
||||
"""Process a message item from the response.
|
||||
|
||||
Args:
|
||||
item: Message item from the response
|
||||
|
||||
Returns:
|
||||
Processed message in AgentResponse format
|
||||
"""
|
||||
# Extract content items - add null check
|
||||
content_items = item.get("content", []) or []
|
||||
|
||||
# Extract text from content items - use output_text type from OpenAI
|
||||
text = ""
|
||||
for content_item in content_items:
|
||||
# Skip if content_item is None or not a dict
|
||||
if content_item is None or not isinstance(content_item, dict):
|
||||
continue
|
||||
|
||||
# In OpenAI Agent Response API, text content is in "output_text" type items
|
||||
if content_item.get("type") == "output_text":
|
||||
text += content_item.get("text", "")
|
||||
|
||||
# Create agent response
|
||||
return {
|
||||
"role": "assistant",
|
||||
"content": text
|
||||
or "I don't have a response for that right now.", # Provide fallback when text is empty
|
||||
"metadata": {"title": "💬 Response"},
|
||||
}
|
||||
|
||||
async def _process_computer_call(self, item: Dict[str, Any], queue: asyncio.Queue) -> None:
|
||||
"""Process a computer call item from the response.
|
||||
|
||||
Args:
|
||||
item: Computer call item
|
||||
queue: Queue to add responses to
|
||||
"""
|
||||
try:
|
||||
# Log the computer call
|
||||
action = item.get("action", {}) or {}
|
||||
if not isinstance(action, dict):
|
||||
logger.warning(f"Expected dict for action, got {type(action)}")
|
||||
action = {}
|
||||
|
||||
action_type = action.get("type", "unknown")
|
||||
logger.info(f"Processing computer call: {action_type}")
|
||||
|
||||
# Execute the tool call
|
||||
result = await self.loop.tool_manager.execute_tool("computer", action)
|
||||
|
||||
# Add any message to the conversation history and queue
|
||||
if result and result.base64_image:
|
||||
# Update message history with the call output
|
||||
self.loop.message_manager.add_user_message(
|
||||
[{"type": "text", "text": f"[Computer action completed: {action_type}]"}]
|
||||
)
|
||||
|
||||
# Add image to messages (using correct content types for Agent Response API)
|
||||
self.loop.message_manager.add_user_message(
|
||||
[
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": result.base64_image,
|
||||
},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
# If browser environment, include URL if available
|
||||
# if (
|
||||
# hasattr(self.loop.computer, "environment")
|
||||
# and self.loop.computer.environment == "browser"
|
||||
# ):
|
||||
# try:
|
||||
# if hasattr(self.loop.computer.interface, "get_current_url"):
|
||||
# current_url = await self.loop.computer.interface.get_current_url()
|
||||
# self.loop.message_manager.add_user_message(
|
||||
# [
|
||||
# {
|
||||
# "type": "text",
|
||||
# "text": f"Current URL: {current_url}",
|
||||
# }
|
||||
# ]
|
||||
# )
|
||||
# except Exception as e:
|
||||
# logger.warning(f"Failed to get current URL: {str(e)}")
|
||||
|
||||
# Log successful completion
|
||||
logger.info(f"Computer call {action_type} executed successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing computer call: {str(e)}")
|
||||
logger.debug(traceback.format_exc())
|
||||
|
||||
# Add error to conversation
|
||||
self.loop.message_manager.add_user_message(
|
||||
[{"type": "text", "text": f"Error executing computer action: {str(e)}"}]
|
||||
)
|
||||
|
||||
# Send error to queue
|
||||
error_response = {
|
||||
"role": "assistant",
|
||||
"content": f"Error executing computer action: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
await queue.put(error_response)
|
||||
15
libs/agent/agent/providers/openai/tools/__init__.py
Normal file
15
libs/agent/agent/providers/openai/tools/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""OpenAI tools module for computer control."""
|
||||
|
||||
from .manager import ToolManager
|
||||
from .computer import ComputerTool
|
||||
from .base import BaseOpenAITool, ToolResult, ToolError, ToolFailure, CLIResult
|
||||
|
||||
__all__ = [
|
||||
"ToolManager",
|
||||
"ComputerTool",
|
||||
"BaseOpenAITool",
|
||||
"ToolResult",
|
||||
"ToolError",
|
||||
"ToolFailure",
|
||||
"CLIResult",
|
||||
]
|
||||
79
libs/agent/agent/providers/openai/tools/base.py
Normal file
79
libs/agent/agent/providers/openai/tools/base.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""OpenAI-specific tool base classes."""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ....core.tools.base import BaseTool
|
||||
|
||||
|
||||
class BaseOpenAITool(BaseTool, metaclass=ABCMeta):
|
||||
"""Abstract base class for OpenAI-defined tools."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the base OpenAI tool."""
|
||||
# No specific initialization needed yet, but included for future extensibility
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, **kwargs) -> Any:
|
||||
"""Executes the tool with the given arguments."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to OpenAI-specific API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters for OpenAI API
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@dataclass(kw_only=True, frozen=True)
|
||||
class ToolResult:
|
||||
"""Represents the result of a tool execution."""
|
||||
|
||||
output: str | None = None
|
||||
error: str | None = None
|
||||
base64_image: str | None = None
|
||||
system: str | None = None
|
||||
content: list[dict] | None = None
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field.name) for field in fields(self))
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(field: str | None, other_field: str | None, concatenate: bool = True):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
content=self.content or other.content, # Use first non-None content
|
||||
)
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""Returns a new ToolResult with the given fields replaced."""
|
||||
return replace(self, **kwargs)
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
|
||||
"""A ToolResult that can be rendered as a CLI output."""
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
|
||||
"""A ToolResult that represents a failure."""
|
||||
|
||||
|
||||
class ToolError(Exception):
|
||||
"""Raised when a tool encounters an error."""
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
319
libs/agent/agent/providers/openai/tools/computer.py
Normal file
319
libs/agent/agent/providers/openai/tools/computer.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""Computer tool for OpenAI."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
from typing import Literal, Any, Dict, Optional, List, Union
|
||||
|
||||
from computer.computer import Computer
|
||||
|
||||
from .base import BaseOpenAITool, ToolError, ToolResult
|
||||
from ....core.tools.computer import BaseComputerTool
|
||||
|
||||
TYPING_DELAY_MS = 12
|
||||
TYPING_GROUP_SIZE = 50
|
||||
|
||||
# Key mapping for special keys
|
||||
KEY_MAPPING = {
|
||||
"enter": "return",
|
||||
"backspace": "delete",
|
||||
"delete": "forwarddelete",
|
||||
"escape": "esc",
|
||||
"pageup": "page_up",
|
||||
"pagedown": "page_down",
|
||||
"arrowup": "up",
|
||||
"arrowdown": "down",
|
||||
"arrowleft": "left",
|
||||
"arrowright": "right",
|
||||
"home": "home",
|
||||
"end": "end",
|
||||
"tab": "tab",
|
||||
"space": "space",
|
||||
"shift": "shift",
|
||||
"control": "control",
|
||||
"alt": "alt",
|
||||
"meta": "command",
|
||||
}
|
||||
|
||||
Action = Literal[
|
||||
"key",
|
||||
"type",
|
||||
"mouse_move",
|
||||
"left_click",
|
||||
"right_click",
|
||||
"double_click",
|
||||
"screenshot",
|
||||
"scroll",
|
||||
]
|
||||
|
||||
|
||||
class ComputerTool(BaseComputerTool, BaseOpenAITool):
|
||||
"""
|
||||
A tool that allows the agent to interact with the screen, keyboard, and mouse of the current computer.
|
||||
"""
|
||||
|
||||
name: Literal["computer"] = "computer"
|
||||
api_type: Literal["computer_use_preview"] = "computer_use_preview"
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
display_num: Optional[int] = None
|
||||
computer: Computer # The CUA Computer instance
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_screenshot_delay = 1.0 # macOS is generally faster than X11
|
||||
_scaling_enabled = True
|
||||
|
||||
def __init__(self, computer: Computer):
|
||||
"""Initialize the computer tool.
|
||||
|
||||
Args:
|
||||
computer: Computer instance
|
||||
"""
|
||||
self.computer = computer
|
||||
self.width = None
|
||||
self.height = None
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# Initialize the base computer tool first
|
||||
BaseComputerTool.__init__(self, computer)
|
||||
# Then initialize the OpenAI tool
|
||||
BaseOpenAITool.__init__(self)
|
||||
|
||||
# Additional initialization
|
||||
self.width = None # Will be initialized from computer interface
|
||||
self.height = None # Will be initialized from computer interface
|
||||
self.display_num = None
|
||||
|
||||
def to_params(self) -> Dict[str, Any]:
|
||||
"""Convert tool to API parameters.
|
||||
|
||||
Returns:
|
||||
Dictionary with tool parameters
|
||||
"""
|
||||
if self.width is None or self.height is None:
|
||||
raise RuntimeError(
|
||||
"Screen dimensions not initialized. Call initialize_dimensions() first."
|
||||
)
|
||||
return {
|
||||
"type": self.api_type,
|
||||
"display_width": self.width,
|
||||
"display_height": self.height,
|
||||
"display_number": self.display_num,
|
||||
}
|
||||
|
||||
async def initialize_dimensions(self):
|
||||
"""Initialize screen dimensions from the computer interface."""
|
||||
try:
|
||||
display_size = await self.computer.interface.get_screen_size()
|
||||
self.width = display_size["width"]
|
||||
self.height = display_size["height"]
|
||||
assert isinstance(self.width, int) and isinstance(self.height, int)
|
||||
self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
|
||||
except Exception as e:
|
||||
# Fall back to defaults if we can't get accurate dimensions
|
||||
self.width = 1024
|
||||
self.height = 768
|
||||
self.logger.warning(
|
||||
f"Failed to get screen dimensions, using defaults: {self.width}x{self.height}. Error: {e}"
|
||||
)
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
*,
|
||||
type: str, # OpenAI uses 'type' instead of 'action'
|
||||
text: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
try:
|
||||
# Ensure dimensions are initialized
|
||||
if self.width is None or self.height is None:
|
||||
await self.initialize_dimensions()
|
||||
if self.width is None or self.height is None:
|
||||
raise ToolError("Failed to initialize screen dimensions")
|
||||
|
||||
if type == "type":
|
||||
if text is None:
|
||||
raise ToolError("text is required for type action")
|
||||
return await self.handle_typing(text)
|
||||
elif type == "click":
|
||||
# Map button to correct action name
|
||||
button = kwargs.get("button")
|
||||
if button is None:
|
||||
raise ToolError("button is required for click action")
|
||||
return await self.handle_click(button, kwargs["x"], kwargs["y"])
|
||||
elif type == "keypress":
|
||||
# Check for keys in kwargs if text is None
|
||||
if text is None:
|
||||
if "keys" in kwargs and isinstance(kwargs["keys"], list):
|
||||
# Pass the keys list directly instead of joining and then splitting
|
||||
return await self.handle_key(kwargs["keys"])
|
||||
else:
|
||||
raise ToolError("Either 'text' or 'keys' is required for keypress action")
|
||||
return await self.handle_key(text)
|
||||
elif type == "mouse_move":
|
||||
if "coordinates" not in kwargs:
|
||||
raise ToolError("coordinates is required for mouse_move action")
|
||||
return await self.handle_mouse_move(
|
||||
kwargs["coordinates"][0], kwargs["coordinates"][1]
|
||||
)
|
||||
elif type == "scroll":
|
||||
# Get x, y coordinates directly from kwargs
|
||||
x = kwargs.get("x")
|
||||
y = kwargs.get("y")
|
||||
if x is None or y is None:
|
||||
raise ToolError("x and y coordinates are required for scroll action")
|
||||
scroll_x = kwargs.get("scroll_x", 0)
|
||||
scroll_y = kwargs.get("scroll_y", 0)
|
||||
return await self.handle_scroll(x, y, scroll_x, scroll_y)
|
||||
elif type == "screenshot":
|
||||
return await self.screenshot()
|
||||
elif type == "wait":
|
||||
duration = kwargs.get("duration", 1.0)
|
||||
await asyncio.sleep(duration)
|
||||
return await self.screenshot()
|
||||
else:
|
||||
raise ToolError(f"Unsupported action: {type}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in ComputerTool.__call__: {str(e)}")
|
||||
raise ToolError(f"Failed to execute {type}: {str(e)}")
|
||||
|
||||
async def handle_click(self, button: str, x: int, y: int) -> ToolResult:
|
||||
"""Handle different click actions."""
|
||||
try:
|
||||
# Perform requested click action
|
||||
if button == "left":
|
||||
await self.computer.interface.left_click(x, y)
|
||||
elif button == "right":
|
||||
await self.computer.interface.right_click(x, y)
|
||||
elif button == "double":
|
||||
await self.computer.interface.double_click(x, y)
|
||||
|
||||
# Wait for UI to update
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(
|
||||
output=f"Performed {button} click at ({x}, {y})",
|
||||
base64_image=base64_screenshot,
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_click: {str(e)}")
|
||||
raise ToolError(f"Failed to perform {button} click at ({x}, {y}): {str(e)}")
|
||||
|
||||
async def handle_typing(self, text: str) -> ToolResult:
|
||||
"""Handle typing text with a small delay between characters."""
|
||||
try:
|
||||
# Type the text with a small delay
|
||||
await self.computer.interface.type_text(text)
|
||||
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Take screenshot after typing
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(output=f"Typed: {text}", base64_image=base64_screenshot)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_typing: {str(e)}")
|
||||
raise ToolError(f"Failed to type '{text}': {str(e)}")
|
||||
|
||||
async def handle_key(self, key: Union[str, List[str]]) -> ToolResult:
|
||||
"""Handle key press, supporting both single keys and combinations.
|
||||
|
||||
Args:
|
||||
key: Either a string (e.g. "ctrl+c") or a list of keys (e.g. ["ctrl", "c"])
|
||||
"""
|
||||
try:
|
||||
# Check if key is already a list
|
||||
if isinstance(key, list):
|
||||
keys = [k.strip().lower() for k in key]
|
||||
else:
|
||||
# Split key string into list if it's a combination (e.g. "ctrl+c")
|
||||
keys = [k.strip().lower() for k in key.split("+")]
|
||||
|
||||
# Map each key
|
||||
mapped_keys = [KEY_MAPPING.get(k, k) for k in keys]
|
||||
|
||||
if len(mapped_keys) > 1:
|
||||
# For key combinations (like Ctrl+C)
|
||||
for k in mapped_keys:
|
||||
await self.computer.interface.press_key(k)
|
||||
await asyncio.sleep(0.1)
|
||||
for k in reversed(mapped_keys):
|
||||
await self.computer.interface.press_key(k)
|
||||
else:
|
||||
# Single key press
|
||||
await self.computer.interface.press_key(mapped_keys[0])
|
||||
|
||||
# Wait briefly
|
||||
await asyncio.sleep(0.3)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(output=f"Pressed key: {key}", base64_image=base64_screenshot)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_key: {str(e)}")
|
||||
raise ToolError(f"Failed to press key '{key}': {str(e)}")
|
||||
|
||||
async def handle_mouse_move(self, x: int, y: int) -> ToolResult:
|
||||
"""Handle mouse movement."""
|
||||
try:
|
||||
# Move cursor to position
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
|
||||
# Wait briefly
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(output=f"Moved cursor to ({x}, {y})", base64_image=base64_screenshot)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_mouse_move: {str(e)}")
|
||||
raise ToolError(f"Failed to move cursor to ({x}, {y}): {str(e)}")
|
||||
|
||||
async def handle_scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> ToolResult:
|
||||
"""Handle scrolling."""
|
||||
try:
|
||||
# Move cursor to position first
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
|
||||
# Scroll based on direction
|
||||
if scroll_y > 0:
|
||||
await self.computer.interface.scroll_down(abs(scroll_y))
|
||||
elif scroll_y < 0:
|
||||
await self.computer.interface.scroll_up(abs(scroll_y))
|
||||
|
||||
# Wait for UI to update
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(
|
||||
output=f"Scrolled at ({x}, {y}) with delta ({scroll_x}, {scroll_y})",
|
||||
base64_image=base64_screenshot,
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in handle_scroll: {str(e)}")
|
||||
raise ToolError(f"Failed to scroll at ({x}, {y}): {str(e)}")
|
||||
|
||||
async def screenshot(self) -> ToolResult:
|
||||
"""Take a screenshot."""
|
||||
try:
|
||||
# Take screenshot
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
|
||||
|
||||
return ToolResult(output="Screenshot taken", base64_image=base64_screenshot)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in screenshot: {str(e)}")
|
||||
raise ToolError(f"Failed to take screenshot: {str(e)}")
|
||||
106
libs/agent/agent/providers/openai/tools/manager.py
Normal file
106
libs/agent/agent/providers/openai/tools/manager.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Tool manager for the OpenAI provider."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, Any, Optional, List, Callable, Awaitable, Union
|
||||
|
||||
from computer import Computer
|
||||
from ..types import ComputerAction, ResponseItemType
|
||||
from .computer import ComputerTool
|
||||
from ....core.tools.base import ToolResult, ToolFailure
|
||||
from ....core.tools.collection import ToolCollection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolManager:
|
||||
"""Manager for computer tools in the OpenAI agent."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
computer: Computer,
|
||||
acknowledge_safety_check_callback: Optional[Callable[[str], Awaitable[bool]]] = None,
|
||||
):
|
||||
"""Initialize the tool manager.
|
||||
|
||||
Args:
|
||||
computer: Computer instance
|
||||
acknowledge_safety_check_callback: Optional callback for safety check acknowledgment
|
||||
"""
|
||||
self.computer = computer
|
||||
self.acknowledge_safety_check_callback = acknowledge_safety_check_callback
|
||||
self._initialized = False
|
||||
self.computer_tool = ComputerTool(computer)
|
||||
self.tools = None
|
||||
logger.info("Initialized OpenAI ToolManager")
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the tool manager."""
|
||||
if not self._initialized:
|
||||
logger.info("Initializing OpenAI ToolManager")
|
||||
|
||||
# Initialize the computer tool
|
||||
await self.computer_tool.initialize_dimensions()
|
||||
|
||||
# Initialize tool collection
|
||||
self.tools = ToolCollection(self.computer_tool)
|
||||
|
||||
self._initialized = True
|
||||
logger.info("OpenAI ToolManager initialized")
|
||||
|
||||
async def get_tools_definition(self) -> List[Dict[str, Any]]:
|
||||
"""Get the tools definition for the OpenAI agent.
|
||||
|
||||
Returns:
|
||||
Tools definition for the OpenAI agent
|
||||
"""
|
||||
if not self.tools:
|
||||
raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
|
||||
# For the OpenAI Agent Response API, we use a special "computer-preview" tool
|
||||
# which provides the correct interface for computer control
|
||||
display_width, display_height = await self._get_computer_dimensions()
|
||||
|
||||
# Get environment, using "mac" as default since we're on macOS
|
||||
environment = getattr(self.computer, "environment", "mac")
|
||||
|
||||
# Ensure environment is one of the allowed values
|
||||
if environment not in ["windows", "mac", "linux", "browser"]:
|
||||
logger.warning(f"Invalid environment value: {environment}, using 'mac' instead")
|
||||
environment = "mac"
|
||||
|
||||
return [
|
||||
{
|
||||
"type": "computer-preview",
|
||||
"display_width": display_width,
|
||||
"display_height": display_height,
|
||||
"environment": environment,
|
||||
}
|
||||
]
|
||||
|
||||
async def _get_computer_dimensions(self) -> tuple[int, int]:
|
||||
"""Get the dimensions of the computer display.
|
||||
|
||||
Returns:
|
||||
Tuple of (width, height)
|
||||
"""
|
||||
# If computer tool is initialized, use its dimensions
|
||||
if self.computer_tool.width is not None and self.computer_tool.height is not None:
|
||||
return (self.computer_tool.width, self.computer_tool.height)
|
||||
|
||||
# Try to get from computer.interface if available
|
||||
screen_size = await self.computer.interface.get_screen_size()
|
||||
return (int(screen_size["width"]), int(screen_size["height"]))
|
||||
|
||||
async def execute_tool(self, name: str, tool_input: Dict[str, Any]) -> ToolResult:
|
||||
"""Execute a tool with the given input.
|
||||
|
||||
Args:
|
||||
name: Name of the tool to execute
|
||||
tool_input: Input parameters for the tool
|
||||
|
||||
Returns:
|
||||
Result of the tool execution
|
||||
"""
|
||||
if not self.tools:
|
||||
raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
return await self.tools.run(name=name, tool_input=tool_input)
|
||||
36
libs/agent/agent/providers/openai/types.py
Normal file
36
libs/agent/agent/providers/openai/types.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""Type definitions for the OpenAI provider."""
|
||||
|
||||
from enum import StrEnum, auto
|
||||
from typing import Dict, List, Optional, Union, Any
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class LLMProvider(StrEnum):
|
||||
"""OpenAI LLM provider types."""
|
||||
|
||||
OPENAI = "openai"
|
||||
|
||||
|
||||
class ResponseItemType(StrEnum):
|
||||
"""Types of items in OpenAI Agent Response output."""
|
||||
|
||||
MESSAGE = "message"
|
||||
COMPUTER_CALL = "computer_call"
|
||||
COMPUTER_CALL_OUTPUT = "computer_call_output"
|
||||
REASONING = "reasoning"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ComputerAction:
|
||||
"""Represents a computer action to be performed."""
|
||||
|
||||
type: str
|
||||
x: Optional[int] = None
|
||||
y: Optional[int] = None
|
||||
text: Optional[str] = None
|
||||
button: Optional[str] = None
|
||||
keys: Optional[List[str]] = None
|
||||
ms: Optional[int] = None
|
||||
scroll_x: Optional[int] = None
|
||||
scroll_y: Optional[int] = None
|
||||
path: Optional[List[Dict[str, int]]] = None
|
||||
98
libs/agent/agent/providers/openai/utils.py
Normal file
98
libs/agent/agent/providers/openai/utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""Utility functions for the OpenAI provider."""
|
||||
|
||||
import logging
|
||||
import json
|
||||
import base64
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ...core.types import AgentResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def format_images_for_openai(images_base64: List[str]) -> List[Dict[str, Any]]:
|
||||
"""Format images for OpenAI Agent Response API.
|
||||
|
||||
Args:
|
||||
images_base64: List of base64 encoded images
|
||||
|
||||
Returns:
|
||||
List of formatted image items for Agent Response API
|
||||
"""
|
||||
return [
|
||||
{"type": "input_image", "image_url": f"data:image/png;base64,{image}"}
|
||||
for image in images_base64
|
||||
]
|
||||
|
||||
|
||||
def extract_message_content(message: Dict[str, Any]) -> str:
|
||||
"""Extract text content from a message.
|
||||
|
||||
Args:
|
||||
message: Message to extract content from
|
||||
|
||||
Returns:
|
||||
Text content from the message
|
||||
"""
|
||||
if isinstance(message.get("content"), str):
|
||||
return message["content"]
|
||||
|
||||
if isinstance(message.get("content"), list):
|
||||
text = ""
|
||||
role = message.get("role", "user")
|
||||
|
||||
for item in message["content"]:
|
||||
if isinstance(item, dict):
|
||||
# For user messages
|
||||
if role == "user" and item.get("type") == "input_text":
|
||||
text += item.get("text", "")
|
||||
# For standard format
|
||||
elif item.get("type") == "text":
|
||||
text += item.get("text", "")
|
||||
# For assistant messages in Agent Response API format
|
||||
elif item.get("type") == "output_text":
|
||||
text += item.get("text", "")
|
||||
return text
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
def sanitize_message(msg: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Sanitize a message for logging by removing large image data.
|
||||
|
||||
Args:
|
||||
msg: Message to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized message
|
||||
"""
|
||||
if not isinstance(msg, dict):
|
||||
return msg
|
||||
|
||||
sanitized = msg.copy()
|
||||
|
||||
# Handle message content
|
||||
if isinstance(sanitized.get("content"), list):
|
||||
sanitized_content = []
|
||||
for item in sanitized["content"]:
|
||||
if isinstance(item, dict):
|
||||
# Handle various image types
|
||||
if item.get("type") == "image_url" and "image_url" in item:
|
||||
sanitized_content.append({"type": "image_url", "image_url": "[omitted]"})
|
||||
elif item.get("type") == "input_image" and "image_url" in item:
|
||||
sanitized_content.append({"type": "input_image", "image_url": "[omitted]"})
|
||||
elif item.get("type") == "image" and "source" in item:
|
||||
sanitized_content.append({"type": "image", "source": "[omitted]"})
|
||||
else:
|
||||
sanitized_content.append(item)
|
||||
else:
|
||||
sanitized_content.append(item)
|
||||
sanitized["content"] = sanitized_content
|
||||
|
||||
# Handle computer_call_output
|
||||
if sanitized.get("type") == "computer_call_output" and "output" in sanitized:
|
||||
output = sanitized["output"]
|
||||
if isinstance(output, dict) and "image_url" in output:
|
||||
sanitized["output"] = {**output, "image_url": "[omitted]"}
|
||||
|
||||
return sanitized
|
||||
@@ -30,6 +30,10 @@ anthropic = [
|
||||
"anthropic>=0.49.0",
|
||||
"boto3>=1.35.81,<2.0.0",
|
||||
]
|
||||
openai = [
|
||||
"openai>=1.14.0,<2.0.0",
|
||||
"httpx>=0.27.0,<0.29.0",
|
||||
]
|
||||
som = [
|
||||
"torch>=2.2.1",
|
||||
"torchvision>=0.17.1",
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
# """Basic tests for the agent package."""
|
||||
|
||||
# import pytest
|
||||
# from agent import OmniComputerAgent, LLMProvider
|
||||
# from agent.base.agent import BaseComputerAgent
|
||||
# from computer import Computer
|
||||
|
||||
# def test_agent_import():
|
||||
# """Test that we can import the OmniComputerAgent class."""
|
||||
# assert OmniComputerAgent is not None
|
||||
# assert LLMProvider is not None
|
||||
|
||||
# def test_agent_init():
|
||||
# """Test that we can create an OmniComputerAgent instance."""
|
||||
# agent = OmniComputerAgent(
|
||||
# provider=LLMProvider.OPENAI,
|
||||
# use_host_computer_server=True
|
||||
# )
|
||||
# assert agent is not None
|
||||
|
||||
# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed")
|
||||
# def test_computer_agent_anthropic():
|
||||
# """Test creating an Anthropic agent."""
|
||||
# agent = ComputerAgent(provider=Provider.ANTHROPIC)
|
||||
# assert isinstance(agent._agent, BaseComputerAgent)
|
||||
|
||||
# def test_computer_agent_invalid_provider():
|
||||
# """Test creating an agent with an invalid provider."""
|
||||
# with pytest.raises(ValueError, match="Unsupported provider"):
|
||||
# ComputerAgent(provider="invalid_provider")
|
||||
|
||||
# def test_computer_agent_uninstalled_provider():
|
||||
# """Test creating an agent with an uninstalled provider."""
|
||||
# with pytest.raises(NotImplementedError, match="OpenAI provider not yet implemented"):
|
||||
# # OpenAI provider is not implemented yet
|
||||
# ComputerAgent(provider=Provider.OPENAI)
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed")
|
||||
# async def test_agent_cleanup():
|
||||
# """Test agent cleanup."""
|
||||
# agent = ComputerAgent(provider=Provider.ANTHROPIC)
|
||||
# await agent.cleanup() # Should not raise any errors
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed")
|
||||
# async def test_agent_direct_initialization():
|
||||
# """Test direct initialization of the agent."""
|
||||
# # Create with default computer
|
||||
# agent = ComputerAgent(provider=Provider.ANTHROPIC)
|
||||
# try:
|
||||
# # Should not raise any errors
|
||||
# await agent.run("test task")
|
||||
# finally:
|
||||
# await agent.cleanup()
|
||||
|
||||
# # Create with custom computer
|
||||
# custom_computer = Computer(
|
||||
# display="1920x1080",
|
||||
# memory="8GB",
|
||||
# cpu="4",
|
||||
# os="macos",
|
||||
# use_host_computer_server=False,
|
||||
# )
|
||||
# agent = ComputerAgent(provider=Provider.ANTHROPIC, computer=custom_computer)
|
||||
# try:
|
||||
# # Should not raise any errors
|
||||
# await agent.run("test task")
|
||||
# finally:
|
||||
# await agent.cleanup()
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.skipif(not hasattr(ComputerAgent, '_ANTHROPIC_AVAILABLE'), reason="Anthropic provider not installed")
|
||||
# async def test_agent_context_manager():
|
||||
# """Test context manager initialization of the agent."""
|
||||
# # Test with default computer
|
||||
# async with ComputerAgent(provider=Provider.ANTHROPIC) as agent:
|
||||
# # Should not raise any errors
|
||||
# await agent.run("test task")
|
||||
|
||||
# # Test with custom computer
|
||||
# custom_computer = Computer(
|
||||
# display="1920x1080",
|
||||
# memory="8GB",
|
||||
# cpu="4",
|
||||
# os="macos",
|
||||
# use_host_computer_server=False,
|
||||
# )
|
||||
# async with ComputerAgent(provider=Provider.ANTHROPIC, computer=custom_computer) as agent:
|
||||
# # Should not raise any errors
|
||||
# await agent.run("test task")
|
||||
@@ -15,7 +15,7 @@
|
||||
</h1>
|
||||
</div>
|
||||
|
||||
**Computer** is a Computer-Use Interface (CUI) framework powering Cua for interacting with local macOS and Linux sandboxes, PyAutoGUI-compatible, and pluggable with any AI agent systems (Cua, Langchain, CrewAI, AutoGen). Computer relies on [Lume](https://github.com/trycua/lume) for creating and managing sandbox environments.
|
||||
**cua-computer** is a Computer-Use Interface (CUI) framework powering Cua for interacting with local macOS and Linux sandboxes, PyAutoGUI-compatible, and pluggable with any AI agent systems (Cua, Langchain, CrewAI, AutoGen). Computer relies on [Lume](https://github.com/trycua/lume) for creating and managing sandbox environments.
|
||||
|
||||
### Get started with Computer
|
||||
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
from typing import Optional, List, Literal, Dict, Any, Union, TYPE_CHECKING, cast
|
||||
from pylume import PyLume
|
||||
from pylume.models import (
|
||||
VMRunOpts,
|
||||
VMUpdateOpts,
|
||||
ImageRef,
|
||||
SharedDirectory,
|
||||
VMStatus
|
||||
)
|
||||
from pylume.models import VMRunOpts, VMUpdateOpts, ImageRef, SharedDirectory, VMStatus
|
||||
import asyncio
|
||||
from .models import Computer as ComputerConfig, Display
|
||||
from .interface.factory import InterfaceFactory
|
||||
@@ -66,8 +60,6 @@ class Computer:
|
||||
port: Optional port to use for the PyLume server
|
||||
host: Host to use for PyLume connections (e.g. "localhost", "host.docker.internal")
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
from .interface.base import BaseComputerInterface
|
||||
|
||||
self.logger = Logger("cua.computer", verbosity)
|
||||
self.logger.info("Initializing Computer...")
|
||||
@@ -159,6 +151,18 @@ class Computer:
|
||||
"""Exit async context manager."""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
"""Enter synchronous context manager."""
|
||||
# Run the event loop to call the async run method
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(self.run())
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit synchronous context manager."""
|
||||
# We could add cleanup here if needed in the future
|
||||
pass
|
||||
|
||||
async def run(self) -> None:
|
||||
"""Initialize the VM and computer interface."""
|
||||
if TYPE_CHECKING:
|
||||
|
||||
Reference in New Issue
Block a user