Working GTA1 loop

Dillon DuPont
2025-08-01 15:49:19 -04:00
parent 5902be2917
commit d5564977f0
9 changed files with 484 additions and 257 deletions

View File

@@ -117,6 +117,13 @@ def sanitize_message(msg: Any) -> Any:
return sanitized
return msg
def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
call_ids = []
for message in messages:
if message.get("type") == "computer_call_output" or message.get("type") == "function_call_output":
call_ids.append(message.get("call_id"))
return call_ids
class ComputerAgent:
"""
Main agent class that automatically selects the appropriate agent loop
@@ -207,6 +214,7 @@ class ComputerAgent:
litellm.custom_provider_map = [
{"provider": "huggingface-local", "custom_handler": hf_adapter}
]
litellm.suppress_debug_info = True
# == Initialize computer agent ==
@@ -390,8 +398,10 @@ class ComputerAgent:
# AGENT OUTPUT PROCESSING
# ============================================================================
async def _handle_item(self, item: Any, computer: Optional[Computer] = None) -> List[Dict[str, Any]]:
async def _handle_item(self, item: Any, computer: Optional[Computer] = None, ignore_call_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
"""Handle each item; may cause a computer action + screenshot."""
if ignore_call_ids and item.get("call_id") and item.get("call_id") in ignore_call_ids:
return []
item_type = item.get("type", None)
@@ -437,7 +447,7 @@ class ComputerAgent:
acknowledged_checks = []
for check in pending_checks:
check_message = check.get("message", str(check))
if acknowledge_safety_check_callback(check_message):
if acknowledge_safety_check_callback(check_message, allow_always=True): # TODO: implement a callback for safety checks
acknowledged_checks.append(check)
else:
raise ValueError(f"Safety check failed: {check_message}")
@@ -512,9 +522,12 @@ class ComputerAgent:
Returns:
AsyncGenerator that yields response chunks
"""
if not self.agent_config_info:
raise ValueError("Agent configuration not found")
capabilities = self.get_capabilities()
if "step" not in capabilities:
raise ValueError(f"Agent loop {self.agent_loop.__name__} does not support step predictions")
raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions")
await self._initialize_computers()
@@ -529,7 +542,7 @@ class ComputerAgent:
"messages": messages,
"stream": stream,
"model": self.model,
"agent_loop": self.agent_loop.__name__,
"agent_loop": self.agent_config_info.agent_class.__name__,
**merged_kwargs
}
await self._on_run_start(run_kwargs, old_items)
@@ -580,9 +593,12 @@ class ComputerAgent:
# Add agent response to new_items
new_items += result.get("output")
# Get output call ids
output_call_ids = get_output_call_ids(result.get("output", []))
# Handle computer actions
for item in result.get("output"):
partial_items = await self._handle_item(item, self.computer_handler)
partial_items = await self._handle_item(item, self.computer_handler, ignore_call_ids=output_call_ids)
new_items += partial_items
# Yield partial response
@@ -612,9 +628,12 @@ class ComputerAgent:
Returns:
None or tuple with (x, y) coordinates
"""
if not self.agent_config_info:
raise ValueError("Agent configuration not found")
capabilities = self.get_capabilities()
if "click" not in capabilities:
raise ValueError(f"Agent loop {self.agent_loop.__name__} does not support click predictions")
raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions")
if hasattr(self.agent_loop, 'predict_click'):
if not image_b64:
if not self.computer_handler:
@@ -634,6 +653,9 @@ class ComputerAgent:
Returns:
List of capability strings (e.g., ["step", "click"])
"""
if not self.agent_config_info:
raise ValueError("Agent configuration not found")
if hasattr(self.agent_loop, 'get_capabilities'):
return self.agent_loop.get_capabilities()
return ["step"] # Default capability

View File

@@ -9,10 +9,7 @@ import io
import logging
try:
from presidio_analyzer import AnalyzerEngine
from presidio_anonymizer import AnonymizerEngine, DeanonymizeEngine
from presidio_anonymizer.entities import RecognizerResult, OperatorConfig
from presidio_image_redactor import ImageRedactorEngine
# TODO: Add Presidio dependencies
from PIL import Image
PRESIDIO_AVAILABLE = True
except ImportError:
@@ -32,11 +29,7 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
def __init__(
self,
anonymize_text: bool = True,
anonymize_images: bool = True,
entities_to_anonymize: Optional[List[str]] = None,
anonymization_operator: str = "replace",
image_redaction_color: Tuple[int, int, int] = (255, 192, 203) # Pink
# TODO: Any extra kwargs if needed
):
"""
Initialize the PII anonymization callback.
@@ -51,23 +44,10 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
if not PRESIDIO_AVAILABLE:
raise ImportError(
"Presidio is not available. Install with: "
"pip install presidio-analyzer presidio-anonymizer presidio-image-redactor"
"pip install cua-agent[pii-anonymization]"
)
self.anonymize_text = anonymize_text
self.anonymize_images = anonymize_images
self.entities_to_anonymize = entities_to_anonymize
self.anonymization_operator = anonymization_operator
self.image_redaction_color = image_redaction_color
# Initialize Presidio engines
self.analyzer = AnalyzerEngine()
self.anonymizer = AnonymizerEngine()
self.deanonymizer = DeanonymizeEngine()
self.image_redactor = ImageRedactorEngine()
# Store anonymization mappings for deanonymization
self.anonymization_mappings: Dict[str, Any] = {}
# TODO: Implement __init__
async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
@@ -79,9 +59,6 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
Returns:
List of messages with PII anonymized
"""
if not self.anonymize_text and not self.anonymize_images:
return messages
anonymized_messages = []
for msg in messages:
anonymized_msg = await self._anonymize_message(msg)
@@ -99,9 +76,6 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
Returns:
List of output with PII deanonymized for tool calls
"""
if not self.anonymize_text:
return output
deanonymized_output = []
for item in output:
# Only deanonymize tool calls and computer_call messages
@@ -114,146 +88,9 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
return deanonymized_output
async def _anonymize_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Anonymize PII in a single message."""
msg_copy = message.copy()
# Anonymize text content
if self.anonymize_text:
msg_copy = await self._anonymize_text_content(msg_copy)
# Redact images in computer_call_output
if self.anonymize_images and msg_copy.get("type") == "computer_call_output":
msg_copy = await self._redact_image_content(msg_copy)
return msg_copy
async def _anonymize_text_content(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Anonymize text content in a message."""
msg_copy = message.copy()
# Handle content array
content = msg_copy.get("content", [])
if isinstance(content, str):
anonymized_text, _ = await self._anonymize_text(content)
msg_copy["content"] = anonymized_text
elif isinstance(content, list):
anonymized_content = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
text = item.get("text", "")
anonymized_text, _ = await self._anonymize_text(text)
item_copy = item.copy()
item_copy["text"] = anonymized_text
anonymized_content.append(item_copy)
else:
anonymized_content.append(item)
msg_copy["content"] = anonymized_content
return msg_copy
async def _redact_image_content(self, message: Dict[str, Any]) -> Dict[str, Any]:
"""Redact PII from images in computer_call_output messages."""
msg_copy = message.copy()
output = msg_copy.get("output", {})
if isinstance(output, dict) and "image_url" in output:
try:
# Extract base64 image data
image_url = output["image_url"]
if image_url.startswith("data:image/"):
# Parse data URL
header, data = image_url.split(",", 1)
image_data = base64.b64decode(data)
# Load image with PIL
image = Image.open(io.BytesIO(image_data))
# Redact PII from image
redacted_image = self.image_redactor.redact(image, self.image_redaction_color)
# Convert back to base64
buffer = io.BytesIO()
redacted_image.save(buffer, format="PNG")
redacted_data = base64.b64encode(buffer.getvalue()).decode()
# Update image URL
output_copy = output.copy()
output_copy["image_url"] = f"data:image/png;base64,{redacted_data}"
msg_copy["output"] = output_copy
except Exception as e:
logger.warning(f"Failed to redact image: {e}")
return msg_copy
# TODO: Implement _anonymize_message
return message
async def _deanonymize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
"""Deanonymize PII in tool calls and computer outputs."""
item_copy = item.copy()
# Handle computer_call arguments
if item.get("type") == "computer_call":
args = item_copy.get("args", {})
if isinstance(args, dict):
deanonymized_args = {}
for key, value in args.items():
if isinstance(value, str):
deanonymized_value, _ = await self._deanonymize_text(value)
deanonymized_args[key] = deanonymized_value
else:
deanonymized_args[key] = value
item_copy["args"] = deanonymized_args
return item_copy
async def _anonymize_text(self, text: str) -> Tuple[str, List[RecognizerResult]]:
"""Anonymize PII in text and return the anonymized text and results."""
if not text.strip():
return text, []
try:
# Analyze text for PII
analyzer_results = self.analyzer.analyze(
text=text,
entities=self.entities_to_anonymize,
language="en"
)
if not analyzer_results:
return text, []
# Anonymize the text
anonymized_result = self.anonymizer.anonymize(
text=text,
analyzer_results=analyzer_results,
operators={entity_type: OperatorConfig(self.anonymization_operator)
for entity_type in set(result.entity_type for result in analyzer_results)}
)
# Store mapping for deanonymization
mapping_key = str(hash(text))
self.anonymization_mappings[mapping_key] = {
"original": text,
"anonymized": anonymized_result.text,
"results": analyzer_results
}
return anonymized_result.text, analyzer_results
except Exception as e:
logger.warning(f"Failed to anonymize text: {e}")
return text, []
async def _deanonymize_text(self, text: str) -> Tuple[str, bool]:
"""Attempt to deanonymize text using stored mappings."""
try:
# Look for matching anonymized text in mappings
for mapping_key, mapping in self.anonymization_mappings.items():
if mapping["anonymized"] == text:
return mapping["original"], True
# If no mapping found, return original text
return text, False
except Exception as e:
logger.warning(f"Failed to deanonymize text: {e}")
return text, False
# TODO: Implement _deanonymize_item
return item
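With the Presidio engines and message handling reduced to TODOs, the callback currently passes data through untouched. Once reimplemented, wiring it into an agent would presumably look like the sketch below (assuming ComputerAgent forwards a callbacks list to these hooks, and that the import path matches; model name illustrative):

from agent import ComputerAgent
from agent.callbacks import PIIAnonymizationCallback  # assumed import path

agent = ComputerAgent(
    model="openai/computer-use-preview",  # illustrative model
    callbacks=[PIIAnonymizationCallback(anonymize_text=True, anonymize_images=True)],
)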

View File

@@ -120,7 +120,7 @@ async def ainput(prompt: str = ""):
async def chat_loop(agent, model: str, container_name: str, initial_prompt: str = "", show_usage: bool = True):
"""Main chat loop with the agent."""
print_welcome(model, agent.agent_loop.__name__, container_name)
print_welcome(model, agent.agent_config_info.agent_class.__name__, container_name)
history = []
@@ -130,7 +130,7 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
total_cost = 0
while True:
if history[-1].get("role") != "user":
if len(history) == 0 or history[-1].get("role") != "user":
# Get user input with prompt
print_colored("> ", end="")
user_input = await ainput()

View File

@@ -93,8 +93,10 @@ class OpenAIComputerHandler:
return ""
def acknowledge_safety_check_callback(message: str) -> bool:
def acknowledge_safety_check_callback(message: str, allow_always: bool = False) -> bool:
"""Safety check callback for user acknowledgment."""
if allow_always:
return True
response = input(
f"Safety Check Warning: {message}\nDo you want to acknowledge and proceed? (y/n): "
).lower()
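The allow_always=True call site earlier (flagged with a TODO) short-circuits this prompt entirely. A policy-based replacement that only auto-acknowledges a known-safe allowlist might look like this sketch (pattern list hypothetical):

KNOWN_SAFE = ("cookie banner", "notification prompt")  # hypothetical allowlist

def policy_safety_check_callback(message: str, allow_always: bool = False) -> bool:
    """Auto-acknowledge allowlisted warnings; prompt the user otherwise."""
    if allow_always or any(p in message.lower() for p in KNOWN_SAFE):
        return True
    return input(
        f"Safety Check Warning: {message}\nDo you want to acknowledge and proceed? (y/n): "
    ).lower().startswith("y")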

View File

@@ -10,8 +10,10 @@ import re
import base64
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from io import BytesIO
import uuid
from PIL import Image
import litellm
import math
from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
@@ -24,6 +26,73 @@ Output the coordinate pair exactly:
(x,y)
'''.strip()
# Global dictionary to map coordinates to descriptions
xy2desc: Dict[Tuple[float, float], str] = {}
GTA1_TOOL_SCHEMA = {
"type": "function",
"name": "computer",
"description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
"parameters": {
"type": "object",
"properties": {
"action": {
"type": "string",
"enum": [
"screenshot",
"click",
"double_click",
"drag",
"type",
"keypress",
"scroll",
"move",
"wait",
"get_current_url",
"get_dimensions",
"get_environment"
],
"description": "The action to perform"
},
"element_description": {
"type": "string",
"description": "Description of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
},
"start_element_description": {
"type": "string",
"description": "Description of the element to start dragging from (required for drag action)"
},
"end_element_description": {
"type": "string",
"description": "Description of the element to drag to (required for drag action)"
},
"text": {
"type": "string",
"description": "The text to type (required for type action)"
},
"keys": {
"type": "string",
"description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
},
"button": {
"type": "string",
"description": "The mouse button to use for click action (left, right, wheel, back, forward) Default: left",
},
"scroll_x": {
"type": "integer",
"description": "Horizontal scroll amount for scroll action (positive for right, negative for left)",
},
"scroll_y": {
"type": "integer",
"description": "Vertical scroll amount for scroll action (positive for down, negative for up)",
},
},
"required": [
"action"
]
}
}
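For reference, a tool call the planning model might emit against this schema (values illustrative) would land in the Responses output as:

import json

example_call = {
    "type": "function_call",
    "name": "computer",
    "arguments": json.dumps({
        "action": "click",
        "element_description": "blue 'Sign in' button in the top right corner",
        "button": "left",
    }),
    "call_id": "call_abc123",  # illustrative id
}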
def extract_coordinates(raw_string: str) -> Tuple[float, float]:
"""Extract coordinates from model output."""
try:
@@ -32,6 +101,173 @@ def extract_coordinates(raw_string: str) -> Tuple[float, float]:
except:
return (0.0, 0.0)
def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
"""Get the last computer_call_output message from a messages list.
Args:
messages: List of messages to search through
Returns:
The last computer_call_output message dict, or None if not found
"""
for message in reversed(messages):
if isinstance(message, dict) and message.get("type") == "computer_call_output":
return message
return None
def _prepare_tools_for_gta1(tool_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Prepare tools for GTA1 API format"""
gta1_tools = []
for schema in tool_schemas:
if schema["type"] == "computer":
gta1_tools.append(GTA1_TOOL_SCHEMA)
else:
gta1_tools.append(schema)
return gta1_tools
async def replace_function_with_computer_call_gta1(item: Dict[str, Any], agent_instance) -> List[Dict[str, Any]]:
"""Convert function_call to computer_call format using GTA1 click prediction."""
global xy2desc
item_type = item.get("type")
async def _get_xy(element_description: Optional[str], last_image_b64: str) -> Union[Tuple[float, float], Tuple[None, None]]:
if element_description is None:
return (None, None)
# Use self.predict_click to get coordinates from description
coords = await agent_instance.predict_click(
model=agent_instance.current_model,
image_b64=last_image_b64,
instruction=element_description
)
if coords:
# Store the mapping from coordinates to description
xy2desc[coords] = element_description
return coords
return (None, None)
if item_type == "function_call":
fn_name = item.get("name")
fn_args = json.loads(item.get("arguments", "{}"))
item_id = item.get("id")
call_id = item.get("call_id")
if fn_name == "computer":
action = fn_args.get("action")
element_description = fn_args.get("element_description")
start_element_description = fn_args.get("start_element_description")
end_element_description = fn_args.get("end_element_description")
text = fn_args.get("text")
keys = fn_args.get("keys")
button = fn_args.get("button")
scroll_x = fn_args.get("scroll_x")
scroll_y = fn_args.get("scroll_y")
# Get the last computer output image for click prediction
last_image_b64 = agent_instance.last_screenshot_b64 or ""
x, y = await _get_xy(element_description, last_image_b64)
start_x, start_y = await _get_xy(start_element_description, last_image_b64)
end_x, end_y = await _get_xy(end_element_description, last_image_b64)
action_args = {
"type": action,
"x": x,
"y": y,
"start_x": start_x,
"start_y": start_y,
"end_x": end_x,
"end_y": end_y,
"text": text,
"keys": keys,
"button": button,
"scroll_x": scroll_x,
"scroll_y": scroll_y
}
# Remove None values to keep the JSON clean
action_args = {k: v for k, v in action_args.items() if v is not None}
return [{
"type": "computer_call",
"action": action_args,
"id": item_id,
"call_id": call_id,
"status": "completed"
}]
return [item]
async def replace_computer_call_with_function_gta1(item: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
Convert computer_call back to function_call format using descriptions.
Only READS from the global xy2desc dictionary.
Args:
item: The item to convert
"""
global xy2desc
item_type = item.get("type")
def _get_element_description(x: Optional[float], y: Optional[float]) -> Optional[str]:
"""Get element description from coordinates, return None if coordinates are None"""
if x is None or y is None:
return None
return xy2desc.get((x, y))
if item_type == "computer_call":
action_data = item.get("action", {})
# Extract coordinates and convert back to element descriptions
element_description = _get_element_description(action_data.get("x"), action_data.get("y"))
start_element_description = _get_element_description(action_data.get("start_x"), action_data.get("start_y"))
end_element_description = _get_element_description(action_data.get("end_x"), action_data.get("end_y"))
# Build function arguments
fn_args = {
"action": action_data.get("type"),
"element_description": element_description,
"start_element_description": start_element_description,
"end_element_description": end_element_description,
"text": action_data.get("text"),
"keys": action_data.get("keys"),
"button": action_data.get("button"),
"scroll_x": action_data.get("scroll_x"),
"scroll_y": action_data.get("scroll_y")
}
# Remove None values to keep the JSON clean
fn_args = {k: v for k, v in fn_args.items() if v is not None}
return [{
"type": "function_call",
"name": "computer",
"arguments": json.dumps(fn_args),
"id": item.get("id"),
"call_id": item.get("call_id"),
"status": "completed",
# Fall back to string representation
# "content": f"Used tool: {action_data.get('type')}({json.dumps(fn_args)})"
}]
elif item_type == "computer_call_output":
# Simple conversion: computer_call_output -> function_call_output (text only), user message (image)
return [
{
"type": "function_call_output",
"call_id": item.get("call_id"),
"output": "Tool executed successfully. See the current computer screenshot below, if nothing has changed yet then you may need to wait before trying again.",
"id": item.get("id"),
"status": "completed"
}, {
"role": "user",
"content": [item.get("output")]
}
]
return [item]
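Taken together, the two converters keep the planning model in description space while the executor works in pixel space: grounding writes each coordinate-to-description pair into xy2desc, and this reverse pass reads it back whenever the history is re-serialized (with computer_call_output split into a text-only function_call_output plus a user image message, presumably because function outputs here carry text only). In miniature:

# After grounding "search text field" to a point:
xy2desc[(412.0, 230.0)] = "search text field"

computer_call = {
    "type": "computer_call",
    "action": {"type": "click", "x": 412.0, "y": 230.0},
}
# The reverse pass recovers the description from the same coordinates:
desc = xy2desc.get((412.0, 230.0))  # -> "search text field"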
def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
"""Smart resize function similar to qwen_vl_utils."""
# Calculate the total pixels
@@ -64,10 +300,14 @@ def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
return new_height, new_width
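The hunk elides most of the body; for orientation, the qwen_vl_utils-style computation it mirrors (round both sides to multiples of factor, then rescale to keep the pixel count within bounds) is sketched below. This is an assumed reconstruction, not the committed code:

import math
from typing import Tuple

def smart_resize_sketch(height: int, width: int, factor: int = 28,
                        min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
    # Snap both sides to the nearest multiple of `factor`.
    h_bar = round(height / factor) * factor
    w_bar = round(width / factor) * factor
    if h_bar * w_bar > max_pixels:
        # Too large: shrink uniformly, rounding down to multiples of `factor`.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = math.floor(height / beta / factor) * factor
        w_bar = math.floor(width / beta / factor) * factor
    elif h_bar * w_bar < min_pixels:
        # Too small: grow uniformly, rounding up to multiples of `factor`.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = math.ceil(height * beta / factor) * factor
        w_bar = math.ceil(width * beta / factor) * factor
    return h_bar, w_bar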
@register_agent(models=r".*GTA1-.*", priority=10)
@register_agent(models=r".*GTA1.*", priority=10)
class GTA1Config(AsyncAgentConfig):
"""GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction."""
def __init__(self):
self.current_model = None
self.last_screenshot_b64 = None
async def predict_step(
self,
messages: Messages,
@@ -84,9 +324,136 @@ class GTA1Config(AsyncAgentConfig):
**kwargs
) -> Dict[str, Any]:
"""
GTA1 does not support step prediction - only click prediction.
GTA1 agent loop implementation using liteLLM responses with element descriptions.
Follows the 4-step process:
1. Prepare tools
2. Replace computer calls with function calls (using descriptions)
3. API call
4. Replace function calls with computer calls (using predict_click)
"""
raise NotImplementedError("GTA1 agent only supports click prediction via predict_click method")
models = model.split("+")
if len(models) != 2:
raise ValueError("GTA1 model must be in the format <gta1_model_name>+<planning_model_name> to be used in an agent loop")
gta1_model, llm_model = models
self.current_model = gta1_model
tools = tools or []
# Step 0: Prepare tools
gta1_tools = _prepare_tools_for_gta1(tools)
# Get last computer_call_output for screenshot reference
# Convert messages to list of dicts first
message_list = []
for message in messages:
if not isinstance(message, dict):
message_list.append(message.__dict__)
else:
message_list.append(message)
last_computer_call_output = get_last_computer_call_output(message_list)
if last_computer_call_output:
image_url = last_computer_call_output.get("output", {}).get("image_url", "")
if image_url.startswith("data:image/"):
self.last_screenshot_b64 = image_url.split(",")[-1]
else:
self.last_screenshot_b64 = image_url
# Step 1: If there's no screenshot, simulate the model calling the screenshot function
pre_output = []
if not self.last_screenshot_b64 and computer_handler:
screenshot_base64 = await computer_handler.screenshot()
if _on_screenshot:
await _on_screenshot(screenshot_base64, "screenshot_initial")
call_id = uuid.uuid4().hex
pre_output += [
{
"type": "message",
"role": "assistant",
"content": [
{
"type": "output_text",
"text": "Taking a screenshot to see the current computer screen."
}
]
},
{
"action": {
"type": "screenshot"
},
"call_id": call_id,
"status": "completed",
"type": "computer_call"
},
{
"type": "computer_call_output",
"call_id": call_id,
"output": {
"type": "input_image",
"image_url": f"data:image/png;base64,{screenshot_base64}"
}
},
]
# Update the last screenshot for future use
self.last_screenshot_b64 = screenshot_base64
message_list += pre_output
# Step 2: Replace computer calls with function calls (using descriptions)
new_messages = []
for message in message_list:
new_messages += await replace_computer_call_with_function_gta1(message)
messages = new_messages
# Step 3: API call
api_kwargs = {
"model": llm_model,
"input": messages,
"tools": gta1_tools if gta1_tools else None,
"stream": stream,
"truncation": "auto",
"num_retries": max_retries,
**kwargs
}
# Call API start hook
if _on_api_start:
await _on_api_start(api_kwargs)
# Use liteLLM responses
response = await litellm.aresponses(**api_kwargs)
# Call API end hook
if _on_api_end:
await _on_api_end(api_kwargs, response)
# Extract usage information
usage = {
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0), # type: ignore
}
if _on_usage:
await _on_usage(usage)
# Step 4: Replace function calls with computer calls (using predict_click)
new_output = []
for i in range(len(response.output)): # type: ignore
output_item = response.output[i] # type: ignore
# Convert to dict if it has model_dump method, otherwise use as-is
if hasattr(output_item, 'model_dump'):
item_dict = output_item.model_dump() # type: ignore
else:
item_dict = output_item # type: ignore
new_output += await replace_function_with_computer_call_gta1(item_dict, self) # type: ignore
return {
"output": pre_output + new_output,
"usage": usage
}
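Because predict_step splits model on "+", a GTA1 deployment pairs a grounding model with a separate planning model in a single specifier. A hypothetical invocation (model names illustrative):

agent = ComputerAgent(
    model="huggingface-local/HelloKKMe/GTA1-7B+openai/gpt-4o",
    tools=[computer],
)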
async def predict_click(
self,
@@ -106,75 +473,70 @@ class GTA1Config(AsyncAgentConfig):
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
try:
# Decode base64 image
image_data = base64.b64decode(image_b64)
image = Image.open(BytesIO(image_data))
width, height = image.width, image.height
# Smart resize the image (similar to qwen_vl_utils)
resized_height, resized_width = smart_resize(
height, width,
factor=28, # Default factor for Qwen models
min_pixels=3136,
max_pixels=4096 * 2160
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height
# Convert resized image back to base64
buffered = BytesIO()
resized_image.save(buffered, format="PNG")
resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()
# Prepare system and user messages
system_message = {
"role": "system",
"content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
}
user_message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{resized_image_b64}"
}
},
{
"type": "text",
"text": instruction
# Decode base64 image
image_data = base64.b64decode(image_b64)
image = Image.open(BytesIO(image_data))
width, height = image.width, image.height
# Smart resize the image (similar to qwen_vl_utils)
resized_height, resized_width = smart_resize(
height, width,
factor=28, # Default factor for Qwen models
min_pixels=3136,
max_pixels=4096 * 2160
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height
# Convert resized image back to base64
buffered = BytesIO()
resized_image.save(buffered, format="PNG")
resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()
# Prepare system and user messages
system_message = {
"role": "system",
"content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
}
user_message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{resized_image_b64}"
}
]
}
# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": [system_message, user_message],
"max_tokens": 32,
"temperature": 0.0,
**kwargs
}
# Use liteLLM acompletion
response = await litellm.acompletion(**api_kwargs)
# Extract response text
output_text = response.choices[0].message.content
# Extract and rescale coordinates
pred_x, pred_y = extract_coordinates(output_text)
pred_x *= scale_x
pred_y *= scale_y
return (pred_x, pred_y)
except Exception as e:
print(f"GTA1 click prediction failed: {e}")
return None
},
{
"type": "text",
"text": instruction
}
]
}
# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": [system_message, user_message],
"max_tokens": 32,
"temperature": 0.0,
**kwargs
}
# Use liteLLM acompletion
response = await litellm.acompletion(**api_kwargs)
# Extract response text
output_text = response.choices[0].message.content
# Extract and rescale coordinates
pred_x, pred_y = extract_coordinates(output_text)
pred_x *= scale_x
pred_y *= scale_y
return (math.floor(pred_x), math.floor(pred_y))
def get_capabilities(self) -> List[AgentCapability]:
"""Return the capabilities supported by this agent."""
return ["click"]
return ["click", "step"]

View File

@@ -0,0 +1,6 @@
model,predict_step,predict_point
anthropic,,
openai,,
uitars,,
omniparser,,
gta1,,

View File

@@ -310,7 +310,6 @@ class OmniparsrConfig(AsyncAgentConfig):
"input": messages,
"tools": openai_tools if openai_tools else None,
"stream": stream,
"reasoning": {"summary": "concise"},
"truncation": "auto",
"num_retries": max_retries,
**kwargs
@@ -331,7 +330,7 @@ class OmniparsrConfig(AsyncAgentConfig):
# Extract usage information
usage = {
**response.usage.model_dump(),
**response.usage.model_dump(), # type: ignore
"response_cost": response._hidden_params.get("response_cost", 0.0),
}
if _on_usage:
@@ -339,7 +338,7 @@ class OmniparsrConfig(AsyncAgentConfig):
# handle som function calls -> xy computer calls
new_output = []
for i in range(len(response.output)):
for i in range(len(response.output)): # type: ignore
new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy)
return {

View File

@@ -5,8 +5,7 @@ Example usage of the agent library with docstring-based tool definitions.
import asyncio
import logging
from agent import agent_loop, ComputerAgent
from agent.types import Messages
from agent import ComputerAgent
from computer import Computer
from computer.helpers import sandboxed

View File

@@ -22,7 +22,7 @@ dependencies = [
"cua-computer>=0.3.0,<0.5.0",
"cua-core>=0.1.8,<0.2.0",
"certifi>=2024.2.2",
"litellm>=1.74.8"
"litellm>=1.74.12"
]
requires-python = ">=3.11"