mirror of
https://github.com/trycua/computer.git
synced 2026-03-05 05:19:08 -06:00
fixed uitars agent, added it to gradio ui
This commit is contained in:
35
libs/agent/agent/providers/uitars/clients/base.py
Normal file
35
libs/agent/agent/providers/uitars/clients/base.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Base client implementation for Omni providers."""
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseUITarsClient:
    """Common parent for provider-specific UI-TARS clients.

    Stores the credentials/model shared by all clients and declares the
    ``run_interleaved`` coroutine that concrete clients must implement.
    """

    def __init__(self, api_key: Optional[str] = None, model: Optional[str] = None):
        """Store the connection settings.

        Args:
            api_key: Optional API key
            model: Optional model name
        """
        self.api_key, self.model = api_key, model

    async def run_interleaved(
        self,
        messages: List[Dict[str, Any]],
        system: str,
        max_tokens: Optional[int] = None,
    ) -> Dict[str, Any]:
        """Run an interleaved chat completion.

        Args:
            messages: List of message dicts
            system: System prompt
            max_tokens: Optional max tokens override

        Returns:
            Response dict

        Raises:
            NotImplementedError: Always; subclasses provide the implementation.
        """
        raise NotImplementedError()
|
||||
204
libs/agent/agent/providers/uitars/clients/oaicompat.py
Normal file
204
libs/agent/agent/providers/uitars/clients/oaicompat.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""OpenAI-compatible client implementation."""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Any
|
||||
import aiohttp
|
||||
import re
|
||||
from .base import BaseUITarsClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# OpenAI-compatible client for the UI_Tars
|
||||
class OAICompatClient(BaseUITarsClient):
    """OpenAI-compatible API client implementation.

    This client can be used with any service that implements the OpenAI API protocol, including:
    - Huggingface Text Generation Interface endpoints
    - vLLM
    - LM Studio
    - LocalAI
    - Ollama (with OpenAI compatibility)
    - Text Generation WebUI
    - Any other service with OpenAI API compatibility
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "Qwen2.5-VL-7B-Instruct",
        provider_base_url: Optional[str] = "http://localhost:8000/v1",
        max_tokens: int = 4096,
        temperature: float = 0.0,
    ):
        """Initialize the OpenAI-compatible client.

        Args:
            api_key: Not used for local endpoints, usually set to "EMPTY"
            model: Model name to use
            provider_base_url: API base URL. Typically in the format "http://localhost:PORT/v1"
                Examples:
                - vLLM: "http://localhost:8000/v1"
                - LM Studio: "http://localhost:1234/v1"
                - LocalAI: "http://localhost:8080/v1"
                - Ollama: "http://localhost:11434/v1"
            max_tokens: Maximum tokens to generate
            temperature: Generation temperature
        """
        super().__init__(api_key=api_key or "EMPTY", model=model)
        # Local endpoints typically don't require an API key
        self.api_key = api_key or "EMPTY"
        self.model = model
        # Fall back to the default URL when None (or an empty string) is passed
        self.provider_base_url = provider_base_url or "http://localhost:8000/v1"
        self.max_tokens = max_tokens
        self.temperature = temperature

    def _extract_base64_image(self, text: str) -> Optional[str]:
        """Extract the base64 payload of the first ``data:image/...;base64,`` URL in *text*.

        The payload is everything after ``base64,`` up to (not including) the
        next double quote, which is how the data appears inside an HTML img tag.

        Returns:
            The base64 string, or None when *text* contains no such data URL.
        """
        pattern = r'data:image/[^;]+;base64,([^"]+)'
        match = re.search(pattern, text)
        return match.group(1) if match else None

    def _get_loggable_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Create a loggable version of messages with image data replaced by a placeholder.

        Content parts of type ``"image"`` and ``"image_url"`` (the type this
        client itself emits) are replaced by a part of the same type whose URL
        is ``"[BASE64_IMAGE_DATA]"``. All other parts, and messages whose
        content is not a list, are passed through unchanged.

        Args:
            messages: List of message dicts

        Returns:
            A new list; the input messages are not mutated.
        """
        loggable_messages = []
        for msg in messages:
            content = msg.get("content")
            if not isinstance(content, list):
                loggable_messages.append(msg)
                continue

            redacted = []
            for part in content:
                # Parts are normally dicts, but tolerate anything else as-is
                part_type = part.get("type") if isinstance(part, dict) else None
                if part_type in ("image", "image_url"):
                    redacted.append(
                        {"type": part_type, "image_url": {"url": "[BASE64_IMAGE_DATA]"}}
                    )
                else:
                    redacted.append(part)
            loggable_messages.append({"role": msg.get("role"), "content": redacted})
        return loggable_messages

    def _to_content_parts(self, text: str) -> List[Dict[str, Any]]:
        """Convert a plain string into a one-element list of content parts.

        A string containing a base64 image data URL becomes a single
        ``image_url`` part (the surrounding text is dropped); any other string
        becomes a single ``text`` part.
        """
        base64_img = self._extract_base64_image(text)
        if base64_img:
            return [
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
                }
            ]
        return [{"type": "text", "text": text}]

    def _build_messages(self, messages: List[Any], system: str) -> List[Dict[str, Any]]:
        """Assemble the message list sent to the API, starting with the system prompt.

        Dict messages whose content is already a list are forwarded untouched;
        dict messages with string content and bare strings (sent with role
        "user") are converted with ``_to_content_parts``.
        """
        final_messages: List[Dict[str, Any]] = [{"role": "system", "content": system}]
        for item in messages:
            if isinstance(item, dict):
                if isinstance(item["content"], list):
                    # Content is already in the correct format
                    final_messages.append(item)
                else:
                    final_messages.append(
                        {
                            "role": item["role"],
                            "content": self._to_content_parts(item["content"]),
                        }
                    )
            else:
                final_messages.append({"role": "user", "content": self._to_content_parts(item)})
        return final_messages

    def _resolve_endpoint_url(self) -> str:
        """Return the chat-completions URL derived from ``provider_base_url``.

        - A URL already ending in ``/chat/completions`` is used as-is.
        - A ``https://api.runpod.ai/v2/<id>...`` URL is rewritten to that
          endpoint's ``/openai/v1/chat/completions`` route.
        - Otherwise ``chat/completions`` is appended, adding a ``/`` separator
          only when the base URL does not already end with one.
        """
        endpoint_url = self.provider_base_url or "http://localhost:8000/v1"
        if endpoint_url.endswith("/chat/completions"):
            return endpoint_url

        if endpoint_url.startswith("https://api.runpod.ai/v2/"):
            # The endpoint ID is the path segment right after "/v2/"
            parts = endpoint_url.split("/")
            if len(parts) >= 5:
                runpod_id = parts[4]
                endpoint_url = f"https://api.runpod.ai/v2/{runpod_id}/openai/v1/chat/completions"
            return endpoint_url

        if endpoint_url.endswith("/"):
            return f"{endpoint_url}chat/completions"
        return f"{endpoint_url}/chat/completions"

    async def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """Run interleaved chat completion.

        Args:
            messages: List of message dicts
            system: System prompt
            max_tokens: Optional max tokens override

        Returns:
            Response dict (the parsed JSON body of the HTTP response)

        Raises:
            Exception: If the response is not JSON or the HTTP status is not 200.
                Any error raised while performing the request is logged and re-raised.
        """
        headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"}

        payload = {
            "model": self.model,
            "messages": self._build_messages(messages, system),
            "temperature": self.temperature,
        }
        payload["max_tokens"] = max_tokens or self.max_tokens

        try:
            async with aiohttp.ClientSession() as session:
                endpoint_url = self._resolve_endpoint_url()
                logger.debug("Using endpoint URL: %s", endpoint_url)

                async with session.post(endpoint_url, headers=headers, json=payload) as response:
                    content_type = response.headers.get("Content-Type", "")
                    logger.debug("Status: %s", response.status)
                    logger.debug("Content-Type: %s", response.headers.get("Content-Type"))

                    # Read the raw text first so it can be logged and quoted in errors
                    response_text = await response.text()
                    logger.debug("Response content: %s", response_text)

                    if "application/json" not in content_type:
                        # Include status and a body excerpt so proxy/gateway failures
                        # (HTML error pages, plain-text errors) are diagnosable
                        raise Exception(
                            f"Response is not JSON format (HTTP {response.status}, "
                            f"Content-Type: {content_type!r}): {response_text[:500]}"
                        )
                    response_json = await response.json()

                    if response.status != 200:
                        # "error" is an object with "message" in the OpenAI schema, but
                        # tolerate servers that return a bare string or another shape
                        error = (
                            response_json.get("error")
                            if isinstance(response_json, dict)
                            else None
                        )
                        if isinstance(error, dict):
                            error_msg = error.get("message", str(response_json))
                        elif error:
                            error_msg = str(error)
                        else:
                            error_msg = str(response_json)
                        raise Exception(f"API error: {error_msg}")

                    return response_json

        except Exception as e:
            # Single logging point for every failure on this path
            logger.error(f"Error in API call: {str(e)}")
            raise
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,7 @@
|
||||
"""Prompts for UI-TARS agent."""
|
||||
|
||||
SYSTEM_PROMPT = "You are a helpful assistant."
|
||||
|
||||
COMPUTER_USE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
|
||||
|
||||
## Output Format
|
||||
|
||||
1
libs/agent/agent/providers/uitars/tools/__init__.py
Normal file
1
libs/agent/agent/providers/uitars/tools/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""UI-TARS tools package."""
|
||||
279
libs/agent/agent/providers/uitars/tools/computer.py
Normal file
279
libs/agent/agent/providers/uitars/tools/computer.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Computer tool for UI-TARS."""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Literal, Union
|
||||
|
||||
from computer import Computer
|
||||
from ....core.tools.base import ToolResult, ToolFailure
|
||||
from ....core.tools.computer import BaseComputerTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Maps each UI-TARS pointer action to the computer-interface method that
# performs it and the verb used in the result message.
_POINTER_ACTIONS = {
    "click": ("left_click", "Clicked"),
    "left_double": ("double_click", "Double-clicked"),
    "right_single": ("right_click", "Right-clicked"),
}


