diff --git a/libs/python/agent/agent/__init__.py b/libs/python/agent/agent/__init__.py
index 6797dab6..08d782d3 100644
--- a/libs/python/agent/agent/__init__.py
+++ b/libs/python/agent/agent/__init__.py
@@ -5,7 +5,7 @@ agent - Decorator-based Computer Use Agent with liteLLM integration
 import logging
 import sys
 
-from .decorators import agent_loop
+from .decorators import register_agent
 from .agent import ComputerAgent
 from .types import Messages, AgentResponse
 
@@ -13,7 +13,7 @@ from .types import Messages, AgentResponse
 from . import loops
 
 __all__ = [
-    "agent_loop",
+    "register_agent",
     "ComputerAgent",
     "Messages",
     "AgentResponse"
diff --git a/libs/python/agent/agent/agent.py b/libs/python/agent/agent/agent.py
index f117fe8b..efacea45 100644
--- a/libs/python/agent/agent/agent.py
+++ b/libs/python/agent/agent/agent.py
@@ -616,9 +616,9 @@ class ComputerAgent:
         if "click" not in capabilities:
             raise ValueError(f"Agent loop {self.agent_loop.__name__} does not support click predictions")
         if hasattr(self.agent_loop, 'predict_click'):
-            if not self.computer_handler:
-                raise ValueError("Computer tool is required for predict_click")
             if not image_b64:
+                if not self.computer_handler:
+                    raise ValueError("Computer tool or image_b64 is required for predict_click")
                 image_b64 = await self.computer_handler.screenshot()
             return await self.agent_loop.predict_click(
                 model=self.model,
diff --git a/libs/python/agent/agent/decorators.py b/libs/python/agent/agent/decorators.py
index 7305b702..7fba0443 100644
--- a/libs/python/agent/agent/decorators.py
+++ b/libs/python/agent/agent/decorators.py
@@ -2,13 +2,8 @@
 Decorators for agent - agent_loop decorator
 """
 
-import asyncio
-import inspect
-from typing import Dict, List, Any, Callable, Optional
-from functools import wraps
-
+from typing import List, Optional
 from .types import AgentConfigInfo
-from .loops.base import AsyncAgentConfig
 
 # Global registry
 _agent_configs: List[AgentConfigInfo] = []
diff --git a/libs/python/agent/agent/loops/base.py b/libs/python/agent/agent/loops/base.py
new file mode 100644
index 00000000..887605b1
--- /dev/null
+++ b/libs/python/agent/agent/loops/base.py
@@ -0,0 +1,76 @@
+"""
+Base protocol for async agent configurations
+"""
+
+from typing import Protocol, List, Dict, Any, Optional, Tuple, Union
+from abc import abstractmethod
+from ..types import AgentCapability
+
+class AsyncAgentConfig(Protocol):
+    """Protocol defining the interface for async agent configurations."""
+
+    @abstractmethod
+    async def predict_step(
+        self,
+        messages: List[Dict[str, Any]],
+        model: str,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        max_retries: Optional[int] = None,
+        stream: bool = False,
+        computer_handler=None,
+        _on_api_start=None,
+        _on_api_end=None,
+        _on_usage=None,
+        _on_screenshot=None,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        Predict the next step based on input items.
+
+        Args:
+            messages: Input items following Responses format (message, function_call, computer_call)
+            model: Model name to use
+            tools: Optional list of tool schemas
+            max_retries: Maximum number of retries for failed API calls
+            stream: Whether to stream responses
+            computer_handler: Computer handler instance
+            _on_api_start: Callback for API start
+            _on_api_end: Callback for API end
+            _on_usage: Callback for usage tracking
+            _on_screenshot: Callback for screenshot events
+            **kwargs: Additional arguments
+
+        Returns:
+            Dictionary with "output" (list of output items) and "usage" (usage statistics)
+        """
+        ...
+
+    @abstractmethod
+    async def predict_click(
+        self,
+        model: str,
+        image_b64: str,
+        instruction: str
+    ) -> Optional[Tuple[int, int]]:
+        """
+        Predict click coordinates based on image and instruction.
+
+        Args:
+            model: Model name to use
+            image_b64: Base64 encoded image
+            instruction: Instruction for where to click
+
+        Returns:
+            None or tuple with (x, y) coordinates
+        """
+        ...
+
+    @abstractmethod
+    def get_capabilities(self) -> List[AgentCapability]:
+        """
+        Get list of capabilities supported by this agent config.
+
+        Returns:
+            List of capability strings (e.g., ["step", "click"])
+        """
+        ...
diff --git a/libs/python/agent/agent/loops/gta1.py b/libs/python/agent/agent/loops/gta1.py
new file mode 100644
index 00000000..4d0d3349
--- /dev/null
+++ b/libs/python/agent/agent/loops/gta1.py
@@ -0,0 +1,178 @@
+"""
+GTA1 agent loop implementation for click prediction using litellm.acompletion
+"""
+
+import asyncio
+import json
+import re
+import base64
+from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
+from io import BytesIO
+from PIL import Image
+import litellm
+
+from ..decorators import register_agent
+from ..types import Messages, AgentResponse, Tools, AgentCapability
+from ..loops.base import AsyncAgentConfig
+
+SYSTEM_PROMPT = '''
+You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
+
+Output the coordinate pair exactly:
+(x,y)
+'''
+
+def extract_coordinates(raw_string: str) -> Tuple[float, float]:
+    """Extract the first (x,y) coordinate pair from model output."""
+    try:
+        matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
+        return tuple(map(float, matches[0]))  # type: ignore
+    except (IndexError, ValueError):
+        return (0.0, 0.0)
+
+def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
+    """Smart resize function similar to qwen_vl_utils."""
+    # Calculate the total pixels
+    total_pixels = height * width
+
+    # If already within bounds, return original dimensions
+    if min_pixels <= total_pixels <= max_pixels:
+        # Round down to the nearest multiple of factor
+        new_height = (height // factor) * factor
+        new_width = (width // factor) * factor
+        return new_height, new_width
+
+    # Calculate scaling factor
+    if total_pixels > max_pixels:
+        scale = (max_pixels / total_pixels) ** 0.5
+    else:
+        scale = (min_pixels / total_pixels) ** 0.5
+
+    # Apply scaling
+    new_height = int(height * scale)
+    new_width = int(width * scale)
+
+    # Round down to the nearest multiple of factor
+    new_height = (new_height // factor) * factor
+    new_width = (new_width // factor) * factor
+
+    # Ensure minimum size
+    new_height = max(new_height, factor)
+    new_width = max(new_width, factor)
+
+    return new_height, new_width
+
+@register_agent(models=r".*GTA1-.*", priority=10)
+class GTA1Config(AsyncAgentConfig):
+    """GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction."""
+
+    async def predict_step(
+        self,
+        messages: Messages,
+        model: str,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        max_retries: Optional[int] = None,
+        stream: bool = False,
+        computer_handler=None,
+        use_prompt_caching: Optional[bool] = False,
+        _on_api_start=None,
+        _on_api_end=None,
+        _on_usage=None,
+        _on_screenshot=None,
+        **kwargs
+    ) -> Dict[str, Any]:
+        """
+        GTA1 does not support step prediction - only click prediction.
+ """ + raise NotImplementedError("GTA1 agent only supports click prediction via predict_click method") + + async def predict_click( + self, + model: str, + image_b64: str, + instruction: str, + **kwargs + ) -> Optional[Tuple[float, float]]: + """ + Predict click coordinates using GTA1 model via litellm.acompletion. + + Args: + model: The GTA1 model name + image_b64: Base64 encoded image + instruction: Instruction for where to click + + Returns: + Tuple of (x, y) coordinates or None if prediction fails + """ + try: + # Decode base64 image + image_data = base64.b64decode(image_b64) + image = Image.open(BytesIO(image_data)) + width, height = image.width, image.height + + # Smart resize the image (similar to qwen_vl_utils) + resized_height, resized_width = smart_resize( + height, width, + factor=28, # Default factor for Qwen models + min_pixels=3136, + max_pixels=4096 * 2160 + ) + resized_image = image.resize((resized_width, resized_height)) + scale_x, scale_y = width / resized_width, height / resized_height + + # Convert resized image back to base64 + buffered = BytesIO() + resized_image.save(buffered, format="PNG") + resized_image_b64 = base64.b64encode(buffered.getvalue()).decode() + + # Prepare system and user messages + system_message = { + "role": "system", + "content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width) + } + + user_message = { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": f"data:image/png;base64,{resized_image_b64}" + } + }, + { + "type": "text", + "text": instruction + } + ] + } + + # Prepare API call kwargs + api_kwargs = { + "model": model, + "messages": [system_message, user_message], + "max_tokens": 32, + "temperature": 0.0, + **kwargs + } + + # Use liteLLM acompletion + response = await litellm.acompletion(**api_kwargs) + + # Extract response text + output_text = response.choices[0].message.content + + # Extract and rescale coordinates + pred_x, pred_y = extract_coordinates(output_text) + pred_x *= scale_x + pred_y *= scale_y + + return (pred_x, pred_y) + + except Exception as e: + print(f"GTA1 click prediction failed: {e}") + return None + + def get_capabilities(self) -> List[AgentCapability]: + """Return the capabilities supported by this agent.""" + return ["click"]