added grounding+planning composed loop

Dillon DuPont
2025-08-04 16:32:05 -04:00
parent d5564977f0
commit f87b8eaea5
10 changed files with 904 additions and 400 deletions

View File

@@ -8,5 +8,6 @@ from . import openai
from . import uitars
from . import omniparser
from . import gta1
from . import composed_grounded
__all__ = ["anthropic", "openai", "uitars", "omniparser", "gta1"]
__all__ = ["anthropic", "openai", "uitars", "omniparser", "gta1", "composed_grounded"]

View File

@@ -1285,7 +1285,7 @@ def _merge_consecutive_text(content_list: List[Dict[str, Any]]) -> List[Dict[str
return merged
@register_agent(models=r".*claude-.*", priority=5)
@register_agent(models=r".*claude-.*")
class AnthropicHostedToolsConfig(AsyncAgentConfig):
"""Anthropic hosted tools agent configuration implementing AsyncAgentConfig protocol."""

View File

@@ -0,0 +1,318 @@
"""
Composed-grounded agent loop implementation that combines grounding and thinking models.
Uses a two-stage approach: grounding model for element detection, thinking model for reasoning.
"""
import uuid
import asyncio
import json
import base64
from typing import Dict, List, Any, Optional, Tuple
from io import BytesIO
from PIL import Image
import litellm
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
from ..loops.base import AsyncAgentConfig
from ..responses import (
convert_computer_calls_xy2desc,
convert_responses_items_to_completion_messages,
convert_completion_messages_to_responses_items,
convert_computer_calls_desc2xy,
get_all_element_descriptions
)
from ..agent import find_agent_config
GROUNDED_COMPUTER_TOOL_SCHEMA = {
"type": "function",
"function": {
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment"
],
"description": "The action to perform"
},
"element_description": {
"type": "string",
"description": "Description of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
},
"start_element_description": {
"type": "string",
"description": "Description of the element to start dragging from (required for drag action)"
},
"end_element_description": {
"type": "string",
"description": "Description of the element to drag to (required for drag action)"
},
"text": {
"type": "string",
"description": "The text to type (required for type action)"
},
"keys": {
"type": "string",
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
},
"button": {
"type": "string",
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
},
},
"required": [
"action"
]
}
}
}
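For illustration (not part of the diff): against GROUNDED_COMPUTER_TOOL_SCHEMA the thinking model names elements by description rather than by pixel coordinates. The descriptions below are invented placeholders.

```python
# Hypothetical tool-call arguments matching GROUNDED_COMPUTER_TOOL_SCHEMA.
click_args = {"action": "click", "element_description": "blue Save button", "button": "left"}
drag_args = {
    "action": "drag",
    "start_element_description": "file icon labeled report.pdf",
    "end_element_description": "trash can in the dock",
}
keypress_args = {"action": "keypress", "keys": "ctrl+c"}
scroll_args = {"action": "scroll", "element_description": "results list", "scroll_y": 400}
```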
def _prepare_tools_for_grounded(tool_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Prepare tools for grounded API format"""
grounded_tools = []
for schema in tool_schemas:
if schema["type"] == "computer":
grounded_tools.append(GROUNDED_COMPUTER_TOOL_SCHEMA)
else:
grounded_tools.append(schema)
return grounded_tools
def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str]:
"""Get the last computer call output image from messages."""
for message in reversed(messages):
if (isinstance(message, dict) and
message.get("type") == "computer_call_output" and
isinstance(message.get("output"), dict) and
message["output"].get("type") == "input_image"):
image_url = message["output"].get("image_url", "")
if image_url.startswith("data:image/png;base64,"):
return image_url.split(",", 1)[1]
return None
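For illustration (not part of the diff), this is the message shape get_last_computer_call_image scans for; the call_id and base64 payload are placeholders.

```python
# Hypothetical responses_items history: the helper walks it in reverse and returns
# the base64 payload of the newest screenshot, or None if there is none.
example_messages = [
    {"role": "user", "content": "Open the settings page"},
    {
        "type": "computer_call_output",
        "call_id": "call_abc123",                                 # placeholder id
        "output": {
            "type": "input_image",
            "image_url": "data:image/png;base64,iVBORw0KGgo...",  # truncated payload
        },
    },
]
assert get_last_computer_call_image(example_messages) == "iVBORw0KGgo..."
```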
@register_agent(r".*\+.*", priority=10)
class ComposedGroundedConfig:
"""
Composed-grounded agent configuration that uses both grounding and thinking models.
The model parameter should be in the format "grounding_model+thinking_model",
e.g., "gpt-4o+claude-3-5-sonnet-20241022"
"""
def __init__(self):
self.desc2xy: Dict[str, Tuple[float, float]] = {}
async def predict_step(
self,
messages: List[Dict[str, Any]],
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
) -> Dict[str, Any]:
"""
Composed-grounded predict step implementation.
Process:
0. Store last computer call image, if none then take a screenshot
1. Convert computer calls from xy to descriptions
2. Convert responses items to completion messages
3. Call thinking model with litellm.acompletion
4. Convert completion messages to responses items
5. Get all element descriptions and populate desc2xy mapping
6. Convert computer calls from descriptions back to xy coordinates
7. Return output and usage
"""
# Parse the composed model
if "+" not in model:
raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
grounding_model, thinking_model = model.split("+", 1)
pre_output_items = []
# Step 0: Store last computer call image, if none then take a screenshot
last_image_b64 = get_last_computer_call_image(messages)
if last_image_b64 is None:
# Take a screenshot
screenshot_b64 = await computer_handler.screenshot() # type: ignore
if screenshot_b64:
call_id = uuid.uuid4().hex
pre_output_items += [
{
"type": "message",
"role": "assistant",
"content": [
{
"type": "output_text",
"text": "Taking a screenshot to see the current computer screen."
}
]
},
{
"action": {
"type": "screenshot"
},
"call_id": call_id,
"status": "completed",
"type": "computer_call"
},
{
"type": "computer_call_output",
"call_id": call_id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_b64}"
}
},
]
last_image_b64 = screenshot_b64
# Call screenshot callback if provided
if _on_screenshot:
await _on_screenshot(screenshot_b64)
tool_schemas = _prepare_tools_for_grounded(tools) # type: ignore
# Step 1: Convert computer calls from xy to descriptions
input_messages = messages + pre_output_items
messages_with_descriptions = convert_computer_calls_xy2desc(input_messages, self.desc2xy)
# Step 2: Convert responses items to completion messages
completion_messages = convert_responses_items_to_completion_messages(
messages_with_descriptions,
allow_images_in_tool_results=False
)
# Step 3: Call thinking model with litellm.acompletion
api_kwargs = {
"model": thinking_model,
"messages": completion_messages,
"tools": tool_schemas,
"max_retries": max_retries,
"stream": stream,
**kwargs
}
if use_prompt_caching:
api_kwargs["use_prompt_caching"] = use_prompt_caching
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
# Make the completion call
response = await litellm.acompletion(**api_kwargs)
# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
# Extract usage information
usage = {
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
await _on_usage(usage)
# Step 4: Convert completion messages back to responses items format
response_dict = response.model_dump() # type: ignore
choice_messages = [choice["message"] for choice in response_dict["choices"]]
thinking_output_items = []
for choice_message in choice_messages:
thinking_output_items.extend(convert_completion_messages_to_responses_items([choice_message]))
# Step 5: Get all element descriptions and populate desc2xy mapping
element_descriptions = get_all_element_descriptions(thinking_output_items)
if element_descriptions and last_image_b64:
# Use grounding model to predict coordinates for each description
grounding_agent_conf = find_agent_config(grounding_model)
if grounding_agent_conf:
grounding_agent = grounding_agent_conf.agent_class()
for desc in element_descriptions:
coords = await grounding_agent.predict_click(
model=grounding_model,
image_b64=last_image_b64,
instruction=desc
)
if coords:
self.desc2xy[desc] = coords
# Step 6: Convert computer calls from descriptions back to xy coordinates
final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)
# Step 7: Return output and usage
return {
"output": pre_output_items + final_output_items,
"usage": usage
}
async def predict_click(
self,
model: str,
image_b64: str,
instruction: str,
**kwargs
) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates using the grounding model.
For composed models, uses only the grounding model part for click prediction.
"""
# Parse the composed model to get grounding model
if "+" not in model:
raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
grounding_model, thinking_model = model.split("+", 1)
# Find and use the grounding agent
grounding_agent_conf = find_agent_config(grounding_model)
if grounding_agent_conf:
grounding_agent = grounding_agent_conf.agent_class()
return await grounding_agent.predict_click(
model=grounding_model,
image_b64=image_b64,
instruction=instruction,
**kwargs
)
return None
def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["click", "step"]

View File

@@ -26,73 +26,6 @@ Output the coordinate pair exactly:
(x,y)
'''.strip()
# Global dictionary to map coordinates to descriptions
xy2desc: Dict[Tuple[float, float], str] = {}
GTA1_TOOL_SCHEMA = {
"type": "function",
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment"
],
"description": "The action to perform"
},
"element_description": {
"type": "string",
"description": "Description of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
},
"start_element_description": {
"type": "string",
"description": "Description of the element to start dragging from (required for drag action)"
},
"end_element_description": {
"type": "string",
"description": "Description of the element to drag to (required for drag action)"
},
"text": {
"type": "string",
"description": "The text to type (required for type action)"
},
"keys": {
"type": "string",
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
},
"button": {
"type": "string",
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
},
},
"required": [
"action"
]
}
}
def extract_coordinates(raw_string: str) -> Tuple[float, float]:
"""Extract coordinates from model output."""
try:
@@ -101,173 +34,6 @@ def extract_coordinates(raw_string: str) -> Tuple[float, float]:
except:
return (0.0, 0.0)
def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Get the last computer_call_output message from a messages list.
Args:
messages: List of messages to search through
Returns:
The last computer_call_output message dict, or None if not found
"""
for message in reversed(messages):
if isinstance(message, dict) and message.get("type") == "computer_call_output":
return message
return None
def _prepare_tools_for_gta1(tool_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Prepare tools for GTA1 API format"""
gta1_tools = []
for schema in tool_schemas:
if schema["type"] == "computer":
gta1_tools.append(GTA1_TOOL_SCHEMA)
else:
gta1_tools.append(schema)
return gta1_tools
async def replace_function_with_computer_call_gta1(item: Dict[str, Any], agent_instance) -> List[Dict[str, Any]]:
"""Convert function_call to computer_call format using GTA1 click prediction."""
global xy2desc
item_type = item.get("type")
async def _get_xy(element_description: Optional[str], last_image_b64: str) -> Union[Tuple[float, float], Tuple[None, None]]:
if element_description is None:
return (None, None)
# Use self.predict_click to get coordinates from description
coords = await agent_instance.predict_click(
model=agent_instance.current_model,
image_b64=last_image_b64,
instruction=element_description
)
if coords:
# Store the mapping from coordinates to description
xy2desc[coords] = element_description
return coords
return (None, None)
if item_type == "function_call":
fn_name = item.get("name")
fn_args = json.loads(item.get("arguments", "{}"))
item_id = item.get("id")
call_id = item.get("call_id")
if fn_name == "computer":
action = fn_args.get("action")
element_description = fn_args.get("element_description")
start_element_description = fn_args.get("start_element_description")
end_element_description = fn_args.get("end_element_description")
text = fn_args.get("text")
keys = fn_args.get("keys")
button = fn_args.get("button")
scroll_x = fn_args.get("scroll_x")
scroll_y = fn_args.get("scroll_y")
# Get the last computer output image for click prediction
last_image_b64 = agent_instance.last_screenshot_b64 or ""
x, y = await _get_xy(element_description, last_image_b64)
start_x, start_y = await _get_xy(start_element_description, last_image_b64)
end_x, end_y = await _get_xy(end_element_description, last_image_b64)
action_args = {
"type": action,
"x": x,
"y": y,
"start_x": start_x,
"start_y": start_y,
"end_x": end_x,
"end_y": end_y,
"text": text,
"keys": keys,
"button": button,
"scroll_x": scroll_x,
"scroll_y": scroll_y
}
# Remove None values to keep the JSON clean
action_args = {k: v for k, v in action_args.items() if v is not None}
return [{
"type": "computer_call",
"action": action_args,
"id": item_id,
"call_id": call_id,
"status": "completed"
}]
return [item]
async def replace_computer_call_with_function_gta1(item: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Convert computer_call back to function_call format using descriptions.
Only READS from the global xy2desc dictionary.
Args:
item: The item to convert
"""
global xy2desc
item_type = item.get("type")
def _get_element_description(x: Optional[float], y: Optional[float]) -> Optional[str]:
"""Get element description from coordinates, return None if coordinates are None"""
if x is None or y is None:
return None
return xy2desc.get((x, y))
if item_type == "computer_call":
action_data = item.get("action", {})
# Extract coordinates and convert back to element descriptions
element_description = _get_element_description(action_data.get("x"), action_data.get("y"))
start_element_description = _get_element_description(action_data.get("start_x"), action_data.get("start_y"))
end_element_description = _get_element_description(action_data.get("end_x"), action_data.get("end_y"))
# Build function arguments
fn_args = {
"action": action_data.get("type"),
"element_description": element_description,
"start_element_description": start_element_description,
"end_element_description": end_element_description,
"text": action_data.get("text"),
"keys": action_data.get("keys"),
"button": action_data.get("button"),
"scroll_x": action_data.get("scroll_x"),
"scroll_y": action_data.get("scroll_y")
}
# Remove None values to keep the JSON clean
fn_args = {k: v for k, v in fn_args.items() if v is not None}
return [{
"type": "function_call",
"name": "computer",
"arguments": json.dumps(fn_args),
"id": item.get("id"),
"call_id": item.get("call_id"),
"status": "completed",
# Fall back to string representation
# "content": f"Used tool: {action_data.get('type')}({json.dumps(fn_args)})"
}]
elif item_type == "computer_call_output":
# Simple conversion: computer_call_output -> function_call_output (text only), user message (image)
return [
{
"type": "function_call_output",
"call_id": item.get("call_id"),
"output": "Tool executed successfully. See the current computer screenshot below, if nothing has changed yet then you may need to wait before trying again.",
"id": item.get("id"),
"status": "completed"
}, {
"role": "user",
"content": [item.get("output")]
}
]
return [item]
def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
"""Smart resize function similar to qwen_vl_utils."""
# Calculate the total pixels
@@ -300,7 +66,7 @@ def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 31
return new_height, new_width
@register_agent(models=r".*GTA1.*", priority=10)
@register_agent(models=r".*GTA1.*")
class GTA1Config(AsyncAgentConfig):
"""GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction."""
@@ -308,153 +74,6 @@ class GTA1Config(AsyncAgentConfig):
self.current_model = None
self.last_screenshot_b64 = None
async def predict_step(
self,
messages: Messages,
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
stream: bool = False,
computer_handler=None,
use_prompt_caching: Optional[bool] = False,
_on_api_start=None,
_on_api_end=None,
_on_usage=None,
_on_screenshot=None,
**kwargs
) -> Dict[str, Any]:
"""
GTA1 agent loop implementation using liteLLM responses with element descriptions.
Follows the 4-step process:
1. Prepare tools
2. Replace computer calls with function calls (using descriptions)
3. API call
4. Replace function calls with computer calls (using predict_click)
"""
models = model.split("+")
if len(models) != 2:
raise ValueError("GTA1 model must be in the format <gta1_model_name>+<planning_model_name> to be used in an agent loop")
gta1_model, llm_model = models
self.current_model = gta1_model
tools = tools or []
# Step 0: Prepare tools
gta1_tools = _prepare_tools_for_gta1(tools)
# Get last computer_call_output for screenshot reference
# Convert messages to list of dicts first
message_list = []
for message in messages:
if not isinstance(message, dict):
message_list.append(message.__dict__)
else:
message_list.append(message)
last_computer_call_output = get_last_computer_call_output(message_list)
if last_computer_call_output:
image_url = last_computer_call_output.get("output", {}).get("image_url", "")
if image_url.startswith("data:image/"):
self.last_screenshot_b64 = image_url.split(",")[-1]
else:
self.last_screenshot_b64 = image_url
# Step 1: If there's no screenshot, simulate the model calling the screenshot function
pre_output = []
if not self.last_screenshot_b64 and computer_handler:
screenshot_base64 = await computer_handler.screenshot()
if _on_screenshot:
await _on_screenshot(screenshot_base64, "screenshot_initial")
call_id = uuid.uuid4().hex
pre_output += [
{
"type": "message",
"role": "assistant",
"content": [
{
"type": "output_text",
"text": "Taking a screenshot to see the current computer screen."
}
]
},
{
"action": {
"type": "screenshot"
},
"call_id": call_id,
"status": "completed",
"type": "computer_call"
},
{
"type": "computer_call_output",
"call_id": call_id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_base64}"
}
},
]
# Update the last screenshot for future use
self.last_screenshot_b64 = screenshot_base64
message_list += pre_output
# Step 2: Replace computer calls with function calls (using descriptions)
new_messages = []
for message in message_list:
new_messages += await replace_computer_call_with_function_gta1(message)
messages = new_messages
# Step 3: API call
api_kwargs = {
"model": llm_model,
"input": messages,
"tools": gta1_tools if gta1_tools else None,
"stream": stream,
"truncation": "auto",
"num_retries": max_retries,
**kwargs
}
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
# Use liteLLM responses
response = await litellm.aresponses(**api_kwargs)
# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
# Extract usage information
usage = {
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0), # type: ignore
}
if _on_usage:
await _on_usage(usage)
# Step 4: Replace function calls with computer calls (using predict_click)
new_output = []
for i in range(len(response.output)): # type: ignore
output_item = response.output[i] # type: ignore
# Convert to dict if it has model_dump method, otherwise use as-is
if hasattr(output_item, 'model_dump'):
item_dict = output_item.model_dump() # type: ignore
else:
item_dict = output_item # type: ignore
new_output += await replace_function_with_computer_call_gta1(item_dict, self) # type: ignore
return {
"output": pre_output + new_output,
"usage": usage
}
async def predict_click(
self,
model: str,
@@ -528,10 +147,10 @@ class GTA1Config(AsyncAgentConfig):
response = await litellm.acompletion(**api_kwargs)
# Extract response text
output_text = response.choices[0].message.content
output_text = response.choices[0].message.content # type: ignore
# Extract and rescale coordinates
pred_x, pred_y = extract_coordinates(output_text)
pred_x, pred_y = extract_coordinates(output_text) # type: ignore
pred_x *= scale_x
pred_y *= scale_y
@@ -539,4 +158,4 @@ class GTA1Config(AsyncAgentConfig):
def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["click", "step"]
return ["click"]

View File

@@ -249,7 +249,7 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
return [item]
@register_agent(models=r"omniparser\+.*|omni\+.*", priority=10)
@register_agent(models=r"omniparser\+.*|omni\+.*")
class OmniparsrConfig(AsyncAgentConfig):
"""Omniparser agent configuration implementing AsyncAgentConfig protocol."""

View File

@@ -3,9 +3,12 @@ OpenAI computer-use-preview agent loop implementation using liteLLM
"""
import asyncio
import base64
import json
from io import BytesIO
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
import litellm
from PIL import Image
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
@@ -36,7 +39,7 @@ def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools:
return openai_tools
@register_agent(models=r".*computer-use-preview.*", priority=10)
@register_agent(models=r".*computer-use-preview.*")
class OpenAIComputerUseConfig:
"""
OpenAI computer-use-preview agent configuration using liteLLM responses.
@@ -128,8 +131,8 @@ class OpenAIComputerUseConfig:
"""
Predict click coordinates based on image and instruction.
Note: OpenAI computer-use-preview doesn't support direct click prediction,
so this returns None.
Uses OpenAI computer-use-preview with manually constructed input items
and a prompt that instructs the agent to only output clicks.
Args:
model: Model name to use
@@ -137,8 +140,94 @@ class OpenAIComputerUseConfig:
instruction: Instruction for where to click
Returns:
None (not supported by OpenAI computer-use-preview)
Tuple of (x, y) coordinates or None if prediction fails
"""
# TODO: implement this correctly
# Scale image to half size
try:
image_data = base64.b64decode(image_b64)
image = Image.open(BytesIO(image_data))
# Scale to half size
new_width = image.width // 2
new_height = image.height // 2
scaled_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
# Convert back to base64
buffer = BytesIO()
scaled_image.save(buffer, format='PNG')
image_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
except Exception:
# If scaling fails, use original image
pass
# Manually construct input items with image and click instruction
input_items = [
{
"role": "user",
"content": f"You are a UI grounding expert. Look at the image and {instruction}. Output ONLY a click action on the target element. No explanations, confirmations, or additional text."
},
{
"role": "user",
"content": [
{
"type": "input_image",
"image_url": f"data:image/png;base64,{image_b64}"
}
]
}
]
# Get image dimensions from base64 data
try:
image_data = base64.b64decode(image_b64)
image = Image.open(BytesIO(image_data))
display_width, display_height = image.size
except Exception:
# Fallback to default dimensions if image parsing fails
display_width, display_height = 1024, 768
# Prepare computer tool for click actions
computer_tool = {
"type": "computer_use_preview",
"display_width": display_width,
"display_height": display_height,
"environment": "linux"
}
# Prepare API call kwargs
api_kwargs = {
"model": model,
"input": input_items,
"tools": [computer_tool],
"stream": False,
"reasoning": {"summary": "concise"},
"truncation": "auto",
"max_tokens": 100 # Keep response short for click prediction
}
# Use liteLLM responses
response = await litellm.aresponses(**api_kwargs)
# Extract click coordinates from response output
output_dict = response.model_dump()
output_items = output_dict.get("output", [])
# print(output_items)
# Look for computer_call with click action
for item in output_items:
if (isinstance(item, dict) and
item.get("type") == "computer_call" and
isinstance(item.get("action"), dict)):
action = item["action"]
if action.get("type") == "click":
x = action.get("x")
y = action.get("y")
if x is not None and y is not None:
return (int(x) * 2, int(y) * 2)
return None
def get_capabilities(self) -> List[AgentCapability]:
@@ -148,4 +237,4 @@ class OpenAIComputerUseConfig:
Returns:
List of capability strings
"""
return ["step"]
return ["click", "step"]

View File

@@ -515,7 +515,7 @@ def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any
return litellm_messages
@register_agent(models=r"(?i).*ui-?tars.*", priority=10)
@register_agent(models=r"(?i).*ui-?tars.*")
class UITARSConfig:
"""
UITARS agent configuration using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B model.

View File

@@ -40,7 +40,7 @@ def make_input_image_item(image_data: Union[str, bytes]) -> EasyInputMessagePara
ResponseInputImageParam(
type="input_image",
image_url=f"data:image/png;base64,{base64.b64encode(image_data).decode('utf-8') if isinstance(image_data, bytes) else image_data}"
)
) # type: ignore
],
role="user",
type="message"
@@ -205,3 +205,479 @@ def make_wait_item(call_id: Optional[str] = None) -> ResponseComputerToolCallPar
status="completed",
type="computer_call"
)
# Conversion functions between element descriptions and coordinates
def convert_computer_calls_desc2xy(responses_items: List[Dict[str, Any]], desc2xy: Dict[str, tuple]) -> List[Dict[str, Any]]:
"""
Convert computer calls from element descriptions to x,y coordinates.
Args:
responses_items: List of response items containing computer calls with element_description
desc2xy: Dictionary mapping element descriptions to (x, y) coordinate tuples
Returns:
List of response items with element_description replaced by x,y coordinates
"""
converted_items = []
for item in responses_items:
if item.get("type") == "computer_call" and "action" in item:
action = item["action"].copy()
# Handle single element_description
if "element_description" in action:
desc = action["element_description"]
if desc in desc2xy:
x, y = desc2xy[desc]
action["x"] = x
action["y"] = y
del action["element_description"]
# Handle start_element_description and end_element_description for drag operations
elif "start_element_description" in action and "end_element_description" in action:
start_desc = action["start_element_description"]
end_desc = action["end_element_description"]
if start_desc in desc2xy and end_desc in desc2xy:
start_x, start_y = desc2xy[start_desc]
end_x, end_y = desc2xy[end_desc]
action["path"] = [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}]
del action["start_element_description"]
del action["end_element_description"]
converted_item = item.copy()
converted_item["action"] = action
converted_items.append(converted_item)
else:
converted_items.append(item)
return converted_items
def convert_computer_calls_xy2desc(responses_items: List[Dict[str, Any]], desc2xy: Dict[str, tuple]) -> List[Dict[str, Any]]:
"""
Convert computer calls from x,y coordinates to element descriptions.
Args:
responses_items: List of response items containing computer calls with x,y coordinates
desc2xy: Dictionary mapping element descriptions to (x, y) coordinate tuples
Returns:
List of response items with x,y coordinates replaced by element_description
"""
# Create reverse mapping from coordinates to descriptions
xy2desc = {coords: desc for desc, coords in desc2xy.items()}
converted_items = []
for item in responses_items:
if item.get("type") == "computer_call" and "action" in item:
action = item["action"].copy()
# Handle single x,y coordinates
if "x" in action and "y" in action:
coords = (action["x"], action["y"])
if coords in xy2desc:
action["element_description"] = xy2desc[coords]
del action["x"]
del action["y"]
# Handle path for drag operations
elif "path" in action and isinstance(action["path"], list) and len(action["path"]) == 2:
start_point = action["path"][0]
end_point = action["path"][1]
if ("x" in start_point and "y" in start_point and
"x" in end_point and "y" in end_point):
start_coords = (start_point["x"], start_point["y"])
end_coords = (end_point["x"], end_point["y"])
if start_coords in xy2desc and end_coords in xy2desc:
action["start_element_description"] = xy2desc[start_coords]
action["end_element_description"] = xy2desc[end_coords]
del action["path"]
converted_item = item.copy()
converted_item["action"] = action
converted_items.append(converted_item)
else:
converted_items.append(item)
return converted_items
def get_all_element_descriptions(responses_items: List[Dict[str, Any]]) -> List[str]:
"""
Extract all element descriptions from computer calls in responses items.
Args:
responses_items: List of response items containing computer calls
Returns:
List of unique element descriptions found in computer calls
"""
descriptions = set()
for item in responses_items:
if item.get("type") == "computer_call" and "action" in item:
action = item["action"]
# Handle single element_description
if "element_description" in action:
descriptions.add(action["element_description"])
# Handle start_element_description and end_element_description for drag operations
if "start_element_description" in action:
descriptions.add(action["start_element_description"])
if "end_element_description" in action:
descriptions.add(action["end_element_description"])
return list(descriptions)
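For illustration (not part of the diff): a round trip through the two converters above, using an invented desc2xy mapping like the one ComposedGroundedConfig maintains.

```python
desc2xy = {"blue Save button": (412, 630)}   # description -> screen coordinates

items = [{
    "type": "computer_call",
    "call_id": "call_1",
    "status": "completed",
    "action": {"type": "click", "button": "left", "x": 412, "y": 630},
}]

# xy -> description: the form the thinking model sees in its transcript.
as_desc = convert_computer_calls_xy2desc(items, desc2xy)
assert as_desc[0]["action"] == {
    "type": "click", "button": "left", "element_description": "blue Save button"
}
assert get_all_element_descriptions(as_desc) == ["blue Save button"]

# description -> xy: the form that is executed against the computer.
back = convert_computer_calls_desc2xy(as_desc, desc2xy)
assert back[0]["action"] == {"type": "click", "button": "left", "x": 412, "y": 630}
```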
# Conversion functions between responses_items and completion messages formats
def convert_responses_items_to_completion_messages(messages: List[Dict[str, Any]], allow_images_in_tool_results: bool = True) -> List[Dict[str, Any]]:
"""Convert responses_items message format to liteLLM completion format.
Args:
messages: List of responses_items format messages
allow_images_in_tool_results: If True, include images in tool role messages.
If False, send tool message + separate user message with image.
"""
completion_messages = []
for message in messages:
msg_type = message.get("type")
role = message.get("role")
# Handle user messages (both with and without explicit type)
if role == "user" or msg_type == "user":
content = message.get("content", "")
if isinstance(content, list):
# Handle list content (images, text blocks)
completion_content = []
for item in content:
if item.get("type") == "input_image":
completion_content.append({
"type": "image_url",
"image_url": {
"url": item.get("image_url")
}
})
elif item.get("type") == "input_text":
completion_content.append({
"type": "text",
"text": item.get("text")
})
elif item.get("type") == "text":
completion_content.append({
"type": "text",
"text": item.get("text")
})
completion_messages.append({
"role": "user",
"content": completion_content
})
elif isinstance(content, str):
# Handle string content
completion_messages.append({
"role": "user",
"content": content
})
# Handle assistant messages
elif role == "assistant" or msg_type == "message":
content = message.get("content", [])
if isinstance(content, list):
text_parts = []
for item in content:
if item.get("type") == "output_text":
text_parts.append(item.get("text", ""))
elif item.get("type") == "text":
text_parts.append(item.get("text", ""))
if text_parts:
completion_messages.append({
"role": "assistant",
"content": "\n".join(text_parts)
})
# Handle reasoning items (convert to assistant message)
elif msg_type == "reasoning":
summary = message.get("summary", [])
text_parts = []
for item in summary:
if item.get("type") == "summary_text":
text_parts.append(item.get("text", ""))
if text_parts:
completion_messages.append({
"role": "assistant",
"content": "\n".join(text_parts)
})
# Handle function calls
elif msg_type == "function_call":
# Add tool call to last assistant message or create new one
if not completion_messages or completion_messages[-1]["role"] != "assistant":
completion_messages.append({
"role": "assistant",
"content": "",
"tool_calls": []
})
if "tool_calls" not in completion_messages[-1]:
completion_messages[-1]["tool_calls"] = []
completion_messages[-1]["tool_calls"].append({
"id": message.get("call_id"),
"type": "function",
"function": {
"name": message.get("name"),
"arguments": message.get("arguments")
}
})
# Handle computer calls
elif msg_type == "computer_call":
# Add tool call to last assistant message or create new one
if not completion_messages or completion_messages[-1]["role"] != "assistant":
completion_messages.append({
"role": "assistant",
"content": "",
"tool_calls": []
})
if "tool_calls" not in completion_messages[-1]:
completion_messages[-1]["tool_calls"] = []
action = message.get("action", {})
completion_messages[-1]["tool_calls"].append({
"id": message.get("call_id"),
"type": "function",
"function": {
"name": "computer",
"arguments": json.dumps(action)
}
})
# Handle function/computer call outputs
elif msg_type in ["function_call_output", "computer_call_output"]:
output = message.get("output")
call_id = message.get("call_id")
if isinstance(output, dict) and output.get("type") == "input_image":
if allow_images_in_tool_results:
# Handle image output as tool response (may not work with all APIs)
completion_messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": [{
"type": "image_url",
"image_url": {
"url": output.get("image_url")
}
}]
})
else:
# Send tool message + separate user message with image (OpenAI compatible)
completion_messages += [{
"role": "tool",
"tool_call_id": call_id,
"content": "[Execution completed. See screenshot below]"
}, {
"role": "user",
"content": [{
"type": "image_url",
"image_url": {
"url": output.get("image_url")
}
}]
}]
else:
# Handle text output as tool response
completion_messages.append({
"role": "tool",
"tool_call_id": call_id,
"content": str(output)
})
return completion_messages
def convert_completion_messages_to_responses_items(completion_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Convert completion messages format to responses_items message format."""
responses_items = []
skip_next = False
for i, message in enumerate(completion_messages):
if skip_next:
skip_next = False
continue
role = message.get("role")
content = message.get("content")
tool_calls = message.get("tool_calls", [])
# Handle assistant messages with text content
if role == "assistant" and content and isinstance(content, str):
responses_items.append({
"type": "message",
"role": "assistant",
"content": [{
"type": "output_text",
"text": content
}]
})
# Handle tool calls
if tool_calls:
for tool_call in tool_calls:
if tool_call.get("type") == "function":
function = tool_call.get("function", {})
function_name = function.get("name")
if function_name == "computer":
# Parse computer action
try:
action = json.loads(function.get("arguments", "{}"))
# Change key from "action" -> "type"
if action.get("action"):
action["type"] = action["action"]
del action["action"]
responses_items.append({
"type": "computer_call",
"call_id": tool_call.get("id"),
"action": action,
"status": "completed"
})
except json.JSONDecodeError:
# Fallback to function call format
responses_items.append({
"type": "function_call",
"call_id": tool_call.get("id"),
"name": function_name,
"arguments": function.get("arguments", "{}"),
"status": "completed"
})
else:
# Regular function call
responses_items.append({
"type": "function_call",
"call_id": tool_call.get("id"),
"name": function_name,
"arguments": function.get("arguments", "{}"),
"status": "completed"
})
# Handle tool messages (function/computer call outputs)
elif role == "tool" and content:
tool_call_id = message.get("tool_call_id")
if isinstance(content, str):
# Check if this is the "[Execution completed. See screenshot below]" pattern
if content == "[Execution completed. See screenshot below]":
# Look ahead for the next user message with image
next_idx = i + 1
if (next_idx < len(completion_messages) and
completion_messages[next_idx].get("role") == "user" and
isinstance(completion_messages[next_idx].get("content"), list)):
# Found the pattern - extract image from next message
next_content = completion_messages[next_idx]["content"]
for item in next_content:
if item.get("type") == "image_url":
responses_items.append({
"type": "computer_call_output",
"call_id": tool_call_id,
"output": {
"type": "input_image",
"image_url": item.get("image_url", {}).get("url")
}
})
# Skip the next user message since we processed it
skip_next = True
break
else:
# No matching user message, treat as regular text
responses_items.append({
"type": "computer_call_output",
"call_id": tool_call_id,
"output": content
})
else:
# Determine if this is a computer call or function call output
try:
# Try to parse as structured output
parsed_content = json.loads(content)
if parsed_content.get("type") == "input_image":
responses_items.append({
"type": "computer_call_output",
"call_id": tool_call_id,
"output": parsed_content
})
else:
responses_items.append({
"type": "computer_call_output",
"call_id": tool_call_id,
"output": content
})
except json.JSONDecodeError:
# Plain text output - could be function or computer call
responses_items.append({
"type": "function_call_output",
"call_id": tool_call_id,
"output": content
})
elif isinstance(content, list):
# Handle structured content (e.g., images)
for item in content:
if item.get("type") == "image_url":
responses_items.append({
"type": "computer_call_output",
"call_id": tool_call_id,
"output": {
"type": "input_image",
"image_url": item.get("image_url", {}).get("url")
}
})
elif item.get("type") == "text":
responses_items.append({
"type": "function_call_output",
"call_id": tool_call_id,
"output": item.get("text")
})
# Handle actual user messages
elif role == "user" and content:
if isinstance(content, list):
# Handle structured user content (e.g., text + images)
user_content = []
for item in content:
if item.get("type") == "image_url":
user_content.append({
"type": "input_image",
"image_url": item.get("image_url", {}).get("url")
})
elif item.get("type") == "text":
user_content.append({
"type": "input_text",
"text": item.get("text")
})
if user_content:
responses_items.append({
"role": "user",
"type": "message",
"content": user_content
})
elif isinstance(content, str):
# Handle simple text user message
responses_items.append({
"role": "user",
"content": content
})
return responses_items
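For illustration (not part of the diff): with allow_images_in_tool_results=False (the setting the composed loop uses), the two converters above round-trip a screenshot result between the tool-message-plus-user-image shape and a single computer_call_output.

```python
# A computer_call_output carrying a screenshot, in responses_items form.
responses_items = [{
    "type": "computer_call_output",
    "call_id": "call_1",
    "output": {"type": "input_image", "image_url": "data:image/png;base64,iVBOR..."},
}]

# Forward: a tool message plus a separate user message holding the image.
completion = convert_responses_items_to_completion_messages(
    responses_items, allow_images_in_tool_results=False
)
assert completion[0] == {
    "role": "tool",
    "tool_call_id": "call_1",
    "content": "[Execution completed. See screenshot below]",
}
assert completion[1]["role"] == "user"

# Backward: the marker text and the trailing user image collapse back into one
# computer_call_output item.
restored = convert_completion_messages_to_responses_items(completion)
assert restored == responses_items
```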

View File

@@ -1,4 +1,3 @@
from .base import ModelProtocol
from .gta1 import GTA1Model
__all__ = ["ModelProtocol", "GTA1Model"]
__all__ = ["ModelProtocol"]

View File

@@ -21,7 +21,6 @@ import torch
# Add parent directory to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from agent.agent import ComputerAgent
from models import GTA1Model
from models.base import ModelProtocol
def get_gpu_memory() -> List[int]:
@@ -82,13 +81,16 @@ def get_available_models() -> List[Union[str, ModelProtocol]]:
"""
local_provider = "huggingface-local/" # Options: huggingface-local/ or mlx/
# from models.gta1 import GTA1Model
models = [
# === ComputerAgent model strings ===
f"{local_provider}HelloKKMe/GTA1-7B",
# f"{local_provider}HelloKKMe/GTA1-7B",
# f"{local_provider}HelloKKMe/GTA1-32B",
"openai/computer-use-preview+openai/gpt-4o-mini"
# === Reference model classes ===
GTA1Model("HelloKKMe/GTA1-7B"),
# GTA1Model("HelloKKMe/GTA1-7B"),
# GTA1Model("HelloKKMe/GTA1-32B"),
]