Fix circular deps, add GTA1 models

This commit is contained in:
Dillon DuPont
2025-07-29 18:53:46 -04:00
parent 876d42af0a
commit 3a67485e42
5 changed files with 259 additions and 10 deletions

View File

@@ -5,7 +5,7 @@ agent - Decorator-based Computer Use Agent with liteLLM integration
import logging
import sys
from .decorators import agent_loop
from .decorators import register_agent
from .agent import ComputerAgent
from .types import Messages, AgentResponse
@@ -13,7 +13,7 @@ from .types import Messages, AgentResponse
from . import loops
__all__ = [
"agent_loop",
"register_agent",
"ComputerAgent",
"Messages",
"AgentResponse"

View File

@@ -616,9 +616,9 @@ class ComputerAgent:
if "click" not in capabilities:
raise ValueError(f"Agent loop {self.agent_loop.__name__} does not support click predictions")
if hasattr(self.agent_loop, 'predict_click'):
if not self.computer_handler:
raise ValueError("Computer tool is required for predict_click")
if not image_b64:
if not self.computer_handler:
raise ValueError("Computer tool or image_b64 is required for predict_click")
image_b64 = await self.computer_handler.screenshot()
return await self.agent_loop.predict_click(
model=self.model,

View File

@@ -2,13 +2,8 @@
Decorators for agent - agent_loop decorator
"""
import asyncio
import inspect
from typing import Dict, List, Any, Callable, Optional
from functools import wraps
from typing import List, Optional
from .types import AgentConfigInfo
from .loops.base import AsyncAgentConfig
# Global registry
_agent_configs: List[AgentConfigInfo] = []

View File

@@ -0,0 +1,76 @@
"""
Base protocol for async agent configurations
"""
from typing import Protocol, List, Dict, Any, Optional, Tuple, Union
from abc import abstractmethod
from ..types import AgentCapability
class AsyncAgentConfig(Protocol):
"""Protocol defining the interface for async agent configurations."""
@abstractmethod
async def predict_step(
self,
messages: List[Dict[str, Any]],
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
) -> Dict[str, Any]:
"""
Predict the next step based on input items.
Args:
messages: Input items following Responses format (message, function_call, computer_call)
model: Model name to use
tools: Optional list of tool schemas
max_retries: Maximum number of retries for failed API calls
stream: Whether to stream responses
computer_handler: Computer handler instance
_on_api_start: Callback for API start
_on_api_end: Callback for API end
_on_usage: Callback for usage tracking
_on_screenshot: Callback for screenshot events
**kwargs: Additional arguments
Returns:
Dictionary with "output" (output items) and "usage" array
"""
...
@abstractmethod
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates based on image and instruction.
Args:
model: Model name to use
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
None or tuple with (x, y) coordinates
"""
...
@abstractmethod
def get_capabilities(self) -> List[AgentCapability]:
"""
Get list of capabilities supported by this agent config.
Returns:
List of capability strings (e.g., ["step", "click"])
"""
...

View File

@@ -0,0 +1,178 @@
"""
GTA1 agent loop implementation for click prediction using litellm.acompletion
"""
import asyncio
import json
import re
import base64
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from io import BytesIO
from PIL import Image
import litellm
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
SYSTEM_PROMPT = '''
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
Output the coordinate pair exactly:
(x,y)
'''
def extract_coordinates(raw_string: str) -> Tuple[float, float]:
"""Extract coordinates from model output."""
try:
matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
return tuple(map(float, matches[0])) # type: ignore
except:
return (0.0, 0.0)
def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
"""Smart resize function similar to qwen_vl_utils."""
# Calculate the total pixels
total_pixels = height * width
# If already within bounds, return original dimensions
if min_pixels <= total_pixels <= max_pixels:
# Round to nearest factor
new_height = (height // factor) * factor
new_width = (width // factor) * factor
return new_height, new_width
# Calculate scaling factor
if total_pixels > max_pixels:
scale = (max_pixels / total_pixels) ** 0.5
else:
scale = (min_pixels / total_pixels) ** 0.5
# Apply scaling
new_height = int(height * scale)
new_width = int(width * scale)
# Round to nearest factor
new_height = (new_height // factor) * factor
new_width = (new_width // factor) * factor
# Ensure minimum size
new_height = max(new_height, factor)
new_width = max(new_width, factor)
return new_height, new_width
@register_agent(models=r".*GTA1-.*", priority=10)
class GTA1Config(AsyncAgentConfig):
"""GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction."""
async def predict_step(
self,
messages: Messages,
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
) -> Dict[str, Any]:
"""
GTA1 does not support step prediction - only click prediction.
"""
raise NotImplementedError("GTA1 agent only supports click prediction via predict_click method")
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
) -> Optional[Tuple[float, float]]:
"""
Predict click coordinates using GTA1 model via litellm.acompletion.
Args:
model: The GTA1 model name
image_b64: Base64 encoded image
instruction: Instruction for where to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
try:
# Decode base64 image
image_data = base64.b64decode(image_b64)
image = Image.open(BytesIO(image_data))
width, height = image.width, image.height
# Smart resize the image (similar to qwen_vl_utils)
resized_height, resized_width = smart_resize(
height, width,
factor=28, # Default factor for Qwen models
min_pixels=3136,
max_pixels=4096 * 2160
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height
# Convert resized image back to base64
buffered = BytesIO()
resized_image.save(buffered, format="PNG")
resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()
# Prepare system and user messages
system_message = {
"role": "system",
"content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
}
user_message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{resized_image_b64}"
}
},
{
"type": "text",
"text": instruction
}
]
}
# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": [system_message, user_message],
"max_tokens": 32,
"temperature": 0.0,
**kwargs
}
# Use liteLLM acompletion
response = await litellm.acompletion(**api_kwargs)
# Extract response text
output_text = response.choices[0].message.content
# Extract and rescale coordinates
pred_x, pred_y = extract_coordinates(output_text)
pred_x *= scale_x
pred_y *= scale_y
return (pred_x, pred_y)
except Exception as e:
print(f"GTA1 click prediction failed: {e}")
return None
def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["click"]