mirror of
https://github.com/trycua/computer.git
synced 2026-05-02 05:01:22 -05:00
Add dev container, fix lints
This commit is contained in:
@@ -48,9 +48,7 @@ except Exception as e:
|
||||
# Other issues with telemetry
|
||||
logger.warning(f"Error initializing telemetry: {e}")
|
||||
|
||||
from .core.factory import AgentFactory
|
||||
from .core.agent import ComputerAgent
|
||||
from .providers.omni.types import LLMProvider, LLM
|
||||
from .types.base import Provider, AgentLoop
|
||||
from .types.base import AgentLoop
|
||||
|
||||
__all__ = ["AgentFactory", "Provider", "ComputerAgent", "AgentLoop", "LLMProvider", "LLM"]
|
||||
__all__ = ["AgentLoop", "LLMProvider", "LLM"]
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Core agent components."""
|
||||
|
||||
from .base_agent import BaseComputerAgent
|
||||
from .loop import BaseLoop
|
||||
from .messages import (
|
||||
create_user_message,
|
||||
@@ -12,7 +11,7 @@ from .messages import (
|
||||
ImageRetentionConfig,
|
||||
)
|
||||
from .callbacks import (
|
||||
CallbackManager,
|
||||
CallbackManager,
|
||||
CallbackHandler,
|
||||
BaseCallbackManager,
|
||||
ContentCallback,
|
||||
@@ -21,9 +20,8 @@ from .callbacks import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseComputerAgent",
|
||||
"BaseLoop",
|
||||
"CallbackManager",
|
||||
"BaseLoop",
|
||||
"CallbackManager",
|
||||
"CallbackHandler",
|
||||
"BaseMessageManager",
|
||||
"ImageRetentionConfig",
|
||||
|
||||
@@ -1,252 +0,0 @@
|
||||
"""Unified computer agent implementation that supports multiple loops."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, TYPE_CHECKING, Union, cast
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
from computer import Computer
|
||||
|
||||
from ..types.base import Provider, AgentLoop
|
||||
from .base_agent import BaseComputerAgent
|
||||
from ..core.telemetry import record_agent_initialization
|
||||
|
||||
# Only import types for type checking to avoid circular imports
|
||||
if TYPE_CHECKING:
|
||||
from ..providers.anthropic.loop import AnthropicLoop
|
||||
from ..providers.omni.loop import OmniLoop
|
||||
from ..providers.omni.parser import OmniParser
|
||||
|
||||
# Import the provider types
|
||||
from ..providers.omni.types import LLMProvider, LLM, Model, LLMModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default models for different providers
|
||||
DEFAULT_MODELS = {
|
||||
LLMProvider.OPENAI: "gpt-4o",
|
||||
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
|
||||
}
|
||||
|
||||
# Map providers to their environment variable names
|
||||
ENV_VARS = {
|
||||
LLMProvider.OPENAI: "OPENAI_API_KEY",
|
||||
LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
|
||||
}
|
||||
|
||||
|
||||
class ComputerAgent(BaseComputerAgent):
|
||||
"""Unified implementation of the computer agent supporting multiple loop types.
|
||||
|
||||
This class consolidates the previous AnthropicComputerAgent and OmniComputerAgent
|
||||
into a single implementation with configurable loop type.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
computer: Computer,
|
||||
loop: AgentLoop = AgentLoop.OMNI,
|
||||
model: Optional[Union[LLM, Dict[str, str], str]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
save_trajectory: bool = True,
|
||||
trajectory_dir: Optional[str] = "trajectories",
|
||||
only_n_most_recent_images: Optional[int] = None,
|
||||
max_retries: int = 3,
|
||||
verbosity: int = logging.INFO,
|
||||
telemetry_enabled: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize a ComputerAgent instance.
|
||||
|
||||
Args:
|
||||
computer: The Computer instance to control
|
||||
loop: The agent loop to use: ANTHROPIC or OMNI
|
||||
model: The model to use. Can be a string, dict or LLM object.
|
||||
Defaults to LLM for the loop type.
|
||||
api_key: The API key to use. If None, will use environment variables.
|
||||
save_trajectory: Whether to save the trajectory.
|
||||
trajectory_dir: The directory to save trajectories to.
|
||||
only_n_most_recent_images: Only keep this many most recent images.
|
||||
max_retries: Maximum number of retries for failed requests.
|
||||
verbosity: Logging level (standard Python logging levels).
|
||||
telemetry_enabled: Whether to enable telemetry tracking. Defaults to True.
|
||||
**kwargs: Additional keyword arguments to pass to the loop.
|
||||
"""
|
||||
super().__init__(computer)
|
||||
self._configure_logging(verbosity)
|
||||
logger.info(f"Initializing ComputerAgent with {loop} loop")
|
||||
|
||||
# Store telemetry preference
|
||||
self.telemetry_enabled = telemetry_enabled
|
||||
|
||||
# Process the model configuration
|
||||
self.model = self._process_model_config(model, loop)
|
||||
self.loop_type = loop
|
||||
self.api_key = api_key
|
||||
|
||||
# Store computer
|
||||
self.computer = computer
|
||||
|
||||
# Save trajectory settings
|
||||
self.save_trajectory = save_trajectory
|
||||
self.trajectory_dir = trajectory_dir
|
||||
self.only_n_most_recent_images = only_n_most_recent_images
|
||||
|
||||
# Store the max retries setting
|
||||
self.max_retries = max_retries
|
||||
|
||||
# Initialize message history
|
||||
self.messages = []
|
||||
|
||||
# Extra kwargs for the loop
|
||||
self.loop_kwargs = kwargs
|
||||
|
||||
# Initialize the actual loop implementation
|
||||
self.loop = self._init_loop()
|
||||
|
||||
# Record initialization in telemetry if enabled
|
||||
if telemetry_enabled:
|
||||
record_agent_initialization()
|
||||
|
||||
def _process_model_config(
|
||||
self, model_input: Optional[Union[LLM, Dict[str, str], str]], loop: AgentLoop
|
||||
) -> LLM:
|
||||
"""Process and normalize model configuration.
|
||||
|
||||
Args:
|
||||
model_input: Input model configuration (LLM, dict, string, or None)
|
||||
loop: The loop type being used
|
||||
|
||||
Returns:
|
||||
Normalized LLM instance
|
||||
"""
|
||||
# Handle case where model_input is None
|
||||
if model_input is None:
|
||||
# Use Anthropic for Anthropic loop, OpenAI for Omni loop
|
||||
default_provider = (
|
||||
LLMProvider.ANTHROPIC if loop == AgentLoop.ANTHROPIC else LLMProvider.OPENAI
|
||||
)
|
||||
return LLM(provider=default_provider)
|
||||
|
||||
# Handle case where model_input is already a LLM or one of its aliases
|
||||
if isinstance(model_input, (LLM, Model, LLMModel)):
|
||||
return model_input
|
||||
|
||||
# Handle case where model_input is a dict
|
||||
if isinstance(model_input, dict):
|
||||
provider = model_input.get("provider", LLMProvider.OPENAI)
|
||||
if isinstance(provider, str):
|
||||
provider = LLMProvider(provider)
|
||||
return LLM(provider=provider, name=model_input.get("name"))
|
||||
|
||||
# Handle case where model_input is a string (model name)
|
||||
if isinstance(model_input, str):
|
||||
default_provider = (
|
||||
LLMProvider.ANTHROPIC if loop == AgentLoop.ANTHROPIC else LLMProvider.OPENAI
|
||||
)
|
||||
return LLM(provider=default_provider, name=model_input)
|
||||
|
||||
raise ValueError(f"Unsupported model configuration: {model_input}")
|
||||
|
||||
def _configure_logging(self, verbosity: int):
|
||||
"""Configure logging based on verbosity level."""
|
||||
# Use the logging level directly without mapping
|
||||
logger.setLevel(verbosity)
|
||||
logging.getLogger("agent").setLevel(verbosity)
|
||||
|
||||
# Log the verbosity level that was set
|
||||
if verbosity <= logging.DEBUG:
|
||||
logger.info("Agent logging set to DEBUG level (full debug information)")
|
||||
elif verbosity <= logging.INFO:
|
||||
logger.info("Agent logging set to INFO level (standard output)")
|
||||
elif verbosity <= logging.WARNING:
|
||||
logger.warning("Agent logging set to WARNING level (warnings and errors only)")
|
||||
elif verbosity <= logging.ERROR:
|
||||
logger.warning("Agent logging set to ERROR level (errors only)")
|
||||
elif verbosity <= logging.CRITICAL:
|
||||
logger.warning("Agent logging set to CRITICAL level (critical errors only)")
|
||||
|
||||
def _init_loop(self) -> Any:
|
||||
"""Initialize the loop based on the loop_type.
|
||||
|
||||
Returns:
|
||||
Initialized loop instance
|
||||
"""
|
||||
# Lazy import OmniLoop and OmniParser to avoid circular imports
|
||||
from ..providers.omni.loop import OmniLoop
|
||||
from ..providers.omni.parser import OmniParser
|
||||
|
||||
if self.loop_type == AgentLoop.ANTHROPIC:
|
||||
from ..providers.anthropic.loop import AnthropicLoop
|
||||
|
||||
# Ensure we always have a valid model name
|
||||
model_name = self.model.name or DEFAULT_MODELS[LLMProvider.ANTHROPIC]
|
||||
|
||||
return AnthropicLoop(
|
||||
api_key=self.api_key,
|
||||
model=model_name,
|
||||
computer=self.computer,
|
||||
save_trajectory=self.save_trajectory,
|
||||
base_dir=self.trajectory_dir,
|
||||
only_n_most_recent_images=self.only_n_most_recent_images,
|
||||
**self.loop_kwargs,
|
||||
)
|
||||
|
||||
# Initialize parser for OmniLoop with appropriate device
|
||||
if "parser" not in self.loop_kwargs:
|
||||
self.loop_kwargs["parser"] = OmniParser()
|
||||
|
||||
# Ensure we always have a valid model name
|
||||
model_name = self.model.name or DEFAULT_MODELS[self.model.provider]
|
||||
|
||||
return OmniLoop(
|
||||
provider=self.model.provider,
|
||||
api_key=self.api_key,
|
||||
model=model_name,
|
||||
computer=self.computer,
|
||||
save_trajectory=self.save_trajectory,
|
||||
base_dir=self.trajectory_dir,
|
||||
only_n_most_recent_images=self.only_n_most_recent_images,
|
||||
**self.loop_kwargs,
|
||||
)
|
||||
|
||||
async def _execute_task(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Execute a task using the appropriate agent loop.
|
||||
|
||||
Args:
|
||||
task: The task to execute
|
||||
|
||||
Returns:
|
||||
AsyncGenerator yielding task outputs
|
||||
"""
|
||||
logger.info(f"Executing task: {task}")
|
||||
|
||||
try:
|
||||
# Create a message from the task
|
||||
task_message = {"role": "user", "content": task}
|
||||
messages_with_task = self.messages + [task_message]
|
||||
|
||||
# Use the run method of the loop
|
||||
async for output in self.loop.run(messages_with_task):
|
||||
yield output
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing task: {e}")
|
||||
raise
|
||||
finally:
|
||||
pass
|
||||
|
||||
async def _execute_action(self, action_type: str, **action_params) -> Any:
|
||||
"""Execute an action with telemetry tracking."""
|
||||
try:
|
||||
# Execute the action
|
||||
result = await super()._execute_action(action_type, **action_params)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(f"Error executing action {action_type}: {e}")
|
||||
raise
|
||||
finally:
|
||||
pass
|
||||
@@ -1,164 +0,0 @@
|
||||
"""Base computer agent implementation."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
|
||||
from computer import Computer
|
||||
|
||||
from ..types.base import Provider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseComputerAgent(ABC):
|
||||
"""Base class for computer agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_retries: int = 3,
|
||||
computer: Optional[Computer] = None,
|
||||
screenshot_dir: Optional[str] = None,
|
||||
log_dir: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the base computer agent.
|
||||
|
||||
Args:
|
||||
max_retries: Maximum number of retry attempts
|
||||
computer: Optional Computer instance
|
||||
screenshot_dir: Directory to save screenshots
|
||||
log_dir: Directory to save logs (set to None to disable logging to files)
|
||||
**kwargs: Additional provider-specific arguments
|
||||
"""
|
||||
self.max_retries = max_retries
|
||||
self.computer = computer or Computer()
|
||||
self.queue = asyncio.Queue()
|
||||
self.screenshot_dir = screenshot_dir
|
||||
self.log_dir = log_dir
|
||||
self._retry_count = 0
|
||||
self.provider = Provider.UNKNOWN
|
||||
|
||||
# Setup logging
|
||||
if self.log_dir:
|
||||
os.makedirs(self.log_dir, exist_ok=True)
|
||||
logger.info(f"Created logs directory: {self.log_dir}")
|
||||
|
||||
# Setup screenshots directory
|
||||
if self.screenshot_dir:
|
||||
os.makedirs(self.screenshot_dir, exist_ok=True)
|
||||
logger.info(f"Created screenshots directory: {self.screenshot_dir}")
|
||||
|
||||
logger.info("BaseComputerAgent initialized")
|
||||
|
||||
async def run(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Run a task using the computer agent.
|
||||
|
||||
Args:
|
||||
task: Task description
|
||||
|
||||
Yields:
|
||||
Task execution updates
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Running task: {task}")
|
||||
|
||||
# Initialize the computer if needed
|
||||
await self._init_if_needed()
|
||||
|
||||
# Execute the task and yield results
|
||||
# The _execute_task method should be implemented to yield results
|
||||
async for result in self._execute_task(task):
|
||||
yield result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in agent run method: {str(e)}")
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
|
||||
async def _init_if_needed(self):
|
||||
"""Initialize the computer interface if it hasn't been initialized yet."""
|
||||
if not self.computer._initialized:
|
||||
logger.info("Computer not initialized, initializing now...")
|
||||
try:
|
||||
# Call run directly without setting the flag first
|
||||
await self.computer.run()
|
||||
logger.info("Computer interface initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing computer interface: {str(e)}")
|
||||
raise
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Initialize the agent when used as a context manager."""
|
||||
logger.info("Entering BaseComputerAgent context")
|
||||
|
||||
# In case the computer wasn't initialized
|
||||
try:
|
||||
# Initialize the computer only if not already initialized
|
||||
logger.info("Checking if computer is already initialized...")
|
||||
if not self.computer._initialized:
|
||||
logger.info("Initializing computer in __aenter__...")
|
||||
# Use the computer's __aenter__ directly instead of calling run()
|
||||
# This avoids the circular dependency
|
||||
await self.computer.__aenter__()
|
||||
logger.info("Computer initialized in __aenter__")
|
||||
else:
|
||||
logger.info("Computer already initialized, skipping initialization")
|
||||
|
||||
# Take a test screenshot to verify the computer is working
|
||||
logger.info("Testing computer with a screenshot...")
|
||||
try:
|
||||
test_screenshot = await self.computer.interface.screenshot()
|
||||
# Determine the screenshot size based on its type
|
||||
if isinstance(test_screenshot, bytes):
|
||||
size = len(test_screenshot)
|
||||
else:
|
||||
# Assume it's an object with base64_image attribute
|
||||
try:
|
||||
size = len(test_screenshot.base64_image)
|
||||
except AttributeError:
|
||||
size = "unknown"
|
||||
logger.info(f"Screenshot test successful, size: {size}")
|
||||
except Exception as e:
|
||||
logger.error(f"Screenshot test failed: {str(e)}")
|
||||
# Even though screenshot failed, we continue since some tests might not need it
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing computer in __aenter__: {str(e)}")
|
||||
raise
|
||||
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Cleanup computer resources if needed."""
|
||||
logger.info("Cleaning up agent resources")
|
||||
|
||||
# Do any necessary cleanup
|
||||
# We're not shutting down the computer here as it might be shared
|
||||
# Just log that we're exiting
|
||||
if exc_type:
|
||||
logger.error(f"Exiting agent context with error: {exc_type.__name__}: {exc_val}")
|
||||
else:
|
||||
logger.info("Exiting agent context normally")
|
||||
|
||||
# If we have a queue, make sure to signal it's done
|
||||
if hasattr(self, "queue") and self.queue:
|
||||
await self.queue.put(None) # Signal that we're done
|
||||
|
||||
@abstractmethod
|
||||
async def _execute_task(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Execute a task. Must be implemented by subclasses.
|
||||
|
||||
This is an async method that returns an AsyncGenerator. Implementations
|
||||
should use 'yield' statements to produce results asynchronously.
|
||||
"""
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": "Base class method called",
|
||||
"metadata": {"title": "Error"},
|
||||
}
|
||||
raise NotImplementedError("Subclasses must implement _execute_task")
|
||||
@@ -1,69 +1,251 @@
|
||||
"""Main entry point for computer agents."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
import os
|
||||
from typing import Any, AsyncGenerator, Dict, Optional, cast
|
||||
from dataclasses import dataclass
|
||||
|
||||
from computer import Computer
|
||||
from ..types.base import Provider
|
||||
from .factory import AgentFactory
|
||||
from ..providers.anthropic.loop import AnthropicLoop
|
||||
from ..providers.omni.loop import OmniLoop
|
||||
from ..providers.omni.parser import OmniParser
|
||||
from ..providers.omni.types import LLMProvider, LLM
|
||||
from .. import AgentLoop
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Default models for different providers
|
||||
DEFAULT_MODELS = {
|
||||
LLMProvider.OPENAI: "gpt-4o",
|
||||
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
|
||||
}
|
||||
|
||||
# Map providers to their environment variable names
|
||||
ENV_VARS = {
|
||||
LLMProvider.OPENAI: "OPENAI_API_KEY",
|
||||
LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
|
||||
}
|
||||
|
||||
|
||||
class ComputerAgent:
|
||||
"""A computer agent that can perform automated tasks using natural language instructions."""
|
||||
|
||||
def __init__(self, provider: Provider, computer: Optional[Computer] = None, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
computer: Computer,
|
||||
model: LLM,
|
||||
loop: AgentLoop,
|
||||
max_retries: int = 3,
|
||||
screenshot_dir: Optional[str] = None,
|
||||
log_dir: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
save_trajectory: bool = True,
|
||||
trajectory_dir: str = "trajectories",
|
||||
only_n_most_recent_images: Optional[int] = None,
|
||||
parser: Optional[OmniParser] = None,
|
||||
verbosity: int = logging.INFO,
|
||||
):
|
||||
"""Initialize the ComputerAgent.
|
||||
|
||||
Args:
|
||||
provider: The AI provider to use (e.g., Provider.ANTHROPIC)
|
||||
computer: Optional Computer instance. If not provided, one will be created with default settings.
|
||||
**kwargs: Additional provider-specific arguments
|
||||
computer: Computer instance. If not provided, one will be created with default settings.
|
||||
max_retries: Maximum number of retry attempts.
|
||||
screenshot_dir: Directory to save screenshots.
|
||||
log_dir: Directory to save logs (set to None to disable logging to files).
|
||||
model: LLM object containing provider and model name. Takes precedence over provider/model_name.
|
||||
provider: The AI provider to use (e.g., LLMProvider.ANTHROPIC). Only used if model is None.
|
||||
api_key: The API key for the provider. If not provided, will look for environment variable.
|
||||
model_name: The model name to use. Only used if model is None.
|
||||
save_trajectory: Whether to save the trajectory.
|
||||
trajectory_dir: Directory to save the trajectory.
|
||||
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests.
|
||||
parser: Parser instance for the OmniLoop. Only used if provider is not ANTHROPIC.
|
||||
verbosity: Logging level.
|
||||
"""
|
||||
self.provider = provider
|
||||
self._computer = computer
|
||||
self._kwargs = kwargs
|
||||
self._agent = None
|
||||
# Basic agent configuration
|
||||
self.max_retries = max_retries
|
||||
self.computer = computer or Computer()
|
||||
self.queue = asyncio.Queue()
|
||||
self.screenshot_dir = screenshot_dir
|
||||
self.log_dir = log_dir
|
||||
self._retry_count = 0
|
||||
self._initialized = False
|
||||
self._in_context = False
|
||||
|
||||
# Create provider-specific agent using factory
|
||||
self._agent = AgentFactory.create(provider=provider, computer=computer, **kwargs)
|
||||
# Set logging level
|
||||
logger.setLevel(verbosity)
|
||||
|
||||
# Setup logging
|
||||
if self.log_dir:
|
||||
os.makedirs(self.log_dir, exist_ok=True)
|
||||
logger.info(f"Created logs directory: {self.log_dir}")
|
||||
|
||||
# Setup screenshots directory
|
||||
if self.screenshot_dir:
|
||||
os.makedirs(self.screenshot_dir, exist_ok=True)
|
||||
logger.info(f"Created screenshots directory: {self.screenshot_dir}")
|
||||
|
||||
# Use the provided LLM object
|
||||
self.provider = model.provider
|
||||
actual_model_name = model.name or DEFAULT_MODELS.get(self.provider, "")
|
||||
|
||||
# Ensure we have a valid model name
|
||||
if not actual_model_name:
|
||||
actual_model_name = DEFAULT_MODELS.get(self.provider, "")
|
||||
if not actual_model_name:
|
||||
raise ValueError(
|
||||
f"No model specified for provider {self.provider} and no default found"
|
||||
)
|
||||
|
||||
# Ensure computer is properly cast for typing purposes
|
||||
computer_instance = cast(Computer, self.computer)
|
||||
|
||||
# Get API key from environment if not provided
|
||||
actual_api_key = api_key or os.environ.get(ENV_VARS[self.provider], "")
|
||||
if not actual_api_key:
|
||||
raise ValueError(f"No API key provided for {self.provider}")
|
||||
|
||||
# Initialize the appropriate loop based on the loop parameter
|
||||
if loop == AgentLoop.ANTHROPIC:
|
||||
self._loop = AnthropicLoop(
|
||||
api_key=actual_api_key,
|
||||
model=actual_model_name,
|
||||
computer=computer_instance,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
)
|
||||
else:
|
||||
# Default to OmniLoop for other loop types
|
||||
# Initialize parser if not provided
|
||||
actual_parser = parser or OmniParser()
|
||||
|
||||
self._loop = OmniLoop(
|
||||
provider=self.provider,
|
||||
api_key=actual_api_key,
|
||||
model=actual_model_name,
|
||||
computer=computer_instance,
|
||||
save_trajectory=save_trajectory,
|
||||
base_dir=trajectory_dir,
|
||||
only_n_most_recent_images=only_n_most_recent_images,
|
||||
parser=actual_parser,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"ComputerAgent initialized with provider: {self.provider}, model: {actual_model_name}"
|
||||
)
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Enter the async context manager."""
|
||||
"""Initialize the agent when used as a context manager."""
|
||||
logger.info("Entering ComputerAgent context")
|
||||
self._in_context = True
|
||||
|
||||
# In case the computer wasn't initialized
|
||||
try:
|
||||
# Initialize the computer only if not already initialized
|
||||
logger.info("Checking if computer is already initialized...")
|
||||
if not self.computer._initialized:
|
||||
logger.info("Initializing computer in __aenter__...")
|
||||
# Use the computer's __aenter__ directly instead of calling run()
|
||||
await self.computer.__aenter__()
|
||||
logger.info("Computer initialized in __aenter__")
|
||||
else:
|
||||
logger.info("Computer already initialized, skipping initialization")
|
||||
|
||||
# Take a test screenshot to verify the computer is working
|
||||
logger.info("Testing computer with a screenshot...")
|
||||
try:
|
||||
test_screenshot = await self.computer.interface.screenshot()
|
||||
# Determine the screenshot size based on its type
|
||||
if isinstance(test_screenshot, (bytes, bytearray, memoryview)):
|
||||
size = len(test_screenshot)
|
||||
elif hasattr(test_screenshot, "base64_image"):
|
||||
size = len(test_screenshot.base64_image)
|
||||
else:
|
||||
size = "unknown"
|
||||
logger.info(f"Screenshot test successful, size: {size}")
|
||||
except Exception as e:
|
||||
logger.error(f"Screenshot test failed: {str(e)}")
|
||||
# Even though screenshot failed, we continue since some tests might not need it
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing computer in __aenter__: {str(e)}")
|
||||
raise
|
||||
|
||||
await self.initialize()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit the async context manager."""
|
||||
"""Cleanup agent resources if needed."""
|
||||
logger.info("Cleaning up agent resources")
|
||||
self._in_context = False
|
||||
|
||||
# Do any necessary cleanup
|
||||
# We're not shutting down the computer here as it might be shared
|
||||
# Just log that we're exiting
|
||||
if exc_type:
|
||||
logger.error(f"Exiting agent context with error: {exc_type.__name__}: {exc_val}")
|
||||
else:
|
||||
logger.info("Exiting agent context normally")
|
||||
|
||||
# If we have a queue, make sure to signal it's done
|
||||
if hasattr(self, "queue") and self.queue:
|
||||
await self.queue.put(None) # Signal that we're done
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialize the agent and its components."""
|
||||
if not self._initialized:
|
||||
if not self._in_context and self._computer:
|
||||
# If not in context manager but have a computer, initialize it
|
||||
await self._computer.run()
|
||||
# Always initialize the computer if available
|
||||
if self.computer and not self.computer._initialized:
|
||||
await self.computer.run()
|
||||
self._initialized = True
|
||||
|
||||
async def _init_if_needed(self):
|
||||
"""Initialize the computer interface if it hasn't been initialized yet."""
|
||||
if not self.computer._initialized:
|
||||
logger.info("Computer not initialized, initializing now...")
|
||||
try:
|
||||
# Call run directly
|
||||
await self.computer.run()
|
||||
logger.info("Computer interface initialized successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing computer interface: {str(e)}")
|
||||
raise
|
||||
|
||||
async def run(self, task: str) -> AsyncGenerator[Dict[str, Any], None]:
|
||||
"""Run the agent with a given task."""
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
"""Run a task using the computer agent.
|
||||
|
||||
if self._agent is None:
|
||||
logger.error("Agent not initialized properly")
|
||||
yield {"error": "Agent not initialized properly"}
|
||||
return
|
||||
Args:
|
||||
task: Task description
|
||||
|
||||
async for result in self._agent.run(task):
|
||||
yield result
|
||||
Yields:
|
||||
Task execution updates
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Running task: {task}")
|
||||
|
||||
@property
|
||||
def computer(self) -> Optional[Computer]:
|
||||
"""Get the underlying computer instance."""
|
||||
return self._agent.computer if self._agent else None
|
||||
# Initialize the computer if needed
|
||||
if not self._initialized:
|
||||
await self.initialize()
|
||||
|
||||
# Format task as a message
|
||||
messages = [{"role": "user", "content": task}]
|
||||
|
||||
# Pass properly formatted messages to the loop
|
||||
if self._loop is None:
|
||||
logger.error("Loop not initialized properly")
|
||||
yield {"error": "Loop not initialized properly"}
|
||||
return
|
||||
|
||||
# Execute the task and yield results
|
||||
async for result in self._loop.run(messages):
|
||||
yield result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in agent run method: {str(e)}")
|
||||
yield {
|
||||
"role": "assistant",
|
||||
"content": f"Error: {str(e)}",
|
||||
"metadata": {"title": "❌ Error"},
|
||||
}
|
||||
|
||||
@@ -84,7 +84,21 @@ class ExperimentManager:
|
||||
if isinstance(data, dict):
|
||||
result = {}
|
||||
for k, v in data.items():
|
||||
result[k] = self.sanitize_log_data(v)
|
||||
# Special handling for 'data' field in Anthropic message source
|
||||
if k == "data" and isinstance(v, str) and len(v) > 1000:
|
||||
result[k] = f"[BASE64_DATA_LENGTH_{len(v)}]"
|
||||
# Special handling for the 'media_type' key which indicates we're in an image block
|
||||
elif k == "media_type" and "image" in str(v):
|
||||
result[k] = v
|
||||
# If we're in an image block, look for a sibling 'data' field with base64 content
|
||||
if (
|
||||
"data" in result
|
||||
and isinstance(result["data"], str)
|
||||
and len(result["data"]) > 1000
|
||||
):
|
||||
result["data"] = f"[BASE64_DATA_LENGTH_{len(result['data'])}]"
|
||||
else:
|
||||
result[k] = self.sanitize_log_data(v)
|
||||
return result
|
||||
elif isinstance(data, list):
|
||||
return [self.sanitize_log_data(item) for item in data]
|
||||
@@ -93,15 +107,18 @@ class ExperimentManager:
|
||||
else:
|
||||
return data
|
||||
|
||||
def save_screenshot(self, img_base64: str, action_type: str = "") -> None:
|
||||
def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
|
||||
"""Save a screenshot to the experiment directory.
|
||||
|
||||
Args:
|
||||
img_base64: Base64 encoded screenshot
|
||||
action_type: Type of action that triggered the screenshot
|
||||
|
||||
Returns:
|
||||
Path to the saved screenshot or None if there was an error
|
||||
"""
|
||||
if not self.current_turn_dir:
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
# Increment screenshot counter
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
"""Factory for creating provider-specific agents."""
|
||||
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from computer import Computer
|
||||
from ..types.base import Provider
|
||||
from .base_agent import BaseComputerAgent
|
||||
|
||||
# Import provider-specific implementations
|
||||
_ANTHROPIC_AVAILABLE = False
|
||||
_OPENAI_AVAILABLE = False
|
||||
_OLLAMA_AVAILABLE = False
|
||||
_OMNI_AVAILABLE = False
|
||||
|
||||
# Try importing providers
|
||||
try:
|
||||
import anthropic
|
||||
from ..providers.anthropic.agent import AnthropicComputerAgent
|
||||
|
||||
_ANTHROPIC_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
import openai
|
||||
|
||||
_OPENAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
try:
|
||||
from ..providers.omni.agent import OmniComputerAgent
|
||||
|
||||
_OMNI_AVAILABLE = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class AgentFactory:
    """Factory for creating provider-specific agent implementations."""

    @staticmethod
    def create(
        provider: Provider, computer: Optional[Computer] = None, **kwargs: Any
    ) -> BaseComputerAgent:
        """Create an agent based on the specified provider.

        Args:
            provider: The AI provider to use
            computer: Optional Computer instance. When omitted, a default
                ``Computer()`` is created, but only after the provider and its
                dependencies have been validated.
            **kwargs: Additional provider-specific arguments. ``max_retries``
                defaults to 3 and may be overridden by passing it here.

        Returns:
            A provider-specific agent implementation

        Raises:
            ImportError: If provider dependencies are not installed
            NotImplementedError: If the provider is known but has no agent
                implementation yet (currently OpenAI)
            ValueError: If provider is not supported
        """
        # Merge the default retry count with caller overrides. Passing
        # ``max_retries=3`` next to ``**kwargs`` would raise TypeError whenever
        # the caller supplies its own ``max_retries``.
        agent_kwargs = {"max_retries": 3, **kwargs}

        # Resolve the agent class first so that every validation error is
        # raised before any Computer instance is constructed.
        if provider == Provider.ANTHROPIC:
            if not _ANTHROPIC_AVAILABLE:
                raise ImportError(
                    "Anthropic provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[anthropic]"
                )
            agent_cls = AnthropicComputerAgent
        elif provider == Provider.OPENAI:
            if not _OPENAI_AVAILABLE:
                raise ImportError(
                    "OpenAI provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[openai]"
                )
            raise NotImplementedError("OpenAI provider not yet implemented")
        elif provider == Provider.OLLAMA:
            if not _OLLAMA_AVAILABLE:
                raise ImportError(
                    "Ollama provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[ollama]"
                )
            # Only import ollama when actually creating an Ollama agent. The
            # try body is limited to the imports so that an ImportError raised
            # while constructing the agent is not misreported as a missing
            # ollama package.
            try:
                import ollama  # noqa: F401  (availability check only)
                from ..providers.ollama.agent import OllamaComputerAgent
            except ImportError as err:
                raise ImportError(
                    "Failed to import ollama package. Install it with: pip install ollama"
                ) from err
            agent_cls = OllamaComputerAgent
        elif provider == Provider.OMNI:
            if not _OMNI_AVAILABLE:
                raise ImportError(
                    "Omni provider requires additional dependencies. "
                    "Install them with: pip install cua-agent[omni]"
                )
            agent_cls = OmniComputerAgent
        else:
            raise ValueError(f"Unsupported provider: {provider}")

        # Create a Computer instance if none is provided
        if computer is None:
            computer = Computer()

        return agent_cls(computer=computer, **agent_kwargs)
|
||||
@@ -141,9 +141,6 @@ class BaseLoop(ABC):
|
||||
# Initialize API client
|
||||
await self.initialize_client()
|
||||
|
||||
# Initialize computer
|
||||
await self.computer.initialize()
|
||||
|
||||
logger.info("Initialization complete.")
|
||||
return
|
||||
except Exception as e:
|
||||
@@ -173,15 +170,22 @@ class BaseLoop(ABC):
|
||||
base64_image = ""
|
||||
|
||||
# Handle different types of screenshot returns
|
||||
if isinstance(screenshot, bytes):
|
||||
if isinstance(screenshot, (bytes, bytearray, memoryview)):
|
||||
# Raw bytes screenshot
|
||||
base64_image = base64.b64encode(screenshot).decode("utf-8")
|
||||
elif hasattr(screenshot, "base64_image"):
|
||||
# Object-style screenshot with attributes
|
||||
base64_image = screenshot.base64_image
|
||||
if hasattr(screenshot, "width") and hasattr(screenshot, "height"):
|
||||
width = screenshot.width
|
||||
height = screenshot.height
|
||||
# Type checking can't infer these attributes, but they exist at runtime
|
||||
# on certain screenshot return types
|
||||
base64_image = getattr(screenshot, "base64_image")
|
||||
width = (
|
||||
getattr(screenshot, "width", width) if hasattr(screenshot, "width") else width
|
||||
)
|
||||
height = (
|
||||
getattr(screenshot, "height", height)
|
||||
if hasattr(screenshot, "height")
|
||||
else height
|
||||
)
|
||||
|
||||
# Create parsed screen data
|
||||
parsed_screen = {
|
||||
|
||||
@@ -4,39 +4,11 @@ import logging
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, Callable
|
||||
|
||||
# Import the core telemetry module
|
||||
TELEMETRY_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from core.telemetry import (
|
||||
record_event,
|
||||
increment,
|
||||
get_telemetry_client,
|
||||
flush,
|
||||
is_telemetry_enabled,
|
||||
is_telemetry_globally_disabled,
|
||||
)
|
||||
|
||||
def increment_counter(counter_name: str, value: int = 1) -> None:
|
||||
"""Wrapper for increment to maintain backward compatibility."""
|
||||
if is_telemetry_enabled():
|
||||
increment(counter_name, value)
|
||||
|
||||
def set_dimension(name: str, value: Any) -> None:
|
||||
"""Set a dimension that will be attached to all events."""
|
||||
logger = logging.getLogger("cua.agent.telemetry")
|
||||
logger.debug(f"Setting dimension {name}={value}")
|
||||
|
||||
TELEMETRY_AVAILABLE = True
|
||||
logger = logging.getLogger("cua.agent.telemetry")
|
||||
logger.info("Successfully imported telemetry")
|
||||
except ImportError as e:
|
||||
logger = logging.getLogger("cua.agent.telemetry")
|
||||
logger.warning(f"Could not import telemetry: {e}")
|
||||
TELEMETRY_AVAILABLE = False
|
||||
|
||||
|
||||
# Local fallbacks in case core telemetry isn't available
|
||||
def _noop(*args: Any, **kwargs: Any) -> None:
|
||||
@@ -44,18 +16,58 @@ def _noop(*args: Any, **kwargs: Any) -> None:
|
||||
pass
|
||||
|
||||
|
||||
# Define default functions with unique names to avoid shadowing
|
||||
_default_record_event = _noop
|
||||
_default_increment_counter = _noop
|
||||
_default_set_dimension = _noop
|
||||
_default_get_telemetry_client = lambda: None
|
||||
_default_flush = _noop
|
||||
_default_is_telemetry_enabled = lambda: False
|
||||
_default_is_telemetry_globally_disabled = lambda: True
|
||||
|
||||
# Set the actual functions to the defaults initially
|
||||
record_event = _default_record_event
|
||||
increment_counter = _default_increment_counter
|
||||
set_dimension = _default_set_dimension
|
||||
get_telemetry_client = _default_get_telemetry_client
|
||||
flush = _default_flush
|
||||
is_telemetry_enabled = _default_is_telemetry_enabled
|
||||
is_telemetry_globally_disabled = _default_is_telemetry_globally_disabled
|
||||
|
||||
logger = logging.getLogger("cua.agent.telemetry")
|
||||
|
||||
# If telemetry isn't available, use no-op functions
|
||||
if not TELEMETRY_AVAILABLE:
|
||||
try:
|
||||
# Import from core telemetry
|
||||
from core.telemetry import (
|
||||
record_event as core_record_event,
|
||||
increment as core_increment,
|
||||
get_telemetry_client as core_get_telemetry_client,
|
||||
flush as core_flush,
|
||||
is_telemetry_enabled as core_is_telemetry_enabled,
|
||||
is_telemetry_globally_disabled as core_is_telemetry_globally_disabled,
|
||||
)
|
||||
|
||||
# Override the default functions with actual implementations
|
||||
record_event = core_record_event
|
||||
get_telemetry_client = core_get_telemetry_client
|
||||
flush = core_flush
|
||||
is_telemetry_enabled = core_is_telemetry_enabled
|
||||
is_telemetry_globally_disabled = core_is_telemetry_globally_disabled
|
||||
|
||||
def increment_counter(counter_name: str, value: int = 1) -> None:
|
||||
"""Wrapper for increment to maintain backward compatibility."""
|
||||
if is_telemetry_enabled():
|
||||
core_increment(counter_name, value)
|
||||
|
||||
def set_dimension(name: str, value: Any) -> None:
|
||||
"""Set a dimension that will be attached to all events."""
|
||||
logger.debug(f"Setting dimension {name}={value}")
|
||||
|
||||
TELEMETRY_AVAILABLE = True
|
||||
logger.info("Successfully imported telemetry")
|
||||
except ImportError as e:
|
||||
logger.warning(f"Could not import telemetry: {e}")
|
||||
logger.debug("Telemetry not available, using no-op functions")
|
||||
record_event = _noop # type: ignore
|
||||
increment_counter = _noop # type: ignore
|
||||
set_dimension = _noop # type: ignore
|
||||
get_telemetry_client = lambda: None # type: ignore
|
||||
flush = _noop # type: ignore
|
||||
is_telemetry_enabled = lambda: False # type: ignore
|
||||
is_telemetry_globally_disabled = lambda: True # type: ignore
|
||||
|
||||
# Get system info once to use in telemetry
|
||||
SYSTEM_INFO = {
|
||||
@@ -71,7 +83,7 @@ def enable_telemetry() -> bool:
|
||||
Returns:
|
||||
bool: True if telemetry was successfully enabled, False otherwise
|
||||
"""
|
||||
global TELEMETRY_AVAILABLE
|
||||
global TELEMETRY_AVAILABLE, record_event, increment_counter, get_telemetry_client, flush, is_telemetry_enabled, is_telemetry_globally_disabled
|
||||
|
||||
# Check if globally disabled using core function
|
||||
if TELEMETRY_AVAILABLE and is_telemetry_globally_disabled():
|
||||
|
||||
@@ -17,6 +17,7 @@ from anthropic.types.beta import (
|
||||
BetaTextBlock,
|
||||
BetaTextBlockParam,
|
||||
BetaToolUseBlockParam,
|
||||
BetaContentBlockParam,
|
||||
)
|
||||
|
||||
# Computer
|
||||
@@ -24,12 +25,12 @@ from computer import Computer
|
||||
|
||||
# Base imports
|
||||
from ...core.loop import BaseLoop
|
||||
from ...core.messages import ImageRetentionConfig
|
||||
from ...core.messages import ImageRetentionConfig as CoreImageRetentionConfig
|
||||
|
||||
# Anthropic provider-specific imports
|
||||
from .api.client import AnthropicClientFactory, BaseAnthropicClient
|
||||
from .tools.manager import ToolManager
|
||||
from .messages.manager import MessageManager
|
||||
from .messages.manager import MessageManager, ImageRetentionConfig
|
||||
from .callbacks.manager import CallbackManager
|
||||
from .prompts import SYSTEM_PROMPT
|
||||
from .types import LLMProvider
|
||||
@@ -48,8 +49,8 @@ class AnthropicLoop(BaseLoop):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
computer: Computer,
|
||||
model: str = "claude-3-7-sonnet-20250219", # Fixed model
|
||||
computer: Optional[Computer] = None,
|
||||
only_n_most_recent_images: Optional[int] = 2,
|
||||
base_dir: Optional[str] = "trajectories",
|
||||
max_retries: int = 3,
|
||||
@@ -69,7 +70,7 @@ class AnthropicLoop(BaseLoop):
|
||||
retry_delay: Delay between retries in seconds
|
||||
save_trajectory: Whether to save trajectory data
|
||||
"""
|
||||
# Initialize base class
|
||||
# Initialize base class with core config
|
||||
super().__init__(
|
||||
computer=computer,
|
||||
model=model,
|
||||
@@ -93,8 +94,8 @@ class AnthropicLoop(BaseLoop):
|
||||
self.message_manager = None
|
||||
self.callback_manager = None
|
||||
|
||||
# Configure image retention
|
||||
self.image_retention_config = ImageRetentionConfig(
|
||||
# Configure image retention with core config
|
||||
self.image_retention_config = CoreImageRetentionConfig(
|
||||
num_images_to_keep=only_n_most_recent_images
|
||||
)
|
||||
|
||||
@@ -113,7 +114,7 @@ class AnthropicLoop(BaseLoop):
|
||||
|
||||
# Initialize message manager
|
||||
self.message_manager = MessageManager(
|
||||
ImageRetentionConfig(
|
||||
image_retention_config=ImageRetentionConfig(
|
||||
num_images_to_keep=self.only_n_most_recent_images, enable_caching=True
|
||||
)
|
||||
)
|
||||
@@ -250,6 +251,10 @@ class AnthropicLoop(BaseLoop):
|
||||
await self._process_screen(parsed_screen, self.message_history)
|
||||
|
||||
# Prepare messages and make API call
|
||||
if self.message_manager is None:
|
||||
raise RuntimeError(
|
||||
"Message manager not initialized. Call initialize_client() first."
|
||||
)
|
||||
prepared_messages = self.message_manager.prepare_messages(
|
||||
cast(List[BetaMessageParam], self.message_history.copy())
|
||||
)
|
||||
@@ -257,7 +262,7 @@ class AnthropicLoop(BaseLoop):
|
||||
# Create new turn directory for this API call
|
||||
self._create_turn_dir()
|
||||
|
||||
# Make API call
|
||||
# Use _make_api_call instead of direct client call to ensure logging
|
||||
response = await self._make_api_call(prepared_messages)
|
||||
|
||||
# Handle the response
|
||||
@@ -287,6 +292,11 @@ class AnthropicLoop(BaseLoop):
|
||||
Returns:
|
||||
API response
|
||||
"""
|
||||
if self.client is None:
|
||||
raise RuntimeError("Client not initialized. Call initialize_client() first.")
|
||||
if self.tool_manager is None:
|
||||
raise RuntimeError("Tool manager not initialized. Call initialize_client() first.")
|
||||
|
||||
last_error = None
|
||||
|
||||
for attempt in range(self.max_retries):
|
||||
@@ -297,6 +307,7 @@ class AnthropicLoop(BaseLoop):
|
||||
"max_tokens": self.max_tokens,
|
||||
"system": SYSTEM_PROMPT,
|
||||
}
|
||||
# Let ExperimentManager handle sanitization
|
||||
self._log_api_call("request", request_data)
|
||||
|
||||
# Setup betas and system
|
||||
@@ -320,7 +331,7 @@ class AnthropicLoop(BaseLoop):
|
||||
betas=betas,
|
||||
)
|
||||
|
||||
# Log success response
|
||||
# Let ExperimentManager handle sanitization
|
||||
self._log_api_call("response", request_data, response)
|
||||
|
||||
return response
|
||||
@@ -365,25 +376,38 @@ class AnthropicLoop(BaseLoop):
|
||||
}
|
||||
)
|
||||
|
||||
if self.callback_manager is None:
|
||||
raise RuntimeError(
|
||||
"Callback manager not initialized. Call initialize_client() first."
|
||||
)
|
||||
|
||||
# Handle tool use blocks and collect results
|
||||
tool_result_content = []
|
||||
for content_block in response_params:
|
||||
# Notify callback of content
|
||||
self.callback_manager.on_content(content_block)
|
||||
self.callback_manager.on_content(cast(BetaContentBlockParam, content_block))
|
||||
|
||||
# Handle tool use
|
||||
if content_block.get("type") == "tool_use":
|
||||
if self.tool_manager is None:
|
||||
raise RuntimeError(
|
||||
"Tool manager not initialized. Call initialize_client() first."
|
||||
)
|
||||
result = await self.tool_manager.execute_tool(
|
||||
name=content_block["name"],
|
||||
tool_input=cast(Dict[str, Any], content_block["input"]),
|
||||
)
|
||||
|
||||
# Create tool result and add to content
|
||||
tool_result = self._make_tool_result(result, content_block["id"])
|
||||
tool_result = self._make_tool_result(
|
||||
cast(ToolResult, result), content_block["id"]
|
||||
)
|
||||
tool_result_content.append(tool_result)
|
||||
|
||||
# Notify callback of tool result
|
||||
self.callback_manager.on_tool_result(result, content_block["id"])
|
||||
self.callback_manager.on_tool_result(
|
||||
cast(ToolResult, result), content_block["id"]
|
||||
)
|
||||
|
||||
# If no tool results, we're done
|
||||
if not tool_result_content:
|
||||
@@ -495,13 +519,13 @@ class AnthropicLoop(BaseLoop):
|
||||
result_text = f"<s>{result.system}</s>\n{result_text}"
|
||||
return result_text
|
||||
|
||||
def _handle_content(self, content: Dict[str, Any]) -> None:
|
||||
def _handle_content(self, content: BetaContentBlockParam) -> None:
|
||||
"""Handle content updates from the assistant."""
|
||||
if content.get("type") == "text":
|
||||
text = content.get("text", "")
|
||||
text_content = cast(BetaTextBlockParam, content)
|
||||
text = text_content["text"]
|
||||
if text == "<DONE>":
|
||||
return
|
||||
|
||||
logger.info(f"Assistant: {text}")
|
||||
|
||||
def _handle_tool_result(self, result: ToolResult, tool_id: str) -> None:
|
||||
@@ -517,5 +541,10 @@ class AnthropicLoop(BaseLoop):
|
||||
"""Handle API interactions."""
|
||||
if error:
|
||||
logger.error(f"API error: {error}")
|
||||
self._log_api_call("error", request, error=error)
|
||||
else:
|
||||
logger.debug(f"API request: {request}")
|
||||
if response:
|
||||
self._log_api_call("response", request, response)
|
||||
else:
|
||||
self._log_api_call("request", request)
|
||||
|
||||
@@ -90,7 +90,9 @@ class MessageManager:
|
||||
blocks_with_cache_control += 1
|
||||
# Add cache control to the last content block only
|
||||
if content and len(content) > 0:
|
||||
content[-1]["cache_control"] = {"type": "ephemeral"}
|
||||
content[-1]["cache_control"] = BetaCacheControlEphemeralParam(
|
||||
type="ephemeral"
|
||||
)
|
||||
else:
|
||||
# Remove any existing cache control
|
||||
if content and len(content) > 0:
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Any, Dict
|
||||
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
|
||||
from ....core.tools.base import BaseTool, ToolError, ToolResult, ToolFailure, CLIResult
|
||||
from ....core.tools.base import BaseTool
|
||||
|
||||
|
||||
class BaseAnthropicTool(BaseTool, metaclass=ABCMeta):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Collection classes for managing multiple tools."""
|
||||
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
|
||||
@@ -22,7 +22,7 @@ class ToolCollection:
|
||||
def to_params(
|
||||
self,
|
||||
) -> list[BetaToolUnionParam]:
|
||||
return [tool.to_params() for tool in self.tools]
|
||||
return cast(list[BetaToolUnionParam], [tool.to_params() for tool in self.tools])
|
||||
|
||||
async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
|
||||
tool = self.tool_map.get(name)
|
||||
|
||||
@@ -61,9 +61,9 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
|
||||
name: Literal["computer"] = "computer"
|
||||
api_type: Literal["computer_20250124"] = "computer_20250124"
|
||||
width: int | None
|
||||
height: int | None
|
||||
display_num: int | None
|
||||
width: int | None = None
|
||||
height: int | None = None
|
||||
display_num: int | None = None
|
||||
computer: Computer # The CUA Computer instance
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -106,6 +106,7 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
display_size = await self.computer.interface.get_screen_size()
|
||||
self.width = display_size["width"]
|
||||
self.height = display_size["height"]
|
||||
assert isinstance(self.width, int) and isinstance(self.height, int)
|
||||
self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
|
||||
|
||||
async def __call__(
|
||||
@@ -120,6 +121,8 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
# Ensure dimensions are initialized
|
||||
if self.width is None or self.height is None:
|
||||
await self.initialize_dimensions()
|
||||
if self.width is None or self.height is None:
|
||||
raise ToolError("Failed to initialize screen dimensions")
|
||||
except Exception as e:
|
||||
raise ToolError(f"Failed to initialize dimensions: {e}")
|
||||
|
||||
@@ -147,7 +150,10 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
self.logger.info(
|
||||
f"Scaling image from {pre_img.size} to {self.width}x{self.height} to match screen dimensions"
|
||||
)
|
||||
pre_img = pre_img.resize((self.width, self.height), Image.Resampling.LANCZOS)
|
||||
if not isinstance(self.width, int) or not isinstance(self.height, int):
|
||||
raise ToolError("Screen dimensions must be integers")
|
||||
size = (int(self.width), int(self.height))
|
||||
pre_img = pre_img.resize(size, Image.Resampling.LANCZOS)
|
||||
|
||||
self.logger.info(f" Current dimensions: {pre_img.width}x{pre_img.height}")
|
||||
|
||||
@@ -160,15 +166,7 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
# Then perform drag operation - check if drag_to exists or we need to use other methods
|
||||
try:
|
||||
if hasattr(self.computer.interface, "drag_to"):
|
||||
await self.computer.interface.drag_to(x, y)
|
||||
else:
|
||||
# Alternative approach: press mouse down, move, release
|
||||
await self.computer.interface.mouse_down()
|
||||
await asyncio.sleep(0.2)
|
||||
await self.computer.interface.move_cursor(x, y)
|
||||
await asyncio.sleep(0.2)
|
||||
await self.computer.interface.mouse_up()
|
||||
await self.computer.interface.drag_to(x, y)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during drag operation: {str(e)}")
|
||||
raise ToolError(f"Failed to perform drag: {str(e)}")
|
||||
@@ -214,9 +212,10 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
self.logger.info(
|
||||
f"Scaling image from {pre_img.size} to {self.width}x{self.height} to match screen dimensions"
|
||||
)
|
||||
pre_img = pre_img.resize(
|
||||
(self.width, self.height), Image.Resampling.LANCZOS
|
||||
)
|
||||
if not isinstance(self.width, int) or not isinstance(self.height, int):
|
||||
raise ToolError("Screen dimensions must be integers")
|
||||
size = (int(self.width), int(self.height))
|
||||
pre_img = pre_img.resize(size, Image.Resampling.LANCZOS)
|
||||
# Save the scaled image back to bytes
|
||||
buffer = io.BytesIO()
|
||||
pre_img.save(buffer, format="PNG")
|
||||
@@ -275,9 +274,10 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
self.logger.info(
|
||||
f"Scaling image from {pre_img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
pre_img = pre_img.resize(
|
||||
(self.width, self.height), Image.Resampling.LANCZOS
|
||||
)
|
||||
if not isinstance(self.width, int) or not isinstance(self.height, int):
|
||||
raise ToolError("Screen dimensions must be integers")
|
||||
size = (int(self.width), int(self.height))
|
||||
pre_img = pre_img.resize(size, Image.Resampling.LANCZOS)
|
||||
|
||||
# Perform the click action
|
||||
if action == "left_click":
|
||||
@@ -335,7 +335,10 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
self.logger.info(
|
||||
f"Scaling image from {pre_img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
pre_img = pre_img.resize((self.width, self.height), Image.Resampling.LANCZOS)
|
||||
if not isinstance(self.width, int) or not isinstance(self.height, int):
|
||||
raise ToolError("Screen dimensions must be integers")
|
||||
size = (int(self.width), int(self.height))
|
||||
pre_img = pre_img.resize(size, Image.Resampling.LANCZOS)
|
||||
|
||||
if action == "key":
|
||||
# Special handling for page up/down on macOS
|
||||
@@ -365,7 +368,7 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
# Handle single key press
|
||||
self.logger.info(f"Pressing key: {text}")
|
||||
try:
|
||||
await self.computer.interface.press(text)
|
||||
await self.computer.interface.press_key(text)
|
||||
output_text = text
|
||||
except ValueError as e:
|
||||
raise ToolError(f"Invalid key: {text}. {str(e)}")
|
||||
@@ -442,7 +445,10 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
self.logger.info(
|
||||
f"Scaling image from {img.size} to {self.width}x{self.height}"
|
||||
)
|
||||
img = img.resize((self.width, self.height), Image.Resampling.LANCZOS)
|
||||
if not isinstance(self.width, int) or not isinstance(self.height, int):
|
||||
raise ToolError("Screen dimensions must be integers")
|
||||
size = (int(self.width), int(self.height))
|
||||
img = img.resize(size, Image.Resampling.LANCZOS)
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format="PNG")
|
||||
screenshot = buffer.getvalue()
|
||||
@@ -451,7 +457,8 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
|
||||
elif action == "cursor_position":
|
||||
pos = await self.computer.interface.get_cursor_position()
|
||||
return ToolResult(output=f"X={int(pos[0])},Y={int(pos[1])}")
|
||||
x, y = pos # Unpack the tuple
|
||||
return ToolResult(output=f"X={int(x)},Y={int(y)}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error during {action} action: {str(e)}")
|
||||
@@ -517,7 +524,10 @@ class ComputerTool(BaseComputerTool, BaseAnthropicTool):
|
||||
# Scale image if needed
|
||||
if img.size != (self.width, self.height):
|
||||
self.logger.info(f"Scaling image from {img.size} to {self.width}x{self.height}")
|
||||
img = img.resize((self.width, self.height), Image.Resampling.LANCZOS)
|
||||
if not isinstance(self.width, int) or not isinstance(self.height, int):
|
||||
raise ToolError("Screen dimensions must be integers")
|
||||
size = (int(self.width), int(self.height))
|
||||
img = img.resize(size, Image.Resampling.LANCZOS)
|
||||
buffer = io.BytesIO()
|
||||
img.save(buffer, format="PNG")
|
||||
screenshot = buffer.getvalue()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, cast
|
||||
from anthropic.types.beta import BetaToolUnionParam
|
||||
from computer.computer import Computer
|
||||
|
||||
@@ -37,7 +37,7 @@ class ToolManager(BaseToolManager):
|
||||
"""Get tool parameters for Anthropic API calls."""
|
||||
if self.tools is None:
|
||||
raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
return self.tools.to_params()
|
||||
return cast(List[BetaToolUnionParam], self.tools.to_params())
|
||||
|
||||
async def execute_tool(self, name: str, tool_input: dict[str, Any]) -> ToolResult:
|
||||
"""Execute a tool with the given input.
|
||||
|
||||
@@ -126,15 +126,18 @@ class ExperimentManager:
|
||||
# Since we no longer want to use the images/ folder, we'll skip this functionality
|
||||
return
|
||||
|
||||
def save_screenshot(self, img_base64: str, action_type: str = "") -> None:
|
||||
def save_screenshot(self, img_base64: str, action_type: str = "") -> Optional[str]:
|
||||
"""Save a screenshot to the experiment directory.
|
||||
|
||||
Args:
|
||||
img_base64: Base64 encoded screenshot
|
||||
action_type: Type of action that triggered the screenshot
|
||||
|
||||
Returns:
|
||||
Optional[str]: Path to the saved screenshot, or None if saving failed
|
||||
"""
|
||||
if not self.current_turn_dir:
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
# Increment screenshot counter
|
||||
|
||||
@@ -13,6 +13,7 @@ import asyncio
|
||||
from httpx import ConnectError, ReadTimeout
|
||||
import shutil
|
||||
import copy
|
||||
from typing import cast
|
||||
|
||||
from .parser import OmniParser, ParseResult, ParserMetadata, UIElement
|
||||
from ...core.loop import BaseLoop
|
||||
@@ -182,8 +183,6 @@ class OmniLoop(BaseLoop):
|
||||
|
||||
if self.provider == LLMProvider.OPENAI:
|
||||
self.client = OpenAIClient(api_key=self.api_key, model=self.model)
|
||||
elif self.provider == LLMProvider.GROQ:
|
||||
self.client = GroqClient(api_key=self.api_key, model=self.model)
|
||||
elif self.provider == LLMProvider.ANTHROPIC:
|
||||
self.client = AnthropicClient(
|
||||
api_key=self.api_key,
|
||||
@@ -329,10 +328,15 @@ class OmniLoop(BaseLoop):
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
async def _handle_response(
|
||||
self, response: Any, messages: List[Dict[str, Any]], parsed_screen: Dict[str, Any]
|
||||
self, response: Any, messages: List[Dict[str, Any]], parsed_screen: ParseResult
|
||||
) -> Tuple[bool, bool]:
|
||||
"""Handle API response.
|
||||
|
||||
Args:
|
||||
response: API response
|
||||
messages: List of messages to update
|
||||
parsed_screen: Current parsed screen information
|
||||
|
||||
Returns:
|
||||
Tuple of (should_continue, action_screenshot_saved)
|
||||
"""
|
||||
@@ -394,7 +398,9 @@ class OmniLoop(BaseLoop):
|
||||
|
||||
try:
|
||||
# Execute action with current parsed screen info
|
||||
await self._execute_action(parsed_content, parsed_screen)
|
||||
await self._execute_action(
|
||||
parsed_content, cast(ParseResult, parsed_screen)
|
||||
)
|
||||
action_screenshot_saved = True
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing action: {str(e)}")
|
||||
@@ -463,7 +469,7 @@ class OmniLoop(BaseLoop):
|
||||
|
||||
try:
|
||||
# Execute action with current parsed screen info
|
||||
await self._execute_action(parsed_content, parsed_screen)
|
||||
await self._execute_action(parsed_content, cast(ParseResult, parsed_screen))
|
||||
action_screenshot_saved = True
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing action: {str(e)}")
|
||||
@@ -488,7 +494,7 @@ class OmniLoop(BaseLoop):
|
||||
|
||||
try:
|
||||
# Execute action with current parsed screen info
|
||||
await self._execute_action(content, parsed_screen)
|
||||
await self._execute_action(content, cast(ParseResult, parsed_screen))
|
||||
action_screenshot_saved = True
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing action: {str(e)}")
|
||||
|
||||
@@ -122,8 +122,9 @@ class OmniParser:
|
||||
# Create a minimal valid result for error cases
|
||||
return ParseResult(
|
||||
elements=[],
|
||||
screen_info=None,
|
||||
annotated_image_base64="",
|
||||
parsed_content_list=[f"Error: {str(e)}"],
|
||||
parsed_content_list=[{"error": str(e)}],
|
||||
metadata=ParserMetadata(
|
||||
image_size=(0, 0),
|
||||
num_icons=0,
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
from .bash import OmniBashTool
|
||||
from .computer import OmniComputerTool
|
||||
from .edit import OmniEditTool
|
||||
from .manager import OmniToolManager
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -177,7 +177,7 @@ class OmniComputerTool(BaseComputerTool):
|
||||
keys = text.split("+")
|
||||
await self.computer.interface.hotkey(*keys)
|
||||
else:
|
||||
await self.computer.interface.press(text)
|
||||
await self.computer.interface.press_key(text)
|
||||
|
||||
# Take screenshot after action
|
||||
screenshot = await self.computer.interface.screenshot()
|
||||
@@ -188,7 +188,8 @@ class OmniComputerTool(BaseComputerTool):
|
||||
)
|
||||
elif action == "cursor_position":
|
||||
pos = await self.computer.interface.get_cursor_position()
|
||||
return ToolResult(output=f"X={int(pos[0])},Y={int(pos[1])}")
|
||||
x, y = pos
|
||||
return ToolResult(output=f"X={int(x)},Y={int(y)}")
|
||||
elif action == "scroll":
|
||||
if direction == "down":
|
||||
self.logger.info(f"Scrolling down, amount: {amount}")
|
||||
|
||||
@@ -10,7 +10,6 @@ from ....core.tools.collection import ToolCollection
|
||||
|
||||
from .bash import OmniBashTool
|
||||
from .computer import OmniComputerTool
|
||||
from .edit import OmniEditTool
|
||||
|
||||
|
||||
class ProviderType(Enum):
|
||||
@@ -35,11 +34,10 @@ class OmniToolManager(BaseToolManager):
|
||||
# Initialize tools
|
||||
self.computer_tool = OmniComputerTool(self.computer)
|
||||
self.bash_tool = OmniBashTool(self.computer)
|
||||
self.edit_tool = OmniEditTool(self.computer)
|
||||
|
||||
def _initialize_tools(self) -> ToolCollection:
|
||||
"""Initialize all available tools."""
|
||||
return ToolCollection(self.computer_tool, self.bash_tool, self.edit_tool)
|
||||
return ToolCollection(self.computer_tool, self.bash_tool)
|
||||
|
||||
async def _initialize_tools_specific(self) -> None:
|
||||
"""Initialize provider-specific tool requirements."""
|
||||
|
||||
@@ -96,7 +96,7 @@ def compress_image_base64(
|
||||
# Resize image
|
||||
new_width = int(img.width * scale_factor)
|
||||
new_height = int(img.height * scale_factor)
|
||||
current_img = img.resize((new_width, new_height), Image.LANCZOS)
|
||||
current_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# Try with reduced size and quality
|
||||
buffer = io.BytesIO()
|
||||
@@ -130,7 +130,9 @@ def compress_image_base64(
|
||||
|
||||
# Last resort: Use minimum quality and size
|
||||
buffer = io.BytesIO()
|
||||
smallest_img = img.resize((int(img.width * 0.5), int(img.height * 0.5)), Image.LANCZOS)
|
||||
smallest_img = img.resize(
|
||||
(int(img.width * 0.5), int(img.height * 0.5)), Image.Resampling.LANCZOS
|
||||
)
|
||||
# Convert to RGB if necessary
|
||||
if smallest_img.mode in ("RGBA", "LA") or (
|
||||
smallest_img.mode == "P" and "transparency" in smallest_img.info
|
||||
|
||||
@@ -1,23 +1,20 @@
|
||||
"""Type definitions for the agent package."""
|
||||
|
||||
from .base import Provider, HostConfig, TaskResult, Annotation
|
||||
from .base import HostConfig, TaskResult, Annotation
|
||||
from .messages import Message, Request, Response, StepMessage, DisengageMessage
|
||||
from .tools import ToolInvocation, ToolInvocationState, ClientAttachment, ToolResult
|
||||
|
||||
__all__ = [
|
||||
# Base types
|
||||
"Provider",
|
||||
"HostConfig",
|
||||
"TaskResult",
|
||||
"Annotation",
|
||||
|
||||
# Message types
|
||||
"Message",
|
||||
"Request",
|
||||
"Response",
|
||||
"StepMessage",
|
||||
"DisengageMessage",
|
||||
|
||||
# Tool types
|
||||
"ToolInvocation",
|
||||
"ToolInvocationState",
|
||||
|
||||
@@ -5,17 +5,6 @@ from typing import Dict, Any
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class Provider(str, Enum):
|
||||
"""Available AI providers."""
|
||||
|
||||
UNKNOWN = "unknown" # Default provider for base class
|
||||
ANTHROPIC = "anthropic"
|
||||
OPENAI = "openai"
|
||||
OLLAMA = "ollama"
|
||||
OMNI = "omni"
|
||||
GROQ = "groq"
|
||||
|
||||
|
||||
class HostConfig(BaseModel):
|
||||
"""Host configuration."""
|
||||
|
||||
@@ -48,6 +37,5 @@ class AgentLoop(Enum):
|
||||
"""Enumeration of available loop types."""
|
||||
|
||||
ANTHROPIC = auto() # Anthropic implementation
|
||||
OPENAI = auto() # OpenAI implementation
|
||||
OMNI = auto() # OmniLoop implementation
|
||||
# Add more loop types as needed
|
||||
|
||||
Reference in New Issue
Block a user