fixed uitars agent, added it to gradio ui

This commit is contained in:
Dillon DuPont
2025-04-23 14:27:42 -04:00
parent b7f3fbe3d3
commit fdf2509a7f
9 changed files with 1261 additions and 591 deletions

View File

@@ -0,0 +1,35 @@
"""Base client implementation for Omni providers."""
import logging
from typing import Dict, List, Optional, Any, Tuple
logger = logging.getLogger(__name__)
class BaseUITarsClient:
"""Base class for provider-specific clients."""
def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
"""Initialize base client.
Args:
api_key: Optional API key
model: Optional model name
"""
self.api_key = api_key
self.model = model
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
) -> Dict[str, Any]:
"""Run interleaved chat completion.
Args:
messages: List of message dicts
system: System prompt
max_tokens: Optional max tokens override
Returns:
Response dict
"""
raise NotImplementedError

View File

@@ -0,0 +1,204 @@
"""OpenAI-compatible client implementation."""
import os
import logging
from typing import Dict, List, Optional, Any
import aiohttp
import re
from .base import BaseUITarsClient
logger = logging.getLogger(__name__)
# OpenAI-compatible client for the UI_Tars
class OAICompatClient(BaseUITarsClient):
"""OpenAI-compatible API client implementation.
This client can be used with any service that implements the OpenAI API protocol, including:
- Huggingface Text Generation Interface endpoints
- vLLM
- LM Studio
- LocalAI
- Ollama (with OpenAI compatibility)
- Text Generation WebUI
- Any other service with OpenAI API compatibility
"""
def __init__(
self,
api_key: Optional[str] = None,
model: str = "Qwen2.5-VL-7B-Instruct",
provider_base_url: Optional[str] = "http://localhost:8000/v1",
max_tokens: int = 4096,
temperature: float = 0.0,
):
"""Initialize the OpenAI-compatible client.
Args:
api_key: Not used for local endpoints, usually set to "EMPTY"
model: Model name to use
provider_base_url: API base URL. Typically in the format "http://localhost:PORT/v1"
Examples:
- vLLM: "http://localhost:8000/v1"
- LM Studio: "http://localhost:1234/v1"
- LocalAI: "http://localhost:8080/v1"
- Ollama: "http://localhost:11434/v1"
max_tokens: Maximum tokens to generate
temperature: Generation temperature
"""
super().__init__(api_key=api_key or "EMPTY", model=model)
self.api_key = api_key or "EMPTY" # Local endpoints typically don't require an API key
self.model = model
self.provider_base_url = (
provider_base_url or "http://localhost:8000/v1"
) # Use default if None
self.max_tokens = max_tokens
self.temperature = temperature
def _extract_base64_image(self, text: str) -> Optional[str]:
"""Extract base64 image data from an HTML img tag."""
pattern = r'data:image/[^;]+;base64,([^"]+)'
match = re.search(pattern, text)
return match.group(1) if match else None
def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Create a loggable version of messages with image data truncated."""
loggable_messages = []
for msg in messages:
if isinstance(msg.get("content"), list):
new_content = []
for content in msg["content"]:
if content.get("type") == "image":
new_content.append(
{"type": "image", "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
)
else:
new_content.append(content)
loggable_messages.append({"role": msg["role"], "content": new_content})
else:
loggable_messages.append(msg)
return loggable_messages
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
) -> Dict[str, Any]:
"""Run interleaved chat completion.
Args:
messages: List of message dicts
system: System prompt
max_tokens: Optional max tokens override
Returns:
Response dict
"""
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}
final_messages = [{"role": "system", "content": system}]
# Process messages
for item in messages:
if isinstance(item, dict):
if isinstance(item["content"], list):
# Content is already in the correct format
final_messages.append(item)
else:
# Single string content, check for image
base64_img = self._extract_base64_image(item["content"])
if base64_img:
message = {
"role": item["role"],
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
}
],
}
else:
message = {
"role": item["role"],
"content": [{"type": "text", "text": item["content"]}],
}
final_messages.append(message)
else:
# String content, check for image
base64_img = self._extract_base64_image(item)
if base64_img:
message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
}
],
}
else:
message = {"role": "user", "content": [{"type": "text", "text": item}]}
final_messages.append(message)
payload = {"model": self.model, "messages": final_messages, "temperature": self.temperature}
payload["max_tokens"] = max_tokens or self.max_tokens
try:
async with aiohttp.ClientSession() as session:
# Use default base URL if none provided
base_url = self.provider_base_url or "http://localhost:8000/v1"
# Check if the base URL already includes the chat/completions endpoint
endpoint_url = base_url
if not endpoint_url.endswith("/chat/completions"):
# If URL is RunPod format, make it OpenAI compatible
if endpoint_url.startswith("https://api.runpod.ai/v2/"):
# Extract RunPod endpoint ID
parts = endpoint_url.split("/")
if len(parts) >= 5:
runpod_id = parts[4]
endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
# If the URL ends with /v1, append /chat/completions
elif endpoint_url.endswith("/v1"):
endpoint_url = f"{endpoint_url}/chat/completions"
# If the URL doesn't end with /v1, make sure it has a proper structure
elif not endpoint_url.endswith("/"):
endpoint_url = f"{endpoint_url}/chat/completions"
else:
endpoint_url = f"{endpoint_url}chat/completions"
# Log the endpoint URL for debugging
logger.debug(f"Using endpoint URL: {endpoint_url}")
async with session.post(endpoint_url, headers=headers, json=payload) as response:
# Log the status and content type
logger.debug(f"Status: {response.status}")
logger.debug(f"Content-Type: {response.headers.get('Content-Type')}")
# Get the raw text of the response
response_text = await response.text()
logger.debug(f"Response content: {response_text}")
# Try to parse as JSON if the content type is appropriate
if "application/json" in response.headers.get('Content-Type', ''):
response_json = await response.json()
else:
raise Exception(f"Response is not JSON format")
# # Optionally try to parse it anyway
# try:
# import json
# response_json = json.loads(response_text)
# except json.JSONDecodeError as e:
# print(f"Failed to parse response as JSON: {e}")
if response.status != 200:
error_msg = response_json.get("error", {}).get(
"message", str(response_json)
)
logger.error(f"Error in API call: {error_msg}")
raise Exception(f"API error: {error_msg}")
return response_json
except Exception as e:
logger.error(f"Error in API call: {str(e)}")
raise

File diff suppressed because it is too large Load Diff

View File

@@ -1,5 +1,7 @@
"""Prompts for UI-TARS agent."""
SYSTEM_PROMPT = "You are a helpful assistant."
COMPUTER_USE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format

View File

@@ -0,0 +1 @@
"""UI-TARS tools package."""

View File

@@ -0,0 +1,279 @@
"""Computer tool for UI-TARS."""
import asyncio
import base64
import logging
import re
from typing import Any, Dict, List, Optional, Literal, Union
from computer import Computer
from ....core.tools.base import ToolResult, ToolFailure
from ....core.tools.computer import BaseComputerTool
logger = logging.getLogger(__name__)
class ComputerTool(BaseComputerTool):
"""
A tool that allows the UI-TARS agent to interact with the screen, keyboard, and mouse.
"""
name: str = "computer"
width: Optional[int] = None
height: Optional[int] = None
computer: Computer
def __init__(self, computer: Computer):
"""Initialize the computer tool.
Args:
computer: Computer instance
"""
super().__init__(computer)
self.computer = computer
self.width = None
self.height = None
self.logger = logging.getLogger(__name__)
def to_params(self) -> Dict[str, Any]:
"""Convert tool to API parameters.
Returns:
Dictionary with tool parameters
"""
if self.width is None or self.height is None:
raise RuntimeError(
"Screen dimensions not initialized. Call initialize_dimensions() first."
)
return {
"type": "computer",
"display_width": self.width,
"display_height": self.height,
}
async def initialize_dimensions(self) -> None:
"""Initialize screen dimensions from the computer interface."""
try:
display_size = await self.computer.interface.get_screen_size()
self.width = display_size["width"]
self.height = display_size["height"]
self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
except Exception as e:
# Fall back to defaults if we can't get accurate dimensions
self.width = 1024
self.height = 768
self.logger.warning(
f"Failed to get screen dimensions, using defaults: {self.width}x{self.height}. Error: {e}"
)
async def __call__(
self,
*,
action: str,
**kwargs,
) -> ToolResult:
"""Execute a computer action.
Args:
action: The action to perform (based on UI-TARS action space)
**kwargs: Additional parameters for the action
Returns:
ToolResult containing action output and possibly a base64 image
"""
try:
# Ensure dimensions are initialized
if self.width is None or self.height is None:
await self.initialize_dimensions()
if self.width is None or self.height is None:
return ToolFailure(error="Failed to initialize screen dimensions")
# Handle actions defined in UI-TARS action space (from prompts.py)
# Handle standard click (left click)
if action == "click":
if "x" in kwargs and "y" in kwargs:
x, y = kwargs["x"], kwargs["y"]
await self.computer.interface.left_click(x, y)
# Wait briefly for UI to update
await asyncio.sleep(0.5)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Clicked at ({x}, {y})",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing coordinates for click action")
# Handle double click
elif action == "left_double":
if "x" in kwargs and "y" in kwargs:
x, y = kwargs["x"], kwargs["y"]
await self.computer.interface.double_click(x, y)
# Wait briefly for UI to update
await asyncio.sleep(0.5)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Double-clicked at ({x}, {y})",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing coordinates for left_double action")
# Handle right click
elif action == "right_single":
if "x" in kwargs and "y" in kwargs:
x, y = kwargs["x"], kwargs["y"]
await self.computer.interface.right_click(x, y)
# Wait briefly for UI to update
await asyncio.sleep(0.5)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Right-clicked at ({x}, {y})",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing coordinates for right_single action")
# Handle typing text
elif action == "type_text":
if "text" in kwargs:
text = kwargs["text"]
await self.computer.interface.type_text(text)
# Wait for UI to update
await asyncio.sleep(0.3)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Typed: {text}",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing text for type action")
# Handle hotkey
elif action == "hotkey":
if "keys" in kwargs:
keys = kwargs["keys"]
for key in keys:
await self.computer.interface.press_key(key)
# Wait for UI to update
await asyncio.sleep(0.3)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Pressed hotkey: {', '.join(keys)}",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing keys for hotkey action")
# Handle drag action
elif action == "drag":
if all(k in kwargs for k in ["start_x", "start_y", "end_x", "end_y"]):
start_x, start_y = kwargs["start_x"], kwargs["start_y"]
end_x, end_y = kwargs["end_x"], kwargs["end_y"]
# Perform drag
await self.computer.interface.move_cursor(start_x, start_y)
await self.computer.interface.drag_to(end_x, end_y)
# Wait for UI to update
await asyncio.sleep(0.5)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Dragged from ({start_x}, {start_y}) to ({end_x}, {end_y})",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing coordinates for drag action")
# Handle scroll action
elif action == "scroll":
if all(k in kwargs for k in ["x", "y", "direction"]):
x, y = kwargs["x"], kwargs["y"]
direction = kwargs["direction"]
# Move cursor to position
await self.computer.interface.move_cursor(x, y)
# Scroll based on direction
if direction == "down":
await self.computer.interface.scroll_down(5)
elif direction == "up":
await self.computer.interface.scroll_up(5)
elif direction == "right":
pass # await self.computer.interface.scroll_right(5)
elif direction == "left":
pass # await self.computer.interface.scroll_left(5)
else:
return ToolFailure(error=f"Invalid scroll direction: {direction}")
# Wait for UI to update
await asyncio.sleep(0.5)
# Take screenshot after action
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output=f"Scrolled {direction} at ({x}, {y})",
base64_image=base64_screenshot,
)
else:
return ToolFailure(error="Missing parameters for scroll action")
# Handle wait action
elif action == "wait":
# Sleep for 5 seconds as specified in the action space
await asyncio.sleep(5)
# Take screenshot after waiting
screenshot = await self.computer.interface.screenshot()
base64_screenshot = base64.b64encode(screenshot).decode("utf-8")
return ToolResult(
output="Waited for 5 seconds",
base64_image=base64_screenshot,
)
# Handle finished action (task completion)
elif action == "finished":
content = kwargs.get("content", "Task completed")
return ToolResult(
output=f"Task finished: {content}",
)
return await self._handle_scroll(action)
else:
return ToolFailure(error=f"Unsupported action: {action}")
except Exception as e:
self.logger.error(f"Error in ComputerTool.__call__: {str(e)}")
return ToolFailure(error=f"Failed to execute {action}: {str(e)}")

View File

@@ -0,0 +1,60 @@
"""Tool manager for the UI-TARS provider."""
import logging
from typing import Any, Dict, List, Optional
from computer import Computer
from ....core.tools import BaseToolManager
from ....core.tools.collection import ToolCollection
from .computer import ComputerTool
logger = logging.getLogger(__name__)
class ToolManager(BaseToolManager):
"""Manages UI-TARS provider tool initialization and execution."""
def __init__(self, computer: Computer):
"""Initialize the tool manager.
Args:
computer: Computer instance for computer-related tools
"""
super().__init__(computer)
# Initialize UI-TARS-specific tools
self.computer_tool = ComputerTool(self.computer)
self._initialized = False
def _initialize_tools(self) -> ToolCollection:
"""Initialize all available tools."""
return ToolCollection(self.computer_tool)
async def _initialize_tools_specific(self) -> None:
"""Initialize UI-TARS provider-specific tool requirements."""
await self.computer_tool.initialize_dimensions()
def get_tool_params(self) -> List[Dict[str, Any]]:
"""Get tool parameters for API calls.
Returns:
List of tool parameters for the current provider's API
"""
if self.tools is None:
raise RuntimeError("Tools not initialized. Call initialize() first.")
return self.tools.to_params()
async def execute_tool(self, name: str, tool_input: dict[str, Any]) -> Any:
"""Execute a tool with the given input.
Args:
name: Name of the tool to execute
tool_input: Input parameters for the tool
Returns:
Result of the tool execution
"""
if self.tools is None:
raise RuntimeError("Tools not initialized. Call initialize() first.")
return await self.tools.run(name=name, tool_input=tool_input)

View File

@@ -0,0 +1,153 @@
"""Utility functions for the UI-TARS provider."""
import logging
import base64
import re
from typing import Any, Dict, List, Optional, Union, Tuple
logger = logging.getLogger(__name__)
def add_box_token(input_string: str) -> str:
"""Add box tokens to the coordinates in the model response.
Args:
input_string: Raw model response
Returns:
String with box tokens added
"""
if "Action: " not in input_string or "start_box=" not in input_string:
return input_string
suffix = input_string.split("Action: ")[0] + "Action: "
actions = input_string.split("Action: ")[1:]
processed_actions = []
for action in actions:
action = action.strip()
coordinates = re.findall(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'", action)
updated_action = action
for coord_type, x, y in coordinates:
updated_action = updated_action.replace(
f"{coord_type}='({x},{y})'",
f"{coord_type}='<|box_start|>({x},{y})<|box_end|>'"
)
processed_actions.append(updated_action)
return suffix + "\n\n".join(processed_actions)
def parse_actions(response: str) -> List[str]:
"""Parse actions from UI-TARS model response.
Args:
response: The raw model response text
Returns:
List of parsed actions
"""
actions = []
# Extract the Action part from the response
if "Action:" in response:
action_text = response.split("Action:")[-1].strip()
# Clean up and format action
if action_text:
# Handle multiple actions separated by newlines
action_parts = action_text.split("\n\n")
for part in action_parts:
if part.strip():
actions.append(part.strip())
return actions
def parse_action_parameters(action: str) -> Tuple[str, Dict[str, Any]]:
"""Parse parameters from an action string.
Args:
action: The action string to parse
Returns:
Tuple of (action_name, action_parameters)
"""
# Handle "finished" action
if action.startswith("finished"):
return "finished", {}
# Parse action parameters
action_match = re.match(r'(\w+)\((.*)\)', action)
if not action_match:
logger.warning(f"Could not parse action: {action}")
return "", {}
action_name = action_match.group(1)
action_params_str = action_match.group(2)
tool_args = {"action": action_name}
# Extract coordinate values from the action
if "start_box" in action_params_str:
# Extract all box coordinates
box_pattern = r"(start_box|end_box)='(?:<\|box_start\|>)?\((\d+),\s*(\d+)\)(?:<\|box_end\|>)?'"
box_matches = re.findall(box_pattern, action_params_str)
# Handle click-type actions
if action_name in ["click", "left_double", "right_single"]:
# Get coordinates from start_box
for box_type, x, y in box_matches:
if box_type == "start_box":
tool_args["x"] = int(x)
tool_args["y"] = int(y)
break
# Handle drag action
elif action_name == "drag":
start_x, start_y = None, None
end_x, end_y = None, None
for box_type, x, y in box_matches:
if box_type == "start_box":
start_x, start_y = int(x), int(y)
elif box_type == "end_box":
end_x, end_y = int(x), int(y)
if not None in [start_x, start_y, end_x, end_y]:
tool_args["start_x"] = start_x
tool_args["start_y"] = start_y
tool_args["end_x"] = end_x
tool_args["end_y"] = end_y
# Handle scroll action
elif action_name == "scroll":
# Get coordinates from start_box
for box_type, x, y in box_matches:
if box_type == "start_box":
tool_args["x"] = int(x)
tool_args["y"] = int(y)
break
# Extract direction
direction_match = re.search(r"direction='([^']+)'", action_params_str)
if direction_match:
tool_args["direction"] = direction_match.group(1)
# Handle typing text
elif action_name == "type":
# Extract text content
content_match = re.search(r"content='([^']*)'", action_params_str)
if content_match:
# Unescape escaped characters
text = content_match.group(1).replace("\\'", "'").replace('\\"', '"').replace("\\n", "\n")
tool_args = {"action": "type_text", "text": text}
# Handle hotkey
elif action_name == "hotkey":
# Extract key combination
key_match = re.search(r"key='([^']*)'", action_params_str)
if key_match:
keys = key_match.group(1).split()
tool_args = {"action": "hotkey", "keys": keys}
return action_name, tool_args

View File

@@ -165,7 +165,6 @@ MODEL_MAPPINGS = {
"uitars": {
# UI-TARS models default to custom endpoint
"default": "ByteDance-Seed/UI-TARS-1.5-7B",
"ui-tars": "ByteDance-Seed/UI-TARS-1.5-7B",
},
"ollama": {
# For Ollama models, we keep the original name
@@ -287,7 +286,9 @@ def get_provider_and_model(model_name: str, loop_provider: str) -> tuple:
# Assign the determined model name
model_name_to_use = cleaned_model_name
# agent_loop remains AgentLoop.OMNI
elif agent_loop == AgentLoop.UITARS:
provider = LLMProvider.OAICOMPAT
model_name_to_use = MODEL_MAPPINGS["uitars"]["default"] # Default
else:
# Default to OpenAI if unrecognized loop
provider = LLMProvider.OPENAI
@@ -557,7 +558,7 @@ def create_gradio_ui(
"OPENAI": openai_models,
"ANTHROPIC": anthropic_models,
"OMNI": omni_models + ["Custom model..."], # Add custom model option
"UITARS": ["UI-TARS (ByteDance-Seed/UI-TARS-1.5-7B)", "Custom model..."], # UI-TARS options
"UITARS": ["Custom model..."], # UI-TARS options
}
# --- Apply Saved Settings (override defaults if available) ---
@@ -814,6 +815,8 @@ def create_gradio_ui(
provider, cleaned_model_name_from_func, agent_loop_type = (
get_provider_and_model(model_string_to_analyze, agent_loop_choice)
)
print(f"provider={provider} cleaned_model_name_from_func={cleaned_model_name_from_func} agent_loop_type={agent_loop_type} agent_loop_choice={agent_loop_choice}")
# Determine the final model name to send to the agent
# If custom selected, use the custom text box value, otherwise use the cleaned name