mirror of https://github.com/trycua/computer.git (synced 2025-12-31 18:40:04 -06:00)

Commit: added grounding+planning composed loop
libs/python/agent/agent/loops/__init__.py

@@ -8,5 +8,6 @@ from . import openai
from . import uitars
from . import omniparser
from . import gta1
+from . import composed_grounded

-__all__ = ["anthropic", "openai", "uitars", "omniparser", "gta1"]
+__all__ = ["anthropic", "openai", "uitars", "omniparser", "gta1", "composed_grounded"]
libs/python/agent/agent/loops/anthropic.py

@@ -1285,7 +1285,7 @@ def _merge_consecutive_text(content_list: List[Dict[str, Any]]) -> List[Dict[str, Any]]:

    return merged

-@register_agent(models=r".*claude-.*", priority=5)
+@register_agent(models=r".*claude-.*")
class AnthropicHostedToolsConfig(AsyncAgentConfig):
    """Anthropic hosted tools agent configuration implementing AsyncAgentConfig protocol."""
libs/python/agent/agent/loops/composed_grounded.py (new file, 318 lines)

@@ -0,0 +1,318 @@
"""
|
||||
Composed-grounded agent loop implementation that combines grounding and thinking models.
|
||||
Uses a two-stage approach: grounding model for element detection, thinking model for reasoning.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import asyncio
|
||||
import json
|
||||
import base64
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
import litellm
|
||||
|
||||
from ..decorators import register_agent
|
||||
from ..types import Messages, AgentResponse, Tools, AgentCapability
|
||||
from ..loops.base import AsyncAgentConfig
|
||||
from ..responses import (
|
||||
convert_computer_calls_xy2desc,
|
||||
convert_responses_items_to_completion_messages,
|
||||
convert_completion_messages_to_responses_items,
|
||||
convert_computer_calls_desc2xy,
|
||||
get_all_element_descriptions
|
||||
)
|
||||
from ..agent import find_agent_config
|
||||
|
||||
GROUNDED_COMPUTER_TOOL_SCHEMA = {
    "type": "function",
    "function": {
        "name": "computer",
        "description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
        "parameters": {
            "type": "object",
            "properties": {
                "action": {
                    "type": "string",
                    "enum": [
                        "screenshot",
                        "click",
                        "double_click",
                        "drag",
                        "type",
                        "keypress",
                        "scroll",
                        "move",
                        "wait",
                        "get_current_url",
                        "get_dimensions",
                        "get_environment"
                    ],
                    "description": "The action to perform"
                },
                "element_description": {
                    "type": "string",
                    "description": "Description of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
                },
                "start_element_description": {
                    "type": "string",
                    "description": "Description of the element to start dragging from (required for drag action)"
                },
                "end_element_description": {
                    "type": "string",
                    "description": "Description of the element to drag to (required for drag action)"
                },
                "text": {
                    "type": "string",
                    "description": "The text to type (required for type action)"
                },
                "keys": {
                    "type": "string",
                    "description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
                },
                "button": {
                    "type": "string",
                    "description": "The mouse button to use for click action (left, right, wheel, back, forward). Default: left"
                },
                "scroll_x": {
                    "type": "integer",
                    "description": "Horizontal scroll amount for scroll action (positive for right, negative for left)"
                },
                "scroll_y": {
                    "type": "integer",
                    "description": "Vertical scroll amount for scroll action (positive for down, negative for up)"
                }
            },
            "required": [
                "action"
            ]
        }
    }
}
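
# Example of the arguments a thinking model might emit against this schema
# (hypothetical values; the grounding model later resolves each description
# to screen coordinates):
#
#   {"action": "click", "element_description": "blue 'Sign in' button", "button": "left"}
#   {"action": "drag",
#    "start_element_description": "file icon named report.pdf",
#    "end_element_description": "trash can icon in the dock"}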

def _prepare_tools_for_grounded(tool_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Prepare tools for grounded API format"""
    grounded_tools = []

    for schema in tool_schemas:
        if schema["type"] == "computer":
            grounded_tools.append(GROUNDED_COMPUTER_TOOL_SCHEMA)
        else:
            grounded_tools.append(schema)

    return grounded_tools


def get_last_computer_call_image(messages: List[Dict[str, Any]]) -> Optional[str]:
    """Get the last computer call output image from messages."""
    for message in reversed(messages):
        if (isinstance(message, dict) and
            message.get("type") == "computer_call_output" and
            isinstance(message.get("output"), dict) and
            message["output"].get("type") == "input_image"):
            image_url = message["output"].get("image_url", "")
            if image_url.startswith("data:image/png;base64,"):
                return image_url.split(",", 1)[1]
    return None

@register_agent(r".*\+.*", priority=10)
class ComposedGroundedConfig:
    """
    Composed-grounded agent configuration that uses both grounding and thinking models.

    The model parameter should be in the format "grounding_model+thinking_model",
    e.g., "gpt-4o+claude-3-5-sonnet-20241022"
    """

    def __init__(self):
        self.desc2xy: Dict[str, Tuple[float, float]] = {}
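        # Note: desc2xy persists on the instance across predict_step calls, so a
        # description grounded once (e.g. "blue 'Submit' button" -> (412, 987),
        # illustrative values) can be reused when the same element comes up again.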

    async def predict_step(
        self,
        messages: List[Dict[str, Any]],
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        use_prompt_caching: Optional[bool] = False,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Composed-grounded predict step implementation.

        Process:
        0. Store the last computer call image; if there is none, take a screenshot
        1. Convert computer calls from x,y coordinates to element descriptions
        2. Convert responses items to completion messages
        3. Call the thinking model with litellm.acompletion
        4. Convert completion messages back to responses items
        5. Get all element descriptions and populate the desc2xy mapping
        6. Convert computer calls from descriptions back to x,y coordinates
        7. Return output and usage
        """
        # Parse the composed model
        if "+" not in model:
            raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
        grounding_model, thinking_model = model.split("+", 1)

        pre_output_items = []

        # Step 0: Store the last computer call image; if there is none, take a screenshot
        last_image_b64 = get_last_computer_call_image(messages)
        if last_image_b64 is None:
            # Take a screenshot
            screenshot_b64 = await computer_handler.screenshot()  # type: ignore
            if screenshot_b64:
                call_id = uuid.uuid4().hex
                pre_output_items += [
                    {
                        "type": "message",
                        "role": "assistant",
                        "content": [
                            {
                                "type": "output_text",
                                "text": "Taking a screenshot to see the current computer screen."
                            }
                        ]
                    },
                    {
                        "action": {
                            "type": "screenshot"
                        },
                        "call_id": call_id,
                        "status": "completed",
                        "type": "computer_call"
                    },
                    {
                        "type": "computer_call_output",
                        "call_id": call_id,
                        "output": {
                            "type": "input_image",
                            "image_url": f"data:image/png;base64,{screenshot_b64}"
                        }
                    },
                ]
                last_image_b64 = screenshot_b64

                # Call screenshot callback if provided
                if _on_screenshot:
                    await _on_screenshot(screenshot_b64)

        tool_schemas = _prepare_tools_for_grounded(tools)  # type: ignore

        # Step 1: Convert computer calls from xy to descriptions
        input_messages = messages + pre_output_items
        messages_with_descriptions = convert_computer_calls_xy2desc(input_messages, self.desc2xy)

        # Step 2: Convert responses items to completion messages
        completion_messages = convert_responses_items_to_completion_messages(
            messages_with_descriptions,
            allow_images_in_tool_results=False
        )

        # Step 3: Call thinking model with litellm.acompletion
        api_kwargs = {
            "model": thinking_model,
            "messages": completion_messages,
            "tools": tool_schemas,
            "max_retries": max_retries,
            "stream": stream,
            **kwargs
        }

        if use_prompt_caching:
            api_kwargs["use_prompt_caching"] = use_prompt_caching

        # Call API start hook
        if _on_api_start:
            await _on_api_start(api_kwargs)

        # Make the completion call
        response = await litellm.acompletion(**api_kwargs)

        # Call API end hook
        if _on_api_end:
            await _on_api_end(api_kwargs, response)

        # Extract usage information
        usage = {
            **response.usage.model_dump(),  # type: ignore
            "response_cost": response._hidden_params.get("response_cost", 0.0),
        }
        if _on_usage:
            await _on_usage(usage)

        # Step 4: Convert completion messages back to responses items format
        response_dict = response.model_dump()  # type: ignore
        choice_messages = [choice["message"] for choice in response_dict["choices"]]
        thinking_output_items = []

        for choice_message in choice_messages:
            thinking_output_items.extend(convert_completion_messages_to_responses_items([choice_message]))

        # Step 5: Get all element descriptions and populate desc2xy mapping
        element_descriptions = get_all_element_descriptions(thinking_output_items)

        if element_descriptions and last_image_b64:
            # Use grounding model to predict coordinates for each description
            grounding_agent_conf = find_agent_config(grounding_model)
            if grounding_agent_conf:
                grounding_agent = grounding_agent_conf.agent_class()

                for desc in element_descriptions:
                    coords = await grounding_agent.predict_click(
                        model=grounding_model,
                        image_b64=last_image_b64,
                        instruction=desc
                    )
                    if coords:
                        self.desc2xy[desc] = coords

        # Step 6: Convert computer calls from descriptions back to xy coordinates
        final_output_items = convert_computer_calls_desc2xy(thinking_output_items, self.desc2xy)

        # Step 7: Return output and usage
        return {
            "output": pre_output_items + final_output_items,
            "usage": usage
        }
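
    # Minimal usage sketch (hypothetical model names; assumes an initialized
    # computer handler exposing an async screenshot() -> base64 str):
    #
    #   config = ComposedGroundedConfig()
    #   result = await config.predict_step(
    #       messages=[{"role": "user", "content": "Open the settings menu"}],
    #       model="huggingface-local/HelloKKMe/GTA1-7B+openai/gpt-4o-mini",
    #       tools=[{"type": "computer"}],
    #       computer_handler=computer,
    #   )
    #   # result["output"] holds responses items; result["usage"] holds token/cost info.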

    async def predict_click(
        self,
        model: str,
        image_b64: str,
        instruction: str,
        **kwargs
    ) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates using the grounding model.

        For composed models, uses only the grounding model part for click prediction.
        """
        # Parse the composed model to get the grounding model
        if "+" not in model:
            raise ValueError(f"Composed model must be in format 'grounding_model+thinking_model', got: {model}")
        grounding_model, thinking_model = model.split("+", 1)

        # Find and use the grounding agent
        grounding_agent_conf = find_agent_config(grounding_model)
        if grounding_agent_conf:
            grounding_agent = grounding_agent_conf.agent_class()
            return await grounding_agent.predict_click(
                model=grounding_model,
                image_b64=image_b64,
                instruction=instruction,
                **kwargs
            )

        return None

    def get_capabilities(self) -> List[AgentCapability]:
        """Return the capabilities supported by this agent."""
        return ["click", "step"]
libs/python/agent/agent/loops/gta1.py

@@ -26,73 +26,6 @@ Output the coordinate pair exactly:
(x,y)
'''.strip()

# Global dictionary to map coordinates to descriptions
xy2desc: Dict[Tuple[float, float], str] = {}

GTA1_TOOL_SCHEMA = {
    "type": "function",
    "name": "computer",
    "description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
    "parameters": {
        "type": "object",
        "properties": {
            "action": {
                "type": "string",
                "enum": [
                    "screenshot",
                    "click",
                    "double_click",
                    "drag",
                    "type",
                    "keypress",
                    "scroll",
                    "move",
                    "wait",
                    "get_current_url",
                    "get_dimensions",
                    "get_environment"
                ],
                "description": "The action to perform"
            },
            "element_description": {
                "type": "string",
                "description": "Description of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
            },
            "start_element_description": {
                "type": "string",
                "description": "Description of the element to start dragging from (required for drag action)"
            },
            "end_element_description": {
                "type": "string",
                "description": "Description of the element to drag to (required for drag action)"
            },
            "text": {
                "type": "string",
                "description": "The text to type (required for type action)"
            },
            "keys": {
                "type": "string",
                "description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
            },
            "button": {
                "type": "string",
                "description": "The mouse button to use for click action (left, right, wheel, back, forward). Default: left"
            },
            "scroll_x": {
                "type": "integer",
                "description": "Horizontal scroll amount for scroll action (positive for right, negative for left)"
            },
            "scroll_y": {
                "type": "integer",
                "description": "Vertical scroll amount for scroll action (positive for down, negative for up)"
            }
        },
        "required": [
            "action"
        ]
    }
}

def extract_coordinates(raw_string: str) -> Tuple[float, float]:
    """Extract coordinates from model output."""
    try:
@@ -101,173 +34,6 @@ def extract_coordinates(raw_string: str) -> Tuple[float, float]:
    except:
        return (0.0, 0.0)
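
# Expected behavior (illustrative): extract_coordinates("(412,87)") -> (412.0, 87.0);
# any unparsable model output falls back to (0.0, 0.0).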

def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Get the last computer_call_output message from a messages list.

    Args:
        messages: List of messages to search through

    Returns:
        The last computer_call_output message dict, or None if not found
    """
    for message in reversed(messages):
        if isinstance(message, dict) and message.get("type") == "computer_call_output":
            return message
    return None

def _prepare_tools_for_gta1(tool_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Prepare tools for GTA1 API format"""
    gta1_tools = []

    for schema in tool_schemas:
        if schema["type"] == "computer":
            gta1_tools.append(GTA1_TOOL_SCHEMA)
        else:
            gta1_tools.append(schema)

    return gta1_tools

async def replace_function_with_computer_call_gta1(item: Dict[str, Any], agent_instance) -> List[Dict[str, Any]]:
    """Convert function_call to computer_call format using GTA1 click prediction."""
    global xy2desc
    item_type = item.get("type")

    async def _get_xy(element_description: Optional[str], last_image_b64: str) -> Union[Tuple[float, float], Tuple[None, None]]:
        if element_description is None:
            return (None, None)
        # Use self.predict_click to get coordinates from description
        coords = await agent_instance.predict_click(
            model=agent_instance.current_model,
            image_b64=last_image_b64,
            instruction=element_description
        )
        if coords:
            # Store the mapping from coordinates to description
            xy2desc[coords] = element_description
            return coords
        return (None, None)

    if item_type == "function_call":
        fn_name = item.get("name")
        fn_args = json.loads(item.get("arguments", "{}"))

        item_id = item.get("id")
        call_id = item.get("call_id")

        if fn_name == "computer":
            action = fn_args.get("action")
            element_description = fn_args.get("element_description")
            start_element_description = fn_args.get("start_element_description")
            end_element_description = fn_args.get("end_element_description")
            text = fn_args.get("text")
            keys = fn_args.get("keys")
            button = fn_args.get("button")
            scroll_x = fn_args.get("scroll_x")
            scroll_y = fn_args.get("scroll_y")

            # Get the last computer output image for click prediction
            last_image_b64 = agent_instance.last_screenshot_b64 or ""

            x, y = await _get_xy(element_description, last_image_b64)
            start_x, start_y = await _get_xy(start_element_description, last_image_b64)
            end_x, end_y = await _get_xy(end_element_description, last_image_b64)

            action_args = {
                "type": action,
                "x": x,
                "y": y,
                "start_x": start_x,
                "start_y": start_y,
                "end_x": end_x,
                "end_y": end_y,
                "text": text,
                "keys": keys,
                "button": button,
                "scroll_x": scroll_x,
                "scroll_y": scroll_y
            }
            # Remove None values to keep the JSON clean
            action_args = {k: v for k, v in action_args.items() if v is not None}

            return [{
                "type": "computer_call",
                "action": action_args,
                "id": item_id,
                "call_id": call_id,
                "status": "completed"
            }]

    return [item]
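
# Sketch of the conversion (hypothetical input/output values):
#   in:  {"type": "function_call", "name": "computer",
#         "arguments": '{"action": "click", "element_description": "OK button"}'}
#   out: [{"type": "computer_call",
#          "action": {"type": "click", "x": 512, "y": 384},  # coords from predict_click
#          "status": "completed"}]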

async def replace_computer_call_with_function_gta1(item: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Convert computer_call back to function_call format using descriptions.
    Only READS from the global xy2desc dictionary.

    Args:
        item: The item to convert
    """
    global xy2desc
    item_type = item.get("type")

    def _get_element_description(x: Optional[float], y: Optional[float]) -> Optional[str]:
        """Get element description from coordinates; return None if coordinates are None"""
        if x is None or y is None:
            return None
        return xy2desc.get((x, y))

    if item_type == "computer_call":
        action_data = item.get("action", {})

        # Extract coordinates and convert back to element descriptions
        element_description = _get_element_description(action_data.get("x"), action_data.get("y"))
        start_element_description = _get_element_description(action_data.get("start_x"), action_data.get("start_y"))
        end_element_description = _get_element_description(action_data.get("end_x"), action_data.get("end_y"))

        # Build function arguments
        fn_args = {
            "action": action_data.get("type"),
            "element_description": element_description,
            "start_element_description": start_element_description,
            "end_element_description": end_element_description,
            "text": action_data.get("text"),
            "keys": action_data.get("keys"),
            "button": action_data.get("button"),
            "scroll_x": action_data.get("scroll_x"),
            "scroll_y": action_data.get("scroll_y")
        }

        # Remove None values to keep the JSON clean
        fn_args = {k: v for k, v in fn_args.items() if v is not None}

        return [{
            "type": "function_call",
            "name": "computer",
            "arguments": json.dumps(fn_args),
            "id": item.get("id"),
            "call_id": item.get("call_id"),
            "status": "completed",
            # Fall back to string representation
            # "content": f"Used tool: {action_data.get('type')}({json.dumps(fn_args)})"
        }]

    elif item_type == "computer_call_output":
        # Simple conversion: computer_call_output -> function_call_output (text only), user message (image)
        return [
            {
                "type": "function_call_output",
                "call_id": item.get("call_id"),
                "output": "Tool executed successfully. See the current computer screenshot below; if nothing has changed yet, you may need to wait before trying again.",
                "id": item.get("id"),
                "status": "completed"
            }, {
                "role": "user",
                "content": [item.get("output")]
            }
        ]

    return [item]
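
# Round-trip note: this is the inverse of replace_function_with_computer_call_gta1,
# but it only recovers descriptions for coordinates previously stored in xy2desc;
# unseen (x, y) pairs yield element_description = None and are dropped from fn_args.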

def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
    """Smart resize function similar to qwen_vl_utils."""
    # Calculate the total pixels
@@ -300,7 +66,7 @@ def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:

    return new_height, new_width

@register_agent(models=r".*GTA1.*", priority=10)
|
||||
@register_agent(models=r".*GTA1.*")
|
||||
class GTA1Config(AsyncAgentConfig):
|
||||
"""GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction."""
|
||||
|
||||
@@ -308,153 +74,6 @@ class GTA1Config(AsyncAgentConfig):
|
||||
self.current_model = None
|
||||
self.last_screenshot_b64 = None
|
||||
|
||||
    async def predict_step(
        self,
        messages: Messages,
        model: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
        stream: bool = False,
        computer_handler=None,
        use_prompt_caching: Optional[bool] = False,
        _on_api_start=None,
        _on_api_end=None,
        _on_usage=None,
        _on_screenshot=None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        GTA1 agent loop implementation using liteLLM responses with element descriptions.

        Follows the 4-step process:
        1. Prepare tools
        2. Replace computer calls with function calls (using descriptions)
        3. API call
        4. Replace function calls with computer calls (using predict_click)
        """
        models = model.split("+")
        if len(models) != 2:
            raise ValueError("GTA1 model must be in the format <gta1_model_name>+<planning_model_name> to be used in an agent loop")

        gta1_model, llm_model = models
        self.current_model = gta1_model

        tools = tools or []

        # Step 0: Prepare tools
        gta1_tools = _prepare_tools_for_gta1(tools)

        # Get last computer_call_output for screenshot reference
        # Convert messages to list of dicts first
        message_list = []
        for message in messages:
            if not isinstance(message, dict):
                message_list.append(message.__dict__)
            else:
                message_list.append(message)

        last_computer_call_output = get_last_computer_call_output(message_list)
        if last_computer_call_output:
            image_url = last_computer_call_output.get("output", {}).get("image_url", "")
            if image_url.startswith("data:image/"):
                self.last_screenshot_b64 = image_url.split(",")[-1]
            else:
                self.last_screenshot_b64 = image_url

        # Step 1: If there's no screenshot, simulate the model calling the screenshot function
        pre_output = []
        if not self.last_screenshot_b64 and computer_handler:
            screenshot_base64 = await computer_handler.screenshot()
            if _on_screenshot:
                await _on_screenshot(screenshot_base64, "screenshot_initial")

            call_id = uuid.uuid4().hex
            pre_output += [
                {
                    "type": "message",
                    "role": "assistant",
                    "content": [
                        {
                            "type": "output_text",
                            "text": "Taking a screenshot to see the current computer screen."
                        }
                    ]
                },
                {
                    "action": {
                        "type": "screenshot"
                    },
                    "call_id": call_id,
                    "status": "completed",
                    "type": "computer_call"
                },
                {
                    "type": "computer_call_output",
                    "call_id": call_id,
                    "output": {
                        "type": "input_image",
                        "image_url": f"data:image/png;base64,{screenshot_base64}"
                    }
                },
            ]

            # Update the last screenshot for future use
            self.last_screenshot_b64 = screenshot_base64

        message_list += pre_output

        # Step 2: Replace computer calls with function calls (using descriptions)
        new_messages = []
        for message in message_list:
            new_messages += await replace_computer_call_with_function_gta1(message)
        messages = new_messages

        # Step 3: API call
        api_kwargs = {
            "model": llm_model,
            "input": messages,
            "tools": gta1_tools if gta1_tools else None,
            "stream": stream,
            "truncation": "auto",
            "num_retries": max_retries,
            **kwargs
        }

        # Call API start hook
        if _on_api_start:
            await _on_api_start(api_kwargs)

        # Use liteLLM responses
        response = await litellm.aresponses(**api_kwargs)

        # Call API end hook
        if _on_api_end:
            await _on_api_end(api_kwargs, response)

        # Extract usage information
        usage = {
            **response.usage.model_dump(),  # type: ignore
            "response_cost": response._hidden_params.get("response_cost", 0.0),  # type: ignore
        }
        if _on_usage:
            await _on_usage(usage)

        # Step 4: Replace function calls with computer calls (using predict_click)
        new_output = []
        for i in range(len(response.output)):  # type: ignore
            output_item = response.output[i]  # type: ignore
            # Convert to dict if it has model_dump method, otherwise use as-is
            if hasattr(output_item, 'model_dump'):
                item_dict = output_item.model_dump()  # type: ignore
            else:
                item_dict = output_item  # type: ignore
            new_output += await replace_function_with_computer_call_gta1(item_dict, self)  # type: ignore

        return {
            "output": pre_output + new_output,
            "usage": usage
        }

    async def predict_click(
        self,
        model: str,
@@ -528,10 +147,10 @@ class GTA1Config(AsyncAgentConfig):
        response = await litellm.acompletion(**api_kwargs)

        # Extract response text
-        output_text = response.choices[0].message.content
+        output_text = response.choices[0].message.content  # type: ignore

        # Extract and rescale coordinates
-        pred_x, pred_y = extract_coordinates(output_text)
+        pred_x, pred_y = extract_coordinates(output_text)  # type: ignore
        pred_x *= scale_x
        pred_y *= scale_y

@@ -539,4 +158,4 @@ class GTA1Config(AsyncAgentConfig):

    def get_capabilities(self) -> List[AgentCapability]:
        """Return the capabilities supported by this agent."""
-        return ["click", "step"]
+        return ["click"]

libs/python/agent/agent/loops/omniparser.py

@@ -249,7 +249,7 @@ async def replace_computer_call_with_function(item: Dict[str, Any], xy2id: Dict[
    return [item]


-@register_agent(models=r"omniparser\+.*|omni\+.*", priority=10)
+@register_agent(models=r"omniparser\+.*|omni\+.*")
class OmniparsrConfig(AsyncAgentConfig):
    """Omniparser agent configuration implementing AsyncAgentConfig protocol."""

libs/python/agent/agent/loops/openai.py

@@ -3,9 +3,12 @@ OpenAI computer-use-preview agent loop implementation using liteLLM
"""

import asyncio
+import base64
import json
+from io import BytesIO
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
import litellm
+from PIL import Image

from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
@@ -36,7 +39,7 @@ def _prepare_tools_for_openai(tool_schemas: List[Dict[str, Any]]) -> Tools:
    return openai_tools


-@register_agent(models=r".*computer-use-preview.*", priority=10)
+@register_agent(models=r".*computer-use-preview.*")
class OpenAIComputerUseConfig:
    """
    OpenAI computer-use-preview agent configuration using liteLLM responses.
@@ -128,8 +131,8 @@ class OpenAIComputerUseConfig:
        """
        Predict click coordinates based on image and instruction.

-        Note: OpenAI computer-use-preview doesn't support direct click prediction,
-        so this returns None.
+        Uses OpenAI computer-use-preview with manually constructed input items
+        and a prompt that instructs the agent to only output clicks.

        Args:
            model: Model name to use
@@ -137,8 +140,94 @@ class OpenAIComputerUseConfig:
            instruction: Instruction for where to click

        Returns:
-            None (not supported by OpenAI computer-use-preview)
+            Tuple of (x, y) coordinates or None if prediction fails
        """
        # TODO: implement this correctly
        # Scale image to half size
        try:
            image_data = base64.b64decode(image_b64)
            image = Image.open(BytesIO(image_data))

            # Scale to half size
            new_width = image.width // 2
            new_height = image.height // 2
            scaled_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

            # Convert back to base64
            buffer = BytesIO()
            scaled_image.save(buffer, format='PNG')
            image_b64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        except Exception:
            # If scaling fails, use original image
            pass

        # Manually construct input items with image and click instruction
        input_items = [
            {
                "role": "user",
                "content": f"You are a UI grounding expert. Look at the image and {instruction}. Output ONLY a click action on the target element. No explanations, confirmations, or additional text."
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "input_image",
                        "image_url": f"data:image/png;base64,{image_b64}"
                    }
                ]
            }
        ]

        # Get image dimensions from base64 data
        try:
            image_data = base64.b64decode(image_b64)
            image = Image.open(BytesIO(image_data))
            display_width, display_height = image.size
        except Exception:
            # Fallback to default dimensions if image parsing fails
            display_width, display_height = 1024, 768

        # Prepare computer tool for click actions
        computer_tool = {
            "type": "computer_use_preview",
            "display_width": display_width,
            "display_height": display_height,
            "environment": "linux"
        }

        # Prepare API call kwargs
        api_kwargs = {
            "model": model,
            "input": input_items,
            "tools": [computer_tool],
            "stream": False,
            "reasoning": {"summary": "concise"},
            "truncation": "auto",
            "max_tokens": 100  # Keep response short for click prediction
        }

        # Use liteLLM responses
        response = await litellm.aresponses(**api_kwargs)

        # Extract click coordinates from response output
        output_dict = response.model_dump()
        output_items = output_dict.get("output", [])

        # print(output_items)

        # Look for computer_call with click action
        for item in output_items:
            if (isinstance(item, dict) and
                item.get("type") == "computer_call" and
                isinstance(item.get("action"), dict)):

                action = item["action"]
                if action.get("type") == "click":
                    x = action.get("x")
                    y = action.get("y")
                    if x is not None and y is not None:
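                        # The screenshot was downscaled to half size above, so the
                        # predicted coordinates are doubled to map back to the
                        # original image resolution.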
                        return (int(x) * 2, int(y) * 2)

        return None

    def get_capabilities(self) -> List[AgentCapability]:
@@ -148,4 +237,4 @@ class OpenAIComputerUseConfig:

        Returns:
            List of capability strings
        """
-        return ["step"]
+        return ["click", "step"]

libs/python/agent/agent/loops/uitars.py

@@ -515,7 +515,7 @@ def convert_uitars_messages_to_litellm(messages: Messages) -> List[Dict[str, Any]]:

    return litellm_messages

-@register_agent(models=r"(?i).*ui-?tars.*", priority=10)
+@register_agent(models=r"(?i).*ui-?tars.*")
class UITARSConfig:
    """
    UITARS agent configuration using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B model.
libs/python/agent/agent/responses.py

@@ -40,7 +40,7 @@ def make_input_image_item(image_data: Union[str, bytes]) -> EasyInputMessageParam:
        ResponseInputImageParam(
            type="input_image",
            image_url=f"data:image/png;base64,{base64.b64encode(image_data).decode('utf-8') if isinstance(image_data, bytes) else image_data}"
-        )
+        )  # type: ignore
    ],
    role="user",
    type="message"
@@ -205,3 +205,479 @@ def make_wait_item(call_id: Optional[str] = None) -> ResponseComputerToolCallParam:
        status="completed",
        type="computer_call"
    )

# Conversion functions between element descriptions and coordinates
def convert_computer_calls_desc2xy(responses_items: List[Dict[str, Any]], desc2xy: Dict[str, tuple]) -> List[Dict[str, Any]]:
    """
    Convert computer calls from element descriptions to x,y coordinates.

    Args:
        responses_items: List of response items containing computer calls with element_description
        desc2xy: Dictionary mapping element descriptions to (x, y) coordinate tuples

    Returns:
        List of response items with element_description replaced by x,y coordinates
    """
    converted_items = []

    for item in responses_items:
        if item.get("type") == "computer_call" and "action" in item:
            action = item["action"].copy()

            # Handle single element_description
            if "element_description" in action:
                desc = action["element_description"]
                if desc in desc2xy:
                    x, y = desc2xy[desc]
                    action["x"] = x
                    action["y"] = y
                    del action["element_description"]

            # Handle start_element_description and end_element_description for drag operations
            elif "start_element_description" in action and "end_element_description" in action:
                start_desc = action["start_element_description"]
                end_desc = action["end_element_description"]

                if start_desc in desc2xy and end_desc in desc2xy:
                    start_x, start_y = desc2xy[start_desc]
                    end_x, end_y = desc2xy[end_desc]
                    action["path"] = [{"x": start_x, "y": start_y}, {"x": end_x, "y": end_y}]
                    del action["start_element_description"]
                    del action["end_element_description"]

            converted_item = item.copy()
            converted_item["action"] = action
            converted_items.append(converted_item)
        else:
            converted_items.append(item)

    return converted_items
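
# Illustrative round trip (hypothetical mapping):
#   desc2xy = {"OK button": (512, 384)}
#   convert_computer_calls_desc2xy(
#       [{"type": "computer_call", "action": {"type": "click", "element_description": "OK button"}}],
#       desc2xy,
#   )
#   -> [{"type": "computer_call", "action": {"type": "click", "x": 512, "y": 384}}]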

def convert_computer_calls_xy2desc(responses_items: List[Dict[str, Any]], desc2xy: Dict[str, tuple]) -> List[Dict[str, Any]]:
    """
    Convert computer calls from x,y coordinates to element descriptions.

    Args:
        responses_items: List of response items containing computer calls with x,y coordinates
        desc2xy: Dictionary mapping element descriptions to (x, y) coordinate tuples

    Returns:
        List of response items with x,y coordinates replaced by element_description
    """
    # Create reverse mapping from coordinates to descriptions
    xy2desc = {coords: desc for desc, coords in desc2xy.items()}

    converted_items = []

    for item in responses_items:
        if item.get("type") == "computer_call" and "action" in item:
            action = item["action"].copy()

            # Handle single x,y coordinates
            if "x" in action and "y" in action:
                coords = (action["x"], action["y"])
                if coords in xy2desc:
                    action["element_description"] = xy2desc[coords]
                    del action["x"]
                    del action["y"]

            # Handle path for drag operations
            elif "path" in action and isinstance(action["path"], list) and len(action["path"]) == 2:
                start_point = action["path"][0]
                end_point = action["path"][1]

                if ("x" in start_point and "y" in start_point and
                    "x" in end_point and "y" in end_point):

                    start_coords = (start_point["x"], start_point["y"])
                    end_coords = (end_point["x"], end_point["y"])

                    if start_coords in xy2desc and end_coords in xy2desc:
                        action["start_element_description"] = xy2desc[start_coords]
                        action["end_element_description"] = xy2desc[end_coords]
                        del action["path"]

            converted_item = item.copy()
            converted_item["action"] = action
            converted_items.append(converted_item)
        else:
            converted_items.append(item)

    return converted_items
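
# Note: the reverse lookup is exact, so only (x, y) pairs produced by a prior
# desc2xy grounding round are rewritten; all other coordinates pass through unchanged.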

def get_all_element_descriptions(responses_items: List[Dict[str, Any]]) -> List[str]:
    """
    Extract all element descriptions from computer calls in responses items.

    Args:
        responses_items: List of response items containing computer calls

    Returns:
        List of unique element descriptions found in computer calls
    """
    descriptions = set()

    for item in responses_items:
        if item.get("type") == "computer_call" and "action" in item:
            action = item["action"]

            # Handle single element_description
            if "element_description" in action:
                descriptions.add(action["element_description"])

            # Handle start_element_description and end_element_description for drag operations
            if "start_element_description" in action:
                descriptions.add(action["start_element_description"])

            if "end_element_description" in action:
                descriptions.add(action["end_element_description"])

    return list(descriptions)

# Conversion functions between responses_items and completion messages formats
def convert_responses_items_to_completion_messages(messages: List[Dict[str, Any]], allow_images_in_tool_results: bool = True) -> List[Dict[str, Any]]:
    """Convert responses_items message format to liteLLM completion format.

    Args:
        messages: List of responses_items format messages
        allow_images_in_tool_results: If True, include images in tool role messages.
            If False, send tool message + separate user message with image.
    """
    completion_messages = []

    for message in messages:
        msg_type = message.get("type")
        role = message.get("role")

        # Handle user messages (both with and without explicit type)
        if role == "user" or msg_type == "user":
            content = message.get("content", "")
            if isinstance(content, list):
                # Handle list content (images, text blocks)
                completion_content = []
                for item in content:
                    if item.get("type") == "input_image":
                        completion_content.append({
                            "type": "image_url",
                            "image_url": {
                                "url": item.get("image_url")
                            }
                        })
                    elif item.get("type") == "input_text":
                        completion_content.append({
                            "type": "text",
                            "text": item.get("text")
                        })
                    elif item.get("type") == "text":
                        completion_content.append({
                            "type": "text",
                            "text": item.get("text")
                        })

                completion_messages.append({
                    "role": "user",
                    "content": completion_content
                })
            elif isinstance(content, str):
                # Handle string content
                completion_messages.append({
                    "role": "user",
                    "content": content
                })

        # Handle assistant messages
        elif role == "assistant" or msg_type == "message":
            content = message.get("content", [])
            if isinstance(content, list):
                text_parts = []
                for item in content:
                    if item.get("type") == "output_text":
                        text_parts.append(item.get("text", ""))
                    elif item.get("type") == "text":
                        text_parts.append(item.get("text", ""))

                if text_parts:
                    completion_messages.append({
                        "role": "assistant",
                        "content": "\n".join(text_parts)
                    })

        # Handle reasoning items (convert to assistant message)
        elif msg_type == "reasoning":
            summary = message.get("summary", [])
            text_parts = []
            for item in summary:
                if item.get("type") == "summary_text":
                    text_parts.append(item.get("text", ""))

            if text_parts:
                completion_messages.append({
                    "role": "assistant",
                    "content": "\n".join(text_parts)
                })

        # Handle function calls
        elif msg_type == "function_call":
            # Add tool call to last assistant message or create new one
            if not completion_messages or completion_messages[-1]["role"] != "assistant":
                completion_messages.append({
                    "role": "assistant",
                    "content": "",
                    "tool_calls": []
                })

            if "tool_calls" not in completion_messages[-1]:
                completion_messages[-1]["tool_calls"] = []

            completion_messages[-1]["tool_calls"].append({
                "id": message.get("call_id"),
                "type": "function",
                "function": {
                    "name": message.get("name"),
                    "arguments": message.get("arguments")
                }
            })

        # Handle computer calls
        elif msg_type == "computer_call":
            # Add tool call to last assistant message or create new one
            if not completion_messages or completion_messages[-1]["role"] != "assistant":
                completion_messages.append({
                    "role": "assistant",
                    "content": "",
                    "tool_calls": []
                })

            if "tool_calls" not in completion_messages[-1]:
                completion_messages[-1]["tool_calls"] = []

            action = message.get("action", {})
            completion_messages[-1]["tool_calls"].append({
                "id": message.get("call_id"),
                "type": "function",
                "function": {
                    "name": "computer",
                    "arguments": json.dumps(action)
                }
            })

        # Handle function/computer call outputs
        elif msg_type in ["function_call_output", "computer_call_output"]:
            output = message.get("output")
            call_id = message.get("call_id")

            if isinstance(output, dict) and output.get("type") == "input_image":
                if allow_images_in_tool_results:
                    # Handle image output as tool response (may not work with all APIs)
                    completion_messages.append({
                        "role": "tool",
                        "tool_call_id": call_id,
                        "content": [{
                            "type": "image_url",
                            "image_url": {
                                "url": output.get("image_url")
                            }
                        }]
                    })
                else:
                    # Send tool message + separate user message with image (OpenAI compatible)
                    completion_messages += [{
                        "role": "tool",
                        "tool_call_id": call_id,
                        "content": "[Execution completed. See screenshot below]"
                    }, {
                        "role": "user",
                        "content": [{
                            "type": "image_url",
                            "image_url": {
                                "url": output.get("image_url")
                            }
                        }]
                    }]
            else:
                # Handle text output as tool response
                completion_messages.append({
                    "role": "tool",
                    "tool_call_id": call_id,
                    "content": str(output)
                })

    return completion_messages
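
# Illustrative example of the mapping (allow_images_in_tool_results=False):
#   in:  {"type": "computer_call_output", "call_id": "c1",
#         "output": {"type": "input_image", "image_url": "data:image/png;base64,..."}}
#   out: [{"role": "tool", "tool_call_id": "c1",
#          "content": "[Execution completed. See screenshot below]"},
#         {"role": "user", "content": [{"type": "image_url",
#          "image_url": {"url": "data:image/png;base64,..."}}]}]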

def convert_completion_messages_to_responses_items(completion_messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Convert completion messages format to responses_items message format."""
    responses_items = []
    skip_next = False

    for i, message in enumerate(completion_messages):
        if skip_next:
            skip_next = False
            continue

        role = message.get("role")
        content = message.get("content")
        tool_calls = message.get("tool_calls", [])

        # Handle assistant messages with text content
        if role == "assistant" and content and isinstance(content, str):
            responses_items.append({
                "type": "message",
                "role": "assistant",
                "content": [{
                    "type": "output_text",
                    "text": content
                }]
            })

        # Handle tool calls
        if tool_calls:
            for tool_call in tool_calls:
                if tool_call.get("type") == "function":
                    function = tool_call.get("function", {})
                    function_name = function.get("name")

                    if function_name == "computer":
                        # Parse computer action
                        try:
                            action = json.loads(function.get("arguments", "{}"))
                            # Change key from "action" -> "type"
                            if action.get("action"):
                                action["type"] = action["action"]
                                del action["action"]
                            responses_items.append({
                                "type": "computer_call",
                                "call_id": tool_call.get("id"),
                                "action": action,
                                "status": "completed"
                            })
                        except json.JSONDecodeError:
                            # Fallback to function call format
                            responses_items.append({
                                "type": "function_call",
                                "call_id": tool_call.get("id"),
                                "name": function_name,
                                "arguments": function.get("arguments", "{}"),
                                "status": "completed"
                            })
                    else:
                        # Regular function call
                        responses_items.append({
                            "type": "function_call",
                            "call_id": tool_call.get("id"),
                            "name": function_name,
                            "arguments": function.get("arguments", "{}"),
                            "status": "completed"
                        })

        # Handle tool messages (function/computer call outputs)
        elif role == "tool" and content:
            tool_call_id = message.get("tool_call_id")
            if isinstance(content, str):
                # Check if this is the "[Execution completed. See screenshot below]" pattern
                if content == "[Execution completed. See screenshot below]":
                    # Look ahead for the next user message with image
                    next_idx = i + 1
                    if (next_idx < len(completion_messages) and
                        completion_messages[next_idx].get("role") == "user" and
                        isinstance(completion_messages[next_idx].get("content"), list)):
                        # Found the pattern - extract image from next message
                        next_content = completion_messages[next_idx]["content"]
                        for item in next_content:
                            if item.get("type") == "image_url":
                                responses_items.append({
                                    "type": "computer_call_output",
                                    "call_id": tool_call_id,
                                    "output": {
                                        "type": "input_image",
                                        "image_url": item.get("image_url", {}).get("url")
                                    }
                                })
                                # Skip the next user message since we processed it
                                skip_next = True
                                break
                    else:
                        # No matching user message, treat as regular text
                        responses_items.append({
                            "type": "computer_call_output",
                            "call_id": tool_call_id,
                            "output": content
                        })
                else:
                    # Determine if this is a computer call or function call output
                    try:
                        # Try to parse as structured output
                        parsed_content = json.loads(content)
                        if parsed_content.get("type") == "input_image":
                            responses_items.append({
                                "type": "computer_call_output",
                                "call_id": tool_call_id,
                                "output": parsed_content
                            })
                        else:
                            responses_items.append({
                                "type": "computer_call_output",
                                "call_id": tool_call_id,
                                "output": content
                            })
                    except json.JSONDecodeError:
                        # Plain text output - could be function or computer call
                        responses_items.append({
                            "type": "function_call_output",
                            "call_id": tool_call_id,
                            "output": content
                        })
            elif isinstance(content, list):
                # Handle structured content (e.g., images)
                for item in content:
                    if item.get("type") == "image_url":
                        responses_items.append({
                            "type": "computer_call_output",
                            "call_id": tool_call_id,
                            "output": {
                                "type": "input_image",
                                "image_url": item.get("image_url", {}).get("url")
                            }
                        })
                    elif item.get("type") == "text":
                        responses_items.append({
                            "type": "function_call_output",
                            "call_id": tool_call_id,
                            "output": item.get("text")
                        })

        # Handle actual user messages
        elif role == "user" and content:
            if isinstance(content, list):
                # Handle structured user content (e.g., text + images)
                user_content = []
                for item in content:
                    if item.get("type") == "image_url":
                        user_content.append({
                            "type": "input_image",
                            "image_url": item.get("image_url", {}).get("url")
                        })
                    elif item.get("type") == "text":
                        user_content.append({
                            "type": "input_text",
                            "text": item.get("text")
                        })

                if user_content:
                    responses_items.append({
                        "role": "user",
                        "type": "message",
                        "content": user_content
                    })
            elif isinstance(content, str):
                # Handle simple text user message
                responses_items.append({
                    "role": "user",
                    "content": content
                })

    return responses_items
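
# Illustrative example: a completion tool call becomes a computer_call item:
#   in:  {"role": "assistant", "content": "", "tool_calls": [{"id": "c2", "type": "function",
#         "function": {"name": "computer", "arguments": '{"action": "click", "x": 10, "y": 20}'}}]}
#   out: [{"type": "computer_call", "call_id": "c2",
#          "action": {"type": "click", "x": 10, "y": 20}, "status": "completed"}]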
@@ -1,4 +1,3 @@
from .base import ModelProtocol
-from .gta1 import GTA1Model

-__all__ = ["ModelProtocol", "GTA1Model"]
+__all__ = ["ModelProtocol"]

@@ -21,7 +21,6 @@ import torch

# Add parent directory to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from agent.agent import ComputerAgent
-from models import GTA1Model
from models.base import ModelProtocol

def get_gpu_memory() -> List[int]:
@@ -82,13 +81,16 @@ def get_available_models() -> List[Union[str, ModelProtocol]]:
    """
    local_provider = "huggingface-local/"  # Options: huggingface-local/ or mlx/

+    # from models.gta1 import GTA1Model

    models = [
        # === ComputerAgent model strings ===
-        f"{local_provider}HelloKKMe/GTA1-7B",
+        # f"{local_provider}HelloKKMe/GTA1-7B",
        # f"{local_provider}HelloKKMe/GTA1-32B",
+        "openai/computer-use-preview+openai/gpt-4o-mini"

        # === Reference model classes ===
-        GTA1Model("HelloKKMe/GTA1-7B"),
+        # GTA1Model("HelloKKMe/GTA1-7B"),
        # GTA1Model("HelloKKMe/GTA1-32B"),
    ]