Merge pull request #169 from trycua/feature/agent/uitars-mlx

Add MLX provider for UI-TARS 1.5
This commit is contained in:
ddupont
2025-05-10 17:57:54 -04:00
committed by GitHub
11 changed files with 348 additions and 23 deletions

View File

@@ -92,7 +92,7 @@ async def main():
agent = ComputerAgent(
computer=computer,
loop="UITARS",
model=LLM(provider="MLX", name="mlx-community/UI-TARS-1.5-7B-6bit")
model=LLM(provider="MLXVLM", name="mlx-community/UI-TARS-1.5-7B-6bit")
)
await agent.run("Find the trycua/cua repository on GitHub and follow the quick start guide")
@@ -193,7 +193,7 @@ For complete examples, see [agent_examples.py](./examples/agent_examples.py) or
from agent import ComputerAgent, LLM, AgentLoop, LLMProvider
# UI-TARS-1.5 agent for local execution with MLX
- ComputerAgent(loop=AgentLoop.UITARS, model=LLM(provider=LLMProvider.MLX, name="mlx-community/UI-TARS-1.5-7B-6bit"))
+ ComputerAgent(loop=AgentLoop.UITARS, model=LLM(provider=LLMProvider.MLXVLM, name="mlx-community/UI-TARS-1.5-7B-6bit"))
# OpenAI Computer-Use agent using OPENAI_API_KEY
ComputerAgent(loop=AgentLoop.OPENAI, model=LLM(provider=LLMProvider.OPENAI, name="computer-use-preview"))
# Anthropic Claude agent using ANTHROPIC_API_KEY

View File

@@ -36,6 +36,7 @@ async def run_agent_example():
# model=LLM(provider=LLMProvider.OPENAI, name="gpt-4o"),
# model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"),
# model=LLM(provider=LLMProvider.OLLAMA, name="gemma3:4b-it-q4_K_M"),
# model=LLM(provider=LLMProvider.MLXVLM, name="mlx-community/UI-TARS-1.5-7B-4bit"),
model=LLM(
provider=LLMProvider.OAICOMPAT,
name="gemma-3-12b-it",

View File

@@ -32,6 +32,7 @@ pip install "cua-agent[all]"
pip install "cua-agent[openai]" # OpenAI Cua Loop
pip install "cua-agent[anthropic]" # Anthropic Cua Loop
pip install "cua-agent[uitars]" # UI-Tars support
pip install "cua-agent[uitars-mlx]" # local UI-Tars support with MLXVLM
pip install "cua-agent[omni]" # Cua Loop based on OmniParser (includes Ollama for local models)
pip install "cua-agent[ui]" # Gradio UI for the agent
```
@@ -136,7 +137,32 @@ The Gradio UI provides:
### Using UI-TARS
- You can use UI-TARS by first following the [deployment guide](https://github.com/bytedance/UI-TARS/blob/main/README_deploy.md). This will give you a provider URL like this: `https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1` which you can use in the gradio UI.
+ The UI-TARS models are available in two forms:
+ 1. **MLX UI-TARS models** (Default): These models run locally using the MLXVLM provider
+    - `mlx-community/UI-TARS-1.5-7B-4bit` (default) - 4-bit quantized version
+    - `mlx-community/UI-TARS-1.5-7B-6bit` - 6-bit quantized version for higher quality
+ ```python
+ agent = ComputerAgent(
+     computer=macos_computer,
+     loop=AgentLoop.UITARS,
+     model=LLM(provider=LLMProvider.MLXVLM, name="mlx-community/UI-TARS-1.5-7B-4bit")
+ )
+ ```
+ 2. **OpenAI-compatible UI-TARS**: For using the original ByteDance model
+    - If you want to use the original ByteDance UI-TARS model via an OpenAI-compatible API, follow the [deployment guide](https://github.com/bytedance/UI-TARS/blob/main/README_deploy.md)
+    - This will give you a provider URL like `https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1` which you can use in the code or Gradio UI:
+ ```python
+ agent = ComputerAgent(
+     computer=macos_computer,
+     loop=AgentLoop.UITARS,
+     model=LLM(provider=LLMProvider.OAICOMPAT, name="tgi",
+               provider_base_url="https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1")
+ )
+ ```
## Agent Loops
@@ -146,7 +172,7 @@ The `cua-agent` package provides three agent loops variations, based on differen
|:-----------|:-----------------|:------------|:-------------|
| `AgentLoop.OPENAI` | • `computer_use_preview` | Use OpenAI Operator CUA model | Not Required |
| `AgentLoop.ANTHROPIC` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219` | Use Anthropic Computer-Use | Not Required |
- | `AgentLoop.UITARS` | • `ByteDance-Seed/UI-TARS-1.5-7B` | Uses ByteDance's UI-TARS 1.5 model | Not Required |
+ | `AgentLoop.UITARS` | • `mlx-community/UI-TARS-1.5-7B-4bit` (default)<br>• `mlx-community/UI-TARS-1.5-7B-6bit`<br>• `ByteDance-Seed/UI-TARS-1.5-7B` (via OpenAI-compatible endpoint) | Uses UI-TARS models with MLXVLM (default) or OAICOMPAT providers | Not Required |
| `AgentLoop.OMNI` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219`<br>• `gpt-4.5-preview`<br>• `gpt-4o`<br>• `gpt-4`<br>• `phi4`<br>• `phi4-mini`<br>• `gemma3`<br>• `...`<br>• `Any Ollama or OpenAI-compatible model` | Use OmniParser for element pixel-detection (SoM) and any VLMs for UI Grounding and Reasoning | OmniParser |
## AgentResponse

View File

@@ -116,6 +116,7 @@ class LoopFactory:
base_dir=trajectory_dir,
only_n_most_recent_images=only_n_most_recent_images,
provider_base_url=provider_base_url,
provider=provider,
)
else:
raise ValueError(f"Unsupported loop type: {loop_type}")

View File

@@ -8,6 +8,7 @@ DEFAULT_MODELS = {
LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
LLMProvider.OAICOMPAT: "Qwen2.5-VL-7B-Instruct",
LLMProvider.MLXVLM: "mlx-community/UI-TARS-1.5-7B-4bit",
}
# Map providers to their environment variable names
@@ -16,4 +17,5 @@ ENV_VARS = {
LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
LLMProvider.OLLAMA: "none",
LLMProvider.OAICOMPAT: "none", # OpenAI-compatible API typically doesn't require an API key
LLMProvider.MLXVLM: "none", # MLX VLM typically doesn't require an API key
}
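# Illustration (comment only, not part of this diff): the new entries give the MLXVLM
# provider a local default model and mark it as needing no API key:
#   DEFAULT_MODELS[LLMProvider.MLXVLM]  ->  "mlx-community/UI-TARS-1.5-7B-4bit"
#   ENV_VARS[LLMProvider.MLXVLM]        ->  "none"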

View File

@@ -23,6 +23,7 @@ class LLMProvider(StrEnum):
OPENAI = "openai"
OLLAMA = "ollama"
OAICOMPAT = "oaicompat"
MLXVLM = "mlxvlm"
@dataclass

View File

@@ -0,0 +1,263 @@
"""MLX LVM client implementation."""
import io
import logging
import base64
import tempfile
import os
import re
import math
from typing import Dict, List, Optional, Any, cast, Tuple
from PIL import Image
from .base import BaseUITarsClient
import mlx.core as mx
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config
from transformers.tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
# Constants for smart_resize
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
def round_by_factor(number: float, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: float, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: float, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
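# Worked example (illustrative comment, not part of the committed file): a 1512x982
# screenshot is already inside [MIN_PIXELS, MAX_PIXELS], so each side simply snaps to
# the nearest multiple of IMAGE_FACTOR (28):
#   smart_resize(height=982, width=1512)  ->  (980, 1512)
# Larger captures are first scaled down by sqrt(height * width / MAX_PIXELS) before
# snapping, so the pixel count handed to the model never exceeds MAX_PIXELS.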
class MLXVLMUITarsClient(BaseUITarsClient):
"""MLX LVM client implementation class."""
def __init__(
self,
model: str = "mlx-community/UI-TARS-1.5-7B-4bit"
):
"""Initialize MLX LVM client.
Args:
model: Model name or path (defaults to mlx-community/UI-TARS-1.5-7B-4bit)
"""
# Load model and processor
model_obj, processor = load(
model,
processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
)
self.config = load_config(model)
self.model = model_obj
self.processor = processor
self.model_name = model
def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
"""Process coordinates in box tokens based on image resizing using smart_resize approach.
Args:
text: Text containing box tokens
original_size: Original image size (width, height)
model_size: Model processed image size (width, height)
Returns:
Text with processed coordinates
"""
# Find all box tokens
box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
def process_coords(match):
model_x, model_y = int(match.group(1)), int(match.group(2))
# Scale coordinates from model space to original image space
# Both original_size and model_size are in (width, height) format
new_x = int(model_x * original_size[0] / model_size[0]) # Width
new_y = int(model_y * original_size[1] / model_size[1]) # Height
return f"<|box_start|>({new_x},{new_y})<|box_end|>"
return re.sub(box_pattern, process_coords, text)
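# Worked example (illustrative comment, not part of the committed file): if a 1512x982
# screenshot was resized to 1512x980 for the model, a predicted click point is scaled
# back to screen space like so:
#   self._process_coordinates("<|box_start|>(756,490)<|box_end|>",
#                             original_size=(1512, 982), model_size=(1512, 980))
#   ->  "<|box_start|>(756,491)<|box_end|>"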
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
) -> Dict[str, Any]:
"""Run interleaved chat completion.
Args:
messages: List of message dicts
system: System prompt
max_tokens: Optional max tokens override
Returns:
Response dict
"""
# Ensure the system message is included
if not any(msg.get("role") == "system" for msg in messages):
messages = [{"role": "system", "content": system}] + messages
# Shallow-copy the message list so the caller's messages are not modified;
# individual messages and content items are copied below before editing
processed_messages = messages.copy()
# Extract images and process messages
images = []
original_sizes = {} # Track original sizes of images for coordinate mapping
model_sizes = {} # Track model processed sizes
image_index = 0
for msg_idx, msg in enumerate(messages):
content = msg.get("content", [])
if not isinstance(content, list):
continue
# Create a copy of the content list to modify
processed_content = []
for item_idx, item in enumerate(content):
if item.get("type") == "image_url":
image_url = item.get("image_url", {}).get("url", "")
pil_image = None
if image_url.startswith("data:image/"):
# Extract base64 data
base64_data = image_url.split(',')[1]
# Convert base64 to PIL Image
image_data = base64.b64decode(base64_data)
pil_image = Image.open(io.BytesIO(image_data))
else:
# Handle file path or URL
pil_image = Image.open(image_url)
# Store original image size for coordinate mapping
original_size = pil_image.size
original_sizes[image_index] = original_size
# Use smart_resize to determine model size
# Note: smart_resize expects (height, width) but PIL gives (width, height)
height, width = original_size[1], original_size[0]
new_height, new_width = smart_resize(height, width)
# Store model size in (width, height) format for consistent coordinate processing
model_sizes[image_index] = (new_width, new_height)
# Resize the image using the calculated dimensions from smart_resize
resized_image = pil_image.resize((new_width, new_height))
images.append(resized_image)
image_index += 1
# Copy items to processed content list
processed_content.append(item.copy())
# Update the processed message content
processed_messages[msg_idx] = msg.copy()
processed_messages[msg_idx]["content"] = processed_content
logger.info(f"resized {len(images)} from {original_sizes[0]} to {model_sizes[0]}")
# Process user text input with box coordinates after image processing
# Swap original_size and model_size arguments for inverse transformation
for msg_idx, msg in enumerate(processed_messages):
if msg.get("role") == "user" and isinstance(msg.get("content"), str):
if "<|box_start|>" in msg.get("content") and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
orig_size = original_sizes[0]
model_size = model_sizes[0]
# Swap arguments to perform inverse transformation for user input
processed_messages[msg_idx]["content"] = self._process_coordinates(msg["content"], model_size, orig_size)
try:
# Format prompt according to model requirements using the processor directly
prompt = self.processor.apply_chat_template(
processed_messages,
tokenize=False,
add_generation_prompt=True
)
tokenizer = cast(PreTrainedTokenizer, self.processor)
print("generating response...")
# Generate response
text_content, usage = generate(
self.model,
tokenizer,
str(prompt),
images,
verbose=False,
max_tokens=max_tokens
)
from pprint import pprint
print("DEBUG - AGENT GENERATION --------")
pprint(text_content)
print("DEBUG - AGENT GENERATION --------")
except Exception as e:
logger.error(f"Error generating response: {str(e)}")
return {
"choices": [
{
"message": {
"role": "assistant",
"content": f"Error generating response: {str(e)}"
},
"finish_reason": "error"
}
],
"model": self.model_name,
"error": str(e)
}
# Process coordinates in the response back to original image space
if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
# Get original image size and model size (using the first image)
orig_size = original_sizes[0]
model_size = model_sizes[0]
# Check if output contains box tokens that need processing
if "<|box_start|>" in text_content:
# Process coordinates from model space back to original image space
text_content = self._process_coordinates(text_content, orig_size, model_size)
# Format response to match OpenAI format
response = {
"choices": [
{
"message": {
"role": "assistant",
"content": text_content
},
"finish_reason": "stop"
}
],
"model": self.model_name,
"usage": usage
}
return response
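# Hypothetical usage sketch (comment only, not part of the committed file); the message
# shape mirrors the OpenAI-style dicts run_interleaved expects, and the base64 payload
# is a placeholder:
#   client = MLXVLMUITarsClient(model="mlx-community/UI-TARS-1.5-7B-4bit")
#   response = await client.run_interleaved(
#       messages=[{"role": "user", "content": [
#           {"type": "text", "text": "Click the Submit button"},
#           {"type": "image_url", "image_url": {"url": "data:image/png;base64,<...>"}},
#       ]}],
#       system="<UI-TARS system prompt>",
#       max_tokens=512,
#   )
#   text = response["choices"][0]["message"]["content"]  # may contain <|box_start|> tokens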

View File

@@ -23,6 +23,7 @@ from .tools.computer import ToolResult
from .prompts import COMPUTER_USE, SYSTEM_PROMPT, MAC_SPECIFIC_NOTES
from .clients.oaicompat import OAICompatClient
from .clients.mlxvlm import MLXVLMUITarsClient
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@@ -44,6 +45,7 @@ class UITARSLoop(BaseLoop):
computer: Computer,
api_key: str,
model: str,
provider: Optional[LLMProvider] = None,
provider_base_url: Optional[str] = "http://localhost:8000/v1",
only_n_most_recent_images: Optional[int] = 2,
base_dir: Optional[str] = "trajectories",
@@ -64,9 +66,10 @@ class UITARSLoop(BaseLoop):
max_retries: Maximum number of retries for API calls
retry_delay: Delay between retries in seconds
save_trajectory: Whether to save trajectory data
+ provider: The LLM provider to use (defaults to OAICOMPAT if not specified)
"""
# Set provider before initializing base class
- self.provider = LLMProvider.OAICOMPAT
+ self.provider = provider or LLMProvider.OAICOMPAT
self.provider_base_url = provider_base_url
# Initialize message manager with image retention config
@@ -113,7 +116,7 @@ class UITARSLoop(BaseLoop):
logger.error(f"Error initializing tool manager: {str(e)}")
logger.warning("Will attempt to initialize tools on first use.")
- # Initialize client for the OAICompat provider
+ # Initialize client for the selected provider
try:
await self.initialize_client()
except Exception as e:
@@ -128,18 +131,28 @@ class UITARSLoop(BaseLoop):
"""Initialize the appropriate client.
Implements abstract method from BaseLoop to set up the specific
- provider client (OAICompat for UI-TARS).
+ provider client based on the configured provider.
"""
try:
logger.info(f"Initializing OAICompat client for UI-TARS with model {self.model}...")
self.client = OAICompatClient(
api_key=self.api_key or "EMPTY", # Local endpoints typically don't require an API key
model=self.model,
provider_base_url=self.provider_base_url,
)
logger.info(f"Initialized OAICompat client with model {self.model}")
+ if self.provider == LLMProvider.MLXVLM:
+     logger.info(f"Initializing MLX VLM client for UI-TARS with model {self.model}...")
+     self.client = MLXVLMUITarsClient(
+         model=self.model,
+     )
+     logger.info(f"Initialized MLX VLM client with model {self.model}")
+ else:
+     # Default to OAICompat client for other providers
+     logger.info(f"Initializing OAICompat client for UI-TARS with model {self.model}...")
+     self.client = OAICompatClient(
+         api_key=self.api_key or "EMPTY",  # Local endpoints typically don't require an API key
+         model=self.model,
+         provider_base_url=self.provider_base_url,
+     )
+     logger.info(f"Initialized OAICompat client with model {self.model}")
except Exception as e:
logger.error(f"Error initializing client: {str(e)}")
self.client = None
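# Illustrative sketch (comment only, not part of the committed file): with the new
# provider argument, passing LLMProvider.MLXVLM selects local inference; the
# constructor arguments mirror the signature shown above:
#   loop = UITARSLoop(computer=computer, api_key="EMPTY",
#                     model="mlx-community/UI-TARS-1.5-7B-4bit",
#                     provider=LLMProvider.MLXVLM)
#   await loop.initialize_client()  # builds MLXVLMUITarsClient instead of OAICompatClient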

View File

@@ -105,7 +105,7 @@ async def to_agent_response_format(
}
],
truncation="auto",
usage=response["usage"],
usage=response.get("usage", {}),
user=None,
metadata={},
response=response

View File

@@ -164,8 +164,10 @@ MODEL_MAPPINGS = {
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet-20250219",
},
"uitars": {
- # UI-TARS models default to custom endpoint
- "default": "ByteDance-Seed/UI-TARS-1.5-7B",
+ # UI-TARS models using the MLXVLM provider
+ "default": "mlx-community/UI-TARS-1.5-7B-4bit",
+ "mlx-community/UI-TARS-1.5-7B-4bit": "mlx-community/UI-TARS-1.5-7B-4bit",
+ "mlx-community/UI-TARS-1.5-7B-6bit": "mlx-community/UI-TARS-1.5-7B-6bit"
},
"ollama": {
# For Ollama models, we keep the original name
@@ -288,8 +290,16 @@ def get_provider_and_model(model_name: str, loop_provider: str) -> tuple:
model_name_to_use = cleaned_model_name
# agent_loop remains AgentLoop.OMNI
elif agent_loop == AgentLoop.UITARS:
- provider = LLMProvider.OAICOMPAT
- model_name_to_use = MODEL_MAPPINGS["uitars"]["default"]  # Default
+ # For UITARS, use MLXVLM provider for the MLX models, OAICOMPAT for custom
+ if model_name == "Custom model...":
+     provider = LLMProvider.OAICOMPAT
+     model_name_to_use = "tgi"
+ else:
+     provider = LLMProvider.MLXVLM
+     # Get the model name from the mappings or use as-is if not found
+     model_name_to_use = MODEL_MAPPINGS["uitars"].get(
+         model_name, model_name if model_name else MODEL_MAPPINGS["uitars"]["default"]
+     )
else:
# Default to OpenAI if unrecognized loop
provider = LLMProvider.OPENAI
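# Illustration (hypothetical calls, comment only, not part of the committed file):
#   get_provider_and_model("mlx-community/UI-TARS-1.5-7B-4bit", "UITARS")
#     -> provider=LLMProvider.MLXVLM, model="mlx-community/UI-TARS-1.5-7B-4bit"
#   get_provider_and_model("Custom model...", "UITARS")
#     -> provider=LLMProvider.OAICOMPAT, model="tgi"
# The loop_provider string "UITARS" is assumed here; in practice the exact argument
# values come from the Gradio UI selections.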
@@ -440,7 +450,11 @@ def create_gradio_ui(
"OPENAI": openai_models,
"ANTHROPIC": anthropic_models,
"OMNI": omni_models + ["Custom model..."], # Add custom model option
"UITARS": ["Custom model..."], # UI-TARS options
"UITARS": [
"mlx-community/UI-TARS-1.5-7B-4bit",
"mlx-community/UI-TARS-1.5-7B-6bit",
"Custom model..."
], # UI-TARS options with MLX models
}
# --- Apply Saved Settings (override defaults if available) ---

View File

@@ -37,6 +37,9 @@ openai = [
uitars = [
"httpx>=0.27.0,<0.29.0",
]
uitars-mlx = [
"mlx-vlm @ git+https://github.com/ddupont808/mlx-vlm.git@stable/fix/qwen2-position-id"
]
ui = [
"gradio>=5.23.3,<6.0.0",
"python-dotenv>=1.0.1,<2.0.0",
@@ -84,7 +87,8 @@ all = [
"requests>=2.31.0,<3.0.0",
"ollama>=0.4.7,<0.5.0",
"gradio>=5.23.3,<6.0.0",
"python-dotenv>=1.0.1,<2.0.0"
"python-dotenv>=1.0.1,<2.0.0",
"mlx-vlm @ git+https://github.com/ddupont808/mlx-vlm.git@stable/fix/qwen2-position-id"
]
[tool.pdm]