diff --git a/libs/python/agent2/README.md b/libs/python/agent2/README.md
new file mode 100644
index 00000000..31115835
--- /dev/null
+++ b/libs/python/agent2/README.md
@@ -0,0 +1,151 @@
+# Agent2 - Computer Use Agent
+
+**agent2** is a clean computer-use agent framework with liteLLM integration for running agentic workflows on macOS and Linux.
+
+## Key Features
+
+- **Docstring-based Tools**: Define tools using standard Python docstrings (no decorators needed)
+- **Regex Model Matching**: Agent loops can match models using regex patterns
+- **liteLLM Integration**: All completions use liteLLM's `.responses()` method
+- **Streaming Support**: Built-in streaming with asyncio.Queue and cancellation support
+- **Computer Tools**: Direct integration with the computer interface for clicks, typing, etc.
+- **Custom Tools**: Easy Python function tools with comprehensive docstrings
+
+## Install
+
+```bash
+pip install "cua-agent2[all]"
+
+# or install specific providers
+pip install "cua-agent2[anthropic]"  # Anthropic support
+pip install "cua-agent2[openai]"     # OpenAI computer-use-preview support
+```
+
+## Usage
+
+### Define Tools
+
+```python
+# No imports needed for tools - just define functions with comprehensive docstrings
+
+def read_file(location: str) -> str:
+    """Read contents of a file
+
+    Parameters
+    ----------
+    location : str
+        Path to the file to read
+
+    Returns
+    -------
+    str
+        Contents of the file
+    """
+    with open(location, 'r') as f:
+        return f.read()
+
+def search_web(query: str) -> str:
+    """Search the web for information
+
+    Parameters
+    ----------
+    query : str
+        Search query to look for
+
+    Returns
+    -------
+    str
+        Search results
+    """
+    return f"Search results for: {query}"
+```
+
+### Define Agent Loops
+
+```python
+from typing import Any, Dict, List, Optional
+
+import litellm
+
+from agent2 import agent_loop
+from agent2.types import Messages
+
+@agent_loop(models=r"claude-3.*", priority=10)
+async def custom_claude_loop(messages: Messages, model: str, stream: bool = False, tools: Optional[List[Dict[str, Any]]] = None, **kwargs):
+    """Custom agent loop for Claude models."""
+    # Map computer tools to the Anthropic format (with a helper you define)
+    anthropic_tools = _prepare_tools_for_anthropic(tools)
+
+    # Your custom logic here
+    response = await litellm.aresponses(
+        model=model,
+        messages=messages,
+        stream=stream,
+        tools=anthropic_tools,
+        **kwargs
+    )
+
+    if stream:
+        async for chunk in response:
+            yield chunk
+    else:
+        yield response
+
+@agent_loop(models=r"omni\+.*", priority=10)
+async def custom_omni_loop(messages: Messages, model: str, stream: bool = False, tools: Optional[List[Dict[str, Any]]] = None, **kwargs):
+    """Custom agent loop for Omni models."""
+    # Map computer tools to the Omni format (with a helper you define)
+    omni_tools, som_prompt = _prepare_tools_for_omni(tools)
+
+    # Your custom logic here
+    response = await litellm.aresponses(
+        model=model.replace("omni+", ""),
+        messages=som_prompt,
+        stream=stream,
+        tools=omni_tools,
+        **kwargs
+    )
+
+    if stream:
+        async for chunk in response:
+            yield chunk
+    else:
+        yield response
+```
+
+### Use ComputerAgent
+
+```python
+from agent2 import ComputerAgent
+from computer import Computer
+
+async def main():
+    async with Computer() as computer:
+        agent = ComputerAgent(
+            model="claude-3-5-sonnet-20241022",
+            tools=[computer, read_file, search_web]
+        )
+
+        messages = [{"role": "user", "content": "Save a picture of a cat to my desktop."}]
+
+        async for chunk in agent.run(messages, stream=True):
+            print(chunk)
+
+        omni_agent = ComputerAgent(
+            model="omni+vertex_ai/gemini-pro",
+            tools=[computer, read_file, search_web]
+        )
+
+        messages = 
[{"role": "user", "content": "Save a picture of a cat to my desktop."}] + + async for chunk in omni_agent.run(messages, stream=True): + print(chunk) +``` + +## Supported Agent Loops + +- **Anthropic**: Claude models with computer use +- **Computer-Use-Preview**: OpenAI's computer use preview models + +## Architecture + +- Agent loops are automatically selected based on model regex matching +- Computer tools are mapped to model-specific schemas +- All completions use `litellm.responses()` for consistency +- Streaming is handled with asyncio.Queue for cancellation support diff --git a/libs/python/agent2/agent2/__init__.py b/libs/python/agent2/agent2/__init__.py new file mode 100644 index 00000000..7125beb3 --- /dev/null +++ b/libs/python/agent2/agent2/__init__.py @@ -0,0 +1,19 @@ +""" +Agent2 - Decorator-based Computer Use Agent with liteLLM integration +""" + +from .decorators import agent_loop +from .agent import ComputerAgent +from .types import Messages, AgentResponse + +# Import loops to register them +from . import loops + +__all__ = [ + "agent_loop", + "ComputerAgent", + "Messages", + "AgentResponse" +] + +__version__ = "0.1.0" diff --git a/libs/python/agent2/agent2/adapters/__init__.py b/libs/python/agent2/agent2/adapters/__init__.py new file mode 100644 index 00000000..c16120b6 --- /dev/null +++ b/libs/python/agent2/agent2/adapters/__init__.py @@ -0,0 +1,9 @@ +""" +Adapters package for agent2 - Custom LLM adapters for LiteLLM +""" + +from .huggingfacelocal_adapter import HuggingFaceLocalAdapter + +__all__ = [ + "HuggingFaceLocalAdapter", +] diff --git a/libs/python/agent2/agent2/adapters/huggingfacelocal_adapter.py b/libs/python/agent2/agent2/adapters/huggingfacelocal_adapter.py new file mode 100644 index 00000000..83352447 --- /dev/null +++ b/libs/python/agent2/agent2/adapters/huggingfacelocal_adapter.py @@ -0,0 +1,216 @@ +import asyncio +import warnings +from typing import Iterator, AsyncIterator, Dict, List, Any, Optional +import torch +from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor +from litellm.types.utils import GenericStreamingChunk, ModelResponse +from litellm import CustomLLM, completion, acompletion + + +class HuggingFaceLocalAdapter(CustomLLM): + """HuggingFace Local Adapter for running vision-language models locally.""" + + def __init__(self, device: str = "auto", **kwargs): + """Initialize the adapter. + + Args: + device: Device to load model on ("auto", "cuda", "cpu", etc.) + **kwargs: Additional arguments + """ + super().__init__() + self.device = device + self.models = {} # Cache for loaded models + self.processors = {} # Cache for loaded processors + + def _load_model_and_processor(self, model_name: str): + """Load model and processor if not already cached. + + Args: + model_name: Name of the model to load + + Returns: + Tuple of (model, processor) + """ + if model_name not in self.models: + # Load model + model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + model_name, + torch_dtype=torch.float16, + device_map=self.device, + attn_implementation="sdpa" + ) + + # Load processor + processor = AutoProcessor.from_pretrained(model_name) + + # Cache them + self.models[model_name] = model + self.processors[model_name] = processor + + return self.models[model_name], self.processors[model_name] + + def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Convert OpenAI format messages to HuggingFace format. 
+ + Args: + messages: Messages in OpenAI format + + Returns: + Messages in HuggingFace format + """ + converted_messages = [] + + for message in messages: + converted_message = { + "role": message["role"], + "content": [] + } + + content = message.get("content", []) + if isinstance(content, str): + # Simple text content + converted_message["content"].append({ + "type": "text", + "text": content + }) + elif isinstance(content, list): + # Multi-modal content + for item in content: + if item.get("type") == "text": + converted_message["content"].append({ + "type": "text", + "text": item.get("text", "") + }) + elif item.get("type") == "image_url": + # Convert image_url format to image format + image_url = item.get("image_url", {}).get("url", "") + converted_message["content"].append({ + "type": "image", + "image": image_url + }) + + converted_messages.append(converted_message) + + return converted_messages + + def _generate(self, **kwargs) -> str: + """Generate response using the local HuggingFace model. + + Args: + **kwargs: Keyword arguments containing messages and model info + + Returns: + Generated text response + """ + # Extract messages and model from kwargs + messages = kwargs.get('messages', []) + model_name = kwargs.get('model', 'ByteDance-Seed/UI-TARS-1.5-7B') + max_new_tokens = kwargs.get('max_tokens', 128) + + # Warn about ignored kwargs + ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'} + if ignored_kwargs: + warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}") + + # Load model and processor + model, processor = self._load_model_and_processor(model_name) + + # Convert messages to HuggingFace format + hf_messages = self._convert_messages(messages) + + # Apply chat template and tokenize + inputs = processor.apply_chat_template( + hf_messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt" + ) + + # Move inputs to the same device as model + if torch.cuda.is_available() and self.device != "cpu": + inputs = inputs.to("cuda") + + # Generate response + with torch.no_grad(): + generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens) + + # Trim input tokens from output + generated_ids_trimmed = [ + out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + + # Decode output + output_text = processor.batch_decode( + generated_ids_trimmed, + skip_special_tokens=True, + clean_up_tokenization_spaces=False + ) + + return output_text[0] if output_text else "" + + def completion(self, *args, **kwargs) -> ModelResponse: + """Synchronous completion method. + + Returns: + ModelResponse with generated text + """ + generated_text = self._generate(**kwargs) + + return completion( + model=f"huggingface-local/{kwargs['model']}", + mock_response=generated_text, + ) + + async def acompletion(self, *args, **kwargs) -> ModelResponse: + """Asynchronous completion method. + + Returns: + ModelResponse with generated text + """ + # Run _generate in thread pool to avoid blocking + generated_text = await asyncio.to_thread(self._generate, **kwargs) + + return await acompletion( + model=f"huggingface-local/{kwargs['model']}", + mock_response=generated_text, + ) + + def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]: + """Synchronous streaming method. 
+ + Returns: + Iterator of GenericStreamingChunk + """ + generated_text = self._generate(**kwargs) + + generic_streaming_chunk: GenericStreamingChunk = { + "finish_reason": "stop", + "index": 0, + "is_finished": True, + "text": generated_text, + "tool_use": None, + "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}, + } + + yield generic_streaming_chunk + + async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]: + """Asynchronous streaming method. + + Returns: + AsyncIterator of GenericStreamingChunk + """ + # Run _generate in thread pool to avoid blocking + generated_text = await asyncio.to_thread(self._generate, **kwargs) + + generic_streaming_chunk: GenericStreamingChunk = { + "finish_reason": "stop", + "index": 0, + "is_finished": True, + "text": generated_text, + "tool_use": None, + "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0}, + } + + yield generic_streaming_chunk \ No newline at end of file diff --git a/libs/python/agent2/agent2/agent.py b/libs/python/agent2/agent2/agent.py new file mode 100644 index 00000000..e14669a7 --- /dev/null +++ b/libs/python/agent2/agent2/agent.py @@ -0,0 +1,564 @@ +""" +ComputerAgent - Main agent class that selects and runs agent loops +""" + +import asyncio +from typing import Dict, List, Any, Optional, AsyncGenerator, Union, cast, Callable, Set + +from litellm.responses.utils import Usage +from .types import Messages, Computer +from .decorators import find_agent_loop +from .computer_handler import OpenAIComputerHandler, acknowledge_safety_check_callback, check_blocklisted_url +import json +import litellm +import litellm.utils +import inspect +from .adapters import HuggingFaceLocalAdapter +from .callbacks import ImageRetentionCallback, LoggingCallback, TrajectorySaverCallback, BudgetManagerCallback + +def get_json(obj: Any, max_depth: int = 10) -> Any: + def custom_serializer(o: Any, depth: int = 0, seen: Set[int] = None) -> Any: + if seen is None: + seen = set() + + # Use model_dump() if available + if hasattr(o, 'model_dump'): + return o.model_dump() + + # Check depth limit + if depth > max_depth: + return f"" + + # Check for circular references using object id + obj_id = id(o) + if obj_id in seen: + return f"" + + # Handle Computer objects + if hasattr(o, '__class__') and 'computer' in getattr(o, '__class__').__name__.lower(): + return f"" + + # Handle objects with __dict__ + if hasattr(o, '__dict__'): + seen.add(obj_id) + try: + result = {} + for k, v in o.__dict__.items(): + if v is not None: + # Recursively serialize with updated depth and seen set + serialized_value = custom_serializer(v, depth + 1, seen.copy()) + result[k] = serialized_value + return result + finally: + seen.discard(obj_id) + + # Handle common types that might contain nested objects + elif isinstance(o, dict): + seen.add(obj_id) + try: + return { + k: custom_serializer(v, depth + 1, seen.copy()) + for k, v in o.items() + if v is not None + } + finally: + seen.discard(obj_id) + + elif isinstance(o, (list, tuple, set)): + seen.add(obj_id) + try: + return [ + custom_serializer(item, depth + 1, seen.copy()) + for item in o + if item is not None + ] + finally: + seen.discard(obj_id) + + # For basic types that json.dumps can handle + elif isinstance(o, (str, int, float, bool)) or o is None: + return o + + # Fallback to string representation + else: + return str(o) + + def remove_nones(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: remove_nones(v) for k, v in obj.items() if v is not None} + elif 
isinstance(obj, list): + return [remove_nones(item) for item in obj if item is not None] + return obj + + # Serialize with circular reference and depth protection + serialized = custom_serializer(obj) + + # Convert to JSON string and back to ensure JSON compatibility + json_str = json.dumps(serialized) + parsed = json.loads(json_str) + + # Final cleanup of any remaining None values + return remove_nones(parsed) + +def sanitize_message(msg: Any) -> Any: + """Return a copy of the message with image_url omitted for computer_call_output messages.""" + if msg.get("type") == "computer_call_output": + output = msg.get("output", {}) + if isinstance(output, dict): + sanitized = msg.copy() + sanitized["output"] = {**output, "image_url": "[omitted]"} + return sanitized + return msg + +class ComputerAgent: + """ + Main agent class that automatically selects the appropriate agent loop + based on the model and executes tool calls. + """ + + def __init__( + self, + model: str, + tools: Optional[List[Any]] = None, + custom_loop: Optional[Callable] = None, + only_n_most_recent_images: Optional[int] = None, + callbacks: Optional[List[Any]] = None, + verbosity: Optional[int] = None, + trajectory_dir: Optional[str] = None, + max_retries: Optional[int] = 3, + screenshot_delay: Optional[float | int] = 0.5, + use_prompt_caching: Optional[bool] = False, + max_trajectory_budget: Optional[float | dict] = None, + **kwargs + ): + """ + Initialize ComputerAgent. + + Args: + model: Model name (e.g., "claude-3-5-sonnet-20241022", "computer-use-preview", "omni+vertex_ai/gemini-pro") + tools: List of tools (computer objects, decorated functions, etc.) + custom_loop: Custom agent loop function to use instead of auto-selection + only_n_most_recent_images: If set, only keep the N most recent images in message history. Adds ImageRetentionCallback automatically. + callbacks: List of AsyncCallbackHandler instances for preprocessing/postprocessing + verbosity: Logging level (logging.DEBUG, logging.INFO, etc.). If set, adds LoggingCallback automatically + trajectory_dir: If set, saves trajectory data (screenshots, responses) to this directory. Adds TrajectorySaverCallback automatically. + max_retries: Maximum number of retries for failed API calls + screenshot_delay: Delay before screenshots in seconds + use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with anthropic providers. 
+ max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded + **kwargs: Additional arguments passed to the agent loop + """ + self.model = model + self.tools = tools or [] + self.custom_loop = custom_loop + self.only_n_most_recent_images = only_n_most_recent_images + self.callbacks = callbacks or [] + self.verbosity = verbosity + self.trajectory_dir = trajectory_dir + self.max_retries = max_retries + self.screenshot_delay = screenshot_delay + self.use_prompt_caching = use_prompt_caching + self.kwargs = kwargs + + # == Add built-in callbacks == + + # Add logging callback if verbosity is set + if self.verbosity is not None: + self.callbacks.append(LoggingCallback(level=self.verbosity)) + + # Add image retention callback if only_n_most_recent_images is set + if self.only_n_most_recent_images: + self.callbacks.append(ImageRetentionCallback(self.only_n_most_recent_images)) + + # Add trajectory saver callback if trajectory_dir is set + if self.trajectory_dir: + self.callbacks.append(TrajectorySaverCallback(self.trajectory_dir)) + + # Add budget manager if max_trajectory_budget is set + if max_trajectory_budget: + if isinstance(max_trajectory_budget, dict): + self.callbacks.append(BudgetManagerCallback(**max_trajectory_budget)) + else: + self.callbacks.append(BudgetManagerCallback(max_trajectory_budget)) + + # == Enable local model providers w/ LiteLLM == + + # Register local model providers + hf_adapter = HuggingFaceLocalAdapter( + device="auto" + ) + litellm.custom_provider_map = [ + {"provider": "huggingface-local", "custom_handler": hf_adapter} + ] + + # == Initialize computer agent == + + # Find the appropriate agent loop + if custom_loop: + self.agent_loop = custom_loop + self.agent_loop_info = None + else: + loop_info = find_agent_loop(model) + if not loop_info: + raise ValueError(f"No agent loop found for model: {model}") + self.agent_loop = loop_info.func + self.agent_loop_info = loop_info + + # Process tools and create tool schemas + self.tool_schemas = self._process_tools() + + # Find computer tool and create interface adapter + computer_handler = None + for schema in self.tool_schemas: + if schema["type"] == "computer": + computer_handler = OpenAIComputerHandler(schema["computer"].interface) + break + self.computer_handler = computer_handler + + def _process_input(self, input: Messages) -> List[Dict[str, Any]]: + """Process input messages and create schemas for the agent loop""" + if isinstance(input, str): + return [{"role": "user", "content": input}] + return [get_json(msg) for msg in input] + + def _process_tools(self) -> List[Dict[str, Any]]: + """Process tools and create schemas for the agent loop""" + schemas = [] + + for tool in self.tools: + # Check if it's a computer object (has interface attribute) + if hasattr(tool, 'interface'): + # This is a computer tool - will be handled by agent loop + schemas.append({ + "type": "computer", + "computer": tool + }) + elif callable(tool): + # Use litellm.utils.function_to_dict to extract schema from docstring + try: + function_schema = litellm.utils.function_to_dict(tool) + schemas.append({ + "type": "function", + "function": function_schema + }) + except Exception as e: + print(f"Warning: Could not process tool {tool}: {e}") + else: + print(f"Warning: Unknown tool type: {tool}") + + return schemas + + def _get_tool(self, name: str) -> Optional[Callable]: + """Get a tool by name""" + for tool in self.tools: + if hasattr(tool, '__name__') and tool.__name__ == name: + return tool + elif 
hasattr(tool, 'func') and tool.func.__name__ == name: + return tool + return None + + # ============================================================================ + # AGENT RUN LOOP LIFECYCLE HOOKS + # ============================================================================ + + async def _on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None: + """Initialize run tracking by calling callbacks.""" + for callback in self.callbacks: + if hasattr(callback, 'on_run_start'): + await callback.on_run_start(kwargs, old_items) + + async def _on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None: + """Finalize run tracking by calling callbacks.""" + for callback in self.callbacks: + if hasattr(callback, 'on_run_end'): + await callback.on_run_end(kwargs, old_items, new_items) + + async def _on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool: + """Check if run should continue by calling callbacks.""" + for callback in self.callbacks: + if hasattr(callback, 'on_run_continue'): + should_continue = await callback.on_run_continue(kwargs, old_items, new_items) + if not should_continue: + return False + return True + + async def _on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Prepare messages for the LLM call by applying callbacks.""" + result = messages + for callback in self.callbacks: + if hasattr(callback, 'on_llm_start'): + result = await callback.on_llm_start(result) + return result + + async def _on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Postprocess messages after the LLM call by applying callbacks.""" + result = messages + for callback in self.callbacks: + if hasattr(callback, 'on_llm_end'): + result = await callback.on_llm_end(result) + return result + + async def _on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None: + """Called when responses are received.""" + for callback in self.callbacks: + if hasattr(callback, 'on_responses'): + await callback.on_responses(get_json(kwargs), get_json(responses)) + + async def _on_computer_call_start(self, item: Dict[str, Any]) -> None: + """Called when a computer call is about to start.""" + for callback in self.callbacks: + if hasattr(callback, 'on_computer_call_start'): + await callback.on_computer_call_start(get_json(item)) + + async def _on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None: + """Called when a computer call has completed.""" + for callback in self.callbacks: + if hasattr(callback, 'on_computer_call_end'): + await callback.on_computer_call_end(get_json(item), get_json(result)) + + async def _on_function_call_start(self, item: Dict[str, Any]) -> None: + """Called when a function call is about to start.""" + for callback in self.callbacks: + if hasattr(callback, 'on_function_call_start'): + await callback.on_function_call_start(get_json(item)) + + async def _on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None: + """Called when a function call has completed.""" + for callback in self.callbacks: + if hasattr(callback, 'on_function_call_end'): + await callback.on_function_call_end(get_json(item), get_json(result)) + + async def _on_text(self, item: Dict[str, Any]) -> None: + """Called when a text message is encountered.""" + for callback in self.callbacks: + if hasattr(callback, 'on_text'): + await 
callback.on_text(get_json(item)) + + async def _on_api_start(self, kwargs: Dict[str, Any]) -> None: + """Called when an LLM API call is about to start.""" + for callback in self.callbacks: + if hasattr(callback, 'on_api_start'): + await callback.on_api_start(get_json(kwargs)) + + async def _on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None: + """Called when an LLM API call has completed.""" + for callback in self.callbacks: + if hasattr(callback, 'on_api_end'): + await callback.on_api_end(get_json(kwargs), get_json(result)) + + async def _on_usage(self, usage: Dict[str, Any]) -> None: + """Called when usage information is received.""" + for callback in self.callbacks: + if hasattr(callback, 'on_usage'): + await callback.on_usage(get_json(usage)) + + async def _on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None: + """Called when a screenshot is taken.""" + for callback in self.callbacks: + if hasattr(callback, 'on_screenshot'): + await callback.on_screenshot(screenshot, name) + + # ============================================================================ + # AGENT OUTPUT PROCESSING + # ============================================================================ + + async def _handle_item(self, item: Any, computer: Optional[Computer] = None) -> List[Dict[str, Any]]: + """Handle each item; may cause a computer action + screenshot.""" + + item_type = item.get("type", None) + + if item_type == "message": + await self._on_text(item) + # # Print messages + # if item.get("content"): + # for content_item in item.get("content"): + # if content_item.get("text"): + # print(content_item.get("text")) + return [] + + if item_type == "computer_call": + await self._on_computer_call_start(item) + if not computer: + raise ValueError("Computer handler is required for computer calls") + + # Perform computer actions + action = item.get("action") + action_type = action.get("type") + + # Extract action arguments (all fields except 'type') + action_args = {k: v for k, v in action.items() if k != "type"} + + # print(f"{action_type}({action_args})") + + # Execute the computer action + computer_method = getattr(computer, action_type, None) + if computer_method: + await computer_method(**action_args) + else: + print(f"Unknown computer action: {action_type}") + return [] + + # Take screenshot after action + if self.screenshot_delay and self.screenshot_delay > 0: + await asyncio.sleep(self.screenshot_delay) + screenshot_base64 = await computer.screenshot() + await self._on_screenshot(screenshot_base64, "screenshot_after") + + # Handle safety checks + pending_checks = item.get("pending_safety_checks", []) + acknowledged_checks = [] + for check in pending_checks: + check_message = check.get("message", str(check)) + if acknowledge_safety_check_callback(check_message): + acknowledged_checks.append(check) + else: + raise ValueError(f"Safety check failed: {check_message}") + + # Create call output + call_output = { + "type": "computer_call_output", + "call_id": item.get("call_id"), + "acknowledged_safety_checks": acknowledged_checks, + "output": { + "type": "input_image", + "image_url": f"data:image/png;base64,{screenshot_base64}", + }, + } + + # Additional URL safety checks for browser environments + if await computer.get_environment() == "browser": + current_url = await computer.get_current_url() + call_output["output"]["current_url"] = current_url + check_blocklisted_url(current_url) + + result = [call_output] + await self._on_computer_call_end(item, result) + return result 
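+
+        # A "function_call" item is dispatched to the matching Python tool: the tool is
+        # looked up by name, its JSON-encoded arguments are parsed, synchronous tools run
+        # via asyncio.to_thread, and the string result becomes a function_call_output item.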
+
+        if item_type == "function_call":
+            await self._on_function_call_start(item)
+            # Perform function call
+            function = self._get_tool(item.get("name"))
+            if not function:
+                raise ValueError(f"Function {item.get('name')} not found")
+
+            args = json.loads(item.get("arguments"))
+
+            # Execute function - use asyncio.to_thread for non-async functions
+            if inspect.iscoroutinefunction(function):
+                result = await function(**args)
+            else:
+                result = await asyncio.to_thread(function, **args)
+
+            # Create function call output
+            call_output = {
+                "type": "function_call_output",
+                "call_id": item.get("call_id"),
+                "output": str(result),
+            }
+
+            result = [call_output]
+            await self._on_function_call_end(item, result)
+            return result
+
+        return []
+
+    # ============================================================================
+    # MAIN AGENT LOOP
+    # ============================================================================
+
+    async def run(
+        self,
+        messages: Messages,
+        stream: bool = False,
+        **kwargs
+    ) -> AsyncGenerator[Dict[str, Any], None]:
+        """
+        Run the agent with the given messages using Computer protocol handler pattern.
+
+        Args:
+            messages: List of message dictionaries
+            stream: Whether to stream the response
+            **kwargs: Additional arguments
+
+        Returns:
+            AsyncGenerator that yields response chunks
+        """
+        # Merge kwargs
+        merged_kwargs = {**self.kwargs, **kwargs}
+
+        old_items = self._process_input(messages)
+        new_items = []
+
+        # Initialize run tracking
+        run_kwargs = {
+            "messages": messages,
+            "stream": stream,
+            "model": self.model,
+            "agent_loop": self.agent_loop.__name__,
+            **merged_kwargs
+        }
+        await self._on_run_start(run_kwargs, old_items)
+
+        while new_items[-1].get("role") != "assistant" if new_items else True:
+            # Lifecycle hook: Check if we should continue based on callbacks (e.g., budget manager)
+            should_continue = await self._on_run_continue(run_kwargs, old_items, new_items)
+            if not should_continue:
+                break
+
+            # Lifecycle hook: Prepare messages for the LLM call
+            # Use cases:
+            # - PII anonymization
+            # - Image retention policy
+            combined_messages = old_items + new_items
+            preprocessed_messages = await self._on_llm_start(combined_messages)
+
+            loop_kwargs = {
+                "messages": preprocessed_messages,
+                "model": self.model,
+                "tools": self.tool_schemas,
+                "stream": False,
+                "computer_handler": self.computer_handler,
+                "max_retries": self.max_retries,
+                "use_prompt_caching": self.use_prompt_caching,
+                **merged_kwargs
+            }
+
+            # Run agent loop iteration
+            result = await self.agent_loop(
+                **loop_kwargs,
+                _on_api_start=self._on_api_start,
+                _on_api_end=self._on_api_end,
+                _on_usage=self._on_usage,
+                _on_screenshot=self._on_screenshot,
+            )
+            result = get_json(result)
+
+            # Lifecycle hook: Postprocess messages after the LLM call
+            # Use cases:
+            # - PII deanonymization (if you want tool calls to see PII)
+            result["output"] = await self._on_llm_end(result.get("output", []))
+            await self._on_responses(loop_kwargs, result)
+
+            # Yield agent response
+            yield result
+
+            # Add agent response to new_items
+            new_items += result.get("output")
+
+            # Handle computer actions
+            for item in result.get("output"):
+                partial_items = await self._handle_item(item, self.computer_handler)
+                new_items += partial_items
+
+                # Yield partial response
+                yield {
+                    "output": partial_items,
+                    "usage": Usage(
+                        prompt_tokens=0,
+                        completion_tokens=0,
+                        total_tokens=0,
+                    )
+                }
+
+        await self._on_run_end(loop_kwargs, old_items, new_items)
\ No newline at end of file
diff --git 
a/libs/python/agent2/agent2/callbacks/__init__.py b/libs/python/agent2/agent2/callbacks/__init__.py new file mode 100644 index 00000000..6f364b1d --- /dev/null +++ b/libs/python/agent2/agent2/callbacks/__init__.py @@ -0,0 +1,17 @@ +""" +Callback system for ComputerAgent preprocessing and postprocessing hooks. +""" + +from .base import AsyncCallbackHandler +from .image_retention import ImageRetentionCallback +from .logging import LoggingCallback +from .trajectory_saver import TrajectorySaverCallback +from .budget_manager import BudgetManagerCallback + +__all__ = [ + "AsyncCallbackHandler", + "ImageRetentionCallback", + "LoggingCallback", + "TrajectorySaverCallback", + "BudgetManagerCallback", +] diff --git a/libs/python/agent2/agent2/callbacks/base.py b/libs/python/agent2/agent2/callbacks/base.py new file mode 100644 index 00000000..01688077 --- /dev/null +++ b/libs/python/agent2/agent2/callbacks/base.py @@ -0,0 +1,153 @@ +""" +Base callback handler interface for ComputerAgent preprocessing and postprocessing hooks. +""" + +from abc import ABC, abstractmethod +from typing import List, Dict, Any, Optional, Union + + +class AsyncCallbackHandler(ABC): + """ + Base class for async callback handlers that can preprocess messages before + the agent loop and postprocess output after the agent loop. + """ + + async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None: + """Called at the start of an agent run loop.""" + pass + + async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None: + """Called at the end of an agent run loop.""" + pass + + async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool: + """Called during agent run loop to determine if execution should continue. + + Args: + kwargs: Run arguments + old_items: Original messages + new_items: New messages generated during run + + Returns: + True to continue execution, False to stop + """ + return True + + async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Called before messages are sent to the agent loop. + + Args: + messages: List of message dictionaries to preprocess + + Returns: + List of preprocessed message dictionaries + """ + return messages + + async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Called after the agent loop returns output. + + Args: + output: List of output message dictionaries to postprocess + + Returns: + List of postprocessed output dictionaries + """ + return output + + async def on_computer_call_start(self, item: Dict[str, Any]) -> None: + """ + Called when a computer call is about to start. + + Args: + item: The computer call item dictionary + """ + pass + + async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None: + """ + Called when a computer call has completed. + + Args: + item: The computer call item dictionary + result: The result of the computer call + """ + pass + + async def on_function_call_start(self, item: Dict[str, Any]) -> None: + """ + Called when a function call is about to start. + + Args: + item: The function call item dictionary + """ + pass + + async def on_function_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None: + """ + Called when a function call has completed. 
+ + Args: + item: The function call item dictionary + result: The result of the function call + """ + pass + + async def on_text(self, item: Dict[str, Any]) -> None: + """ + Called when a text message is encountered. + + Args: + item: The message item dictionary + """ + pass + + async def on_api_start(self, kwargs: Dict[str, Any]) -> None: + """ + Called when an API call is about to start. + + Args: + kwargs: The kwargs being passed to the API call + """ + pass + + async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None: + """ + Called when an API call has completed. + + Args: + kwargs: The kwargs that were passed to the API call + result: The result of the API call + """ + pass + + async def on_usage(self, usage: Dict[str, Any]) -> None: + """ + Called when usage information is received. + + Args: + usage: The usage information + """ + pass + + async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None: + """ + Called when a screenshot is taken. + + Args: + screenshot: The screenshot image + name: The name of the screenshot + """ + pass + + async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None: + """ + Called when responses are received. + + Args: + kwargs: The kwargs being passed to the agent loop + responses: The responses received + """ + pass \ No newline at end of file diff --git a/libs/python/agent2/agent2/callbacks/budget_manager.py b/libs/python/agent2/agent2/callbacks/budget_manager.py new file mode 100644 index 00000000..bc17c695 --- /dev/null +++ b/libs/python/agent2/agent2/callbacks/budget_manager.py @@ -0,0 +1,44 @@ +from typing import Dict, List, Any +from .base import AsyncCallbackHandler + +class BudgetExceededError(Exception): + """Exception raised when budget is exceeded.""" + pass + +class BudgetManagerCallback(AsyncCallbackHandler): + """Budget manager callback that tracks usage costs and can stop execution when budget is exceeded.""" + + def __init__(self, max_budget: float, reset_after_each_run: bool = True, raise_error: bool = False): + """ + Initialize BudgetManagerCallback. 
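+
+        Typically constructed via ComputerAgent(max_trajectory_budget=...), which passes
+        either a float (used as max_budget) or a dict of these keyword arguments.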
+ + Args: + max_budget: Maximum budget allowed + reset_after_each_run: Whether to reset budget after each run + raise_error: Whether to raise an error when budget is exceeded + """ + self.max_budget = max_budget + self.reset_after_each_run = reset_after_each_run + self.raise_error = raise_error + self.total_cost = 0.0 + + async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None: + """Reset budget if configured to do so.""" + if self.reset_after_each_run: + self.total_cost = 0.0 + + async def on_usage(self, usage: Dict[str, Any]) -> None: + """Track usage costs.""" + if "response_cost" in usage: + self.total_cost += usage["response_cost"] + + async def on_run_continue(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> bool: + """Check if budget allows continuation.""" + if self.total_cost >= self.max_budget: + if self.raise_error: + raise BudgetExceededError(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}") + else: + print(f"Budget exceeded: ${self.total_cost} >= ${self.max_budget}") + return False + return True + \ No newline at end of file diff --git a/libs/python/agent2/agent2/callbacks/image_retention.py b/libs/python/agent2/agent2/callbacks/image_retention.py new file mode 100644 index 00000000..d91754b1 --- /dev/null +++ b/libs/python/agent2/agent2/callbacks/image_retention.py @@ -0,0 +1,139 @@ +""" +Image retention callback handler that limits the number of recent images in message history. +""" + +from typing import List, Dict, Any, Optional +from .base import AsyncCallbackHandler + + +class ImageRetentionCallback(AsyncCallbackHandler): + """ + Callback handler that applies image retention policy to limit the number + of recent images in message history to prevent context window overflow. + """ + + def __init__(self, only_n_most_recent_images: Optional[int] = None): + """ + Initialize the image retention callback. + + Args: + only_n_most_recent_images: If set, only keep the N most recent images in message history + """ + self.only_n_most_recent_images = only_n_most_recent_images + + async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Apply image retention policy to messages before sending to agent loop. + + Args: + messages: List of message dictionaries + + Returns: + List of messages with image retention policy applied + """ + if self.only_n_most_recent_images is None: + return messages + + return self._apply_image_retention(messages) + + def _apply_image_retention(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Apply image retention policy to keep only the N most recent images. + + Removes computer_call_output items with image_url and their corresponding computer_call items, + keeping only the most recent N image pairs based on only_n_most_recent_images setting. 
+ + Args: + messages: List of message dictionaries + + Returns: + Filtered list of messages with image retention applied + """ + if self.only_n_most_recent_images is None: + return messages + + # First pass: Assign call_id to reasoning items based on the next computer_call + messages_with_call_ids = [] + for i, msg in enumerate(messages): + msg_copy = msg.copy() if isinstance(msg, dict) else msg + + # If this is a reasoning item without a call_id, find the next computer_call + if (msg_copy.get("type") == "reasoning" and + not msg_copy.get("call_id")): + # Look ahead for the next computer_call + for j in range(i + 1, len(messages)): + next_msg = messages[j] + if (next_msg.get("type") == "computer_call" and + next_msg.get("call_id")): + msg_copy["call_id"] = next_msg.get("call_id") + break + + messages_with_call_ids.append(msg_copy) + + # Find all computer_call_output items with images and their call_ids + image_call_ids = [] + for msg in reversed(messages_with_call_ids): # Process in reverse to get most recent first + if (msg.get("type") == "computer_call_output" and + isinstance(msg.get("output"), dict) and + "image_url" in msg.get("output", {})): + call_id = msg.get("call_id") + if call_id and call_id not in image_call_ids: + image_call_ids.append(call_id) + if len(image_call_ids) >= self.only_n_most_recent_images: + break + + # Keep the most recent N image call_ids (reverse to get chronological order) + keep_call_ids = set(image_call_ids[:self.only_n_most_recent_images]) + + # Filter messages: remove computer_call, computer_call_output, and reasoning for old images + filtered_messages = [] + for msg in messages_with_call_ids: + msg_type = msg.get("type") + call_id = msg.get("call_id") + + # Remove old computer_call items + if msg_type == "computer_call" and call_id not in keep_call_ids: + # Check if this call_id corresponds to an image call + has_image_output = any( + m.get("type") == "computer_call_output" and + m.get("call_id") == call_id and + isinstance(m.get("output"), dict) and + "image_url" in m.get("output", {}) + for m in messages_with_call_ids + ) + if has_image_output: + continue # Skip this computer_call + + # Remove old computer_call_output items with images + if (msg_type == "computer_call_output" and + call_id not in keep_call_ids and + isinstance(msg.get("output"), dict) and + "image_url" in msg.get("output", {})): + continue # Skip this computer_call_output + + # Remove old reasoning items that are paired with removed computer calls + if (msg_type == "reasoning" and + call_id and call_id not in keep_call_ids): + # Check if this call_id corresponds to an image call that's being removed + has_image_output = any( + m.get("type") == "computer_call_output" and + m.get("call_id") == call_id and + isinstance(m.get("output"), dict) and + "image_url" in m.get("output", {}) + for m in messages_with_call_ids + ) + if has_image_output: + continue # Skip this reasoning item + + filtered_messages.append(msg) + + # Clean up: Remove call_id from reasoning items before returning + final_messages = [] + for msg in filtered_messages: + if msg.get("type") == "reasoning" and "call_id" in msg: + # Create a copy without call_id for reasoning items + cleaned_msg = {k: v for k, v in msg.items() if k != "call_id"} + final_messages.append(cleaned_msg) + else: + final_messages.append(msg) + + return final_messages \ No newline at end of file diff --git a/libs/python/agent2/agent2/callbacks/logging.py b/libs/python/agent2/agent2/callbacks/logging.py new file mode 100644 index 00000000..5a64fd16 --- 
/dev/null +++ b/libs/python/agent2/agent2/callbacks/logging.py @@ -0,0 +1,247 @@ +""" +Logging callback for ComputerAgent that provides configurable logging of agent lifecycle events. +""" + +import json +import logging +from typing import Dict, List, Any, Optional, Union +from .base import AsyncCallbackHandler + + +def sanitize_image_urls(data: Any) -> Any: + """ + Recursively search for 'image_url' keys and set their values to '[omitted]'. + + Args: + data: Any data structure (dict, list, or primitive type) + + Returns: + A deep copy of the data with all 'image_url' values replaced with '[omitted]' + """ + if isinstance(data, dict): + # Create a copy of the dictionary + sanitized = {} + for key, value in data.items(): + if key == "image_url": + sanitized[key] = "[omitted]" + else: + # Recursively sanitize the value + sanitized[key] = sanitize_image_urls(value) + return sanitized + + elif isinstance(data, list): + # Recursively sanitize each item in the list + return [sanitize_image_urls(item) for item in data] + + else: + # For primitive types (str, int, bool, None, etc.), return as-is + return data + + +class LoggingCallback(AsyncCallbackHandler): + """ + Callback handler that logs agent lifecycle events with configurable verbosity. + + Logging levels: + - DEBUG: All events including API calls, message preprocessing, and detailed outputs + - INFO: Major lifecycle events (start/end, messages, outputs) + - WARNING: Only warnings and errors + - ERROR: Only errors + """ + + def __init__(self, logger: Optional[logging.Logger] = None, level: int = logging.INFO): + """ + Initialize the logging callback. + + Args: + logger: Logger instance to use. If None, creates a logger named 'agent2.ComputerAgent' + level: Logging level (logging.DEBUG, logging.INFO, etc.) 
+ """ + self.logger = logger or logging.getLogger('agent2.ComputerAgent') + self.level = level + + # Set up logger if it doesn't have handlers + if not self.logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + ) + handler.setFormatter(formatter) + self.logger.addHandler(handler) + self.logger.setLevel(level) + + def _update_usage(self, usage: Dict[str, Any]) -> None: + """Update total usage statistics.""" + def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None: + for key, value in source.items(): + if isinstance(value, dict): + if key not in target: + target[key] = {} + add_dicts(target[key], value) + else: + if key not in target: + target[key] = 0 + target[key] += value + add_dicts(self.total_usage, usage) + + async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None: + """Called before the run starts.""" + self.total_usage = {} + + async def on_usage(self, usage: Dict[str, Any]) -> None: + """Called when usage information is received.""" + self._update_usage(usage) + + async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None: + """Called after the run ends.""" + def format_dict(d, indent=0): + lines = [] + prefix = f" - {' ' * indent}" + for key, value in d.items(): + if isinstance(value, dict): + lines.append(f"{prefix}{key}:") + lines.extend(format_dict(value, indent + 1)) + elif isinstance(value, float): + lines.append(f"{prefix}{key}: ${value:.4f}") + else: + lines.append(f"{prefix}{key}: {value}") + return lines + + formatted_output = "\n".join(format_dict(self.total_usage)) + self.logger.info(f"Total usage:\n{formatted_output}") + + async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Called before LLM processing starts.""" + if self.logger.isEnabledFor(logging.INFO): + self.logger.info(f"LLM processing started with {len(messages)} messages") + if self.logger.isEnabledFor(logging.DEBUG): + sanitized_messages = [sanitize_image_urls(msg) for msg in messages] + self.logger.debug(f"LLM input messages: {json.dumps(sanitized_messages, indent=2)}") + return messages + + async def on_llm_end(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Called after LLM processing ends.""" + if self.logger.isEnabledFor(logging.DEBUG): + sanitized_messages = [sanitize_image_urls(msg) for msg in messages] + self.logger.debug(f"LLM output: {json.dumps(sanitized_messages, indent=2)}") + return messages + + async def on_computer_call_start(self, item: Dict[str, Any]) -> None: + """Called when a computer call starts.""" + action = item.get("action", {}) + action_type = action.get("type", "unknown") + action_args = {k: v for k, v in action.items() if k != "type"} + + # INFO level logging for the action + self.logger.info(f"Computer: {action_type}({action_args})") + + # DEBUG level logging for full details + if self.logger.isEnabledFor(logging.DEBUG): + self.logger.debug(f"Computer call started: {json.dumps(action, indent=2)}") + + async def on_computer_call_end(self, item: Dict[str, Any], result: Any) -> None: + """Called when a computer call ends.""" + if self.logger.isEnabledFor(logging.DEBUG): + action = item.get("action", "unknown") + self.logger.debug(f"Computer call completed: {json.dumps(action, indent=2)}") + if result: + sanitized_result = sanitize_image_urls(result) + self.logger.debug(f"Computer call result: 
{json.dumps(sanitized_result, indent=2)}") + + async def on_function_call_start(self, item: Dict[str, Any]) -> None: + """Called when a function call starts.""" + name = item.get("name", "unknown") + arguments = item.get("arguments", "{}") + + # INFO level logging for the function call + self.logger.info(f"Function: {name}({arguments})") + + # DEBUG level logging for full details + if self.logger.isEnabledFor(logging.DEBUG): + self.logger.debug(f"Function call started: {name}") + + async def on_function_call_end(self, item: Dict[str, Any], result: Any) -> None: + """Called when a function call ends.""" + # INFO level logging for function output (similar to function_call_output) + if result: + # Handle both list and direct result formats + if isinstance(result, list) and len(result) > 0: + output = result[0].get("output", str(result)) if isinstance(result[0], dict) else str(result[0]) + else: + output = str(result) + + # Truncate long outputs + if len(output) > 100: + output = output[:100] + "..." + + self.logger.info(f"Output: {output}") + + # DEBUG level logging for full details + if self.logger.isEnabledFor(logging.DEBUG): + name = item.get("name", "unknown") + self.logger.debug(f"Function call completed: {name}") + if result: + self.logger.debug(f"Function call result: {json.dumps(result, indent=2)}") + + async def on_text(self, item: Dict[str, Any]) -> None: + """Called when a text message is encountered.""" + # Get the role to determine if it's Agent or User + role = item.get("role", "unknown") + content_items = item.get("content", []) + + # Process content items to build display text + text_parts = [] + for content_item in content_items: + content_type = content_item.get("type", "output_text") + if content_type == "output_text": + text_content = content_item.get("text", "") + if not text_content.strip(): + text_parts.append("[empty]") + else: + # Truncate long text and add ellipsis + if len(text_content) > 2048: + text_parts.append(text_content[:2048] + "...") + else: + text_parts.append(text_content) + else: + # Non-text content, show as [type] + text_parts.append(f"[{content_type}]") + + # Join all text parts + display_text = ''.join(text_parts) if text_parts else "[empty]" + + # Log with appropriate level and format + if role == "assistant": + self.logger.info(f"Agent: {display_text}") + elif role == "user": + self.logger.info(f"User: {display_text}") + else: + # Fallback for unknown roles, use debug level + if self.logger.isEnabledFor(logging.DEBUG): + self.logger.debug(f"Text message ({role}): {display_text}") + + async def on_api_start(self, kwargs: Dict[str, Any]) -> None: + """Called when an API call is about to start.""" + if self.logger.isEnabledFor(logging.DEBUG): + model = kwargs.get("model", "unknown") + self.logger.debug(f"API call starting for model: {model}") + # Log sanitized messages if present + if "messages" in kwargs: + sanitized_messages = sanitize_image_urls(kwargs["messages"]) + self.logger.debug(f"API call messages: {json.dumps(sanitized_messages, indent=2)}") + elif "input" in kwargs: + sanitized_input = sanitize_image_urls(kwargs["input"]) + self.logger.debug(f"API call input: {json.dumps(sanitized_input, indent=2)}") + + async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None: + """Called when an API call has completed.""" + if self.logger.isEnabledFor(logging.DEBUG): + model = kwargs.get("model", "unknown") + self.logger.debug(f"API call completed for model: {model}") + self.logger.debug(f"API call result: 
{json.dumps(sanitize_image_urls(result), indent=2)}") + + async def on_screenshot(self, item: Union[str, bytes], name: str = "screenshot") -> None: + """Called when a screenshot is taken.""" + if self.logger.isEnabledFor(logging.DEBUG): + image_size = len(item) / 1024 + self.logger.debug(f"Screenshot captured: {name} {image_size:.2f} KB") \ No newline at end of file diff --git a/libs/python/agent2/agent2/callbacks/pii_anonymization.py b/libs/python/agent2/agent2/callbacks/pii_anonymization.py new file mode 100644 index 00000000..f5c31a61 --- /dev/null +++ b/libs/python/agent2/agent2/callbacks/pii_anonymization.py @@ -0,0 +1,259 @@ +""" +PII anonymization callback handler using Microsoft Presidio for text and image redaction. +""" + +from typing import List, Dict, Any, Optional, Tuple +from .base import AsyncCallbackHandler +import base64 +import io +import logging + +try: + from presidio_analyzer import AnalyzerEngine + from presidio_anonymizer import AnonymizerEngine, DeanonymizeEngine + from presidio_anonymizer.entities import RecognizerResult, OperatorConfig + from presidio_image_redactor import ImageRedactorEngine + from PIL import Image + PRESIDIO_AVAILABLE = True +except ImportError: + PRESIDIO_AVAILABLE = False + +logger = logging.getLogger(__name__) + +class PIIAnonymizationCallback(AsyncCallbackHandler): + """ + Callback handler that anonymizes PII in text and images using Microsoft Presidio. + + This handler: + 1. Anonymizes PII in messages before sending to the agent loop + 2. Deanonymizes PII in tool calls and message outputs after the agent loop + 3. Redacts PII from images in computer_call_output messages + """ + + def __init__( + self, + anonymize_text: bool = True, + anonymize_images: bool = True, + entities_to_anonymize: Optional[List[str]] = None, + anonymization_operator: str = "replace", + image_redaction_color: Tuple[int, int, int] = (255, 192, 203) # Pink + ): + """ + Initialize the PII anonymization callback. + + Args: + anonymize_text: Whether to anonymize text content + anonymize_images: Whether to redact images + entities_to_anonymize: List of entity types to anonymize (None for all) + anonymization_operator: Presidio operator to use ("replace", "mask", "redact", etc.) + image_redaction_color: RGB color for image redaction + """ + if not PRESIDIO_AVAILABLE: + raise ImportError( + "Presidio is not available. Install with: " + "pip install presidio-analyzer presidio-anonymizer presidio-image-redactor" + ) + + self.anonymize_text = anonymize_text + self.anonymize_images = anonymize_images + self.entities_to_anonymize = entities_to_anonymize + self.anonymization_operator = anonymization_operator + self.image_redaction_color = image_redaction_color + + # Initialize Presidio engines + self.analyzer = AnalyzerEngine() + self.anonymizer = AnonymizerEngine() + self.deanonymizer = DeanonymizeEngine() + self.image_redactor = ImageRedactorEngine() + + # Store anonymization mappings for deanonymization + self.anonymization_mappings: Dict[str, Any] = {} + + async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Anonymize PII in messages before sending to agent loop. 
+ + Args: + messages: List of message dictionaries + + Returns: + List of messages with PII anonymized + """ + if not self.anonymize_text and not self.anonymize_images: + return messages + + anonymized_messages = [] + for msg in messages: + anonymized_msg = await self._anonymize_message(msg) + anonymized_messages.append(anonymized_msg) + + return anonymized_messages + + async def on_llm_end(self, output: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Deanonymize PII in tool calls and message outputs after agent loop. + + Args: + output: List of output dictionaries + + Returns: + List of output with PII deanonymized for tool calls + """ + if not self.anonymize_text: + return output + + deanonymized_output = [] + for item in output: + # Only deanonymize tool calls and computer_call messages + if item.get("type") in ["computer_call", "computer_call_output"]: + deanonymized_item = await self._deanonymize_item(item) + deanonymized_output.append(deanonymized_item) + else: + deanonymized_output.append(item) + + return deanonymized_output + + async def _anonymize_message(self, message: Dict[str, Any]) -> Dict[str, Any]: + """Anonymize PII in a single message.""" + msg_copy = message.copy() + + # Anonymize text content + if self.anonymize_text: + msg_copy = await self._anonymize_text_content(msg_copy) + + # Redact images in computer_call_output + if self.anonymize_images and msg_copy.get("type") == "computer_call_output": + msg_copy = await self._redact_image_content(msg_copy) + + return msg_copy + + async def _anonymize_text_content(self, message: Dict[str, Any]) -> Dict[str, Any]: + """Anonymize text content in a message.""" + msg_copy = message.copy() + + # Handle content array + content = msg_copy.get("content", []) + if isinstance(content, str): + anonymized_text, _ = await self._anonymize_text(content) + msg_copy["content"] = anonymized_text + elif isinstance(content, list): + anonymized_content = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "text": + text = item.get("text", "") + anonymized_text, _ = await self._anonymize_text(text) + item_copy = item.copy() + item_copy["text"] = anonymized_text + anonymized_content.append(item_copy) + else: + anonymized_content.append(item) + msg_copy["content"] = anonymized_content + + return msg_copy + + async def _redact_image_content(self, message: Dict[str, Any]) -> Dict[str, Any]: + """Redact PII from images in computer_call_output messages.""" + msg_copy = message.copy() + output = msg_copy.get("output", {}) + + if isinstance(output, dict) and "image_url" in output: + try: + # Extract base64 image data + image_url = output["image_url"] + if image_url.startswith("data:image/"): + # Parse data URL + header, data = image_url.split(",", 1) + image_data = base64.b64decode(data) + + # Load image with PIL + image = Image.open(io.BytesIO(image_data)) + + # Redact PII from image + redacted_image = self.image_redactor.redact(image, self.image_redaction_color) + + # Convert back to base64 + buffer = io.BytesIO() + redacted_image.save(buffer, format="PNG") + redacted_data = base64.b64encode(buffer.getvalue()).decode() + + # Update image URL + output_copy = output.copy() + output_copy["image_url"] = f"data:image/png;base64,{redacted_data}" + msg_copy["output"] = output_copy + + except Exception as e: + logger.warning(f"Failed to redact image: {e}") + + return msg_copy + + async def _deanonymize_item(self, item: Dict[str, Any]) -> Dict[str, Any]: + """Deanonymize PII in tool calls and computer outputs.""" + item_copy = 
item.copy() + + # Handle computer_call arguments + if item.get("type") == "computer_call": + args = item_copy.get("args", {}) + if isinstance(args, dict): + deanonymized_args = {} + for key, value in args.items(): + if isinstance(value, str): + deanonymized_value, _ = await self._deanonymize_text(value) + deanonymized_args[key] = deanonymized_value + else: + deanonymized_args[key] = value + item_copy["args"] = deanonymized_args + + return item_copy + + async def _anonymize_text(self, text: str) -> Tuple[str, List[RecognizerResult]]: + """Anonymize PII in text and return the anonymized text and results.""" + if not text.strip(): + return text, [] + + try: + # Analyze text for PII + analyzer_results = self.analyzer.analyze( + text=text, + entities=self.entities_to_anonymize, + language="en" + ) + + if not analyzer_results: + return text, [] + + # Anonymize the text + anonymized_result = self.anonymizer.anonymize( + text=text, + analyzer_results=analyzer_results, + operators={entity_type: OperatorConfig(self.anonymization_operator) + for entity_type in set(result.entity_type for result in analyzer_results)} + ) + + # Store mapping for deanonymization + mapping_key = str(hash(text)) + self.anonymization_mappings[mapping_key] = { + "original": text, + "anonymized": anonymized_result.text, + "results": analyzer_results + } + + return anonymized_result.text, analyzer_results + + except Exception as e: + logger.warning(f"Failed to anonymize text: {e}") + return text, [] + + async def _deanonymize_text(self, text: str) -> Tuple[str, bool]: + """Attempt to deanonymize text using stored mappings.""" + try: + # Look for matching anonymized text in mappings + for mapping_key, mapping in self.anonymization_mappings.items(): + if mapping["anonymized"] == text: + return mapping["original"], True + + # If no mapping found, return original text + return text, False + + except Exception as e: + logger.warning(f"Failed to deanonymize text: {e}") + return text, False diff --git a/libs/python/agent2/agent2/callbacks/trajectory_saver.py b/libs/python/agent2/agent2/callbacks/trajectory_saver.py new file mode 100644 index 00000000..b59563d5 --- /dev/null +++ b/libs/python/agent2/agent2/callbacks/trajectory_saver.py @@ -0,0 +1,305 @@ +""" +Trajectory saving callback handler for ComputerAgent. +""" + +import os +import json +import uuid +from datetime import datetime +import base64 +from pathlib import Path +from typing import List, Dict, Any, Optional, Union, override +from PIL import Image, ImageDraw +import io +from .base import AsyncCallbackHandler + +def sanitize_image_urls(data: Any) -> Any: + """ + Recursively search for 'image_url' keys and set their values to '[omitted]'. + + Args: + data: Any data structure (dict, list, or primitive type) + + Returns: + A deep copy of the data with all 'image_url' values replaced with '[omitted]' + """ + if isinstance(data, dict): + # Create a copy of the dictionary + sanitized = {} + for key, value in data.items(): + if key == "image_url": + sanitized[key] = "[omitted]" + else: + # Recursively sanitize the value + sanitized[key] = sanitize_image_urls(value) + return sanitized + + elif isinstance(data, list): + # Recursively sanitize each item in the list + return [sanitize_image_urls(item) for item in data] + + else: + # For primitive types (str, int, bool, None, etc.), return as-is + return data + + +class TrajectorySaverCallback(AsyncCallbackHandler): + """ + Callback handler that saves agent trajectories to disk. 
+ + Saves each run as a separate trajectory with unique ID, and each turn + within the trajectory gets its own folder with screenshots and responses. + """ + + def __init__(self, trajectory_dir: str): + """ + Initialize trajectory saver. + + Args: + trajectory_dir: Base directory to save trajectories + """ + self.trajectory_dir = Path(trajectory_dir) + self.trajectory_id: Optional[str] = None + self.current_turn: int = 0 + self.current_artifact: int = 0 + self.model: Optional[str] = None + self.total_usage: Dict[str, Any] = {} + + # Ensure trajectory directory exists + self.trajectory_dir.mkdir(parents=True, exist_ok=True) + + def _get_turn_dir(self) -> Path: + """Get the directory for the current turn.""" + if not self.trajectory_id: + raise ValueError("Trajectory not initialized - call _on_run_start first") + + # format: trajectory_id/turn_000 + turn_dir = self.trajectory_dir / self.trajectory_id / f"turn_{self.current_turn:03d}" + turn_dir.mkdir(parents=True, exist_ok=True) + return turn_dir + + def _save_artifact(self, name: str, artifact: Union[str, bytes, Dict[str, Any]]) -> None: + """Save an artifact to the current turn directory.""" + turn_dir = self._get_turn_dir() + if isinstance(artifact, bytes): + # format: turn_000/0000_name.png + artifact_filename = f"{self.current_artifact:04d}_{name}" + artifact_path = turn_dir / f"{artifact_filename}.png" + with open(artifact_path, "wb") as f: + f.write(artifact) + else: + # format: turn_000/0000_name.json + artifact_filename = f"{self.current_artifact:04d}_{name}" + artifact_path = turn_dir / f"{artifact_filename}.json" + with open(artifact_path, "w") as f: + json.dump(sanitize_image_urls(artifact), f, indent=2) + self.current_artifact += 1 + + def _update_usage(self, usage: Dict[str, Any]) -> None: + """Update total usage statistics.""" + def add_dicts(target: Dict[str, Any], source: Dict[str, Any]) -> None: + for key, value in source.items(): + if isinstance(value, dict): + if key not in target: + target[key] = {} + add_dicts(target[key], value) + else: + if key not in target: + target[key] = 0 + target[key] += value + add_dicts(self.total_usage, usage) + + @override + async def on_run_start(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]]) -> None: + """Initialize trajectory tracking for a new run.""" + model = kwargs.get("model", "unknown") + model_name_short = model.split("+")[-1].split("/")[-1].lower()[:16] + if "+" in model: + model_name_short = model.split("+")[0].lower()[:4] + "_" + model_name_short + + # id format: yyyy-mm-dd_model_hhmmss_uuid[:4] + now = datetime.now() + self.trajectory_id = f"{now.strftime('%Y-%m-%d')}_{model_name_short}_{now.strftime('%H%M%S')}_{str(uuid.uuid4())[:4]}" + self.current_turn = 0 + self.current_artifact = 0 + self.model = model + self.total_usage = {} + + # Create trajectory directory + trajectory_path = self.trajectory_dir / self.trajectory_id + trajectory_path.mkdir(parents=True, exist_ok=True) + + # Save trajectory metadata + metadata = { + "trajectory_id": self.trajectory_id, + "created_at": str(uuid.uuid1().time), + "status": "running", + "kwargs": kwargs, + } + + with open(trajectory_path / "metadata.json", "w") as f: + json.dump(metadata, f, indent=2) + + @override + async def on_run_end(self, kwargs: Dict[str, Any], old_items: List[Dict[str, Any]], new_items: List[Dict[str, Any]]) -> None: + """Finalize run tracking by updating metadata with completion status, usage, and new items.""" + if not self.trajectory_id: + return + + # Update metadata with completion status, total 
usage, and new items + trajectory_path = self.trajectory_dir / self.trajectory_id + metadata_path = trajectory_path / "metadata.json" + + # Read existing metadata + if metadata_path.exists(): + with open(metadata_path, "r") as f: + metadata = json.load(f) + else: + metadata = {} + + # Update metadata with completion info + metadata.update({ + "status": "completed", + "completed_at": str(uuid.uuid1().time), + "total_usage": self.total_usage, + "new_items": sanitize_image_urls(new_items), + "total_turns": self.current_turn + }) + + # Save updated metadata + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + @override + async def on_api_start(self, kwargs: Dict[str, Any]) -> None: + if not self.trajectory_id: + return + + self._save_artifact("api_start", { "kwargs": kwargs }) + + @override + async def on_api_end(self, kwargs: Dict[str, Any], result: Any) -> None: + """Save API call result.""" + if not self.trajectory_id: + return + + self._save_artifact("api_result", { "kwargs": kwargs, "result": result }) + + @override + async def on_screenshot(self, screenshot: Union[str, bytes], name: str = "screenshot") -> None: + """Save a screenshot.""" + if isinstance(screenshot, str): + screenshot = base64.b64decode(screenshot) + self._save_artifact(name, screenshot) + + @override + async def on_usage(self, usage: Dict[str, Any]) -> None: + """Called when usage information is received.""" + self._update_usage(usage) + + @override + async def on_responses(self, kwargs: Dict[str, Any], responses: Dict[str, Any]) -> None: + """Save responses to the current turn directory and update usage statistics.""" + if not self.trajectory_id: + return + + # Save responses + turn_dir = self._get_turn_dir() + response_data = { + "timestamp": str(uuid.uuid1().time), + "model": self.model, + "kwargs": kwargs, + "response": responses + } + + self._save_artifact("agent_response", response_data) + + # Increment turn counter + self.current_turn += 1 + + def _draw_crosshair_on_image(self, image_bytes: bytes, x: int, y: int) -> bytes: + """ + Draw a red dot and crosshair at the specified coordinates on the image. + + Args: + image_bytes: The original image as bytes + x: X coordinate for the crosshair + y: Y coordinate for the crosshair + + Returns: + Modified image as bytes with red dot and crosshair + """ + # Open the image + image = Image.open(io.BytesIO(image_bytes)) + draw = ImageDraw.Draw(image) + + # Draw crosshair lines (red, 2px thick) + crosshair_size = 20 + line_width = 2 + color = "red" + + # Horizontal line + draw.line([(x - crosshair_size, y), (x + crosshair_size, y)], fill=color, width=line_width) + # Vertical line + draw.line([(x, y - crosshair_size), (x, y + crosshair_size)], fill=color, width=line_width) + + # Draw center dot (filled circle) + dot_radius = 3 + draw.ellipse([(x - dot_radius, y - dot_radius), (x + dot_radius, y + dot_radius)], fill=color) + + # Convert back to bytes + output = io.BytesIO() + image.save(output, format='PNG') + return output.getvalue() + + @override + async def on_computer_call_end(self, item: Dict[str, Any], result: List[Dict[str, Any]]) -> None: + """ + Called when a computer call has completed. + Saves screenshots and computer call output. 
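+
+        Illustrative artifacts written for a click action (the numeric prefix
+        is the running artifact counter, so exact filenames vary):
+
+            turn_003/0007_computer_call_result.json
+            turn_003/0008_screenshot_action.png  (screenshot annotated with a crosshair)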
+ """ + if not self.trajectory_id: + return + + self._save_artifact("computer_call_result", { "item": item, "result": result }) + + # Check if action has x/y coordinates and there's a screenshot in the result + action = item.get("action", {}) + if "x" in action and "y" in action: + # Look for screenshot in the result + for result_item in result: + if (result_item.get("type") == "computer_call_output" and + result_item.get("output", {}).get("type") == "input_image"): + + image_url = result_item["output"]["image_url"] + + # Extract base64 image data + if image_url.startswith("data:image/"): + # Format: data:image/png;base64, + base64_data = image_url.split(",", 1)[1] + else: + # Assume it's just base64 data + base64_data = image_url + + try: + # Decode the image + image_bytes = base64.b64decode(base64_data) + + # Draw crosshair at the action coordinates + annotated_image = self._draw_crosshair_on_image( + image_bytes, + int(action["x"]), + int(action["y"]) + ) + + # Save as screenshot_action + self._save_artifact("screenshot_action", annotated_image) + + except Exception as e: + # If annotation fails, just log and continue + print(f"Failed to annotate screenshot: {e}") + + break # Only process the first screenshot found + + # Increment turn counter + self.current_turn += 1 \ No newline at end of file diff --git a/libs/python/agent2/agent2/computer_handler.py b/libs/python/agent2/agent2/computer_handler.py new file mode 100644 index 00000000..4a9f0186 --- /dev/null +++ b/libs/python/agent2/agent2/computer_handler.py @@ -0,0 +1,107 @@ +""" +Computer handler implementation for OpenAI computer-use-preview protocol. +""" + +import base64 +from typing import Dict, List, Any, Literal +from .types import Computer + + +class OpenAIComputerHandler: + """Computer handler that implements the Computer protocol using the computer interface.""" + + def __init__(self, computer_interface): + """Initialize with a computer interface (from tool schema).""" + self.interface = computer_interface + + async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]: + """Get the current environment type.""" + # For now, return a default - this could be enhanced to detect actual environment + return "windows" + + async def get_dimensions(self) -> tuple[int, int]: + """Get screen dimensions as (width, height).""" + screen_size = await self.interface.get_screen_size() + return screen_size["width"], screen_size["height"] + + async def screenshot(self) -> str: + """Take a screenshot and return as base64 string.""" + screenshot_bytes = await self.interface.screenshot() + return base64.b64encode(screenshot_bytes).decode('utf-8') + + async def click(self, x: int, y: int, button: str = "left") -> None: + """Click at coordinates with specified button.""" + if button == "left": + await self.interface.left_click(x, y) + elif button == "right": + await self.interface.right_click(x, y) + else: + # Default to left click for unknown buttons + await self.interface.left_click(x, y) + + async def double_click(self, x: int, y: int) -> None: + """Double click at coordinates.""" + await self.interface.double_click(x, y) + + async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None: + """Scroll at coordinates with specified scroll amounts.""" + await self.interface.move_cursor(x, y) + await self.interface.scroll(scroll_x, scroll_y) + + async def type(self, text: str) -> None: + """Type text.""" + await self.interface.type_text(text) + + async def wait(self, ms: int = 1000) -> None: + """Wait for 
specified milliseconds.""" + import asyncio + await asyncio.sleep(ms / 1000.0) + + async def move(self, x: int, y: int) -> None: + """Move cursor to coordinates.""" + await self.interface.move_cursor(x, y) + + async def keypress(self, keys: List[str]) -> None: + """Press key combination.""" + if len(keys) == 1: + await self.interface.press_key(keys[0]) + else: + # Handle key combinations + await self.interface.hotkey(*keys) + + async def drag(self, path: List[Dict[str, int]]) -> None: + """Drag along specified path.""" + if not path: + return + + # Start drag from first point + start = path[0] + await self.interface.mouse_down(start["x"], start["y"]) + + # Move through path + for point in path[1:]: + await self.interface.move_cursor(point["x"], point["y"]) + + # End drag at last point + end = path[-1] + await self.interface.mouse_up(end["x"], end["y"]) + + async def get_current_url(self) -> str: + """Get current URL (for browser environments).""" + # This would need to be implemented based on the specific browser interface + # For now, return empty string + return "" + + +def acknowledge_safety_check_callback(message: str) -> bool: + """Safety check callback for user acknowledgment.""" + response = input( + f"Safety Check Warning: {message}\nDo you want to acknowledge and proceed? (y/n): " + ).lower() + return response.strip() == "y" + + +def check_blocklisted_url(url: str) -> None: + """Check if URL is blocklisted (placeholder implementation).""" + # This would contain actual URL checking logic + pass diff --git a/libs/python/agent2/agent2/decorators.py b/libs/python/agent2/agent2/decorators.py new file mode 100644 index 00000000..102d26fb --- /dev/null +++ b/libs/python/agent2/agent2/decorators.py @@ -0,0 +1,90 @@ +""" +Decorators for agent2 - agent_loop decorator +""" + +import asyncio +import inspect +from typing import Dict, List, Any, Callable, Optional +from functools import wraps + +from .types import AgentLoopInfo + +# Global registry +_agent_loops: List[AgentLoopInfo] = [] + +def agent_loop(models: str, priority: int = 0): + """ + Decorator to register an agent loop function. 
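+
+    Registered loops are kept in a module-level registry sorted by priority
+    (highest first); find_agent_loop(model) returns the first registered loop
+    whose regex matches the model name, so more specific loops should be given
+    a higher priority than generic fallbacks.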
+ + Args: + models: Regex pattern to match supported models + priority: Priority for loop selection (higher = more priority) + """ + def decorator(func: Callable): + # Validate function signature + sig = inspect.signature(func) + required_params = {'messages', 'model'} + func_params = set(sig.parameters.keys()) + + if not required_params.issubset(func_params): + missing = required_params - func_params + raise ValueError(f"Agent loop function must have parameters: {missing}") + + # Register the loop + loop_info = AgentLoopInfo( + func=func, + models_regex=models, + priority=priority + ) + _agent_loops.append(loop_info) + + # Sort by priority (highest first) + _agent_loops.sort(key=lambda x: x.priority, reverse=True) + + @wraps(func) + async def wrapper(*args, **kwargs): + # Wrap the function in an asyncio.Queue for cancellation support + queue = asyncio.Queue() + task = None + + try: + # Create a task that can be cancelled + async def run_loop(): + try: + result = await func(*args, **kwargs) + await queue.put(('result', result)) + except Exception as e: + await queue.put(('error', e)) + + task = asyncio.create_task(run_loop()) + + # Wait for result or cancellation + event_type, data = await queue.get() + + if event_type == 'error': + raise data + return data + + except asyncio.CancelledError: + if task: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + raise + + return wrapper + + return decorator + +def get_agent_loops() -> List[AgentLoopInfo]: + """Get all registered agent loops""" + return _agent_loops.copy() + +def find_agent_loop(model: str) -> Optional[AgentLoopInfo]: + """Find the best matching agent loop for a model""" + for loop_info in _agent_loops: + if loop_info.matches_model(model): + return loop_info + return None diff --git a/libs/python/agent2/agent2/loops/__init__.py b/libs/python/agent2/agent2/loops/__init__.py new file mode 100644 index 00000000..c02cdb5d --- /dev/null +++ b/libs/python/agent2/agent2/loops/__init__.py @@ -0,0 +1,11 @@ +""" +Agent loops for agent2 +""" + +# Import the loops to register them +from . import anthropic +from . import openai +from . import uitars +from . 
import omniparser + +__all__ = ["anthropic", "openai", "uitars", "omniparser"] diff --git a/libs/python/agent2/agent2/loops/anthropic.py b/libs/python/agent2/agent2/loops/anthropic.py new file mode 100644 index 00000000..23a587f5 --- /dev/null +++ b/libs/python/agent2/agent2/loops/anthropic.py @@ -0,0 +1,728 @@ +""" +Anthropic hosted tools agent loop implementation using liteLLM +""" + +import asyncio +import json +from typing import Dict, List, Any, AsyncGenerator, Union, Optional +import litellm +from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig + +from ..decorators import agent_loop +from ..types import Messages, AgentResponse, Tools +from ..responses import ( + make_reasoning_item, + make_output_text_item, + make_click_item, + make_double_click_item, + make_drag_item, + make_keypress_item, + make_move_item, + make_scroll_item, + make_type_item, + make_wait_item, + make_input_image_item, + make_screenshot_item +) + +# Model version mapping to tool version and beta flag +MODEL_TOOL_MAPPING = [ + # Claude 4 models + { + "pattern": r"claude-4|claude-opus-4|claude-sonnet-4", + "tool_version": "computer_20250124", + "beta_flag": "computer-use-2025-01-24" + }, + # Claude 3.7 models + { + "pattern": r"claude-3\.?7|claude-3-7", + "tool_version": "computer_20250124", + "beta_flag": "computer-use-2025-01-24" + }, + # Claude 3.5 models (fallback) + { + "pattern": r"claude-3\.?5|claude-3-5", + "tool_version": "computer_20241022", + "beta_flag": "computer-use-2024-10-22" + } +] + +def _get_tool_config_for_model(model: str) -> Dict[str, str]: + """Get tool version and beta flag for the given model.""" + import re + + for mapping in MODEL_TOOL_MAPPING: + if re.search(mapping["pattern"], model, re.IGNORECASE): + return { + "tool_version": mapping["tool_version"], + "beta_flag": mapping["beta_flag"] + } + + # Default to Claude 3.5 configuration + return { + "tool_version": "computer_20241022", + "beta_flag": "computer-use-2024-10-22" + } + +def _map_computer_tool_to_anthropic(computer_tool: Any, tool_version: str) -> Dict[str, Any]: + """Map a computer tool to Anthropic's hosted tool schema.""" + return { + "type": tool_version, + "function": { + "name": "computer", + "parameters": { + "display_height_px": getattr(computer_tool, 'display_height', 768), + "display_width_px": getattr(computer_tool, 'display_width', 1024), + "display_number": getattr(computer_tool, 'display_number', 1), + }, + }, + } + +def _prepare_tools_for_anthropic(tool_schemas: List[Dict[str, Any]], model: str) -> Tools: + """Prepare tools for Anthropic API format.""" + tool_config = _get_tool_config_for_model(model) + anthropic_tools = [] + + for schema in tool_schemas: + if schema["type"] == "computer": + # Map computer tool to Anthropic format + anthropic_tools.append(_map_computer_tool_to_anthropic( + schema["computer"], + tool_config["tool_version"] + )) + elif schema["type"] == "function": + # Function tools - convert to Anthropic format + function_schema = schema["function"] + anthropic_tools.append({ + "type": "function", + "function": { + "name": function_schema["name"], + "description": function_schema.get("description", ""), + "parameters": function_schema.get("parameters", {}) + } + }) + + return anthropic_tools + +def _convert_responses_items_to_completion_messages(messages: Messages) -> List[Dict[str, Any]]: + """Convert responses_items message format to liteLLM completion format.""" + completion_messages = [] + + for message in messages: + msg_type = 
message.get("type") + role = message.get("role") + + # Handle user messages (both with and without explicit type) + if role == "user" or msg_type == "user": + content = message.get("content", "") + if isinstance(content, list): + # Multi-modal content - convert input_image to image format + converted_content = [] + for item in content: + if isinstance(item, dict) and item.get("type") == "input_image": + # Convert input_image to Anthropic image format + image_url = item.get("image_url", "") + if image_url and image_url != "[omitted]": + # Extract base64 data from data URL + if "," in image_url: + base64_data = image_url.split(",")[-1] + else: + base64_data = image_url + + converted_content.append({ + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": base64_data + } + }) + else: + # Keep other content types as-is + converted_content.append(item) + + completion_messages.append({ + "role": "user", + "content": converted_content if converted_content else content + }) + else: + # Text content + completion_messages.append({ + "role": "user", + "content": content + }) + + # Handle assistant messages + elif role == "assistant": + content = message.get("content", []) + if isinstance(content, str): + content = [{ "type": "output_text", "text": content }] + + content = "\n".join(item.get("text", "") for item in content) + completion_messages.append({ + "role": "assistant", + "content": content + }) + + elif msg_type == "reasoning": + # Reasoning becomes part of assistant message + summary = message.get("summary", []) + reasoning_text = "" + + if isinstance(summary, list) and summary: + # Extract text from summary items + for item in summary: + if isinstance(item, dict) and item.get("type") == "summary_text": + reasoning_text = item.get("text", "") + break + else: + # Fallback to direct reasoning field + reasoning_text = message.get("reasoning", "") + + if reasoning_text: + completion_messages.append({ + "role": "assistant", + "content": reasoning_text + }) + + elif msg_type == "computer_call": + # Computer call becomes tool use in assistant message + action = message.get("action", {}) + action_type = action.get("type") + call_id = message.get("call_id", "call_1") + + tool_use_content = [] + + if action_type == "click": + tool_use_content.append({ + "type": "tool_use", + "id": call_id, + "name": "computer", + "input": { + "action": "click", + "coordinate": [action.get("x", 0), action.get("y", 0)] + } + }) + elif action_type == "type": + tool_use_content.append({ + "type": "tool_use", + "id": call_id, + "name": "computer", + "input": { + "action": "type", + "text": action.get("text", "") + } + }) + elif action_type == "key": + tool_use_content.append({ + "type": "tool_use", + "id": call_id, + "name": "computer", + "input": { + "action": "key", + "key": action.get("key", "") + } + }) + elif action_type == "wait": + tool_use_content.append({ + "type": "tool_use", + "id": call_id, + "name": "computer", + "input": { + "action": "screenshot" + } + }) + elif action_type == "screenshot": + tool_use_content.append({ + "type": "tool_use", + "id": call_id, + "name": "computer", + "input": { + "action": "screenshot" + } + }) + + # Convert tool_use_content to OpenAI tool_calls format + openai_tool_calls = [] + for tool_use in tool_use_content: + openai_tool_calls.append({ + "id": tool_use["id"], + "type": "function", + "function": { + "name": tool_use["name"], + "arguments": json.dumps(tool_use["input"]) + } + }) + + # If the last completion message is an assistant message, 
extend the tool_calls + if completion_messages and completion_messages[-1].get("role") == "assistant": + if "tool_calls" not in completion_messages[-1]: + completion_messages[-1]["tool_calls"] = [] + completion_messages[-1]["tool_calls"].extend(openai_tool_calls) + else: + # Create new assistant message with tool calls + completion_messages.append({ + "role": "assistant", + "content": None, + "tool_calls": openai_tool_calls + }) + + elif msg_type == "computer_call_output": + # Computer call output becomes OpenAI function result + output = message.get("output", {}) + call_id = message.get("call_id", "call_1") + + if output.get("type") == "input_image": + # Screenshot result - convert to OpenAI format with image_url content + image_url = output.get("image_url", "") + completion_messages.append({ + "role": "function", + "name": "computer", + "tool_call_id": call_id, + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }] + }) + else: + # Text result - convert to OpenAI format + completion_messages.append({ + "role": "function", + "name": "computer", + "tool_call_id": call_id, + "content": str(output) + }) + + return completion_messages + +def _convert_completion_to_responses_items(response: Any) -> List[Dict[str, Any]]: + """Convert liteLLM completion response to responses_items message format.""" + responses_items = [] + + if not response or not hasattr(response, 'choices') or not response.choices: + return responses_items + + choice = response.choices[0] + message = choice.message + + # Handle text content + if hasattr(message, 'content') and message.content: + if isinstance(message.content, str): + responses_items.append(make_output_text_item(message.content)) + elif isinstance(message.content, list): + for content_item in message.content: + if isinstance(content_item, dict): + if content_item.get("type") == "text": + responses_items.append(make_output_text_item(content_item.get("text", ""))) + elif content_item.get("type") == "tool_use": + # Convert tool use to computer call + tool_input = content_item.get("input", {}) + action_type = tool_input.get("action") + call_id = content_item.get("id") + + # Action reference: + # https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/computer-use-tool#available-actions + + # Basic actions (all versions) + if action_type == "screenshot": + responses_items.append(make_screenshot_item(call_id=call_id)) + elif action_type == "left_click": + coordinate = tool_input.get("coordinate", [0, 0]) + responses_items.append(make_click_item( + x=coordinate[0] if len(coordinate) > 0 else 0, + y=coordinate[1] if len(coordinate) > 1 else 0, + call_id=call_id + )) + elif action_type == "type": + responses_items.append(make_type_item( + text=tool_input.get("text", ""), + call_id=call_id + )) + elif action_type == "key": + responses_items.append(make_keypress_item( + key=tool_input.get("key", ""), + call_id=call_id + )) + elif action_type == "mouse_move": + # Mouse move - create a custom action item + coordinate = tool_input.get("coordinate", [0, 0]) + responses_items.append({ + "type": "computer_call", + "call_id": call_id, + "action": { + "type": "mouse_move", + "x": coordinate[0] if len(coordinate) > 0 else 0, + "y": coordinate[1] if len(coordinate) > 1 else 0 + } + }) + + # Enhanced actions (computer_20250124) Available in Claude 4 and Claude Sonnet 3.7 + elif action_type == "scroll": + coordinate = tool_input.get("coordinate", [0, 0]) + responses_items.append(make_scroll_item( + x=coordinate[0] if len(coordinate) > 0 else 0, + 
y=coordinate[1] if len(coordinate) > 1 else 0, + direction=tool_input.get("scroll_direction", "down"), + amount=tool_input.get("scroll_amount", 3), + call_id=call_id + )) + elif action_type == "left_click_drag": + start_coord = tool_input.get("start_coordinate", [0, 0]) + end_coord = tool_input.get("end_coordinate", [0, 0]) + responses_items.append(make_drag_item( + start_x=start_coord[0] if len(start_coord) > 0 else 0, + start_y=start_coord[1] if len(start_coord) > 1 else 0, + end_x=end_coord[0] if len(end_coord) > 0 else 0, + end_y=end_coord[1] if len(end_coord) > 1 else 0, + call_id=call_id + )) + elif action_type == "right_click": + coordinate = tool_input.get("coordinate", [0, 0]) + responses_items.append(make_click_item( + x=coordinate[0] if len(coordinate) > 0 else 0, + y=coordinate[1] if len(coordinate) > 1 else 0, + button="right", + call_id=call_id + )) + elif action_type == "middle_click": + coordinate = tool_input.get("coordinate", [0, 0]) + responses_items.append(make_click_item( + x=coordinate[0] if len(coordinate) > 0 else 0, + y=coordinate[1] if len(coordinate) > 1 else 0, + button="wheel", + call_id=call_id + )) + elif action_type == "double_click": + coordinate = tool_input.get("coordinate", [0, 0]) + responses_items.append(make_double_click_item( + x=coordinate[0] if len(coordinate) > 0 else 0, + y=coordinate[1] if len(coordinate) > 1 else 0, + call_id=call_id + )) + elif action_type == "triple_click": + # coordinate = tool_input.get("coordinate", [0, 0]) + # responses_items.append({ + # "type": "computer_call", + # "call_id": call_id, + # "action": { + # "type": "triple_click", + # "x": coordinate[0] if len(coordinate) > 0 else 0, + # "y": coordinate[1] if len(coordinate) > 1 else 0 + # } + # }) + raise NotImplementedError("triple_click") + elif action_type == "left_mouse_down": + # coordinate = tool_input.get("coordinate", [0, 0]) + # responses_items.append({ + # "type": "computer_call", + # "call_id": call_id, + # "action": { + # "type": "mouse_down", + # "button": "left", + # "x": coordinate[0] if len(coordinate) > 0 else 0, + # "y": coordinate[1] if len(coordinate) > 1 else 0 + # } + # }) + raise NotImplementedError("left_mouse_down") + elif action_type == "left_mouse_up": + # coordinate = tool_input.get("coordinate", [0, 0]) + # responses_items.append({ + # "type": "computer_call", + # "call_id": call_id, + # "action": { + # "type": "mouse_up", + # "button": "left", + # "x": coordinate[0] if len(coordinate) > 0 else 0, + # "y": coordinate[1] if len(coordinate) > 1 else 0 + # } + # }) + raise NotImplementedError("left_mouse_up") + elif action_type == "hold_key": + # responses_items.append({ + # "type": "computer_call", + # "call_id": call_id, + # "action": { + # "type": "key_hold", + # "key": tool_input.get("key", "") + # } + # }) + raise NotImplementedError("hold_key") + elif action_type == "wait": + responses_items.append(make_wait_item( + call_id=call_id + )) + else: + raise ValueError(f"Unknown action type: {action_type}") + + # Handle tool calls (alternative format) + if hasattr(message, 'tool_calls') and message.tool_calls: + for tool_call in message.tool_calls: + print(tool_call) + if tool_call.function.name == "computer": + try: + args = json.loads(tool_call.function.arguments) + action_type = args.get("action") + call_id = tool_call.id + + # Basic actions (all versions) + if action_type == "screenshot": + responses_items.append(make_screenshot_item( + call_id=call_id + )) + elif action_type in ["click", "left_click"]: + coordinate = args.get("coordinate", 
[0, 0]) + responses_items.append(make_click_item( + x=coordinate[0] if len(coordinate) > 0 else 0, + y=coordinate[1] if len(coordinate) > 1 else 0, + call_id=call_id + )) + elif action_type == "type": + responses_items.append(make_type_item( + text=args.get("text", ""), + call_id=call_id + )) + elif action_type == "key": + responses_items.append(make_keypress_item( + key=args.get("key", ""), + call_id=call_id + )) + elif action_type == "mouse_move": + coordinate = args.get("coordinate", [0, 0]) + responses_items.append(make_move_item( + x=coordinate[0] if len(coordinate) > 0 else 0, + y=coordinate[1] if len(coordinate) > 1 else 0, + call_id=call_id + )) + + # Enhanced actions (computer_20250124) Available in Claude 4 and Claude Sonnet 3.7 + elif action_type == "scroll": + coordinate = args.get("coordinate", [0, 0]) + direction = args.get("scroll_direction", "down") + amount = args.get("scroll_amount", 3) + scroll_x = amount if direction == "left" else \ + -amount if direction == "right" else 0 + scroll_y = amount if direction == "up" else \ + -amount if direction == "down" else 0 + responses_items.append(make_scroll_item( + x=coordinate[0] if len(coordinate) > 0 else 0, + y=coordinate[1] if len(coordinate) > 1 else 0, + scroll_x=scroll_x, + scroll_y=scroll_y, + call_id=call_id + )) + elif action_type == "left_click_drag": + start_coord = args.get("start_coordinate", [0, 0]) + end_coord = args.get("end_coordinate", [0, 0]) + responses_items.append(make_drag_item( + start_x=start_coord[0] if len(start_coord) > 0 else 0, + start_y=start_coord[1] if len(start_coord) > 1 else 0, + end_x=end_coord[0] if len(end_coord) > 0 else 0, + end_y=end_coord[1] if len(end_coord) > 1 else 0, + call_id=call_id + )) + elif action_type == "right_click": + coordinate = args.get("coordinate", [0, 0]) + responses_items.append(make_click_item( + x=coordinate[0] if len(coordinate) > 0 else 0, + y=coordinate[1] if len(coordinate) > 1 else 0, + button="right", + call_id=call_id + )) + elif action_type == "middle_click": + coordinate = args.get("coordinate", [0, 0]) + responses_items.append(make_click_item( + x=coordinate[0] if len(coordinate) > 0 else 0, + y=coordinate[1] if len(coordinate) > 1 else 0, + button="scroll", + call_id=call_id + )) + elif action_type == "double_click": + coordinate = args.get("coordinate", [0, 0]) + responses_items.append(make_double_click_item( + x=coordinate[0] if len(coordinate) > 0 else 0, + y=coordinate[1] if len(coordinate) > 1 else 0, + call_id=call_id + )) + elif action_type == "triple_click": + raise NotImplementedError("triple_click") + elif action_type == "left_mouse_down": + raise NotImplementedError("left_mouse_down") + elif action_type == "left_mouse_up": + raise NotImplementedError("left_mouse_up") + elif action_type == "hold_key": + raise NotImplementedError("hold_key") + elif action_type == "wait": + responses_items.append(make_wait_item( + call_id=call_id + )) + except json.JSONDecodeError: + print("Failed to decode tool call arguments") + # Skip malformed tool calls + continue + + return responses_items + +def _add_cache_control(completion_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Add cache control to completion messages""" + num_writes = 0 + for message in completion_messages: + message["cache_control"] = { "type": "ephemeral" } + num_writes += 1 + # Cache control has a maximum of 4 blocks + if num_writes >= 4: + break + + return completion_messages + +def _combine_completion_messages(completion_messages: List[Dict[str, Any]]) -> List[Dict[str, 
Any]]: + """Combine completion messages with the same role""" + if not completion_messages: + return completion_messages + + combined_messages = [] + + for message in completion_messages: + # If this is the first message or role is different from last, add as new message + if not combined_messages or combined_messages[-1]["role"] != message["role"]: + # Ensure content is a list format and normalize text content + new_message = message.copy() + new_message["content"] = _normalize_content(message.get("content", "")) + + # Copy tool_calls if present + if "tool_calls" in message: + new_message["tool_calls"] = message["tool_calls"].copy() + + combined_messages.append(new_message) + else: + # Same role as previous message, combine them + last_message = combined_messages[-1] + + # Combine content + current_content = _normalize_content(message.get("content", "")) + last_message["content"].extend(current_content) + + # Combine tool_calls if present + if "tool_calls" in message: + if "tool_calls" not in last_message: + last_message["tool_calls"] = [] + last_message["tool_calls"].extend(message["tool_calls"]) + + # Post-process to merge consecutive text blocks + for message in combined_messages: + message["content"] = _merge_consecutive_text(message["content"]) + + return combined_messages + +def _normalize_content(content) -> List[Dict[str, Any]]: + """Normalize content to list format""" + if isinstance(content, str): + if content.strip(): # Only add non-empty strings + return [{"type": "text", "text": content}] + else: + return [] + elif isinstance(content, list): + return content.copy() + else: + return [] + +def _merge_consecutive_text(content_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Merge consecutive text blocks with newlines""" + if not content_list: + return content_list + + merged = [] + + for item in content_list: + if (item.get("type") == "text" and + merged and + merged[-1].get("type") == "text"): + # Merge with previous text block + merged[-1]["text"] += "\n" + item["text"] + else: + merged.append(item.copy()) + + return merged + +@agent_loop(models=r".*claude-.*", priority=5) +async def anthropic_hosted_tools_loop( + messages: Messages, + model: str, + tools: Optional[List[Dict[str, Any]]] = None, + max_retries: Optional[int] = None, + stream: bool = False, + computer_handler=None, + use_prompt_caching: Optional[bool] = False, + _on_api_start=None, + _on_api_end=None, + _on_usage=None, + _on_screenshot=None, + **kwargs +) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]: + """ + Anthropic hosted tools agent loop using liteLLM acompletion. + + Supports Anthropic's computer use models with hosted tools. 
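+
+    The computer tool schema and beta flag are chosen per model via
+    MODEL_TOOL_MAPPING: Claude 4 and Claude 3.7 models use the
+    computer_20250124 tool with the computer-use-2025-01-24 beta flag, while
+    Claude 3.5 models (and any unrecognized Claude model) fall back to
+    computer_20241022 with computer-use-2024-10-22.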
+ """ + tools = tools or [] + + # Get tool configuration for this model + tool_config = _get_tool_config_for_model(model) + + # Prepare tools for Anthropic API + anthropic_tools = _prepare_tools_for_anthropic(tools, model) + + # Convert responses_items messages to completion format + completion_messages = _convert_responses_items_to_completion_messages(messages) + if use_prompt_caching: + # First combine messages to reduce number of blocks + completion_messages = _combine_completion_messages(completion_messages) + # Then add cache control, anthropic requires explicit "cache_control" dicts + completion_messages = _add_cache_control(completion_messages) + + # Prepare API call kwargs + api_kwargs = { + "model": model, + "messages": completion_messages, + "tools": anthropic_tools if anthropic_tools else None, + "stream": stream, + "num_retries": max_retries, + **kwargs + } + + # Add beta header for computer use + if anthropic_tools: + api_kwargs["headers"] = { + "anthropic-beta": tool_config["beta_flag"] + } + + # Call API start hook + if _on_api_start: + await _on_api_start(api_kwargs) + + # Use liteLLM acompletion + response = await litellm.acompletion(**api_kwargs) + + # Call API end hook + if _on_api_end: + await _on_api_end(api_kwargs, response) + + # Convert response to responses_items format + responses_items = _convert_completion_to_responses_items(response) + + # Extract usage information + responses_usage = { + **LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(), + "response_cost": response._hidden_params.get("response_cost", 0.0), + } + if _on_usage: + await _on_usage(responses_usage) + + # Create agent response + agent_response = { + "output": responses_items, + "usage": responses_usage + } + + return agent_response diff --git a/libs/python/agent2/agent2/loops/omniparser.py b/libs/python/agent2/agent2/loops/omniparser.py new file mode 100644 index 00000000..3c08d887 --- /dev/null +++ b/libs/python/agent2/agent2/loops/omniparser.py @@ -0,0 +1,339 @@ +""" +OpenAI computer-use-preview agent loop implementation using liteLLM +""" + +import asyncio +import json +from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple +import litellm +import inspect +import base64 + +from ..decorators import agent_loop +from ..types import Messages, AgentResponse, Tools + +SOM_TOOL_SCHEMA = { + "type": "function", + "name": "computer", + "description": "Control a computer by taking screenshots and interacting with UI elements. This tool shows screenshots with numbered elements overlaid on them. Each UI element has been assigned a unique ID number that you can see in the image. 
Use the element's ID number to interact with any element instead of pixel coordinates.", + "parameters": { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": [ + "screenshot", + "click", + "double_click", + "drag", + "type", + "keypress", + "scroll", + "move", + "wait", + "get_current_url", + "get_dimensions", + "get_environment" + ], + "description": "The action to perform" + }, + "element_id": { + "type": "integer", + "description": "The ID of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)" + }, + "start_element_id": { + "type": "integer", + "description": "The ID of the element to start dragging from (required for drag action)" + }, + "end_element_id": { + "type": "integer", + "description": "The ID of the element to drag to (required for drag action)" + }, + "text": { + "type": "string", + "description": "The text to type (required for type action)" + }, + "keys": { + "type": "string", + "description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')" + }, + "button": { + "type": "string", + "description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left", + }, + "scroll_x": { + "type": "integer", + "description": "Horizontal scroll amount for scroll action (positive for right, negative for left)", + }, + "scroll_y": { + "type": "integer", + "description": "Vertical scroll amount for scroll action (positive for down, negative for up)", + }, + }, + "required": [ + "action" + ] + } +} + +OMNIPARSER_AVAILABLE = False +try: + from som import OmniParser + OMNIPARSER_AVAILABLE = True +except ImportError: + pass +OMNIPARSER_SINGLETON = None + +def get_parser() -> OmniParser: + global OMNIPARSER_SINGLETON + if OMNIPARSER_SINGLETON is None: + OMNIPARSER_SINGLETON = OmniParser() + return OMNIPARSER_SINGLETON + +def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]: + """Get the last computer_call_output message from a messages list. 
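+
+    The list is scanned in reverse, so the most recent computer_call_output
+    (typically the latest screenshot) is returned.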
+ + Args: + messages: List of messages to search through + + Returns: + The last computer_call_output message dict, or None if not found + """ + for message in reversed(messages): + if isinstance(message, dict) and message.get("type") == "computer_call_output": + return message + return None + +def _prepare_tools_for_omniparser(tool_schemas: List[Dict[str, Any]]) -> Tuple[Tools, dict]: + """Prepare tools for OpenAI API format""" + omniparser_tools = [] + id2xy = dict() + + for schema in tool_schemas: + if schema["type"] == "computer": + omniparser_tools.append(SOM_TOOL_SCHEMA) + if "id2xy" in schema: + id2xy = schema["id2xy"] + else: + schema["id2xy"] = id2xy + elif schema["type"] == "function": + # Function tools use OpenAI-compatible schema directly (liteLLM expects this format) + # Schema should be: {type, name, description, parameters} + omniparser_tools.append({ "type": "function", **schema["function"] }) + + return omniparser_tools, id2xy + +async def replace_function_with_computer_call(item: Dict[str, Any], id2xy: Dict[int, Tuple[float, float]]): + item_type = item.get("type") + + def _get_xy(element_id: Optional[int]) -> Union[Tuple[float, float], Tuple[None, None]]: + if element_id is None: + return (None, None) + return id2xy.get(element_id, (None, None)) + + if item_type == "function_call": + fn_name = item.get("name") + fn_args = json.loads(item.get("arguments", "{}")) + + item_id = item.get("id") + call_id = item.get("call_id") + + if fn_name == "computer": + action = fn_args.get("action") + element_id = fn_args.get("element_id") + start_element_id = fn_args.get("start_element_id") + end_element_id = fn_args.get("end_element_id") + text = fn_args.get("text") + keys = fn_args.get("keys") + button = fn_args.get("button") + scroll_x = fn_args.get("scroll_x") + scroll_y = fn_args.get("scroll_y") + + x, y = _get_xy(element_id) + start_x, start_y = _get_xy(start_element_id) + end_x, end_y = _get_xy(end_element_id) + + action_args = { + "type": action, + "x": x, + "y": y, + "start_x": start_x, + "start_y": start_y, + "end_x": end_x, + "end_y": end_y, + "text": text, + "keys": keys, + "button": button, + "scroll_x": scroll_x, + "scroll_y": scroll_y + } + # Remove None values to keep the JSON clean + action_args = {k: v for k, v in action_args.items() if v is not None} + + return [{ + "type": "computer_call", + "action": action_args, + "id": item_id, + "call_id": call_id, + "status": "completed" + }] + + return [item] + +async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[Tuple[float, float], int]): + """ + Convert computer_call back to function_call format. + Also handles computer_call_output -> function_call_output conversion. 
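+
+    For example (illustrative values), a computer_call whose action is
+    {"type": "click", "x": 100, "y": 200} becomes a function_call named
+    "computer" with arguments {"action": "click", "element_id": <id>}, where
+    <id> is looked up from the xy2id mapping for (100, 200).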
+ + Args: + item: The item to convert + xy2id: Mapping from (x, y) coordinates to element IDs + """ + item_type = item.get("type") + + def _get_element_id(x: Optional[float], y: Optional[float]) -> Optional[int]: + """Get element ID from coordinates, return None if coordinates are None""" + if x is None or y is None: + return None + return xy2id.get((x, y)) + + if item_type == "computer_call": + action_data = item.get("action", {}) + + # Extract coordinates and convert back to element IDs + element_id = _get_element_id(action_data.get("x"), action_data.get("y")) + start_element_id = _get_element_id(action_data.get("start_x"), action_data.get("start_y")) + end_element_id = _get_element_id(action_data.get("end_x"), action_data.get("end_y")) + + # Build function arguments + fn_args = { + "action": action_data.get("type"), + "element_id": element_id, + "start_element_id": start_element_id, + "end_element_id": end_element_id, + "text": action_data.get("text"), + "keys": action_data.get("keys"), + "button": action_data.get("button"), + "scroll_x": action_data.get("scroll_x"), + "scroll_y": action_data.get("scroll_y") + } + + # Remove None values to keep the JSON clean + fn_args = {k: v for k, v in fn_args.items() if v is not None} + + return [{ + "type": "function_call", + "name": "computer", + "arguments": json.dumps(fn_args), + "id": item.get("id"), + "call_id": item.get("call_id"), + "status": "completed", + + # Fall back to string representation + "content": f"Used tool: {action_data.get("type")}({json.dumps(fn_args)})" + }] + + elif item_type == "computer_call_output": + # Simple conversion: computer_call_output -> function_call_output + return [{ + "type": "function_call_output", + "call_id": item.get("call_id"), + "content": [item.get("output")], + "id": item.get("id"), + "status": "completed" + }] + + return [item] + + +@agent_loop(models=r"omniparser\+.*|omni\+.*", priority=10) +async def omniparser_loop( + messages: Messages, + model: str, + tools: Optional[List[Dict[str, Any]]] = None, + max_retries: Optional[int] = None, + stream: bool = False, + computer_handler=None, + use_prompt_caching: Optional[bool] = False, + _on_api_start=None, + _on_api_end=None, + _on_usage=None, + _on_screenshot=None, + **kwargs +) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]: + """ + OpenAI computer-use-preview agent loop using liteLLM responses. + + Supports OpenAI's computer use preview models. + """ + if not OMNIPARSER_AVAILABLE: + raise ValueError("omniparser loop requires som to be installed. 
Install it with `pip install cua-som`.") + + tools = tools or [] + + llm_model = model.split('+')[-1] + + # Prepare tools for OpenAI API + openai_tools, id2xy = _prepare_tools_for_omniparser(tools) + + # Find last computer_call_output + last_computer_call_output = get_last_computer_call_output(messages) + if last_computer_call_output: + image_url = last_computer_call_output.get("output", {}).get("image_url", "") + image_data = image_url.split(",")[-1] + if image_data: + parser = get_parser() + result = parser.parse(image_data) + if _on_screenshot: + await _on_screenshot(result.annotated_image_base64, "annotated_image") + for element in result.elements: + id2xy[element.id] = ((element.bbox.x1 + element.bbox.x2) / 2, (element.bbox.y1 + element.bbox.y2) / 2) + + # handle computer calls -> function calls + new_messages = [] + for message in messages: + if not isinstance(message, dict): + message = message.__dict__ + new_messages += await replace_computer_call_with_function(message, id2xy) + messages = new_messages + + # Prepare API call kwargs + api_kwargs = { + "model": llm_model, + "input": messages, + "tools": openai_tools if openai_tools else None, + "stream": stream, + "reasoning": {"summary": "concise"}, + "truncation": "auto", + "num_retries": max_retries, + **kwargs + } + + # Call API start hook + if _on_api_start: + await _on_api_start(api_kwargs) + + print(str(api_kwargs)[:1000]) + + # Use liteLLM responses + response = await litellm.aresponses(**api_kwargs) + + # Call API end hook + if _on_api_end: + await _on_api_end(api_kwargs, response) + + # Extract usage information + response.usage = { + **response.usage.model_dump(), + "response_cost": response._hidden_params.get("response_cost", 0.0), + } + if _on_usage: + await _on_usage(response.usage) + + # handle som function calls -> xy computer calls + new_output = [] + for i in range(len(response.output)): + new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy) + response.output = new_output + + return response diff --git a/libs/python/agent2/agent2/loops/openai.py b/libs/python/agent2/agent2/loops/openai.py new file mode 100644 index 00000000..84b79d1f --- /dev/null +++ b/libs/python/agent2/agent2/loops/openai.py @@ -0,0 +1,95 @@ +""" +OpenAI computer-use-preview agent loop implementation using liteLLM +""" + +import asyncio +import json +from typing import Dict, List, Any, AsyncGenerator, Union, Optional +import litellm + +from ..decorators import agent_loop +from ..types import Messages, AgentResponse, Tools + +def _map_computer_tool_to_openai(computer_tool: Any) -> Dict[str, Any]: + """Map a computer tool to OpenAI's computer-use-preview tool schema""" + return { + "type": "computer_use_preview", + "display_width": getattr(computer_tool, 'display_width', 1024), + "display_height": getattr(computer_tool, 'display_height', 768), + "environment": getattr(computer_tool, 'environment', "linux") # mac, windows, linux, browser + } + + +def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools: + """Prepare tools for OpenAI API format""" + openai_tools = [] + + for schema in tool_schemas: + if schema["type"] == "computer": + # Map computer tool to OpenAI format + openai_tools.append(_map_computer_tool_to_openai(schema["computer"])) + elif schema["type"] == "function": + # Function tools use OpenAI-compatible schema directly (liteLLM expects this format) + # Schema should be: {type, name, description, parameters} + openai_tools.append({ "type": "function", **schema["function"] }) + 
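+    # Illustrative shape of the prepared list (display values are taken from
+    # the computer tool's attributes; defaults from _map_computer_tool_to_openai
+    # are shown here):
+    #   [{"type": "computer_use_preview", "display_width": 1024,
+    #     "display_height": 768, "environment": "linux"},
+    #    {"type": "function", "name": "...", "description": "...", "parameters": {...}}]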
+ return openai_tools + + +@agent_loop(models=r".*computer-use-preview.*", priority=10) +async def openai_computer_use_loop( + messages: Messages, + model: str, + tools: Optional[List[Dict[str, Any]]] = None, + max_retries: Optional[int] = None, + stream: bool = False, + computer_handler=None, + use_prompt_caching: Optional[bool] = False, + _on_api_start=None, + _on_api_end=None, + _on_usage=None, + _on_screenshot=None, + **kwargs +) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]: + """ + OpenAI computer-use-preview agent loop using liteLLM responses. + + Supports OpenAI's computer use preview models. + """ + tools = tools or [] + + # Prepare tools for OpenAI API + openai_tools = _prepare_tools_for_openai(tools) + + # Prepare API call kwargs + api_kwargs = { + "model": model, + "input": messages, + "tools": openai_tools if openai_tools else None, + "stream": stream, + "reasoning": {"summary": "concise"}, + "truncation": "auto", + "num_retries": max_retries, + **kwargs + } + + # Call API start hook + if _on_api_start: + await _on_api_start(api_kwargs) + + # Use liteLLM responses + response = await litellm.aresponses(**api_kwargs) + + # Call API end hook + if _on_api_end: + await _on_api_end(api_kwargs, response) + + # Extract usage information + response.usage = { + **response.usage.model_dump(), + "response_cost": response._hidden_params.get("response_cost", 0.0), + } + if _on_usage: + await _on_usage(response.usage) + + return response diff --git a/libs/python/agent2/agent2/loops/uitars.py b/libs/python/agent2/agent2/loops/uitars.py new file mode 100644 index 00000000..e82e005d --- /dev/null +++ b/libs/python/agent2/agent2/loops/uitars.py @@ -0,0 +1,688 @@ +""" +UITARS agent loop implementation using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B +""" + +import asyncio +from ctypes import cast +import json +import base64 +import math +import re +import ast +from typing import Dict, List, Any, AsyncGenerator, Union, Optional +from io import BytesIO +from PIL import Image +import litellm +from litellm.types.utils import ModelResponse +from litellm.responses.litellm_completion_transformation.transformation import LiteLLMCompletionResponsesConfig +from litellm.responses.utils import Usage +from openai.types.responses.response_computer_tool_call_param import ActionType, ResponseComputerToolCallParam +from openai.types.responses.response_input_param import ComputerCallOutput +from openai.types.responses.response_output_message_param import ResponseOutputMessageParam +from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam, Summary + +from ..decorators import agent_loop +from ..types import Messages, AgentResponse, Tools +from ..responses import ( + make_reasoning_item, + make_output_text_item, + make_click_item, + make_double_click_item, + make_drag_item, + make_keypress_item, + make_scroll_item, + make_type_item, + make_wait_item, + make_input_image_item +) + +# Constants from reference code +IMAGE_FACTOR = 28 +MIN_PIXELS = 100 * 28 * 28 +MAX_PIXELS = 16384 * 28 * 28 +MAX_RATIO = 200 + +FINISH_WORD = "finished" +WAIT_WORD = "wait" +ENV_FAIL_WORD = "error_env" +CALL_USER = "call_user" + +# Action space prompt for UITARS +UITARS_ACTION_SPACE = """ +click(start_box='<|box_start|>(x1,y1)<|box_end|>') +left_double(start_box='<|box_start|>(x1,y1)<|box_end|>') +right_single(start_box='<|box_start|>(x1,y1)<|box_end|>') +drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>') +hotkey(key='') +type(content='') #If you 
want to submit your input, use "\\n" at the end of `content`. +scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') +wait() #Sleep for 5s and take a screenshot to check for any changes. +finished(content='xxx') # Use escape characters \\', \\", and \\n in content part to ensure we can parse the content in normal python string format. +""" + +UITARS_PROMPT_TEMPLATE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. + +## Output Format +``` +Thought: ... +Action: ... +``` + +## Action Space +{action_space} + +## Note +- Use {language} in `Thought` part. +- Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part. + +## User Instruction +{instruction} +""" + + +def round_by_factor(number: float, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: float, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: float, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS +) -> tuple[int, int]: + """ + Rescales the image so that the following conditions are met: + 1. Both dimensions (height and width) are divisible by 'factor'. + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + 3. The aspect ratio of the image is maintained as closely as possible. + """ + if max(height, width) / min(height, width) > MAX_RATIO: + raise ValueError( + f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = floor_by_factor(height / beta, factor) + w_bar = floor_by_factor(width / beta, factor) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + return h_bar, w_bar + + +def escape_single_quotes(text): + """Escape single quotes in text for safe string formatting.""" + pattern = r"(? 
List[Dict[str, Any]]: + """Parse UITARS model response into structured actions.""" + text = text.strip() + + # Extract thought + thought = None + if text.startswith("Thought:"): + thought_match = re.search(r"Thought: (.+?)(?=\s*Action:|$)", text, re.DOTALL) + if thought_match: + thought = thought_match.group(1).strip() + + # Extract action + if "Action:" not in text: + raise ValueError("No Action found in response") + + action_str = text.split("Action:")[-1].strip() + + # Handle special case for type actions + if "type(content" in action_str: + def escape_quotes(match): + return match.group(1) + + pattern = r"type\(content='(.*?)'\)" + content = re.sub(pattern, escape_quotes, action_str) + action_str = escape_single_quotes(content) + action_str = "type(content='" + action_str + "')" + + + # Parse the action + parsed_action = parse_action(action_str.replace("\n", "\\n").lstrip()) + if parsed_action is None: + raise ValueError(f"Action can't parse: {action_str}") + + action_type = parsed_action["function"] + params = parsed_action["args"] + + # Process parameters + action_inputs = {} + for param_name, param in params.items(): + if param == "": + continue + param = str(param).lstrip() + action_inputs[param_name.strip()] = param + + # Handle coordinate parameters + if "start_box" in param_name or "end_box" in param_name: + # Parse coordinates like '(x,y)' or '(x1,y1,x2,y2)' + numbers = param.replace("(", "").replace(")", "").split(",") + float_numbers = [float(num.strip()) / 1000 for num in numbers] # Normalize to 0-1 range + + if len(float_numbers) == 2: + # Single point, duplicate for box format + float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]] + + action_inputs[param_name.strip()] = str(float_numbers) + + return [{ + "thought": thought, + "action_type": action_type, + "action_inputs": action_inputs, + "text": text + }] + + +def convert_to_computer_actions(parsed_responses: List[Dict[str, Any]], image_width: int, image_height: int) -> List[ResponseComputerToolCallParam | ResponseOutputMessageParam]: + """Convert parsed UITARS responses to computer actions.""" + computer_actions = [] + + for response in parsed_responses: + action_type = response.get("action_type") + action_inputs = response.get("action_inputs", {}) + + if action_type == "finished": + finished_text = action_inputs.get("content", "Task completed successfully.") + computer_actions.append(make_output_text_item(finished_text)) + break + + elif action_type == "wait": + computer_actions.append(make_wait_item()) + + elif action_type == "call_user": + computer_actions.append(make_output_text_item("I need assistance from the user to proceed with this task.")) + + elif action_type in ["click", "left_single"]: + start_box = action_inputs.get("start_box") + if start_box: + coords = eval(start_box) + x = int((coords[0] + coords[2]) / 2 * image_width) + y = int((coords[1] + coords[3]) / 2 * image_height) + + computer_actions.append(make_click_item(x, y, "left")) + + elif action_type == "double_click": + start_box = action_inputs.get("start_box") + if start_box: + coords = eval(start_box) + x = int((coords[0] + coords[2]) / 2 * image_width) + y = int((coords[1] + coords[3]) / 2 * image_height) + + computer_actions.append(make_double_click_item(x, y)) + + elif action_type == "right_click": + start_box = action_inputs.get("start_box") + if start_box: + coords = eval(start_box) + x = int((coords[0] + coords[2]) / 2 * image_width) + y = int((coords[1] + coords[3]) / 2 * image_height) + + 
computer_actions.append(make_click_item(x, y, "right")) + + elif action_type == "type": + content = action_inputs.get("content", "") + computer_actions.append(make_type_item(content)) + + elif action_type == "hotkey": + key = action_inputs.get("key", "") + keys = key.split() + computer_actions.append(make_keypress_item(keys)) + + elif action_type == "press": + key = action_inputs.get("key", "") + computer_actions.append(make_keypress_item([key])) + + elif action_type == "scroll": + start_box = action_inputs.get("start_box") + direction = action_inputs.get("direction", "down") + + if start_box: + coords = eval(start_box) + x = int((coords[0] + coords[2]) / 2 * image_width) + y = int((coords[1] + coords[3]) / 2 * image_height) + else: + x, y = image_width // 2, image_height // 2 + + scroll_y = 5 if "up" in direction.lower() else -5 + computer_actions.append(make_scroll_item(x, y, 0, scroll_y)) + + elif action_type == "drag": + start_box = action_inputs.get("start_box") + end_box = action_inputs.get("end_box") + + if start_box and end_box: + start_coords = eval(start_box) + end_coords = eval(end_box) + + start_x = int((start_coords[0] + start_coords[2]) / 2 * image_width) + start_y = int((start_coords[1] + start_coords[3]) / 2 * image_height) + end_x = int((end_coords[0] + end_coords[2]) / 2 * image_width) + end_y = int((end_coords[1] + end_coords[3]) / 2 * image_height) + + path = [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}] + computer_actions.append(make_drag_item(path)) + + return computer_actions + + +def pil_to_base64(image: Image.Image) -> str: + """Convert PIL image to base64 string.""" + buffer = BytesIO() + image.save(buffer, format="PNG") + return base64.b64encode(buffer.getvalue()).decode("utf-8") + + +def process_image_for_uitars(image_data: str, max_pixels: int = MAX_PIXELS, min_pixels: int = MIN_PIXELS) -> tuple[Image.Image, int, int]: + """Process image for UITARS model input.""" + # Decode base64 image + if image_data.startswith('data:image'): + image_data = image_data.split(',')[1] + + image_bytes = base64.b64decode(image_data) + image = Image.open(BytesIO(image_bytes)) + + original_width, original_height = image.size + + # Resize image according to UITARS requirements + if image.width * image.height > max_pixels: + resize_factor = math.sqrt(max_pixels / (image.width * image.height)) + width = int(image.width * resize_factor) + height = int(image.height * resize_factor) + image = image.resize((width, height)) + + if image.width * image.height < min_pixels: + resize_factor = math.sqrt(min_pixels / (image.width * image.height)) + width = math.ceil(image.width * resize_factor) + height = math.ceil(image.height * resize_factor) + image = image.resize((width, height)) + + if image.mode != "RGB": + image = image.convert("RGB") + + return image, original_width, original_height + + +def sanitize_message(msg: Any) -> Any: + """Return a copy of the message with image_url ommited within content parts""" + if isinstance(msg, dict): + result = {} + for key, value in msg.items(): + if key == "content" and isinstance(value, list): + result[key] = [ + {k: v for k, v in item.items() if k != "image_url"} if isinstance(item, dict) else item + for item in value + ] + else: + result[key] = value + return result + elif isinstance(msg, list): + return [sanitize_message(item) for item in msg] + else: + return msg + + +def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any]]: + """ + Convert UITARS internal message format back to LiteLLM format. 
+ + This function processes reasoning, computer_call, and computer_call_output messages + and converts them to the appropriate LiteLLM assistant message format. + + Args: + messages: List of UITARS internal messages + + Returns: + List of LiteLLM formatted messages + """ + litellm_messages = [] + current_assistant_content = [] + + for message in messages: + if isinstance(message, dict): + message_type = message.get("type") + + if message_type == "reasoning": + # Extract reasoning text from summary + summary = message.get("summary", []) + if summary and isinstance(summary, list): + for summary_item in summary: + if isinstance(summary_item, dict) and summary_item.get("type") == "summary_text": + reasoning_text = summary_item.get("text", "") + if reasoning_text: + current_assistant_content.append(f"Thought: {reasoning_text}") + + elif message_type == "computer_call": + # Convert computer action to UITARS action format + action = message.get("action", {}) + action_type = action.get("type") + + if action_type == "click": + x, y = action.get("x", 0), action.get("y", 0) + button = action.get("button", "left") + if button == "left": + action_text = f"Action: click(start_box='({x},{y})')" + elif button == "right": + action_text = f"Action: right_single(start_box='({x},{y})')" + else: + action_text = f"Action: click(start_box='({x},{y})')" + + elif action_type == "double_click": + x, y = action.get("x", 0), action.get("y", 0) + action_text = f"Action: left_double(start_box='({x},{y})')" + + elif action_type == "drag": + start_x, start_y = action.get("start_x", 0), action.get("start_y", 0) + end_x, end_y = action.get("end_x", 0), action.get("end_y", 0) + action_text = f"Action: drag(start_box='({start_x},{start_y})', end_box='({end_x},{end_y})')" + + elif action_type == "key": + key = action.get("key", "") + action_text = f"Action: hotkey(key='{key}')" + + elif action_type == "type": + text = action.get("text", "") + # Escape single quotes in the text + escaped_text = escape_single_quotes(text) + action_text = f"Action: type(content='{escaped_text}')" + + elif action_type == "scroll": + x, y = action.get("x", 0), action.get("y", 0) + direction = action.get("direction", "down") + action_text = f"Action: scroll(start_box='({x},{y})', direction='{direction}')" + + elif action_type == "wait": + action_text = "Action: wait()" + + else: + # Fallback for unknown action types + action_text = f"Action: {action_type}({action})" + + current_assistant_content.append(action_text) + + # When we hit a computer_call_output, finalize the current assistant message + if current_assistant_content: + litellm_messages.append({ + "role": "assistant", + "content": [{"type": "text", "text": "\n".join(current_assistant_content)}] + }) + current_assistant_content = [] + + elif message_type == "computer_call_output": + # Add screenshot from computer call output + output = message.get("output", {}) + if isinstance(output, dict) and output.get("type") == "input_image": + image_url = output.get("image_url", "") + if image_url: + litellm_messages.append({ + "role": "user", + "content": [{"type": "image_url", "image_url": {"url": image_url}}] + }) + + elif message.get("role") == "user": + # # Handle user messages + # content = message.get("content", "") + # if isinstance(content, str): + # litellm_messages.append({ + # "role": "user", + # "content": content + # }) + # elif isinstance(content, list): + # litellm_messages.append({ + # "role": "user", + # "content": content + # }) + pass + + # Add any remaining assistant content + if 
current_assistant_content: + litellm_messages.append({ + "role": "assistant", + "content": current_assistant_content + }) + + return litellm_messages + +@agent_loop(models=r"(?i).*ui-?tars.*", priority=10) +async def uitars_loop( + messages: Messages, + model: str, + tools: Optional[List[Dict[str, Any]]] = None, + max_retries: Optional[int] = None, + stream: bool = False, + computer_handler=None, + use_prompt_caching: Optional[bool] = False, + _on_api_start=None, + _on_api_end=None, + _on_usage=None, + _on_screenshot=None, + **kwargs +) -> Union[AgentResponse, AsyncGenerator[Dict[str, Any], None]]: + """ + UITARS agent loop using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B model. + + Supports UITARS vision-language models for computer control. + """ + tools = tools or [] + + # Create response items + response_items = [] + + # Find computer tool for screen dimensions + computer_tool = None + for tool_schema in tools: + if tool_schema["type"] == "computer": + computer_tool = tool_schema["computer"] + break + + # Get screen dimensions + screen_width, screen_height = 1024, 768 + if computer_tool: + try: + screen_width, screen_height = await computer_tool.get_dimensions() + except: + pass + + # Process messages to extract instruction and image + instruction = "" + image_data = None + + # Convert messages to list if string + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + # Extract instruction and latest screenshot + for message in reversed(messages): + if isinstance(message, dict): + content = message.get("content", "") + + # Handle different content formats + if isinstance(content, str): + if not instruction and message.get("role") == "user": + instruction = content + elif isinstance(content, list): + for item in content: + if isinstance(item, dict): + if item.get("type") == "text" and not instruction: + instruction = item.get("text", "") + elif item.get("type") == "image_url" and not image_data: + image_url = item.get("image_url", {}) + if isinstance(image_url, dict): + image_data = image_url.get("url", "") + else: + image_data = image_url + + # Also check for computer_call_output with screenshots + if message.get("type") == "computer_call_output" and not image_data: + output = message.get("output", {}) + if isinstance(output, dict) and output.get("type") == "input_image": + image_data = output.get("image_url", "") + + if instruction and image_data: + break + + if not instruction: + instruction = "Help me complete this task by analyzing the screen and taking appropriate actions." + + # Create prompt + user_prompt = UITARS_PROMPT_TEMPLATE.format( + instruction=instruction, + action_space=UITARS_ACTION_SPACE, + language="English" + ) + + # Convert conversation history to LiteLLM format + history_messages = convert_uitars_messages_to_litellm(messages) + + # Prepare messages for liteLLM + litellm_messages = [ + { + "role": "system", + "content": "You are a helpful assistant." 
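+            # The system prompt stays generic; the task instruction and the UITARS
+            # action space are injected below via UITARS_PROMPT_TEMPLATE in the
+            # first user message.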
+ } + ] + + # Add current user instruction with screenshot + current_user_message = { + "role": "user", + "content": [ + {"type": "text", "text": user_prompt}, + ] + } + litellm_messages.append(current_user_message) + + # Process image for UITARS + if not image_data: + # Take screenshot if none found in messages + if computer_handler: + image_data = await computer_handler.screenshot() + await _on_screenshot(image_data, "screenshot_before") + + # Add screenshot to output items so it can be retained in history + response_items.append(make_input_image_item(image_data)) + else: + raise ValueError("No screenshot found in messages and no computer_handler provided") + processed_image, original_width, original_height = process_image_for_uitars(image_data) + encoded_image = pil_to_base64(processed_image) + + # Add conversation history + if history_messages: + litellm_messages.extend(history_messages) + else: + litellm_messages.append({ + "role": "user", + "content": [ + {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}} + ] + }) + + # Prepare API call kwargs + api_kwargs = { + "model": model, + "messages": litellm_messages, + "max_tokens": kwargs.get("max_tokens", 500), + "temperature": kwargs.get("temperature", 0.0), + "do_sample": kwargs.get("temperature", 0.0) > 0.0, + "num_retries": max_retries, + **{k: v for k, v in kwargs.items() if k not in ["max_tokens", "temperature"]} + } + + # Call API start hook + if _on_api_start: + await _on_api_start(api_kwargs) + + # Call liteLLM with UITARS model + response = await litellm.acompletion(**api_kwargs) + + # Call API end hook + if _on_api_end: + await _on_api_end(api_kwargs, response) + + # Extract response content + response_content = response.choices[0].message.content.strip() # type: ignore + + # Parse UITARS response + parsed_responses = parse_uitars_response(response_content, original_width, original_height) + + # Convert to computer actions + computer_actions = convert_to_computer_actions(parsed_responses, original_width, original_height) + + # Add computer actions to response items + thought = parsed_responses[0].get("thought", "") + if thought: + response_items.append(make_reasoning_item(thought)) + response_items.extend(computer_actions) + + # Extract usage information + response_usage = { + **LiteLLMCompletionResponsesConfig._transform_chat_completion_usage_to_responses_usage(response.usage).model_dump(), + "response_cost": response._hidden_params.get("response_cost", 0.0), + } + if _on_usage: + await _on_usage(response_usage) + + # Create agent response + agent_response = { + "output": response_items, + "usage": response_usage + } + + return agent_response \ No newline at end of file diff --git a/libs/python/agent2/agent2/responses.py b/libs/python/agent2/agent2/responses.py new file mode 100644 index 00000000..2d7e85d0 --- /dev/null +++ b/libs/python/agent2/agent2/responses.py @@ -0,0 +1,207 @@ +""" +Functions for making various Responses API items from different types of responses. +Based on the OpenAI spec for Responses API items. 
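+
+Each helper returns a plain Responses API item dict (the TypedDict "param" types from
+the openai package) with freshly generated `id` and, where applicable, `call_id`
+values, so the results can be appended directly to an agent loop's output items.
+For example, make_click_item(100, 200) yields a computer_call item whose action is
+{"type": "click", "x": 100, "y": 200, "button": "left"}.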
+""" + +import base64 +import json +import uuid +from typing import List, Dict, Any, Literal, Union, Optional + +from openai.types.responses.response_computer_tool_call_param import ( + ResponseComputerToolCallParam, + ActionClick, + ActionDoubleClick, + ActionDrag, + ActionDragPath, + ActionKeypress, + ActionMove, + ActionScreenshot, + ActionScroll, + ActionType as ActionTypeAction, + ActionWait, + PendingSafetyCheck +) + +from openai.types.responses.response_function_tool_call_param import ResponseFunctionToolCallParam +from openai.types.responses.response_output_text_param import ResponseOutputTextParam +from openai.types.responses.response_reasoning_item_param import ResponseReasoningItemParam, Summary +from openai.types.responses.response_output_message_param import ResponseOutputMessageParam +from openai.types.responses.easy_input_message_param import EasyInputMessageParam +from openai.types.responses.response_input_image_param import ResponseInputImageParam + +def random_id(): + return str(uuid.uuid4()) + +# User message items +def make_input_image_item(image_data: Union[str, bytes]) -> EasyInputMessageParam: + return EasyInputMessageParam( + content=[ + ResponseInputImageParam( + type="input_image", + image_url=f"data:image/png;base64,{base64.b64encode(image_data).decode('utf-8') if isinstance(image_data, bytes) else image_data}" + ) + ], + role="user", + type="message" + ) + +# Text items +def make_reasoning_item(reasoning: str) -> ResponseReasoningItemParam: + return ResponseReasoningItemParam( + id=random_id(), + summary=[ + Summary(text=reasoning, type="summary_text") + ], + type="reasoning" + ) + +def make_output_text_item(content: str) -> ResponseOutputMessageParam: + return ResponseOutputMessageParam( + id=random_id(), + content=[ + ResponseOutputTextParam( + text=content, + type="output_text", + annotations=[] + ) + ], + role="assistant", + status="completed", + type="message" + ) + +# Function call items +def make_function_call_item(function_name: str, arguments: Dict[str, Any], call_id: Optional[str] = None) -> ResponseFunctionToolCallParam: + return ResponseFunctionToolCallParam( + id=random_id(), + call_id=call_id if call_id else random_id(), + name=function_name, + arguments=json.dumps(arguments), + status="completed", + type="function_call" + ) + +# Computer tool call items +def make_click_item(x: int, y: int, button: Literal["left", "right", "wheel", "back", "forward"] = "left", call_id: Optional[str] = None) -> ResponseComputerToolCallParam: + return ResponseComputerToolCallParam( + id=random_id(), + call_id=call_id if call_id else random_id(), + action=ActionClick( + button=button, + type="click", + x=x, + y=y + ), + pending_safety_checks=[], + status="completed", + type="computer_call" + ) + +def make_double_click_item(x: int, y: int, call_id: Optional[str] = None) -> ResponseComputerToolCallParam: + return ResponseComputerToolCallParam( + id=random_id(), + call_id=call_id if call_id else random_id(), + action=ActionDoubleClick( + type="double_click", + x=x, + y=y + ), + pending_safety_checks=[], + status="completed", + type="computer_call" + ) + +def make_drag_item(path: List[Dict[str, int]], call_id: Optional[str] = None) -> ResponseComputerToolCallParam: + drag_path = [ActionDragPath(x=point["x"], y=point["y"]) for point in path] + return ResponseComputerToolCallParam( + id=random_id(), + call_id=call_id if call_id else random_id(), + action=ActionDrag( + path=drag_path, + type="drag" + ), + pending_safety_checks=[], + status="completed", + type="computer_call" 
+ ) + +def make_keypress_item(keys: List[str], call_id: Optional[str] = None) -> ResponseComputerToolCallParam: + return ResponseComputerToolCallParam( + id=random_id(), + call_id=call_id if call_id else random_id(), + action=ActionKeypress( + keys=keys, + type="keypress" + ), + pending_safety_checks=[], + status="completed", + type="computer_call" + ) + +def make_move_item(x: int, y: int, call_id: Optional[str] = None) -> ResponseComputerToolCallParam: + return ResponseComputerToolCallParam( + id=random_id(), + call_id=call_id if call_id else random_id(), + action=ActionMove( + type="move", + x=x, + y=y + ), + pending_safety_checks=[], + status="completed", + type="computer_call" + ) + +def make_screenshot_item(call_id: Optional[str] = None) -> ResponseComputerToolCallParam: + return ResponseComputerToolCallParam( + id=random_id(), + call_id=call_id if call_id else random_id(), + action=ActionScreenshot( + type="screenshot" + ), + pending_safety_checks=[], + status="completed", + type="computer_call" + ) + +def make_scroll_item(x: int, y: int, scroll_x: int, scroll_y: int, call_id: Optional[str] = None) -> ResponseComputerToolCallParam: + return ResponseComputerToolCallParam( + id=random_id(), + call_id=call_id if call_id else random_id(), + action=ActionScroll( + scroll_x=scroll_x, + scroll_y=scroll_y, + type="scroll", + x=x, + y=y + ), + pending_safety_checks=[], + status="completed", + type="computer_call" + ) + +def make_type_item(text: str, call_id: Optional[str] = None) -> ResponseComputerToolCallParam: + return ResponseComputerToolCallParam( + id=random_id(), + call_id=call_id if call_id else random_id(), + action=ActionTypeAction( + text=text, + type="type" + ), + pending_safety_checks=[], + status="completed", + type="computer_call" + ) + +def make_wait_item(call_id: Optional[str] = None) -> ResponseComputerToolCallParam: + return ResponseComputerToolCallParam( + id=random_id(), + call_id=call_id if call_id else random_id(), + action=ActionWait( + type="wait" + ), + pending_safety_checks=[], + status="completed", + type="computer_call" + ) diff --git a/libs/python/agent2/agent2/types.py b/libs/python/agent2/agent2/types.py new file mode 100644 index 00000000..2999fad1 --- /dev/null +++ b/libs/python/agent2/agent2/types.py @@ -0,0 +1,79 @@ +""" +Type definitions for agent2 +""" + +from typing import Dict, List, Any, Optional, Callable, Protocol, Literal +from pydantic import BaseModel +import re +from litellm import ResponseInputParam, ResponsesAPIResponse, ToolParam +from collections.abc import Iterable + +# Agent input types +Messages = str | ResponseInputParam +Tools = Optional[Iterable[ToolParam]] + +# Agent output types +AgentResponse = ResponsesAPIResponse + +# Agent loop registration +class AgentLoopInfo(BaseModel): + """Information about a registered agent loop""" + func: Callable + models_regex: str + priority: int = 0 + + def matches_model(self, model: str) -> bool: + """Check if this loop matches the given model""" + return bool(re.match(self.models_regex, model)) + +# Computer tool interface +class Computer(Protocol): + """Protocol defining the interface for computer interactions.""" + + async def get_environment(self) -> Literal["windows", "mac", "linux", "browser"]: + """Get the current environment type.""" + ... + + async def get_dimensions(self) -> tuple[int, int]: + """Get screen dimensions as (width, height).""" + ... + + async def screenshot(self) -> str: + """Take a screenshot and return as base64 string.""" + ... 
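+
+    # This Protocol is structural: any object exposing these async methods satisfies
+    # the interface without explicit inheritance (example.py passes the Computer class
+    # from the cua-computer package). Agent loops such as the UITARS loop call
+    # get_dimensions() and screenshot() on the registered computer tool.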
+ + async def click(self, x: int, y: int, button: str = "left") -> None: + """Click at coordinates with specified button.""" + ... + + async def double_click(self, x: int, y: int) -> None: + """Double click at coordinates.""" + ... + + async def scroll(self, x: int, y: int, scroll_x: int, scroll_y: int) -> None: + """Scroll at coordinates with specified scroll amounts.""" + ... + + async def type(self, text: str) -> None: + """Type text.""" + ... + + async def wait(self, ms: int = 1000) -> None: + """Wait for specified milliseconds.""" + ... + + async def move(self, x: int, y: int) -> None: + """Move cursor to coordinates.""" + ... + + async def keypress(self, keys: List[str]) -> None: + """Press key combination.""" + ... + + async def drag(self, path: List[Dict[str, int]]) -> None: + """Drag along specified path.""" + ... + + async def get_current_url(self) -> str: + """Get current URL (for browser environments).""" + ... diff --git a/libs/python/agent2/example.py b/libs/python/agent2/example.py new file mode 100644 index 00000000..21a8da46 --- /dev/null +++ b/libs/python/agent2/example.py @@ -0,0 +1,148 @@ +""" +Example usage of the agent2 library with docstring-based tool definitions. +""" + +import asyncio +import logging + +from agent2 import agent_loop, ComputerAgent +from agent2.types import Messages +from computer import Computer +from computer.helpers import sandboxed + +@sandboxed() +def read_file(location: str) -> str: + """Read contents of a file + + Parameters + ---------- + location : str + Path to the file to read + + Returns + ------- + str + Contents of the file or error message + """ + try: + with open(location, 'r') as f: + return f.read() + except Exception as e: + return f"Error reading file: {str(e)}" + +def save_note(content: str, filename: str = "note.txt") -> str: + """Save content to a note file + + Parameters + ---------- + content : str + Content to save to the file + filename : str, optional + Name of the file to save to (default is "note.txt") + + Returns + ------- + str + Success or error message + """ + try: + with open(filename, 'w') as f: + f.write(content) + return f"Saved note to {filename}" + except Exception as e: + return f"Error saving note: {str(e)}" + +def calculate(a: int, b: int) -> int: + """Calculate the sum of two integers + + Parameters + ---------- + a : int + First integer + b : int + Second integer + + Returns + ------- + int + Sum of the two integers + """ + return a + b + +async def main(): + """Example usage of ComputerAgent with different models""" + + # Example 1: Using Claude with computer and custom tools + print("=== Example 1: Claude with Computer ===") + + import os + import dotenv + import json + dotenv.load_dotenv() + + assert os.getenv("CUA_CONTAINER_NAME") is not None, "CUA_CONTAINER_NAME is not set" + assert os.getenv("CUA_API_KEY") is not None, "CUA_API_KEY is not set" + + async with Computer( + os_type="linux", + provider_type="cloud", + name=os.getenv("CUA_CONTAINER_NAME") or "", + api_key=os.getenv("CUA_API_KEY") or "" + ) as computer: + agent = ComputerAgent( + # Supported models: + + # == OpenAI CUA (computer-use-preview) == + # model="openai/computer-use-preview", + + # == Anthropic CUA (Claude > 3.5) == + # model="anthropic/claude-opus-4-20250514", + # model="anthropic/claude-sonnet-4-20250514", + # model="anthropic/claude-3-7-sonnet-20250219", + # model="anthropic/claude-3-5-sonnet-20240620", + + # == UI-TARS == + # model="huggingface-local/ByteDance-Seed/UI-TARS-1.5-7B", + # TODO: add local mlx provider + # 
model="mlx-community/UI-TARS-1.5-7B-6bit", + # model="ollama_chat/0000/ui-tars-1.5-7b", + + # == Omniparser + Any LLM == + # model="omniparser+..." + model="omniparser+anthropic/claude-opus-4-20250514", + + tools=[computer], + only_n_most_recent_images=3, + verbosity=logging.INFO, + trajectory_dir="trajectories", + use_prompt_caching=True, + max_trajectory_budget={ "max_budget": 1.0, "raise_error": True, "reset_after_each_run": False }, + ) + + history = [] + while True: + user_input = input("> ") + history.append({"role": "user", "content": user_input}) + + # Non-streaming usage + async for result in agent.run(history, stream=False): + history += result["output"] + + # # Print output + # for item in result["output"]: + # if item["type"] == "message": + # print(item["content"][0]["text"]) + # elif item["type"] == "computer_call": + # action = item["action"] + # action_type = action["type"] + # action_args = {k: v for k, v in action.items() if k != "type"} + # print(f"{action_type}({action_args})") + # elif item["type"] == "function_call": + # action = item["name"] + # action_args = item["arguments"] + # print(f"{action}({action_args})") + # elif item["type"] == "function_call_output": + # print("===>", item["output"]) + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/libs/python/agent2/pyproject.toml b/libs/python/agent2/pyproject.toml new file mode 100644 index 00000000..7d209bfa --- /dev/null +++ b/libs/python/agent2/pyproject.toml @@ -0,0 +1,52 @@ +[build-system] +requires = ["pdm-backend"] +build-backend = "pdm.backend" + +[project] +name = "cua-agent2" +version = "0.1.0" +description = "CUA Agent2 - Decorator-based Computer Use Agent with liteLLM integration" +readme = "README.md" +authors = [ + { name = "TryCua", email = "gh@trycua.com" } +] +dependencies = [ + "httpx>=0.27.0", + "aiohttp>=3.9.3", + "asyncio", + "anyio>=4.4.1", + "typing-extensions>=4.12.2", + "pydantic>=2.6.4", + "rich>=13.7.1", + "python-dotenv>=1.0.1", + "cua-computer>=0.3.0,<0.4.0", + "cua-core>=0.1.0,<0.2.0", + "certifi>=2024.2.2", + "litellm>=1.0.0" +] +requires-python = ">=3.11" + +[project.optional-dependencies] +anthropic = [ + "anthropic>=0.49.0", + "boto3>=1.35.81", +] +openai = [ + "openai>=1.14.0", + "httpx>=0.27.0", +] +all = [ + "anthropic>=0.49.0", + "boto3>=1.35.81", + "openai>=1.14.0", + "httpx>=0.27.0", +] + +[tool.uv] +constraint-dependencies = ["fastrtc>0.43.0", "mlx-audio>0.2.3"] + +[tool.pdm] +distribution = true + +[tool.pdm.build] +includes = ["agent2/"]