class ComputerTool(BaseComputerTool):
    """
    A tool that allows the UI-TARS agent to interact with the screen, keyboard, and mouse.
    """

    name: str = "computer"
    width: Optional[int] = None
    height: Optional[int] = None
    computer: Computer

    def __init__(self, computer: Computer):
        """Initialize the computer tool.

        Args:
            computer: Computer instance
        """
        super().__init__(computer)
        self.computer = computer
        # Screen size is unknown until initialize_dimensions() has run
        self.width = None
        self.height = None
        self.logger = logging.getLogger(__name__)

    def to_params(self) -> Dict[str, Any]:
        """Convert tool to API parameters.

        Returns:
            Dictionary with tool parameters

        Raises:
            RuntimeError: If the screen dimensions have not been initialized.
        """
        if self.width is None or self.height is None:
            raise RuntimeError(
                "Screen dimensions not initialized. Call initialize_dimensions() first."
            )
        return {
            "type": "computer",
            "display_width": self.width,
            "display_height": self.height,
        }

    async def initialize_dimensions(self) -> None:
        """Initialize screen dimensions from the computer interface.

        Falls back to 1024x768 (with a warning) if the interface query fails.
        """
        try:
            display_size = await self.computer.interface.get_screen_size()
            self.width = display_size["width"]
            self.height = display_size["height"]
            self.logger.info(f"Initialized screen dimensions to {self.width}x{self.height}")
        except Exception as e:
            # Fall back to defaults if we can't get accurate dimensions
            self.width = 1024
            self.height = 768
            self.logger.warning(
                f"Failed to get screen dimensions, using defaults: {self.width}x{self.height}. Error: {e}"
            )

    async def _result_with_screenshot(self, output: str, delay: float) -> ToolResult:
        """Wait *delay* seconds for the UI to settle, then return *output* with a fresh screenshot."""
        await asyncio.sleep(delay)
        screenshot = await self.computer.interface.screenshot()
        return ToolResult(
            output=output,
            base64_image=base64.b64encode(screenshot).decode("utf-8"),
        )

    async def __call__(
        self,
        *,
        action: str,
        **kwargs,
    ) -> ToolResult:
        """Execute a computer action.

        Supported actions and their required keyword arguments:
            click / left_double / right_single: x, y
            type_text: text
            hotkey: keys (each key is pressed in turn)
            drag: start_x, start_y, end_x, end_y
            scroll: x, y, direction ("up", "down", "left" or "right")
            wait: none (sleeps 5 seconds)
            finished: optional content

        Args:
            action: The action to perform (based on UI-TARS action space)
            **kwargs: Additional parameters for the action

        Returns:
            ToolResult containing action output and possibly a base64 image.
            A ToolFailure is returned (not raised) for missing parameters,
            unsupported actions, and any exception raised while executing.
        """
        try:
            # Ensure dimensions are initialized
            if self.width is None or self.height is None:
                await self.initialize_dimensions()
                if self.width is None or self.height is None:
                    return ToolFailure(error="Failed to initialize screen dimensions")

            # Pointer actions: left click, double click, right click
            if action in _POINTER_ACTIONS:
                if "x" not in kwargs or "y" not in kwargs:
                    return ToolFailure(error=f"Missing coordinates for {action} action")
                method_name, verb = _POINTER_ACTIONS[action]
                x, y = kwargs["x"], kwargs["y"]
                await getattr(self.computer.interface, method_name)(x, y)
                return await self._result_with_screenshot(f"{verb} at ({x}, {y})", delay=0.5)

            if action == "type_text":
                if "text" not in kwargs:
                    return ToolFailure(error="Missing text for type action")
                text = kwargs["text"]
                await self.computer.interface.type_text(text)
                return await self._result_with_screenshot(f"Typed: {text}", delay=0.3)

            if action == "hotkey":
                if "keys" not in kwargs:
                    return ToolFailure(error="Missing keys for hotkey action")
                keys = kwargs["keys"]
                # NOTE(review): keys are pressed one after another, not held together
                # as a chord; confirm this is what the interface expects for hotkeys.
                for key in keys:
                    await self.computer.interface.press_key(key)
                return await self._result_with_screenshot(
                    f"Pressed hotkey: {', '.join(keys)}", delay=0.3
                )

            if action == "drag":
                if not all(k in kwargs for k in ["start_x", "start_y", "end_x", "end_y"]):
                    return ToolFailure(error="Missing coordinates for drag action")
                start_x, start_y = kwargs["start_x"], kwargs["start_y"]
                end_x, end_y = kwargs["end_x"], kwargs["end_y"]
                # Position the cursor at the start point, then drag to the end point
                await self.computer.interface.move_cursor(start_x, start_y)
                await self.computer.interface.drag_to(end_x, end_y)
                return await self._result_with_screenshot(
                    f"Dragged from ({start_x}, {start_y}) to ({end_x}, {end_y})", delay=0.5
                )

            if action == "scroll":
                if not all(k in kwargs for k in ["x", "y", "direction"]):
                    return ToolFailure(error="Missing parameters for scroll action")
                x, y = kwargs["x"], kwargs["y"]
                direction = kwargs["direction"]

                # Move cursor to position
                await self.computer.interface.move_cursor(x, y)

                if direction == "down":
                    await self.computer.interface.scroll_down(5)
                elif direction == "up":
                    await self.computer.interface.scroll_up(5)
                elif direction in ("right", "left"):
                    # Horizontal scrolling is not performed: the direction is
                    # accepted, but only the cursor move above takes effect.
                    pass
                else:
                    return ToolFailure(error=f"Invalid scroll direction: {direction}")

                return await self._result_with_screenshot(
                    f"Scrolled {direction} at ({x}, {y})", delay=0.5
                )

            if action == "wait":
                # Sleep for 5 seconds as specified in the action space
                return await self._result_with_screenshot("Waited for 5 seconds", delay=5)

            if action == "finished":
                # Task completion: no screenshot is attached
                content = kwargs.get("content", "Task completed")
                return ToolResult(
                    output=f"Task finished: {content}",
                )

            return ToolFailure(error=f"Unsupported action: {action}")

        except Exception as e:
            self.logger.error(f"Error in ComputerTool.__call__: {str(e)}")
            return ToolFailure(error=f"Failed to execute {action}: {str(e)}")
