Files
computer/libs/agent/agent/providers/uitars/loop.py
T
2025-04-23 14:27:42 -04:00

595 lines
23 KiB
Python

"""UI-TARS-specific agent loop implementation."""
import logging
import asyncio
import re
import os
import json
import base64
import copy
from typing import Any, Dict, List, Optional, Tuple, AsyncGenerator, cast
from httpx import ConnectError, ReadTimeout
from ...core.base import BaseLoop
from ...core.messages import StandardMessageManager, ImageRetentionConfig
from ...core.types import AgentResponse, LLMProvider
from ...core.visualization import VisualizationHelper
from computer import Computer
from .utils import add_box_token, parse_actions, parse_action_parameters
from .tools.manager import ToolManager
from .tools.computer import ToolResult
from .prompts import COMPUTER_USE, SYSTEM_PROMPT
from .clients.oaicompat import OAICompatClient
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class UITARSLoop(BaseLoop):
"""UI-TARS-specific implementation of the agent loop.
This class extends BaseLoop to provide support for the UI-TARS model
with computer control capabilities.
"""
###########################################
# INITIALIZATION AND CONFIGURATION
###########################################
def __init__(
self,
computer: Computer,
api_key: str,
model: str,
provider_base_url: Optional[str] = "http://localhost:8000/v1",
only_n_most_recent_images: Optional[int] = 2,
base_dir: Optional[str] = "trajectories",
max_retries: int = 3,
retry_delay: float = 1.0,
save_trajectory: bool = True,
**kwargs,
):
"""Initialize the loop.
Args:
computer: Computer instance
api_key: API key (may not be needed for local endpoints)
model: Model name (e.g., "ui-tars")
provider_base_url: Base URL for the API provider
only_n_most_recent_images: Maximum number of recent screenshots to include in API requests
base_dir: Base directory for saving experiment data
max_retries: Maximum number of retries for API calls
retry_delay: Delay between retries in seconds
save_trajectory: Whether to save trajectory data
"""
# Set provider before initializing base class
self.provider = LLMProvider.OAICOMPAT
self.provider_base_url = provider_base_url
# Initialize message manager with image retention config
self.message_manager = StandardMessageManager(
config=ImageRetentionConfig(num_images_to_keep=only_n_most_recent_images)
)
# Initialize base class (which will set up experiment manager)
super().__init__(
computer=computer,
model=model,
api_key=api_key,
max_retries=max_retries,
retry_delay=retry_delay,
base_dir=base_dir,
save_trajectory=save_trajectory,
only_n_most_recent_images=only_n_most_recent_images,
**kwargs,
)
# Set API client attributes
self.client = None
self.retry_count = 0
# Initialize visualization helper
self.viz_helper = VisualizationHelper(agent=self)
# Initialize tool manager
self.tool_manager = ToolManager(computer=computer)
logger.info("UITARSLoop initialized with StandardMessageManager")
async def initialize(self) -> None:
"""Initialize the loop by setting up tools and clients."""
# Initialize base class
await super().initialize()
# Initialize tool manager with error handling
try:
logger.info("Initializing tool manager...")
await self.tool_manager.initialize()
logger.info("Tool manager initialized successfully.")
except Exception as e:
logger.error(f"Error initializing tool manager: {str(e)}")
logger.warning("Will attempt to initialize tools on first use.")
# Initialize client for the OAICompat provider
try:
await self.initialize_client()
except Exception as e:
logger.error(f"Error initializing client: {str(e)}")
raise RuntimeError(f"Failed to initialize client: {str(e)}")
###########################################
# CLIENT INITIALIZATION - IMPLEMENTING ABSTRACT METHOD
###########################################
async def initialize_client(self) -> None:
"""Initialize the appropriate client.
Implements abstract method from BaseLoop to set up the specific
provider client (OAICompat for UI-TARS).
"""
try:
logger.info(f"Initializing OAICompat client for UI-TARS with model {self.model}...")
self.client = OAICompatClient(
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
model=self.model,
provider_base_url=self.provider_base_url,
)
logger.info(f"Initialized OAICompat client with model {self.model}")
except Exception as e:
logger.error(f"Error initializing client: {str(e)}")
self.client = None
raise RuntimeError(f"Failed to initialize client: {str(e)}")
###########################################
# MESSAGE FORMATTING
###########################################
def to_uitars_format(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Convert messages to UI-TARS compatible format.
Args:
messages: List of messages in standard format
Returns:
List of messages formatted for UI-TARS
"""
# Create a copy of the messages to avoid modifying the original
uitars_messages = copy.deepcopy(messages)
# Find the first user message to modify
first_user_idx = None
instruction = ""
for idx, msg in enumerate(uitars_messages):
if msg.get("role") == "user":
first_user_idx = idx
content = msg.get("content", "")
if isinstance(content, str):
instruction = content
break
elif isinstance(content, list):
for item in content:
if item.get("type") == "text":
instruction = item.get("text", "")
break
if instruction:
break
# Only modify the first user message if found
if first_user_idx is not None and instruction:
# Create the computer use prompt
user_prompt = COMPUTER_USE.format(
instruction=instruction,
language="English"
)
# Replace the content of the first user message
if isinstance(uitars_messages[first_user_idx].get("content", ""), str):
uitars_messages[first_user_idx]["content"] = [{"type": "text", "text": user_prompt}]
elif isinstance(uitars_messages[first_user_idx].get("content", ""), list):
# Find and replace only the text part, keeping images
for i, item in enumerate(uitars_messages[first_user_idx]["content"]):
if item.get("type") == "text":
uitars_messages[first_user_idx]["content"][i]["text"] = user_prompt
break
# Add box tokens to assistant responses
for idx, msg in enumerate(uitars_messages):
if msg.get("role") == "assistant":
content = msg.get("content", "")
if content and isinstance(content, list):
for i, part in enumerate(content):
if part.get('type') == 'text':
uitars_messages[idx]["content"][i]["text"] = add_box_token(part['text'])
return uitars_messages
###########################################
# API CALL HANDLING
###########################################
async def _make_api_call(self, messages: List[Dict[str, Any]], system_prompt: str) -> Any:
"""Make API call to provider with retry logic."""
# Create new turn directory for this API call
self._create_turn_dir()
request_data = None
last_error = None
for attempt in range(self.max_retries):
try:
# Ensure client is initialized
if self.client is None:
logger.info(
f"Client not initialized in _make_api_call (attempt {attempt+1}), initializing now..."
)
await self.initialize_client()
if self.client is None:
raise RuntimeError("Failed to initialize client")
# Convert messages to UI-TARS format
uitars_messages = self.to_uitars_format(messages)
# Log request
request_data = {
"messages": uitars_messages,
"max_tokens": self.max_tokens,
"system": system_prompt,
}
self._log_api_call("request", request_data)
# Make API call
response = await self.client.run_interleaved(
messages=uitars_messages,
system=system_prompt,
max_tokens=self.max_tokens,
)
# Log success response
self._log_api_call("response", request_data, response)
return response
except (ConnectError, ReadTimeout) as e:
last_error = e
logger.warning(
f"Connection error on attempt {attempt + 1}/{self.max_retries}: {str(e)}"
)
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay * (attempt + 1)) # Exponential backoff
# Reset client on connection errors to force re-initialization
self.client = None
continue
except RuntimeError as e:
# Handle client initialization errors specifically
last_error = e
self._log_api_call("error", request_data, error=e)
logger.error(
f"Client initialization error (attempt {attempt + 1}/{self.max_retries}): {str(e)}"
)
if attempt < self.max_retries - 1:
# Reset client to force re-initialization
self.client = None
await asyncio.sleep(self.retry_delay)
continue
except Exception as e:
# Log unexpected error
last_error = e
self._log_api_call("error", request_data, error=e)
logger.error(f"Unexpected error in API call: {str(e)}")
if attempt < self.max_retries - 1:
await asyncio.sleep(self.retry_delay)
continue
# If we get here, all retries failed
error_message = f"API call failed after {self.max_retries} attempts"
if last_error:
error_message += f": {str(last_error)}"
logger.error(error_message)
raise RuntimeError(error_message)
###########################################
# RESPONSE AND ACTION HANDLING
###########################################
async def _handle_response(
self, response: Any, messages: List[Dict[str, Any]]
) -> Tuple[bool, bool]:
"""Handle API response.
Args:
response: API response
messages: List of messages to update
Returns:
Tuple of (should_continue, action_screenshot_saved)
"""
action_screenshot_saved = False
try:
# Step 1: Extract the raw response text
raw_text = None
try:
# OpenAI-compatible response format
raw_text = response["choices"][0]["message"]["content"]
except (KeyError, TypeError, IndexError) as e:
logger.error(f"Invalid response format: {str(e)}")
return True, action_screenshot_saved
# Step 2: Add the response to message history
self.message_manager.add_assistant_message([{"type": "text", "text": raw_text}])
# Step 3: Parse actions from the response
parsed_actions = parse_actions(raw_text)
if not parsed_actions:
logger.warning("No action found in the response")
return True, action_screenshot_saved
# Step 4: Execute each action
for action in parsed_actions:
action_type = None
# Handle "finished" action
if action.startswith("finished"):
logger.info("Agent completed the task")
return False, action_screenshot_saved
# Process other action types (click, type, etc.)
try:
# Parse action parameters using the utility function
action_name, tool_args = parse_action_parameters(action)
if not action_name:
logger.warning(f"Could not parse action: {action}")
continue
# Mark actions that would create screenshots
if action_name in ["click", "left_double", "right_single", "drag", "scroll"]:
action_screenshot_saved = True
# Execute the tool with prepared arguments
await self._ensure_tools_initialized()
# Let's log what we're about to execute for debugging
logger.info(f"Executing computer tool with arguments: {tool_args}")
result = await self.tool_manager.execute_tool(name="computer", tool_input=tool_args)
# Handle the result
if hasattr(result, "error") and result.error:
logger.error(f"Error executing tool: {result.error}")
else:
# Action was successful
logger.info(f"Successfully executed {action_name}")
# Save screenshot if one was returned and we haven't already saved one
if hasattr(result, "base64_image") and result.base64_image:
self._save_screenshot(result.base64_image, action_type=action_name)
action_screenshot_saved = True
except Exception as e:
logger.error(f"Error executing action {action}: {str(e)}")
# Continue the loop if there are actions to process
return True, action_screenshot_saved
except Exception as e:
logger.error(f"Error handling response: {str(e)}")
# Add error message using the message manager
error_message = [{"type": "text", "text": f"Error: {str(e)}"}]
self.message_manager.add_assistant_message(error_message)
raise
###########################################
# SCREEN HANDLING
###########################################
async def _get_current_screen(self, save_screenshot: bool = True) -> str:
"""Get the current screen as a base64 encoded image.
Args:
save_screenshot: Whether to save the screenshot
Returns:
Base64 encoded screenshot
"""
try:
# Take a screenshot
screenshot = await self.computer.interface.screenshot()
# Convert to base64
img_base64 = base64.b64encode(screenshot).decode("utf-8")
# Process screenshot through hooks and save if needed
await self.handle_screenshot(img_base64, action_type="state")
# Save screenshot if requested
if save_screenshot and self.save_trajectory:
self._save_screenshot(img_base64, action_type="state")
return img_base64
except Exception as e:
logger.error(f"Error getting current screen: {str(e)}")
raise
###########################################
# SYSTEM PROMPT
###########################################
def _get_system_prompt(self) -> str:
"""Get the system prompt for the model."""
return SYSTEM_PROMPT
###########################################
# MAIN LOOP - IMPLEMENTING ABSTRACT METHOD
###########################################
async def run(self, messages: List[Dict[str, Any]]) -> AsyncGenerator[Dict[str, Any], None]:
"""Run the agent loop with provided messages.
Args:
messages: List of messages in standard OpenAI format
Yields:
Agent response format
"""
# Initialize the message manager with the provided messages
self.message_manager.messages = messages.copy()
logger.info(f"Starting UITARSLoop run with {len(self.message_manager.messages)} messages")
# Continue running until explicitly told to stop
running = True
turn_created = False
# Track if an action-specific screenshot has been saved this turn
action_screenshot_saved = False
attempt = 0
max_attempts = 3
while running and attempt < max_attempts:
try:
# Create a new turn directory if it's not already created
if not turn_created:
self._create_turn_dir()
turn_created = True
# Ensure client is initialized
if self.client is None:
logger.info("Initializing client...")
await self.initialize_client()
if self.client is None:
raise RuntimeError("Failed to initialize client")
logger.info("Client initialized successfully")
# Get current screen
base64_screenshot = await self._get_current_screen()
# Add screenshot to message history
self.message_manager.add_user_message(
[
{
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{base64_screenshot}"},
}
]
)
logger.info("Added screenshot to message history")
# Get system prompt
system_prompt = self._get_system_prompt()
# Make API call with retries
response = await self._make_api_call(
self.message_manager.messages, system_prompt
)
# Handle the response (may execute actions)
# Returns: (should_continue, action_screenshot_saved)
should_continue, new_screenshot_saved = await self._handle_response(
response, self.message_manager.messages
)
# Update whether an action screenshot was saved this turn
action_screenshot_saved = action_screenshot_saved or new_screenshot_saved
# Parse actions from the raw response
raw_response = response["choices"][0]["message"]["content"]
parsed_actions = parse_actions(raw_response)
# Extract thought content if available
thought = ""
if "Thought:" in raw_response:
thought_match = re.search(r"Thought: (.*?)(?=\s*Action:|$)", raw_response, re.DOTALL)
if thought_match:
thought = thought_match.group(1).strip()
# Create standardized thought response format
thought_response = {
"role": "assistant",
"content": thought or raw_response,
"metadata": {
"title": "🧠 UI-TARS Thoughts"
}
}
# Create action response format
action_response = {
"role": "assistant",
"content": str(parsed_actions),
"metadata": {
"title": "🖱️ UI-TARS Actions",
}
}
# Yield both responses to the caller (thoughts first, then actions)
yield thought_response
if parsed_actions:
yield action_response
# Check if we should continue this conversation
running = should_continue
# Create a new turn directory if we're continuing
if running:
turn_created = False
# Reset attempt counter on success
attempt = 0
except Exception as e:
attempt += 1
error_msg = f"Error in run method (attempt {attempt}/{max_attempts}): {str(e)}"
logger.error(error_msg)
# If this is our last attempt, provide more info about the error
if attempt >= max_attempts:
logger.error(f"Maximum retry attempts reached. Last error was: {str(e)}")
yield {
"error": str(e),
"metadata": {"title": "❌ Error"},
}
# Create a brief delay before retrying
await asyncio.sleep(1)
###########################################
# UTILITY METHODS
###########################################
async def _ensure_tools_initialized(self) -> None:
"""Ensure the tool manager and tools are initialized before use."""
if not hasattr(self.tool_manager, "tools") or self.tool_manager.tools is None:
logger.info("Tools not initialized. Initializing now...")
await self.tool_manager.initialize()
logger.info("Tools initialized successfully.")
async def process_model_response(self, response_text: str) -> Optional[Dict[str, Any]]:
"""Process model response to extract tool calls.
Args:
response_text: Model response text
Returns:
Extracted tool information, or None if no tool call was found
"""
# UI-TARS doesn't use the standard tool call format, so we parse its actions differently
parsed_actions = parse_actions(response_text)
if parsed_actions:
return {"actions": parsed_actions}
return None