mirror of https://github.com/trycua/computer.git (synced 2026-01-01 11:00:31 -06:00)

Commit: working gta1 loop
@@ -117,6 +117,13 @@ def sanitize_message(msg: Any) -> Any:
        return sanitized
    return msg

+def get_output_call_ids(messages: List[Dict[str, Any]]) -> List[str]:
+    call_ids = []
+    for message in messages:
+        if message.get("type") == "computer_call_output" or message.get("type") == "function_call_output":
+            call_ids.append(message.get("call_id"))
+    return call_ids
+
class ComputerAgent:
    """
    Main agent class that automatically selects the appropriate agent loop
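For illustration, a minimal sketch of what the new helper returns, using hypothetical message dicts shaped like the two types it checks for:

    # Hypothetical inputs; only the "type" and "call_id" keys matter here.
    messages = [
        {"type": "computer_call", "call_id": "abc"},
        {"type": "computer_call_output", "call_id": "abc", "output": {}},
        {"type": "function_call_output", "call_id": "def", "output": "ok"},
    ]
    get_output_call_ids(messages)  # -> ["abc", "def"]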
@@ -207,6 +214,7 @@ class ComputerAgent:
            litellm.custom_provider_map = [
                {"provider": "huggingface-local", "custom_handler": hf_adapter}
            ]
+            litellm.suppress_debug_info = True

        # == Initialize computer agent ==

@@ -390,8 +398,10 @@ class ComputerAgent:
    # AGENT OUTPUT PROCESSING
    # ============================================================================

-    async def _handle_item(self, item: Any, computer: Optional[Computer] = None) -> List[Dict[str, Any]]:
+    async def _handle_item(self, item: Any, computer: Optional[Computer] = None, ignore_call_ids: Optional[List[str]] = None) -> List[Dict[str, Any]]:
        """Handle each item; may cause a computer action + screenshot."""
+        if ignore_call_ids and item.get("call_id") and item.get("call_id") in ignore_call_ids:
+            return []
+
        item_type = item.get("type", None)

@@ -437,7 +447,7 @@ class ComputerAgent:
                acknowledged_checks = []
                for check in pending_checks:
                    check_message = check.get("message", str(check))
-                    if acknowledge_safety_check_callback(check_message):
+                    if acknowledge_safety_check_callback(check_message, allow_always=True): # TODO: implement a callback for safety checks
                        acknowledged_checks.append(check)
                    else:
                        raise ValueError(f"Safety check failed: {check_message}")
@@ -512,9 +522,12 @@ class ComputerAgent:
        Returns:
            AsyncGenerator that yields response chunks
        """
+        if not self.agent_config_info:
+            raise ValueError("Agent configuration not found")
+
        capabilities = self.get_capabilities()
        if "step" not in capabilities:
-            raise ValueError(f"Agent loop {self.agent_loop.__name__} does not support step predictions")
+            raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support step predictions")

        await self._initialize_computers()

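A usage sketch of the capability gate introduced here, assuming a model string that resolves to a step-capable loop; the model name and the run() call are taken from surrounding context, not guaranteed by this hunk:

    agent = ComputerAgent(model="openai/computer-use-preview")  # illustrative model string
    if "step" in agent.get_capabilities():
        # Safe to use the step loop; otherwise only predict_click is available.
        async for chunk in agent.run(messages):
            ...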
@@ -529,7 +542,7 @@ class ComputerAgent:
            "messages": messages,
            "stream": stream,
            "model": self.model,
-            "agent_loop": self.agent_loop.__name__,
+            "agent_loop": self.agent_config_info.agent_class.__name__,
            **merged_kwargs
        }
        await self._on_run_start(run_kwargs, old_items)
@@ -580,9 +593,12 @@ class ComputerAgent:
                # Add agent response to new_items
                new_items += result.get("output")

+                # Get output call ids
+                output_call_ids = get_output_call_ids(result.get("output", []))
+
                # Handle computer actions
                for item in result.get("output"):
-                    partial_items = await self._handle_item(item, self.computer_handler)
+                    partial_items = await self._handle_item(item, self.computer_handler, ignore_call_ids=output_call_ids)
                    new_items += partial_items

                # Yield partial response
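The ignore list exists because an agent loop may emit a computer_call together with its computer_call_output in the same response (the GTA1 loop's simulated initial screenshot further down does exactly this). A sketch of the effect, with hypothetical items:

    output = [
        {"type": "computer_call", "call_id": "s1", "action": {"type": "screenshot"}},
        {"type": "computer_call_output", "call_id": "s1", "output": {"type": "input_image", "image_url": "..."}},
    ]
    ids = get_output_call_ids(output)  # ["s1"]
    # _handle_item(item, computer, ignore_call_ids=ids) returns [] for call_id "s1",
    # so the already-answered screenshot action is not executed a second time.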
@@ -612,9 +628,12 @@ class ComputerAgent:
        Returns:
            None or tuple with (x, y) coordinates
        """
+        if not self.agent_config_info:
+            raise ValueError("Agent configuration not found")
+
        capabilities = self.get_capabilities()
        if "click" not in capabilities:
-            raise ValueError(f"Agent loop {self.agent_loop.__name__} does not support click predictions")
+            raise ValueError(f"Agent loop {self.agent_config_info.agent_class.__name__} does not support click predictions")
        if hasattr(self.agent_loop, 'predict_click'):
            if not image_b64:
                if not self.computer_handler:
@@ -634,6 +653,9 @@ class ComputerAgent:
        Returns:
            List of capability strings (e.g., ["step", "click"])
        """
+        if not self.agent_config_info:
+            raise ValueError("Agent configuration not found")
+
        if hasattr(self.agent_loop, 'get_capabilities'):
            return self.agent_loop.get_capabilities()
        return ["step"]  # Default capability
@@ -9,10 +9,7 @@ import io
import logging

try:
-    from presidio_analyzer import AnalyzerEngine
-    from presidio_anonymizer import AnonymizerEngine, DeanonymizeEngine
-    from presidio_anonymizer.entities import RecognizerResult, OperatorConfig
-    from presidio_image_redactor import ImageRedactorEngine
+    # TODO: Add Presidio dependencies
    from PIL import Image
    PRESIDIO_AVAILABLE = True
except ImportError:
@@ -32,11 +29,7 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):

    def __init__(
        self,
-        anonymize_text: bool = True,
-        anonymize_images: bool = True,
-        entities_to_anonymize: Optional[List[str]] = None,
-        anonymization_operator: str = "replace",
-        image_redaction_color: Tuple[int, int, int] = (255, 192, 203)  # Pink
+        # TODO: Any extra kwargs if needed
    ):
        """
        Initialize the PII anonymization callback.
@@ -51,23 +44,10 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
        if not PRESIDIO_AVAILABLE:
            raise ImportError(
                "Presidio is not available. Install with: "
-                "pip install presidio-analyzer presidio-anonymizer presidio-image-redactor"
+                "pip install cua-agent[pii-anonymization]"
            )

-        self.anonymize_text = anonymize_text
-        self.anonymize_images = anonymize_images
-        self.entities_to_anonymize = entities_to_anonymize
-        self.anonymization_operator = anonymization_operator
-        self.image_redaction_color = image_redaction_color
-
-        # Initialize Presidio engines
-        self.analyzer = AnalyzerEngine()
-        self.anonymizer = AnonymizerEngine()
-        self.deanonymizer = DeanonymizeEngine()
-        self.image_redactor = ImageRedactorEngine()
-
-        # Store anonymization mappings for deanonymization
-        self.anonymization_mappings: Dict[str, Any] = {}
+        # TODO: Implement __init__

    async def on_llm_start(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
@@ -79,9 +59,6 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
        Returns:
            List of messages with PII anonymized
        """
-        if not self.anonymize_text and not self.anonymize_images:
-            return messages
-
        anonymized_messages = []
        for msg in messages:
            anonymized_msg = await self._anonymize_message(msg)
@@ -99,9 +76,6 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
        Returns:
            List of output with PII deanonymized for tool calls
        """
-        if not self.anonymize_text:
-            return output
-
        deanonymized_output = []
        for item in output:
            # Only deanonymize tool calls and computer_call messages
@@ -114,146 +88,9 @@ class PIIAnonymizationCallback(AsyncCallbackHandler):
        return deanonymized_output

    async def _anonymize_message(self, message: Dict[str, Any]) -> Dict[str, Any]:
        """Anonymize PII in a single message."""
-        msg_copy = message.copy()
-
-        # Anonymize text content
-        if self.anonymize_text:
-            msg_copy = await self._anonymize_text_content(msg_copy)
-
-        # Redact images in computer_call_output
-        if self.anonymize_images and msg_copy.get("type") == "computer_call_output":
-            msg_copy = await self._redact_image_content(msg_copy)
-
-        return msg_copy
-
-    async def _anonymize_text_content(self, message: Dict[str, Any]) -> Dict[str, Any]:
-        """Anonymize text content in a message."""
-        msg_copy = message.copy()
-
-        # Handle content array
-        content = msg_copy.get("content", [])
-        if isinstance(content, str):
-            anonymized_text, _ = await self._anonymize_text(content)
-            msg_copy["content"] = anonymized_text
-        elif isinstance(content, list):
-            anonymized_content = []
-            for item in content:
-                if isinstance(item, dict) and item.get("type") == "text":
-                    text = item.get("text", "")
-                    anonymized_text, _ = await self._anonymize_text(text)
-                    item_copy = item.copy()
-                    item_copy["text"] = anonymized_text
-                    anonymized_content.append(item_copy)
-                else:
-                    anonymized_content.append(item)
-            msg_copy["content"] = anonymized_content
-
-        return msg_copy
-
-    async def _redact_image_content(self, message: Dict[str, Any]) -> Dict[str, Any]:
-        """Redact PII from images in computer_call_output messages."""
-        msg_copy = message.copy()
-        output = msg_copy.get("output", {})
-
-        if isinstance(output, dict) and "image_url" in output:
-            try:
-                # Extract base64 image data
-                image_url = output["image_url"]
-                if image_url.startswith("data:image/"):
-                    # Parse data URL
-                    header, data = image_url.split(",", 1)
-                    image_data = base64.b64decode(data)
-
-                    # Load image with PIL
-                    image = Image.open(io.BytesIO(image_data))
-
-                    # Redact PII from image
-                    redacted_image = self.image_redactor.redact(image, self.image_redaction_color)
-
-                    # Convert back to base64
-                    buffer = io.BytesIO()
-                    redacted_image.save(buffer, format="PNG")
-                    redacted_data = base64.b64encode(buffer.getvalue()).decode()
-
-                    # Update image URL
-                    output_copy = output.copy()
-                    output_copy["image_url"] = f"data:image/png;base64,{redacted_data}"
-                    msg_copy["output"] = output_copy
-
-            except Exception as e:
-                logger.warning(f"Failed to redact image: {e}")
-
-        return msg_copy
+        # TODO: Implement _anonymize_message
+        return message

    async def _deanonymize_item(self, item: Dict[str, Any]) -> Dict[str, Any]:
        """Deanonymize PII in tool calls and computer outputs."""
-        item_copy = item.copy()
-
-        # Handle computer_call arguments
-        if item.get("type") == "computer_call":
-            args = item_copy.get("args", {})
-            if isinstance(args, dict):
-                deanonymized_args = {}
-                for key, value in args.items():
-                    if isinstance(value, str):
-                        deanonymized_value, _ = await self._deanonymize_text(value)
-                        deanonymized_args[key] = deanonymized_value
-                    else:
-                        deanonymized_args[key] = value
-                item_copy["args"] = deanonymized_args
-
-        return item_copy
-
-    async def _anonymize_text(self, text: str) -> Tuple[str, List[RecognizerResult]]:
-        """Anonymize PII in text and return the anonymized text and results."""
-        if not text.strip():
-            return text, []
-
-        try:
-            # Analyze text for PII
-            analyzer_results = self.analyzer.analyze(
-                text=text,
-                entities=self.entities_to_anonymize,
-                language="en"
-            )
-
-            if not analyzer_results:
-                return text, []
-
-            # Anonymize the text
-            anonymized_result = self.anonymizer.anonymize(
-                text=text,
-                analyzer_results=analyzer_results,
-                operators={entity_type: OperatorConfig(self.anonymization_operator)
-                           for entity_type in set(result.entity_type for result in analyzer_results)}
-            )
-
-            # Store mapping for deanonymization
-            mapping_key = str(hash(text))
-            self.anonymization_mappings[mapping_key] = {
-                "original": text,
-                "anonymized": anonymized_result.text,
-                "results": analyzer_results
-            }
-
-            return anonymized_result.text, analyzer_results
-
-        except Exception as e:
-            logger.warning(f"Failed to anonymize text: {e}")
-            return text, []
-
-    async def _deanonymize_text(self, text: str) -> Tuple[str, bool]:
-        """Attempt to deanonymize text using stored mappings."""
-        try:
-            # Look for matching anonymized text in mappings
-            for mapping_key, mapping in self.anonymization_mappings.items():
-                if mapping["anonymized"] == text:
-                    return mapping["original"], True
-
-            # If no mapping found, return original text
-            return text, False
-
-        except Exception as e:
-            logger.warning(f"Failed to deanonymize text: {e}")
-            return text, False
+        # TODO: Implement _deanonymize_item
+        return item
@@ -120,7 +120,7 @@ async def ainput(prompt: str = ""):

async def chat_loop(agent, model: str, container_name: str, initial_prompt: str = "", show_usage: bool = True):
    """Main chat loop with the agent."""
-    print_welcome(model, agent.agent_loop.__name__, container_name)
+    print_welcome(model, agent.agent_config_info.agent_class.__name__, container_name)

    history = []

@@ -130,7 +130,7 @@ async def chat_loop(agent, model: str, container_name: str, initial_prompt: str
    total_cost = 0

    while True:
-        if history[-1].get("role") != "user":
+        if len(history) == 0 or history[-1].get("role") != "user":
            # Get user input with prompt
            print_colored("> ", end="")
            user_input = await ainput()

@@ -93,8 +93,10 @@ class OpenAIComputerHandler:
        return ""


-def acknowledge_safety_check_callback(message: str) -> bool:
+def acknowledge_safety_check_callback(message: str, allow_always: bool = False) -> bool:
    """Safety check callback for user acknowledgment."""
+    if allow_always:
+        return True
    response = input(
        f"Safety Check Warning: {message}\nDo you want to acknowledge and proceed? (y/n): "
    ).lower()

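Because the call site above passes allow_always=True, every check is auto-acknowledged, so callers that need real gating must supply their own callback. A hedged sketch of a stricter, hypothetical policy:

    def deny_destructive_checks(message: str, allow_always: bool = False) -> bool:
        # Illustrative policy: never auto-acknowledge warnings that mention deletion.
        if "delete" in message.lower():
            return False
        return True if allow_always else input(f"{message} (y/n): ").lower() == "y"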
@@ -10,8 +10,10 @@ import re
import base64
from typing import Dict, List, Any, AsyncGenerator, Union, Optional, Tuple
from io import BytesIO
+import uuid
from PIL import Image
import litellm
+import math

from ..decorators import register_agent
from ..types import Messages, AgentResponse, Tools, AgentCapability
@@ -24,6 +26,73 @@ Output the coordinate pair exactly:
(x,y)
'''.strip()

+# Global dictionary to map coordinates to descriptions
+xy2desc: Dict[Tuple[float, float], str] = {}
+
+GTA1_TOOL_SCHEMA = {
+    "type": "function",
+    "name": "computer",
+    "description": "Control a computer by taking screenshots and interacting with UI elements. This tool uses element descriptions to locate and interact with UI elements on the screen (e.g., 'red submit button', 'search text field', 'hamburger menu icon', 'close button in top right corner').",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "action": {
+                "type": "string",
+                "enum": [
+                    "screenshot",
+                    "click",
+                    "double_click",
+                    "drag",
+                    "type",
+                    "keypress",
+                    "scroll",
+                    "move",
+                    "wait",
+                    "get_current_url",
+                    "get_dimensions",
+                    "get_environment"
+                ],
+                "description": "The action to perform"
+            },
+            "element_description": {
+                "type": "string",
+                "description": "Description of the element to interact with (required for click, double_click, move, scroll actions, and as start/end for drag)"
+            },
+            "start_element_description": {
+                "type": "string",
+                "description": "Description of the element to start dragging from (required for drag action)"
+            },
+            "end_element_description": {
+                "type": "string",
+                "description": "Description of the element to drag to (required for drag action)"
+            },
+            "text": {
+                "type": "string",
+                "description": "The text to type (required for type action)"
+            },
+            "keys": {
+                "type": "string",
+                "description": "Key combination to press (required for keypress action). Single key for individual key press, multiple keys for combinations (e.g., 'ctrl+c')"
+            },
+            "button": {
+                "type": "string",
+                "description": "The mouse button to use for click action (left, right, wheel, back, forward). Default: left"
+            },
+            "scroll_x": {
+                "type": "integer",
+                "description": "Horizontal scroll amount for scroll action (positive for right, negative for left)"
+            },
+            "scroll_y": {
+                "type": "integer",
+                "description": "Vertical scroll amount for scroll action (positive for down, negative for up)"
+            }
+        },
+        "required": [
+            "action"
+        ]
+    }
+}
+
def extract_coordinates(raw_string: str) -> Tuple[float, float]:
    """Extract coordinates from model output."""
    try:
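For illustration, a function_call the planning model might emit against this schema (the call_id and the element description are hypothetical):

    {
        "type": "function_call",
        "name": "computer",
        "arguments": json.dumps({
            "action": "click",
            "element_description": "blue 'Sign in' button in the top right corner",
            "button": "left"
        }),
        "call_id": "call_123"
    }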
@@ -32,6 +101,173 @@ def extract_coordinates(raw_string: str) -> Tuple[float, float]:
    except:
        return (0.0, 0.0)

+def get_last_computer_call_output(messages: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
+    """Get the last computer_call_output message from a messages list.
+
+    Args:
+        messages: List of messages to search through
+
+    Returns:
+        The last computer_call_output message dict, or None if not found
+    """
+    for message in reversed(messages):
+        if isinstance(message, dict) and message.get("type") == "computer_call_output":
+            return message
+    return None
+
+def _prepare_tools_for_gta1(tool_schemas: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Prepare tools for GTA1 API format."""
+    gta1_tools = []
+
+    for schema in tool_schemas:
+        if schema["type"] == "computer":
+            gta1_tools.append(GTA1_TOOL_SCHEMA)
+        else:
+            gta1_tools.append(schema)
+
+    return gta1_tools
+
+async def replace_function_with_computer_call_gta1(item: Dict[str, Any], agent_instance) -> List[Dict[str, Any]]:
+    """Convert function_call to computer_call format using GTA1 click prediction."""
+    global xy2desc
+    item_type = item.get("type")
+
+    async def _get_xy(element_description: Optional[str], last_image_b64: str) -> Union[Tuple[float, float], Tuple[None, None]]:
+        if element_description is None:
+            return (None, None)
+        # Use predict_click to get coordinates from the description
+        coords = await agent_instance.predict_click(
+            model=agent_instance.current_model,
+            image_b64=last_image_b64,
+            instruction=element_description
+        )
+        if coords:
+            # Store the mapping from coordinates to description
+            xy2desc[coords] = element_description
+            return coords
+        return (None, None)
+
+    if item_type == "function_call":
+        fn_name = item.get("name")
+        fn_args = json.loads(item.get("arguments", "{}"))
+
+        item_id = item.get("id")
+        call_id = item.get("call_id")
+
+        if fn_name == "computer":
+            action = fn_args.get("action")
+            element_description = fn_args.get("element_description")
+            start_element_description = fn_args.get("start_element_description")
+            end_element_description = fn_args.get("end_element_description")
+            text = fn_args.get("text")
+            keys = fn_args.get("keys")
+            button = fn_args.get("button")
+            scroll_x = fn_args.get("scroll_x")
+            scroll_y = fn_args.get("scroll_y")
+
+            # Get the last computer output image for click prediction
+            last_image_b64 = agent_instance.last_screenshot_b64 or ""
+
+            x, y = await _get_xy(element_description, last_image_b64)
+            start_x, start_y = await _get_xy(start_element_description, last_image_b64)
+            end_x, end_y = await _get_xy(end_element_description, last_image_b64)
+
+            action_args = {
+                "type": action,
+                "x": x,
+                "y": y,
+                "start_x": start_x,
+                "start_y": start_y,
+                "end_x": end_x,
+                "end_y": end_y,
+                "text": text,
+                "keys": keys,
+                "button": button,
+                "scroll_x": scroll_x,
+                "scroll_y": scroll_y
+            }
+            # Remove None values to keep the JSON clean
+            action_args = {k: v for k, v in action_args.items() if v is not None}
+
+            return [{
+                "type": "computer_call",
+                "action": action_args,
+                "id": item_id,
+                "call_id": call_id,
+                "status": "completed"
+            }]
+
+    return [item]
+
+async def replace_computer_call_with_function_gta1(item: Dict[str, Any]) -> List[Dict[str, Any]]:
+    """
+    Convert computer_call back to function_call format using descriptions.
+    Only READS from the global xy2desc dictionary.
+
+    Args:
+        item: The item to convert
+    """
+    global xy2desc
+    item_type = item.get("type")
+
+    def _get_element_description(x: Optional[float], y: Optional[float]) -> Optional[str]:
+        """Get element description from coordinates, return None if coordinates are None."""
+        if x is None or y is None:
+            return None
+        return xy2desc.get((x, y))
+
+    if item_type == "computer_call":
+        action_data = item.get("action", {})
+
+        # Extract coordinates and convert back to element descriptions
+        element_description = _get_element_description(action_data.get("x"), action_data.get("y"))
+        start_element_description = _get_element_description(action_data.get("start_x"), action_data.get("start_y"))
+        end_element_description = _get_element_description(action_data.get("end_x"), action_data.get("end_y"))
+
+        # Build function arguments
+        fn_args = {
+            "action": action_data.get("type"),
+            "element_description": element_description,
+            "start_element_description": start_element_description,
+            "end_element_description": end_element_description,
+            "text": action_data.get("text"),
+            "keys": action_data.get("keys"),
+            "button": action_data.get("button"),
+            "scroll_x": action_data.get("scroll_x"),
+            "scroll_y": action_data.get("scroll_y")
+        }
+
+        # Remove None values to keep the JSON clean
+        fn_args = {k: v for k, v in fn_args.items() if v is not None}
+
+        return [{
+            "type": "function_call",
+            "name": "computer",
+            "arguments": json.dumps(fn_args),
+            "id": item.get("id"),
+            "call_id": item.get("call_id"),
+            "status": "completed",
+            # Fall back to string representation
+            # "content": f"Used tool: {action_data.get('type')}({json.dumps(fn_args)})"
+        }]
+
+    elif item_type == "computer_call_output":
+        # Simple conversion: computer_call_output -> function_call_output (text only), user message (image)
+        return [
+            {
+                "type": "function_call_output",
+                "call_id": item.get("call_id"),
+                "output": "Tool executed successfully. See the current computer screenshot below; if nothing has changed yet, you may need to wait before trying again.",
+                "id": item.get("id"),
+                "status": "completed"
+            }, {
+                "role": "user",
+                "content": [item.get("output")]
+            }
+        ]
+
+    return [item]
+
def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 3136, max_pixels: int = 8847360) -> Tuple[int, int]:
    """Smart resize function similar to qwen_vl_utils."""
    # Calculate the total pixels
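A minimal round-trip sketch of the shared xy2desc mapping, assuming predict_click resolved a description to (312.0, 88.0); the values are hypothetical:

    xy2desc[(312.0, 88.0)] = "search text field"

    computer_call = {
        "type": "computer_call",
        "action": {"type": "click", "x": 312.0, "y": 88.0},
        "call_id": "c1",
    }
    # replace_computer_call_with_function_gta1 looks up (312.0, 88.0) in xy2desc and
    # rebuilds {"action": "click", "element_description": "search text field"},
    # so the planning model only ever sees descriptions, never raw coordinates.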
@@ -64,10 +300,14 @@ def smart_resize(height: int, width: int, factor: int = 28, min_pixels: int = 31

    return new_height, new_width

-@register_agent(models=r".*GTA1-.*", priority=10)
+@register_agent(models=r".*GTA1.*", priority=10)
class GTA1Config(AsyncAgentConfig):
    """GTA1 agent configuration implementing AsyncAgentConfig protocol for click prediction."""

+    def __init__(self):
+        self.current_model = None
+        self.last_screenshot_b64 = None
+
    async def predict_step(
        self,
        messages: Messages,
@@ -84,9 +324,136 @@ class GTA1Config(AsyncAgentConfig):
        **kwargs
    ) -> Dict[str, Any]:
        """
-        GTA1 does not support step prediction - only click prediction.
+        GTA1 agent loop implementation using liteLLM responses with element descriptions.
+
+        Follows the 4-step process:
+        1. Prepare tools
+        2. Replace computer calls with function calls (using descriptions)
+        3. API call
+        4. Replace function calls with computer calls (using predict_click)
        """
-        raise NotImplementedError("GTA1 agent only supports click prediction via predict_click method")
+        models = model.split("+")
+        if len(models) != 2:
+            raise ValueError("GTA1 model must be in the format <gta1_model_name>+<planning_model_name> to be used in an agent loop")
+
+        gta1_model, llm_model = models
+        self.current_model = gta1_model
+
+        tools = tools or []
+
+        # Step 0: Prepare tools
+        gta1_tools = _prepare_tools_for_gta1(tools)
+
+        # Get last computer_call_output for screenshot reference
+        # Convert messages to a list of dicts first
+        message_list = []
+        for message in messages:
+            if not isinstance(message, dict):
+                message_list.append(message.__dict__)
+            else:
+                message_list.append(message)
+
+        last_computer_call_output = get_last_computer_call_output(message_list)
+        if last_computer_call_output:
+            image_url = last_computer_call_output.get("output", {}).get("image_url", "")
+            if image_url.startswith("data:image/"):
+                self.last_screenshot_b64 = image_url.split(",")[-1]
+            else:
+                self.last_screenshot_b64 = image_url
+
+        # Step 1: If there's no screenshot, simulate the model calling the screenshot function
+        pre_output = []
+        if not self.last_screenshot_b64 and computer_handler:
+            screenshot_base64 = await computer_handler.screenshot()
+            if _on_screenshot:
+                await _on_screenshot(screenshot_base64, "screenshot_initial")
+
+            call_id = uuid.uuid4().hex
+            pre_output += [
+                {
+                    "type": "message",
+                    "role": "assistant",
+                    "content": [
+                        {
+                            "type": "output_text",
+                            "text": "Taking a screenshot to see the current computer screen."
+                        }
+                    ]
+                },
+                {
+                    "action": {
+                        "type": "screenshot"
+                    },
+                    "call_id": call_id,
+                    "status": "completed",
+                    "type": "computer_call"
+                },
+                {
+                    "type": "computer_call_output",
+                    "call_id": call_id,
+                    "output": {
+                        "type": "input_image",
+                        "image_url": f"data:image/png;base64,{screenshot_base64}"
+                    }
+                },
+            ]
+
+            # Update the last screenshot for future use
+            self.last_screenshot_b64 = screenshot_base64
+
+        message_list += pre_output
+
+        # Step 2: Replace computer calls with function calls (using descriptions)
+        new_messages = []
+        for message in message_list:
+            new_messages += await replace_computer_call_with_function_gta1(message)
+        messages = new_messages
+
+        # Step 3: API call
+        api_kwargs = {
+            "model": llm_model,
+            "input": messages,
+            "tools": gta1_tools if gta1_tools else None,
+            "stream": stream,
+            "truncation": "auto",
+            "num_retries": max_retries,
+            **kwargs
+        }
+
+        # Call API start hook
+        if _on_api_start:
+            await _on_api_start(api_kwargs)
+
+        # Use liteLLM responses
+        response = await litellm.aresponses(**api_kwargs)
+
+        # Call API end hook
+        if _on_api_end:
+            await _on_api_end(api_kwargs, response)
+
+        # Extract usage information
+        usage = {
+            **response.usage.model_dump(),  # type: ignore
+            "response_cost": response._hidden_params.get("response_cost", 0.0),  # type: ignore
+        }
+        if _on_usage:
+            await _on_usage(usage)
+
+        # Step 4: Replace function calls with computer calls (using predict_click)
+        new_output = []
+        for i in range(len(response.output)):  # type: ignore
+            output_item = response.output[i]  # type: ignore
+            # Convert to dict if it has model_dump method, otherwise use as-is
+            if hasattr(output_item, 'model_dump'):
+                item_dict = output_item.model_dump()  # type: ignore
+            else:
+                item_dict = output_item  # type: ignore
+            new_output += await replace_function_with_computer_call_gta1(item_dict, self)  # type: ignore
+
+        return {
+            "output": pre_output + new_output,
+            "usage": usage
+        }

    async def predict_click(
        self,
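Putting the loop together, a hedged usage sketch of the grounder+planner composition that predict_step expects; both model names are illustrative and `computer` stands for a Computer instance from the computer package:

    agent = ComputerAgent(
        model="huggingface-local/HelloKKMe/GTA1-7B+openai/gpt-4o",  # <gta1_model_name>+<planning_model_name>
        tools=[computer],
    )
    async for result in agent.run("Open the settings page"):
        ...

The string before '+' is routed to the GTA1 grounding model via predict_click; the string after it becomes the llm_model passed to litellm.aresponses.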
@@ -106,75 +473,70 @@ class GTA1Config(AsyncAgentConfig):
        Returns:
            Tuple of (x, y) coordinates or None if prediction fails
        """
-        try:
        # Decode base64 image
        image_data = base64.b64decode(image_b64)
        image = Image.open(BytesIO(image_data))
        width, height = image.width, image.height

        # Smart resize the image (similar to qwen_vl_utils)
        resized_height, resized_width = smart_resize(
            height, width,
            factor=28,  # Default factor for Qwen models
            min_pixels=3136,
            max_pixels=4096 * 2160
        )
        resized_image = image.resize((resized_width, resized_height))
        scale_x, scale_y = width / resized_width, height / resized_height

        # Convert resized image back to base64
        buffered = BytesIO()
        resized_image.save(buffered, format="PNG")
        resized_image_b64 = base64.b64encode(buffered.getvalue()).decode()

        # Prepare system and user messages
        system_message = {
            "role": "system",
            "content": SYSTEM_PROMPT.format(height=resized_height, width=resized_width)
        }

        user_message = {
            "role": "user",
            "content": [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/png;base64,{resized_image_b64}"
                    }
                },
                {
                    "type": "text",
                    "text": instruction
                }
            ]
        }

        # Prepare API call kwargs
        api_kwargs = {
            "model": model,
            "messages": [system_message, user_message],
            "max_tokens": 32,
            "temperature": 0.0,
            **kwargs
        }

        # Use liteLLM acompletion
        response = await litellm.acompletion(**api_kwargs)

        # Extract response text
        output_text = response.choices[0].message.content

        # Extract and rescale coordinates
        pred_x, pred_y = extract_coordinates(output_text)
        pred_x *= scale_x
        pred_y *= scale_y

-            return (pred_x, pred_y)
-
-        except Exception as e:
-            print(f"GTA1 click prediction failed: {e}")
-            return None
+        return (math.floor(pred_x), math.floor(pred_y))

    def get_capabilities(self) -> List[AgentCapability]:
        """Return the capabilities supported by this agent."""
-        return ["click"]
+        return ["click", "step"]
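A worked example of the rescaling arithmetic above, with illustrative dimensions: for a 1920x1080 screenshot, suppose smart_resize returns (756, 1344), multiples of 28 with the same aspect ratio. Then:

    width, height = 1920, 1080
    resized_height, resized_width = 756, 1344   # illustrative smart_resize output
    scale_x = width / resized_width             # 10/7, about 1.4286
    scale_y = height / resized_height           # 10/7, about 1.4286
    # A model prediction of (448, 280) in resized space maps back to roughly
    # (640, 400) in full-resolution screen space after math.floor.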
6	libs/python/agent/agent/loops/model_types.csv	Normal file
@@ -0,0 +1,6 @@
+model,predict_step,predict_point
+anthropic,✅,✅
+openai,✅,✅
+uitars,✅,✅
+omniparser,❌,✅
+gta1,❌,✅
@@ -310,7 +310,6 @@ class OmniparsrConfig(AsyncAgentConfig):
            "input": messages,
            "tools": openai_tools if openai_tools else None,
            "stream": stream,
-            "reasoning": {"summary": "concise"},
            "truncation": "auto",
            "num_retries": max_retries,
            **kwargs
@@ -331,7 +330,7 @@ class OmniparsrConfig(AsyncAgentConfig):

        # Extract usage information
        usage = {
-            **response.usage.model_dump(),
+            **response.usage.model_dump(),  # type: ignore
            "response_cost": response._hidden_params.get("response_cost", 0.0),
        }
        if _on_usage:
@@ -339,7 +338,7 @@ class OmniparsrConfig(AsyncAgentConfig):

        # handle som function calls -> xy computer calls
        new_output = []
-        for i in range(len(response.output)):
+        for i in range(len(response.output)):  # type: ignore
            new_output += await replace_function_with_computer_call(response.output[i].model_dump(), id2xy)

        return {
@@ -5,8 +5,7 @@ Example usage of the agent library with docstring-based tool definitions.
import asyncio
import logging

-from agent import agent_loop, ComputerAgent
-from agent.types import Messages
+from agent import ComputerAgent
from computer import Computer
from computer.helpers import sandboxed

@@ -22,7 +22,7 @@ dependencies = [
    "cua-computer>=0.3.0,<0.5.0",
    "cua-core>=0.1.8,<0.2.0",
    "certifi>=2024.2.2",
-    "litellm>=1.74.8"
+    "litellm>=1.74.12"
]
requires-python = ">=3.11"