mirror of
https://github.com/trycua/computer.git
synced 2026-01-02 03:20:22 -06:00
Merge pull request #558 from tamoghnokandar/tamoghnokandar-patch-1
[AGENT] - New Model Gelato-30B-A3B added
This commit is contained in:
@@ -1,36 +1,38 @@
|
||||
"""
|
||||
Agent loops for agent
|
||||
"""
|
||||
|
||||
# Import the loops to register them
|
||||
from . import (
|
||||
anthropic,
|
||||
composed_grounded,
|
||||
gemini,
|
||||
glm45v,
|
||||
gta1,
|
||||
holo,
|
||||
internvl,
|
||||
moondream3,
|
||||
omniparser,
|
||||
openai,
|
||||
opencua,
|
||||
qwen,
|
||||
uitars,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"anthropic",
|
||||
"openai",
|
||||
"uitars",
|
||||
"omniparser",
|
||||
"gta1",
|
||||
"composed_grounded",
|
||||
"glm45v",
|
||||
"opencua",
|
||||
"internvl",
|
||||
"holo",
|
||||
"moondream3",
|
||||
"gemini",
|
||||
"qwen",
|
||||
]
|
||||
"""
Agent loops for agent
"""

# Import the loops to register them
from . import (
    anthropic,
    composed_grounded,
    gelato,
    gemini,
    glm45v,
    gta1,
    holo,
    internvl,
    moondream3,
    omniparser,
    openai,
    opencua,
    qwen,
    uitars,
)

# Keep __all__ in the same alphabetical order as the import list above so a
# reviewer can diff the two at a glance; previously the list was unordered and
# "gelato" was appended at the end.
__all__ = [
    "anthropic",
    "composed_grounded",
    "gelato",
    "gemini",
    "glm45v",
    "gta1",
    "holo",
    "internvl",
    "moondream3",
    "omniparser",
    "openai",
    "opencua",
    "qwen",
    "uitars",
]
|
||||
|
||||
183 lines added: libs/python/agent/agent/loops/gelato.py (new file)
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Gelato agent loop implementation for click prediction using litellm.acompletion
|
||||
Model: https://huggingface.co/mlfoundations/Gelato-30B-A3B
|
||||
Code: https://github.com/mlfoundations/Gelato/tree/main
|
||||
"""
|
||||
|
||||
import base64
|
||||
import math
|
||||
import re
|
||||
from io import BytesIO
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import litellm
|
||||
from PIL import Image
|
||||
|
||||
from ..decorators import register_agent
|
||||
from ..loops.base import AsyncAgentConfig
|
||||
from ..types import AgentCapability
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. For elements with area, return the center point.
|
||||
|
||||
Output the coordinate pair exactly:
|
||||
(x,y)
|
||||
"""
|
||||
|
||||
|
||||
def extract_coordinates(raw_string):
    """
    Extract the first "(x, y)" coordinate pair from a model response.

    Args:
        raw_string: Raw model output text, e.g. "(100, 200)" or
            "click at (10.5, 20.25)".

    Returns:
        Tuple (x, y) of floats parsed from the first matching pair, or
        (0, 0) when no coordinate pair is present.
    """
    matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
    if not matches:
        # Fall back to the top-left corner rather than raising; callers
        # treat (0, 0) as "no usable prediction".
        return 0, 0
    x, y = matches[0]
    # float() accepts both integer and decimal captures. The previous int()
    # conversion raised ValueError on decimals — which the regex explicitly
    # allows — and a bare `except:` then silently swallowed the error.
    return float(x), float(y)
