Merge pull request #169 from trycua/feature/agent/uitars-mlx
Add MLX provider for UI-TARS 1.5
@@ -92,7 +92,7 @@ async def main():
         agent = ComputerAgent(
             computer=computer,
             loop="UITARS",
-            model=LLM(provider="MLX", name="mlx-community/UI-TARS-1.5-7B-6bit")
+            model=LLM(provider="MLXVLM", name="mlx-community/UI-TARS-1.5-7B-6bit")
         )
         await agent.run("Find the trycua/cua repository on GitHub and follow the quick start guide")
@@ -193,7 +193,7 @@ For complete examples, see [agent_examples.py](./examples/agent_examples.py) or
 from agent import ComputerAgent, LLM, AgentLoop, LLMProvider

 # UI-TARS-1.5 agent for local execution with MLX
-ComputerAgent(loop=AgentLoop.UITARS, model=LLM(provider=LLMProvider.MLX, name="mlx-community/UI-TARS-1.5-7B-6bit"))
+ComputerAgent(loop=AgentLoop.UITARS, model=LLM(provider=LLMProvider.MLXVLM, name="mlx-community/UI-TARS-1.5-7B-6bit"))
 # OpenAI Computer-Use agent using OPENAI_API_KEY
 ComputerAgent(loop=AgentLoop.OPENAI, model=LLM(provider=LLMProvider.OPENAI, name="computer-use-preview"))
 # Anthropic Claude agent using ANTHROPIC_API_KEY
@@ -36,6 +36,7 @@ async def run_agent_example():
             # model=LLM(provider=LLMProvider.OPENAI, name="gpt-4o"),
             # model=LLM(provider=LLMProvider.ANTHROPIC, name="claude-3-7-sonnet-20250219"),
             # model=LLM(provider=LLMProvider.OLLAMA, name="gemma3:4b-it-q4_K_M"),
+            # model=LLM(provider=LLMProvider.MLXVLM, name="mlx-community/UI-TARS-1.5-7B-4bit"),
             model=LLM(
                 provider=LLMProvider.OAICOMPAT,
                 name="gemma-3-12b-it",
@@ -32,6 +32,7 @@ pip install "cua-agent[all]"
 pip install "cua-agent[openai]" # OpenAI Cua Loop
 pip install "cua-agent[anthropic]" # Anthropic Cua Loop
 pip install "cua-agent[uitars]" # UI-Tars support
+pip install "cua-agent[uitars-mlx]" # local UI-Tars support with MLXVLM
 pip install "cua-agent[omni]" # Cua Loop based on OmniParser (includes Ollama for local models)
 pip install "cua-agent[ui]" # Gradio UI for the agent
 ```
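Editor's aside, not part of the diff: after installing the new `uitars-mlx` extra, a one-line import check confirms the pinned `mlx-vlm` fork is usable before pointing the UITARS loop at it. The import names below are the same ones the new client module added in this commit relies on.

```python
# Quick sanity check for the uitars-mlx extra (illustrative, not from the repo).
# These are the imports the new mlxvlm client uses.
from mlx_vlm import load, generate  # noqa: F401
from mlx_vlm.utils import load_config  # noqa: F401

print("mlx-vlm is importable; the MLXVLM provider should be usable on this machine")
```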
@@ -136,7 +137,32 @@ The Gradio UI provides:

 ### Using UI-TARS

-You can use UI-TARS by first following the [deployment guide](https://github.com/bytedance/UI-TARS/blob/main/README_deploy.md). This will give you a provider URL like this: `https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1` which you can use in the gradio UI.
+The UI-TARS models are available in two forms:
+
+1. **MLX UI-TARS models** (Default): These models run locally using MLXVLM provider
+   - `mlx-community/UI-TARS-1.5-7B-4bit` (default) - 4-bit quantized version
+   - `mlx-community/UI-TARS-1.5-7B-6bit` - 6-bit quantized version for higher quality
+
+   ```python
+   agent = ComputerAgent(
+       computer=macos_computer,
+       loop=AgentLoop.UITARS,
+       model=LLM(provider=LLMProvider.MLXVLM, name="mlx-community/UI-TARS-1.5-7B-4bit")
+   )
+   ```
+
+2. **OpenAI-compatible UI-TARS**: For using the original ByteDance model
+   - If you want to use the original ByteDance UI-TARS model via an OpenAI-compatible API, follow the [deployment guide](https://github.com/bytedance/UI-TARS/blob/main/README_deploy.md)
+   - This will give you a provider URL like `https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1` which you can use in the code or Gradio UI:
+
+   ```python
+   agent = ComputerAgent(
+       computer=macos_computer,
+       loop=AgentLoop.UITARS,
+       model=LLM(provider=LLMProvider.OAICOMPAT, name="tgi",
+                 provider_base_url="https://**************.us-east-1.aws.endpoints.huggingface.cloud/v1")
+   )
+   ```

 ## Agent Loops
@@ -146,7 +172,7 @@ The `cua-agent` package provides three agent loops variations, based on different
 |:-----------|:-----------------|:------------|:-------------|
 | `AgentLoop.OPENAI` | • `computer_use_preview` | Use OpenAI Operator CUA model | Not Required |
 | `AgentLoop.ANTHROPIC` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219` | Use Anthropic Computer-Use | Not Required |
-| `AgentLoop.UITARS` | • `ByteDance-Seed/UI-TARS-1.5-7B` | Uses ByteDance's UI-TARS 1.5 model | Not Required |
+| `AgentLoop.UITARS` | • `mlx-community/UI-TARS-1.5-7B-4bit` (default)<br>• `mlx-community/UI-TARS-1.5-7B-6bit`<br>• `ByteDance-Seed/UI-TARS-1.5-7B` (via openAI-compatible endpoint) | Uses UI-TARS models with MLXVLM (default) or OAICOMPAT providers | Not Required |
 | `AgentLoop.OMNI` | • `claude-3-5-sonnet-20240620`<br>• `claude-3-7-sonnet-20250219`<br>• `gpt-4.5-preview`<br>• `gpt-4o`<br>• `gpt-4`<br>• `phi4`<br>• `phi4-mini`<br>• `gemma3`<br>• `...`<br>• `Any Ollama or OpenAI-compatible model` | Use OmniParser for element pixel-detection (SoM) and any VLMs for UI Grounding and Reasoning | OmniParser |

 ## AgentResponse
@@ -116,6 +116,7 @@ class LoopFactory:
                 base_dir=trajectory_dir,
                 only_n_most_recent_images=only_n_most_recent_images,
                 provider_base_url=provider_base_url,
+                provider=provider,
             )
         else:
             raise ValueError(f"Unsupported loop type: {loop_type}")
@@ -8,6 +8,7 @@ DEFAULT_MODELS = {
     LLMProvider.ANTHROPIC: "claude-3-7-sonnet-20250219",
     LLMProvider.OLLAMA: "gemma3:4b-it-q4_K_M",
     LLMProvider.OAICOMPAT: "Qwen2.5-VL-7B-Instruct",
+    LLMProvider.MLXVLM: "mlx-community/UI-TARS-1.5-7B-4bit",
 }

 # Map providers to their environment variable names
@@ -16,4 +17,5 @@ ENV_VARS = {
     LLMProvider.ANTHROPIC: "ANTHROPIC_API_KEY",
     LLMProvider.OLLAMA: "none",
     LLMProvider.OAICOMPAT: "none",  # OpenAI-compatible API typically doesn't require an API key
+    LLMProvider.MLXVLM: "none",  # MLX VLM typically doesn't require an API key
 }
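The two hunks above register MLXVLM in the provider constants. A minimal sketch of how those tables are read (editor's illustration; the dictionary contents are copied from the hunks, the lookup pattern itself is an assumption):

```python
# Requires Python 3.11+ for StrEnum, matching the repo's own LLMProvider base.
from enum import StrEnum

class LLMProvider(StrEnum):
    MLXVLM = "mlxvlm"

DEFAULT_MODELS = {LLMProvider.MLXVLM: "mlx-community/UI-TARS-1.5-7B-4bit"}
ENV_VARS = {LLMProvider.MLXVLM: "none"}

provider = LLMProvider.MLXVLM
model = DEFAULT_MODELS[provider]          # -> "mlx-community/UI-TARS-1.5-7B-4bit"
needs_key = ENV_VARS[provider] != "none"  # -> False: local MLX inference needs no API key
print(model, needs_key)
```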
@@ -23,6 +23,7 @@ class LLMProvider(StrEnum):
     OPENAI = "openai"
     OLLAMA = "ollama"
     OAICOMPAT = "oaicompat"
+    MLXVLM = "mlxvlm"


 @dataclass
libs/agent/agent/providers/uitars/clients/mlxvlm.py (new file, 263 lines)
@@ -0,0 +1,263 @@
"""MLX LVM client implementation."""

import io
import logging
import base64
import tempfile
import os
import re
import math
from typing import Dict, List, Optional, Any, cast, Tuple
from PIL import Image

from .base import BaseUITarsClient
import mlx.core as mx
from mlx_vlm import load, generate
from mlx_vlm.prompt_utils import apply_chat_template
from mlx_vlm.utils import load_config
from transformers.tokenization_utils import PreTrainedTokenizer

logger = logging.getLogger(__name__)

# Constants for smart_resize
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200


def round_by_factor(number: float, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: float, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: float, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:

    1. Both dimensions (height and width) are divisible by 'factor'.
    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is maintained as closely as possible.
    """
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar


class MLXVLMUITarsClient(BaseUITarsClient):
    """MLX LVM client implementation class."""

    def __init__(
        self,
        model: str = "mlx-community/UI-TARS-1.5-7B-4bit"
    ):
        """Initialize MLX LVM client.

        Args:
            model: Model name or path (defaults to mlx-community/UI-TARS-1.5-7B-4bit)
        """
        # Load model and processor
        model_obj, processor = load(
            model,
            processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
        )
        self.config = load_config(model)
        self.model = model_obj
        self.processor = processor
        self.model_name = model

    def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
        """Process coordinates in box tokens based on image resizing using smart_resize approach.

        Args:
            text: Text containing box tokens
            original_size: Original image size (width, height)
            model_size: Model processed image size (width, height)

        Returns:
            Text with processed coordinates
        """
        # Find all box tokens
        box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"

        def process_coords(match):
            model_x, model_y = int(match.group(1)), int(match.group(2))
            # Scale coordinates from model space to original image space
            # Both original_size and model_size are in (width, height) format
            new_x = int(model_x * original_size[0] / model_size[0])  # Width
            new_y = int(model_y * original_size[1] / model_size[1])  # Height
            return f"<|box_start|>({new_x},{new_y})<|box_end|>"

        return re.sub(box_pattern, process_coords, text)

    async def run_interleaved(
        self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
    ) -> Dict[str, Any]:
        """Run interleaved chat completion.

        Args:
            messages: List of message dicts
            system: System prompt
            max_tokens: Optional max tokens override

        Returns:
            Response dict
        """
        # Ensure the system message is included
        if not any(msg.get("role") == "system" for msg in messages):
            messages = [{"role": "system", "content": system}] + messages

        # Create a deep copy of messages to avoid modifying the original
        processed_messages = messages.copy()

        # Extract images and process messages
        images = []
        original_sizes = {}  # Track original sizes of images for coordinate mapping
        model_sizes = {}  # Track model processed sizes
        image_index = 0

        for msg_idx, msg in enumerate(messages):
            content = msg.get("content", [])
            if not isinstance(content, list):
                continue

            # Create a copy of the content list to modify
            processed_content = []

            for item_idx, item in enumerate(content):
                if item.get("type") == "image_url":
                    image_url = item.get("image_url", {}).get("url", "")
                    pil_image = None

                    if image_url.startswith("data:image/"):
                        # Extract base64 data
                        base64_data = image_url.split(',')[1]
                        # Convert base64 to PIL Image
                        image_data = base64.b64decode(base64_data)
                        pil_image = Image.open(io.BytesIO(image_data))
                    else:
                        # Handle file path or URL
                        pil_image = Image.open(image_url)

                    # Store original image size for coordinate mapping
                    original_size = pil_image.size
                    original_sizes[image_index] = original_size

                    # Use smart_resize to determine model size
                    # Note: smart_resize expects (height, width) but PIL gives (width, height)
                    height, width = original_size[1], original_size[0]
                    new_height, new_width = smart_resize(height, width)
                    # Store model size in (width, height) format for consistent coordinate processing
                    model_sizes[image_index] = (new_width, new_height)

                    # Resize the image using the calculated dimensions from smart_resize
                    resized_image = pil_image.resize((new_width, new_height))
                    images.append(resized_image)
                    image_index += 1

                # Copy items to processed content list
                processed_content.append(item.copy())

            # Update the processed message content
            processed_messages[msg_idx] = msg.copy()
            processed_messages[msg_idx]["content"] = processed_content

        logger.info(f"resized {len(images)} from {original_sizes[0]} to {model_sizes[0]}")

        # Process user text input with box coordinates after image processing
        # Swap original_size and model_size arguments for inverse transformation
        for msg_idx, msg in enumerate(processed_messages):
            if msg.get("role") == "user" and isinstance(msg.get("content"), str):
                if "<|box_start|>" in msg.get("content") and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
                    orig_size = original_sizes[0]
                    model_size = model_sizes[0]
                    # Swap arguments to perform inverse transformation for user input
                    processed_messages[msg_idx]["content"] = self._process_coordinates(msg["content"], model_size, orig_size)

        try:
            # Format prompt according to model requirements using the processor directly
            prompt = self.processor.apply_chat_template(
                processed_messages,
                tokenize=False,
                add_generation_prompt=True
            )
            tokenizer = cast(PreTrainedTokenizer, self.processor)

            print("generating response...")

            # Generate response
            text_content, usage = generate(
                self.model,
                tokenizer,
                str(prompt),
                images,
                verbose=False,
                max_tokens=max_tokens
            )

            from pprint import pprint
            print("DEBUG - AGENT GENERATION --------")
            pprint(text_content)
            print("DEBUG - AGENT GENERATION --------")
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            return {
                "choices": [
                    {
                        "message": {
                            "role": "assistant",
                            "content": f"Error generating response: {str(e)}"
                        },
                        "finish_reason": "error"
                    }
                ],
                "model": self.model_name,
                "error": str(e)
            }

        # Process coordinates in the response back to original image space
        if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
            # Get original image size and model size (using the first image)
            orig_size = original_sizes[0]
            model_size = model_sizes[0]

            # Check if output contains box tokens that need processing
            if "<|box_start|>" in text_content:
                # Process coordinates from model space back to original image space
                text_content = self._process_coordinates(text_content, orig_size, model_size)

        # Format response to match OpenAI format
        response = {
            "choices": [
                {
                    "message": {
                        "role": "assistant",
                        "content": text_content
                    },
                    "finish_reason": "stop"
                }
            ],
            "model": self.model_name,
            "usage": usage
        }

        return response
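To make the resize and coordinate logic in the new client concrete, here is a small worked example (editor's illustration, not part of the commit; the screenshot size and box coordinate are chosen for the arithmetic). It mirrors what `smart_resize` and `_process_coordinates` do: snap each side to a multiple of 28, let the model answer in the resized space, then map the box back to the original image.

```python
# Worked example of the resize / coordinate round trip in mlxvlm.py.
IMAGE_FACTOR = 28
width, height = 1512, 982                       # PIL size is (width, height)

# smart_resize rounds each side to a multiple of 28
# (the min/max pixel budget is already satisfied for this size).
new_w = round(width / IMAGE_FACTOR) * IMAGE_FACTOR    # 1512
new_h = round(height / IMAGE_FACTOR) * IMAGE_FACTOR   # 980
assert (new_w, new_h) == (1512, 980)

# The model emits coordinates in the resized space, e.g. "<|box_start|>(756,490)<|box_end|>".
# _process_coordinates scales them back to the original screenshot.
model_x, model_y = 756, 490
orig_x = int(model_x * width / new_w)           # 756
orig_y = int(model_y * height / new_h)          # 491
assert (orig_x, orig_y) == (756, 491)
print(orig_x, orig_y)
```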
@@ -23,6 +23,7 @@ from .tools.computer import ToolResult
 from .prompts import COMPUTER_USE, SYSTEM_PROMPT, MAC_SPECIFIC_NOTES

 from .clients.oaicompat import OAICompatClient
+from .clients.mlxvlm import MLXVLMUITarsClient

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
@@ -44,6 +45,7 @@ class UITARSLoop(BaseLoop):
         computer: Computer,
         api_key: str,
         model: str,
+        provider: Optional[LLMProvider] = None,
         provider_base_url: Optional[str] = "http://localhost:8000/v1",
         only_n_most_recent_images: Optional[int] = 2,
         base_dir: Optional[str] = "trajectories",
@@ -64,9 +66,10 @@ class UITARSLoop(BaseLoop):
             max_retries: Maximum number of retries for API calls
             retry_delay: Delay between retries in seconds
             save_trajectory: Whether to save trajectory data
+            provider: The LLM provider to use (defaults to OAICOMPAT if not specified)
         """
         # Set provider before initializing base class
-        self.provider = LLMProvider.OAICOMPAT
+        self.provider = provider or LLMProvider.OAICOMPAT
         self.provider_base_url = provider_base_url

         # Initialize message manager with image retention config
@@ -113,7 +116,7 @@ class UITARSLoop(BaseLoop):
             logger.error(f"Error initializing tool manager: {str(e)}")
             logger.warning("Will attempt to initialize tools on first use.")

-        # Initialize client for the OAICompat provider
+        # Initialize client for the selected provider
         try:
             await self.initialize_client()
         except Exception as e:
@@ -128,18 +131,28 @@
         """Initialize the appropriate client.

         Implements abstract method from BaseLoop to set up the specific
-        provider client (OAICompat for UI-TARS).
+        provider client based on the configured provider.
         """
         try:
-            logger.info(f"Initializing OAICompat client for UI-TARS with model {self.model}...")
-
-            self.client = OAICompatClient(
-                api_key=self.api_key or "EMPTY",  # Local endpoints typically don't require an API key
-                model=self.model,
-                provider_base_url=self.provider_base_url,
-            )
-
-            logger.info(f"Initialized OAICompat client with model {self.model}")
+            if self.provider == LLMProvider.MLXVLM:
+                logger.info(f"Initializing MLX VLM client for UI-TARS with model {self.model}...")
+
+                self.client = MLXVLMUITarsClient(
+                    model=self.model,
+                )
+
+                logger.info(f"Initialized MLX VLM client with model {self.model}")
+            else:
+                # Default to OAICompat client for other providers
+                logger.info(f"Initializing OAICompat client for UI-TARS with model {self.model}...")
+
+                self.client = OAICompatClient(
+                    api_key=self.api_key or "EMPTY",  # Local endpoints typically don't require an API key
+                    model=self.model,
+                    provider_base_url=self.provider_base_url,
+                )
+
+                logger.info(f"Initialized OAICompat client with model {self.model}")
         except Exception as e:
             logger.error(f"Error initializing client: {str(e)}")
             self.client = None
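The dispatch added to `initialize_client` reduces to a provider check with an OAICOMPAT fallback. A self-contained sketch of that control flow (editor's illustration; stand-in strings take the place of the real client classes):

```python
from enum import StrEnum

class LLMProvider(StrEnum):
    OAICOMPAT = "oaicompat"
    MLXVLM = "mlxvlm"

def pick_client(provider: LLMProvider | None) -> str:
    provider = provider or LLMProvider.OAICOMPAT   # same fallback as UITARSLoop.__init__
    if provider == LLMProvider.MLXVLM:
        return "MLXVLMUITarsClient"                # local mlx-vlm inference
    return "OAICompatClient"                       # HTTP endpoint (e.g. a TGI deployment)

assert pick_client(None) == "OAICompatClient"
assert pick_client(LLMProvider.MLXVLM) == "MLXVLMUITarsClient"
```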
@@ -105,7 +105,7 @@ async def to_agent_response_format(
             }
         ],
         truncation="auto",
-        usage=response["usage"],
+        usage=response.get("usage", {}),
         user=None,
         metadata={},
         response=response
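The switch to `.get` is defensive: the MLXVLM client's error path (see `mlxvlm.py` above) returns a response with an `error` key but no `usage`, so direct indexing would raise. A tiny illustration, not from the repo:

```python
# Error-path response shape from MLXVLMUITarsClient: no "usage" key present.
error_response = {"choices": [], "model": "mlx-community/UI-TARS-1.5-7B-4bit", "error": "boom"}

# usage = error_response["usage"]        # would raise KeyError
usage = error_response.get("usage", {})  # {} - what the patched line does
assert usage == {}
```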
@@ -164,8 +164,10 @@ MODEL_MAPPINGS = {
         "claude-3-7-sonnet-20250219": "claude-3-7-sonnet-20250219",
     },
     "uitars": {
-        # UI-TARS models default to custom endpoint
-        "default": "ByteDance-Seed/UI-TARS-1.5-7B",
+        # UI-TARS models using MLXVLM provider
+        "default": "mlx-community/UI-TARS-1.5-7B-4bit",
+        "mlx-community/UI-TARS-1.5-7B-4bit": "mlx-community/UI-TARS-1.5-7B-4bit",
+        "mlx-community/UI-TARS-1.5-7B-6bit": "mlx-community/UI-TARS-1.5-7B-6bit"
     },
     "ollama": {
         # For Ollama models, we keep the original name
@@ -288,8 +290,16 @@ def get_provider_and_model(model_name: str, loop_provider: str) -> tuple:
         model_name_to_use = cleaned_model_name
         # agent_loop remains AgentLoop.OMNI
     elif agent_loop == AgentLoop.UITARS:
-        provider = LLMProvider.OAICOMPAT
-        model_name_to_use = MODEL_MAPPINGS["uitars"]["default"]  # Default
+        # For UITARS, use MLXVLM provider for the MLX models, OAICOMPAT for custom
+        if model_name == "Custom model...":
+            provider = LLMProvider.OAICOMPAT
+            model_name_to_use = "tgi"
+        else:
+            provider = LLMProvider.MLXVLM
+            # Get the model name from the mappings or use as-is if not found
+            model_name_to_use = MODEL_MAPPINGS["uitars"].get(
+                model_name, model_name if model_name else MODEL_MAPPINGS["uitars"]["default"]
+            )
     else:
         # Default to OpenAI if unrecognized loop
         provider = LLMProvider.OPENAI
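The new UITARS branch in `get_provider_and_model` is easiest to check by its expected outputs. A self-contained sketch reproducing just that branch (the helper name and reduced signature are the editor's, not the repo's):

```python
MODEL_MAPPINGS_UITARS = {
    "default": "mlx-community/UI-TARS-1.5-7B-4bit",
    "mlx-community/UI-TARS-1.5-7B-4bit": "mlx-community/UI-TARS-1.5-7B-4bit",
    "mlx-community/UI-TARS-1.5-7B-6bit": "mlx-community/UI-TARS-1.5-7B-6bit",
}

def uitars_provider_and_model(model_name: str) -> tuple[str, str]:
    # Mirrors the elif agent_loop == AgentLoop.UITARS branch shown above
    if model_name == "Custom model...":
        return "oaicompat", "tgi"
    return "mlxvlm", MODEL_MAPPINGS_UITARS.get(
        model_name, model_name if model_name else MODEL_MAPPINGS_UITARS["default"]
    )

assert uitars_provider_and_model("Custom model...") == ("oaicompat", "tgi")
assert uitars_provider_and_model("mlx-community/UI-TARS-1.5-7B-6bit") == ("mlxvlm", "mlx-community/UI-TARS-1.5-7B-6bit")
assert uitars_provider_and_model("") == ("mlxvlm", "mlx-community/UI-TARS-1.5-7B-4bit")
```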
@@ -440,7 +450,11 @@ def create_gradio_ui(
         "OPENAI": openai_models,
         "ANTHROPIC": anthropic_models,
         "OMNI": omni_models + ["Custom model..."],  # Add custom model option
-        "UITARS": ["Custom model..."],  # UI-TARS options
+        "UITARS": [
+            "mlx-community/UI-TARS-1.5-7B-4bit",
+            "mlx-community/UI-TARS-1.5-7B-6bit",
+            "Custom model..."
+        ],  # UI-TARS options with MLX models
     }

     # --- Apply Saved Settings (override defaults if available) ---
@@ -37,6 +37,9 @@ openai = [
 uitars = [
     "httpx>=0.27.0,<0.29.0",
 ]
+uitars-mlx = [
+    "mlx-vlm @ git+https://github.com/ddupont808/mlx-vlm.git@stable/fix/qwen2-position-id"
+]
 ui = [
     "gradio>=5.23.3,<6.0.0",
     "python-dotenv>=1.0.1,<2.0.0",
@@ -84,7 +87,8 @@ all = [
     "requests>=2.31.0,<3.0.0",
     "ollama>=0.4.7,<0.5.0",
     "gradio>=5.23.3,<6.0.0",
-    "python-dotenv>=1.0.1,<2.0.0"
+    "python-dotenv>=1.0.1,<2.0.0",
+    "mlx-vlm @ git+https://github.com/ddupont808/mlx-vlm.git@stable/fix/qwen2-position-id"
 ]

 [tool.pdm]