|
||||
60
libs/agent/agent/providers/uitars/tools/manager.py
Normal file
60
libs/agent/agent/providers/uitars/tools/manager.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Tool manager for the UI-TARS provider."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from computer import Computer
|
||||
from ....core.tools import BaseToolManager
|
||||
from ....core.tools.collection import ToolCollection
|
||||
from .computer import ComputerTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolManager(BaseToolManager):
    """Owns the UI-TARS computer tool and routes tool calls to it."""

    def __init__(self, computer: Computer):
        """Create the manager and its UI-TARS computer tool.

        Args:
            computer: Computer instance for computer-related tools
        """
        super().__init__(computer)
        # UI-TARS exposes a single tool, built from the manager's computer
        self._initialized = False
        self.computer_tool = ComputerTool(self.computer)

    def _initialize_tools(self) -> ToolCollection:
        """Build the collection holding every available tool."""
        collection = ToolCollection(self.computer_tool)
        return collection

    async def _initialize_tools_specific(self) -> None:
        """Run UI-TARS-specific setup: load the screen dimensions into the computer tool."""
        await self.computer_tool.initialize_dimensions()

    def get_tool_params(self) -> List[Dict[str, Any]]:
        """Get tool parameters for API calls.

        Returns:
            List of tool parameters for the current provider's API

        Raises:
            RuntimeError: If the tools have not been initialized yet.
        """
        if self.tools is not None:
            return self.tools.to_params()
        raise RuntimeError("Tools not initialized. Call initialize() first.")

    async def execute_tool(self, name: str, tool_input: dict[str, Any]) -> Any:
        """Execute a tool with the given input.

        Args:
            name: Name of the tool to execute
            tool_input: Input parameters for the tool

        Returns:
            Result of the tool execution

        Raises:
            RuntimeError: If the tools have not been initialized yet.
        """
        if self.tools is not None:
            return await self.tools.run(name=name, tool_input=tool_input)
        raise RuntimeError("Tools not initialized. Call initialize() first.")