|
||||
|
||||
|
||||
def smart_resize(
    height: int,
    width: int,
    factor: int = 28,
    min_pixels: int = 3136,
    max_pixels: int = 8847360,
) -> Tuple[int, int]:
    """Smart resize function similar to qwen_vl_utils.

    Scales (height, width) so the total pixel count lies within
    [min_pixels, max_pixels], then rounds each dimension down to a multiple
    of `factor` (the vision patch size), never below `factor` itself.

    Args:
        height: Original image height in pixels.
        width: Original image width in pixels.
        factor: Patch-size multiple each output dimension must satisfy.
        min_pixels: Lower bound on total output pixels before rounding.
        max_pixels: Upper bound on total output pixels before rounding.

    Returns:
        (new_height, new_width) tuple of factor-aligned dimensions.
    """
    total_pixels = height * width

    # If already within bounds, just align to the factor grid.
    if min_pixels <= total_pixels <= max_pixels:
        # Clamp to at least `factor`: a very thin in-bounds image (e.g.
        # 20x200) would otherwise round one dimension down to 0.
        new_height = max((height // factor) * factor, factor)
        new_width = max((width // factor) * factor, factor)
        return new_height, new_width

    # Out of bounds: shrink toward max_pixels or grow toward min_pixels.
    if total_pixels > max_pixels:
        scale = (max_pixels / total_pixels) ** 0.5
    else:
        scale = (min_pixels / total_pixels) ** 0.5

    # Apply scaling, align to the factor grid, and enforce the minimum size.
    new_height = max((int(height * scale) // factor) * factor, factor)
    new_width = max((int(width * scale) // factor) * factor, factor)

    return new_height, new_width
|
||||
|
||||
|
||||
@register_agent(models=r".*Gelato.*")
class GelatoConfig(AsyncAgentConfig):
    """Gelato agent configuration implementing AsyncAgentConfig protocol for click prediction."""

    def __init__(self):
        # Name of the model most recently used; not read anywhere in this
        # class yet — presumably kept for parity with other loop configs.
        self.current_model = None
        # Base64 of the last screenshot seen; also currently unused here.
        self.last_screenshot_b64 = None

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs,
    ) -> Dict[str, Any]:
        # Gelato is used only for click grounding; full agent-step
        # prediction is intentionally unsupported.
        raise NotImplementedError()

    async def predict_click(
        self, model: str, image_b64: str, instruction: str, **kwargs
    ) -> Optional[Tuple[float, float]]:
        """
        Predict click coordinates using the Gelato model via litellm.acompletion.

        Args:
            model: The Gelato model name
            image_b64: Base64 encoded image
            instruction: Instruction for where to click

        Returns:
            Tuple of (x, y) coordinates in the ORIGINAL image's pixel space
            (predictions are rescaled back from the resized image), or
            (0, 0) when the response contains no parseable coordinates.
        """
        # Decode base64 image to measure its native resolution.
        image_data = base64.b64decode(image_b64)
        image = Image.open(BytesIO(image_data))
        width, height = image.width, image.height

        # Smart resize the image (similar to qwen_vl_utils) so its area fits
        # the model's pixel budget and both sides align to the patch grid.
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=28,  # Default factor for Qwen models
            min_pixels=3136,
            max_pixels=4096 * 2160,
        )
        resized_image = image.resize((resized_width, resized_height))
        # Per-axis factors to map predicted points back to original pixels.
        scale_x, scale_y = width / resized_width, height / resized_height

        # Convert resized image back to base64 for the data-URL payload.
        buffered = BytesIO()
        resized_image.save(buffered, format="PNG")
        resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()

        # Prepare system and user messages (OpenAI-style content parts).
        system_message = {
            "role": "system",
            "content": [{"type": "text", "text": SYSTEM_PROMPT.strip()}],
        }

        user_message = {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/png;base64,{resized_image_b64}"},
                },
                {"type": "text", "text": instruction},
            ],
        }

        # Prepare API call kwargs; temperature 0 for deterministic grounding.
        api_kwargs = {
            "model": model,
            "messages": [system_message, user_message],
            "max_tokens": 2056,
            "temperature": 0.0,
            **kwargs,
        }

        # Use liteLLM acompletion
        response = await litellm.acompletion(**api_kwargs)

        # Extract response text
        output_text = response.choices[0].message.content  # type: ignore

        # Extract coordinates (in resized-image space) and rescale them to
        # the original image's coordinate system.
        pred_x, pred_y = extract_coordinates(output_text)  # type: ignore
        pred_x *= scale_x
        pred_y *= scale_y

        return (math.floor(pred_x), math.floor(pred_y))

    def get_capabilities(self) -> List[AgentCapability]:
        """Return the capabilities supported by this agent."""
        return ["click"]
|
||||
Reference in New Issue
Block a user