mirror of
https://github.com/trycua/computer.git
synced 2026-01-02 03:20:22 -06:00
Fix circular deps, add GTA1 models
This commit is contained in:
@@ -5,7 +5,7 @@ agent - Decorator-based Computer Use Agent with liteLLM integration
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from .decorators import agent_loop
|
||||
from .decorators import register_agent
|
||||
from .agent import ComputerAgent
|
||||
from .types import Messages, AgentResponse
|
||||
|
||||
@@ -13,7 +13,7 @@ from .types import Messages, AgentResponse
|
||||
from . import loops
|
||||
|
||||
__all__ = [
|
||||
"agent_loop",
|
||||
"register_agent",
|
||||
"ComputerAgent",
|
||||
"Messages",
|
||||
"AgentResponse"
|
||||
|
||||
@@ -616,9 +616,9 @@ class ComputerAgent:
|
||||
if "click" not in capabilities:
|
||||
raise ValueError(f"Agent loop {self.agent_loop.__name__} does not support click predictions")
|
||||
if hasattr(self.agent_loop, 'predict_click'):
|
||||
if not self.computer_handler:
|
||||
raise ValueError("Computer tool is required for predict_click")
|
||||
if not image_b64:
|
||||
if not self.computer_handler:
|
||||
raise ValueError("Computer tool or image_b64 is required for predict_click")
|
||||
image_b64 = await self.computer_handler.screenshot()
|
||||
return await self.agent_loop.predict_click(
|
||||
model=self.model,
|
||||
|
||||
@@ -2,13 +2,8 @@
|
||||
Decorators for agent - agent_loop decorator
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Dict, List, Any, Callable, Optional
|
||||
from functools import wraps
|
||||
|
||||
from typing import List, Optional
|
||||
from .types import AgentConfigInfo
|
||||
from .loops.base import AsyncAgentConfig
|
||||
|
||||
# Global registry
|
||||
_agent_configs: List[AgentConfigInfo] = []
|
||||
|
||||
76
libs/python/agent/agent/loops/base.py
Normal file
76
libs/python/agent/agent/loops/base.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Base protocol for async agent configurations
|
||||
"""
|
||||
|
||||
from typing import Protocol, List, Dict, Any, Optional, Tuple, Union
|
||||
from abc import abstractmethod
|
||||
from ..types import AgentCapability
|
||||
|
||||
class AsyncAgentConfig(Protocol):
    """Protocol defining the interface for async agent configurations.

    An agent configuration exposes two asynchronous prediction entry points,
    ``predict_step`` and ``predict_click``, plus ``get_capabilities``, which
    reports which capabilities an implementation supports.

    All three members are declared with ``@abstractmethod``, so a class that
    explicitly inherits from this protocol has to override each of them
    before it can be instantiated.
    """

    @abstractmethod
    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Predict the next step based on input items.

        Args:
            messages: Input items following Responses format (message, function_call, computer_call)
            model: Model name to use
            tools: Optional list of tool schemas
            max_retries: Maximum number of retries for failed API calls
            stream: Whether to stream responses
            computer_handler: Computer handler instance
            _on_api_start: Callback for API start
            _on_api_end: Callback for API end
            _on_usage: Callback for usage tracking
            _on_screenshot: Callback for screenshot events
            **kwargs: Additional arguments

        Returns:
            Dictionary with "output" (output items) and "usage" array
        """
        ...

    @abstractmethod
    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates based on image and instruction.

        Args:
            model: Model name to use
            image_b64: Base64 encoded image
            instruction: Instruction for where to click

        Returns:
            None or tuple with (x, y) coordinates
        """
        ...

    @abstractmethod
    def get_capabilities(self) -> List[AgentCapability]:
        """
        Get list of capabilities supported by this agent config.

        Returns:
            List of capability strings (e.g., ["step", "click"])
        """
        ...
|
||||
178
libs/python/agent/agent/loops/gta1.py
Normal file
178
libs/python/agent/agent/loops/gta1.py
Normal file
@@ -0,0 +1,178 @@
|
||||
"""
|
||||
GTA1 agent loop implementation for click prediction using litellm.acompletion
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import base64
|
||||
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
import litellm
|
||||
|
||||
from ..decorators import register_agent
|
||||
from ..types import Messages, AgentResponse, Tools, AgentCapability
|
||||
from ..loops.base import AsyncAgentConfig
|
||||
|
||||
SYSTEM_PROMPT = '''
|
||||
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
|
||||
|
||||
Output the coordinate pair exactly:
|
||||
(x,y)
|
||||
'''
|
||||
|
||||
def extract_coordinates(raw_string: str) -> Tuple[float, float]:
    """Extract the first ``(x, y)`` coordinate pair from model output.

    The text is searched for a parenthesised pair of (optionally negative,
    optionally fractional) numbers separated by a comma, e.g. ``(12, 34)``
    or ``(1.5,-2)``. Only the first pair found is used.

    Args:
        raw_string: Text produced by the model.

    Returns:
        The pair as floats, or ``(0.0, 0.0)`` when the input is not a string
        or contains no coordinate pair.
    """
    # Non-string input (e.g. ``None`` when the model returned no content)
    # maps to the same fallback as "no match" instead of relying on a bare
    # ``except`` to hide the resulting TypeError.
    if not isinstance(raw_string, str):
        return (0.0, 0.0)

    match = re.search(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
    if match is None:
        return (0.0, 0.0)

    x_text, y_text = match.groups()
    return (float(x_text), float(y_text))
|
||||
|
||||
def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
    """Compute target dimensions that are multiples of ``factor``.

    Intended to behave similarly to the ``smart_resize`` helper in
    qwen_vl_utils. When the pixel count lies outside
    ``[min_pixels, max_pixels]`` both sides are scaled by the same ratio so
    the area moves to the violated bound. Each side is then rounded *down*
    to a multiple of ``factor`` and clamped to at least ``factor``.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        factor: Both returned dimensions are multiples of this value.
        min_pixels: Lower bound on the pixel count before rounding.
        max_pixels: Upper bound on the pixel count before rounding.

    Returns:
        ``(new_height, new_width)``, each a positive multiple of ``factor``.
    """
    total_pixels = height * width

    # Pick a uniform scale only when the area is out of bounds; images that
    # are already within bounds keep their size before rounding.
    if total_pixels > max_pixels:
        scale = (max_pixels / total_pixels) ** 0.5
    elif total_pixels < min_pixels:
        scale = (min_pixels / total_pixels) ** 0.5
    else:
        scale = None

    if scale is not None:
        height = int(height * scale)
        width = int(width * scale)

    # Round down to a multiple of factor. The clamp is applied on every path:
    # a side shorter than factor would otherwise floor to 0, which is not a
    # usable image dimension.
    new_height = max((height // factor) * factor, factor)
    new_width = max((width // factor) * factor, factor)

    return new_height, new_width
|
||||
|
||||
@register_agent(models=r".*GTA1-.*", priority=10)
class GTA1Config(AsyncAgentConfig):
    """GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction.

    Registered for model names matching ``.*GTA1-.*``. Only ``predict_click``
    is functional; ``predict_step`` always raises ``NotImplementedError`` and
    ``get_capabilities`` reports ``["click"]`` accordingly.
    """

    async def predict_step(
        self,
        messages: Messages,
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        use_prompt_caching: Optional[bool] = False,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        GTA1 does not support step prediction - only click prediction.

        The parameters exist only to satisfy the protocol signature; none of
        them are used.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("GTA1 agent only supports click prediction via predict_click method")

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs
    ) -> Optional[Tuple[float, float]]:
        """
        Predict click coordinates using GTA1 model via litellm.acompletion.

        The screenshot is decoded, resized with ``smart_resize``, re-encoded
        as PNG and sent to the model together with the instruction. The
        coordinates in the reply refer to the resized image and are scaled
        back to the original image size before being returned.

        Args:
            model: The GTA1 model name
            image_b64: Base64 encoded image
            instruction: Instruction for where to click
            **kwargs: Extra arguments forwarded to ``litellm.acompletion``.
                They are merged last, so they override the defaults set
                here (``max_tokens=32``, ``temperature=0.0``).

        Returns:
            Tuple of (x, y) coordinates in original-image pixels, or None if
            prediction fails (any exception, or a reply with no text).
        """
        # Imported locally so the method does not depend on a module-level
        # logging import.
        import logging

        logger = logging.getLogger(__name__)

        try:
            # Decode base64 image
            image_data = base64.b64decode(image_b64)
            image = Image.open(BytesIO(image_data))
            width, height = image.width, image.height

            # Smart resize the image (similar to qwen_vl_utils)
            resized_height, resized_width = smart_resize(
                height, width,
                factor=28,  # Default factor for Qwen models
                min_pixels=3136,
                max_pixels=4096 * 2160
            )
            resized_image = image.resize((resized_width, resized_height))
            # Factors that map resized-image coordinates back to the original.
            scale_x, scale_y = width / resized_width, height / resized_height

            # Convert resized image back to base64
            buffered = BytesIO()
            resized_image.save(buffered, format="PNG")
            resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()

            # Prepare system and user messages; the prompt states the
            # resolution of the image the model actually receives.
            system_message = {
                "role": "system",
                "content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
            }

            user_message = {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": f"data:image/png;base64,{resized_image_b64}"
                        }
                    },
                    {
                        "type": "text",
                        "text": instruction
                    }
                ]
            }

            # Prepare API call kwargs (caller kwargs win over the defaults)
            api_kwargs = {
                "model": model,
                "messages": [system_message, user_message],
                "max_tokens": 32,
                "temperature": 0.0,
                **kwargs
            }

            # Use liteLLM acompletion
            response = await litellm.acompletion(**api_kwargs)

            # Extract response text
            output_text = response.choices[0].message.content

            # A reply without text is a failed prediction; report it as such
            # rather than letting it turn into a click at the origin.
            if not output_text:
                logger.warning(
                    "GTA1 click prediction failed: %s",
                    "model returned no text content",
                )
                return None

            # Extract and rescale coordinates
            pred_x, pred_y = extract_coordinates(output_text)
            pred_x *= scale_x
            pred_y *= scale_y

            return (pred_x, pred_y)

        except Exception as e:
            # Best-effort contract: callers get None on any failure.
            logger.warning("GTA1 click prediction failed: %s", e)
            return None

    def get_capabilities(self) -> List[AgentCapability]:
        """Return the capabilities supported by this agent."""
        return ["click"]
|
||||
Reference in New Issue
Block a user