|
||||
153
libs/agent/agent/providers/uitars/utils.py
Normal file
153
libs/agent/agent/providers/uitars/utils.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Utility functions for the UI-TARS provider."""
|
||||
|
||||
import logging
|
||||
import base64
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def add_box_token(input_string: str) -> str:
    """Add box tokens to the coordinates in the model response.

    Every ``start_box='(x,y)'`` / ``end_box='(x,y)'`` found after an
    ``"Action: "`` marker is rewritten to
    ``start_box='<|box_start|>(x,y)<|box_end|>'``. Whitespace after the comma
    is accepted and dropped in the rewritten form. Coordinates that already
    carry box tokens are left untouched.

    Args:
        input_string: Raw model response

    Returns:
        String with box tokens added. The input is returned unchanged when it
        contains no ``"Action: "`` marker or no ``start_box=`` parameter. When
        several ``"Action: "`` segments are present they are joined with a
        blank line after a single ``"Action: "`` marker.
    """
    if "Action: " not in input_string or "start_box=" not in input_string:
        return input_string

    # Everything before the first marker is kept verbatim as the prefix
    head, *actions = input_string.split("Action: ")
    prefix = head + "Action: "

    # Substitute on the regex match itself so that "(x, y)" (with whitespace)
    # is wrapped too; an exact-string replace would miss it.
    box_pattern = re.compile(r"(start_box|end_box)='\((\d+),\s*(\d+)\)'")
    processed_actions = [
        box_pattern.sub(r"\1='<|box_start|>(\2,\3)<|box_end|>'", action.strip())
        for action in actions
    ]

    return prefix + "\n\n".join(processed_actions)
|
||||
|
||||
|
||||
def parse_actions(response: str) -> List[str]:
    """Split the action section of a UI-TARS response into individual actions.

    Only the text after the last ``"Action:"`` marker is considered; it is
    split on blank lines and each non-empty piece is stripped.

    Args:
        response: The raw model response text

    Returns:
        List of parsed actions (empty when there is no marker or nothing follows it)
    """
    if "Action:" not in response:
        return []

    # Text following the final marker
    tail = response.rpartition("Action:")[2].strip()

    # Multiple actions are separated by blank lines
    return [chunk.strip() for chunk in tail.split("\n\n") if chunk.strip()]
|
||||
|
||||
|
||||
def parse_action_parameters(action: str) -> Tuple[str, Dict[str, Any]]:
    """Parse parameters from an action string.

    The action is expected to look like ``name(param='value', ...)``.
    Parsing is driven by the action name:

    - ``click`` / ``left_double`` / ``right_single``: ``x``, ``y`` from the
      first ``start_box``.
    - ``drag``: ``start_x``, ``start_y``, ``end_x``, ``end_y`` from
      ``start_box`` and ``end_box`` (only set when both are present).
    - ``scroll``: ``x``, ``y`` from the first ``start_box`` plus ``direction``.
    - ``type``: the arguments become ``{"action": "type_text", "text": ...}``
      with ``\\'``, ``\\"`` and ``\\n`` escapes in the content decoded.
    - ``hotkey``: ``keys`` is the whitespace-split ``key`` value.

    Box coordinates may optionally be wrapped in ``<|box_start|>`` /
    ``<|box_end|>`` tokens.

    Args:
        action: The action string to parse

    Returns:
        Tuple of (action_name, action_parameters). The parameters always
        contain an ``"action"`` key, except that a string starting with
        ``finished`` yields ``("finished", {})`` and an unparseable string
        yields ``("", {})`` after logging a warning.
    """
    # Handle "finished" action
    if action.startswith("finished"):
        return "finished", {}

    # DOTALL so that parameters containing literal newlines still parse
    action_match = re.match(r"(\w+)\((.*)\)", action, re.DOTALL)
    if not action_match:
        logger.warning(f"Could not parse action: {action}")
        return "", {}

    action_name = action_match.group(1)
    action_params_str = action_match.group(2)

    tool_args: Dict[str, Any] = {"action": action_name}

    # Dispatch on the action name, not on whether the parameter text happens to
    # contain "start_box": typed text may legitimately contain that word.
    if action_name in ("click", "left_double", "right_single", "drag", "scroll"):
        box_pattern = r"(start_box|end_box)='(?:<\|box_start\|>)?\((\d+),\s*(\d+)\)(?:<\|box_end\|>)?'"
        box_matches = re.findall(box_pattern, action_params_str)

        if action_name == "drag":
            # Later boxes of the same kind override earlier ones
            points: Dict[str, Tuple[int, int]] = {}
            for box_type, x, y in box_matches:
                points[box_type] = (int(x), int(y))
            if "start_box" in points and "end_box" in points:
                tool_args["start_x"], tool_args["start_y"] = points["start_box"]
                tool_args["end_x"], tool_args["end_y"] = points["end_box"]
        else:
            # Click-type and scroll actions use the first start_box only
            for box_type, x, y in box_matches:
                if box_type == "start_box":
                    tool_args["x"] = int(x)
                    tool_args["y"] = int(y)
                    break

            if action_name == "scroll":
                direction_match = re.search(r"direction='([^']+)'", action_params_str)
                if direction_match:
                    tool_args["direction"] = direction_match.group(1)

    elif action_name == "type":
        # Prefer an escape-aware match so that \' inside the text does not end
        # the content early; fall back to the plain pattern for malformed input.
        content_match = re.search(
            r"content='((?:\\.|[^'\\])*)'", action_params_str, re.DOTALL
        ) or re.search(r"content='([^']*)'", action_params_str)
        if content_match:
            # Unescape escaped characters
            text = (
                content_match.group(1)
                .replace("\\'", "'")
                .replace('\\"', '"')
                .replace("\\n", "\n")
            )
            tool_args = {"action": "type_text", "text": text}

    elif action_name == "hotkey":
        # Keys of the combination are separated by whitespace
        key_match = re.search(r"key='([^']*)'", action_params_str)
        if key_match:
            keys = key_match.group(1).split()
            tool_args = {"action": "hotkey", "keys": keys}

    return action_name, tool_args
|
||||
@@ -165,7 +165,6 @@ MODEL_MAPPINGS = {
|
||||
"uitars": {
|
||||
# UI-TARS models default to custom endpoint
|
||||
"default": "ByteDance-Seed/UI-TARS-1.5-7B",
|
||||
"ui-tars": "ByteDance-Seed/UI-TARS-1.5-7B",
|
||||
},
|
||||
"ollama": {
|
||||
# For Ollama models, we keep the original name
|
||||
@@ -287,7 +286,9 @@ def get_provider_and_model(model_name: str, loop_provider: str) -> tuple:
|
||||
# Assign the determined model name
|
||||
model_name_to_use = cleaned_model_name
|
||||
# agent_loop remains AgentLoop.OMNI
|
||||
|
||||
elif agent_loop == AgentLoop.UITARS:
|
||||
provider = LLMProvider.OAICOMPAT
|
||||
model_name_to_use = MODEL_MAPPINGS["uitars"]["default"] # Default
|
||||
else:
|
||||
# Default to OpenAI if unrecognized loop
|
||||
provider = LLMProvider.OPENAI
|
||||
@@ -557,7 +558,7 @@ def create_gradio_ui(
|
||||
"OPENAI": openai_models,
|
||||
"ANTHROPIC": anthropic_models,
|
||||
"OMNI": omni_models + ["Custom model..."], # Add custom model option
|
||||
"UITARS": ["UI-TARS (ByteDance-Seed/UI-TARS-1.5-7B)", "Custom model..."], # UI-TARS options
|
||||
"UITARS": ["Custom model..."], # UI-TARS options
|
||||
}
|
||||
|
||||
# --- Apply Saved Settings (override defaults if available) ---
|
||||
@@ -814,6 +815,8 @@ def create_gradio_ui(
|
||||
provider, cleaned_model_name_from_func, agent_loop_type = (
|
||||
get_provider_and_model(model_string_to_analyze, agent_loop_choice)
|
||||
)
|
||||
|
||||
print(f"provider={provider} cleaned_model_name_from_func={cleaned_model_name_from_func} agent_loop_type={agent_loop_type} agent_loop_choice={agent_loop_choice}")
|
||||
|
||||
# Determine the final model name to send to the agent
|
||||
# If custom selected, use the custom text box value, otherwise use the cleaned name
|
||||
|
||||
Reference in New Issue
Block a user