diff --git a/libs/agent/agent/providers/uitars/clients/mlxvlm.py b/libs/agent/agent/providers/uitars/clients/mlxvlm.py
index d1e5dfff..c0c9b459 100644
--- a/libs/agent/agent/providers/uitars/clients/mlxvlm.py
+++ b/libs/agent/agent/providers/uitars/clients/mlxvlm.py
@@ -5,7 +5,8 @@ import logging
 import base64
 import tempfile
 import os
-from typing import Dict, List, Optional, Any, cast
+import re
+from typing import Dict, List, Optional, Any, cast, Tuple
 from PIL import Image
 
 from .base import BaseUITarsClient
@@ -21,11 +22,17 @@ logger = logging.getLogger(__name__)
 class MLXVLMUITarsClient(BaseUITarsClient):
     """MLX LVM client implementation class."""
 
-    def __init__(self, model: str = "mlx-community/UI-TARS-1.5-7B-4bit"):
+    def __init__(
+        self,
+        model: str = "mlx-community/UI-TARS-1.5-7B-4bit",
+        force_resolution: Optional[Tuple[int, int]] = (1512, 982)
+    ):
         """Initialize MLX LVM client.
 
         Args:
             model: Model name or path (defaults to mlx-community/UI-TARS-1.5-7B-4bit)
+            force_resolution: Optional target resolution to resize images to (width, height).
+                If None, images will not be resized.
         """
         # Load model and processor
         model_obj, processor = load(model)
@@ -33,8 +40,32 @@ class MLXVLMUITarsClient(BaseUITarsClient):
         self.model = model_obj
         self.processor = processor
         self.model_name = model
+        self.force_resolution = force_resolution
+
+    def _remap_coordinates(self, text: str, original_size: Tuple[int, int], target_size: Tuple[int, int]) -> str:
+        """Remap coordinates in box tokens based on image resizing.
+
+        Args:
+            text: Text containing box tokens
+            original_size: Original image size (width, height)
+            target_size: Target image size (width, height)
+
+        Returns:
+            Text with remapped coordinates
+        """
+        # Find all box tokens
+        box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
+
+        def remap_coords(match):
+            x, y = int(match.group(1)), int(match.group(2))
+            # Scale coordinates to the new dimensions
+            new_x = int(x * target_size[0] / original_size[0])
+            new_y = int(y * target_size[1] / original_size[1])
+            return f"<|box_start|>({new_x},{new_y})<|box_end|>"
+
+        return re.sub(box_pattern, remap_coords, text)
 
     async def run_interleaved(
         self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
     ) -> Dict[str, Any]:
@@ -51,32 +82,79 @@
         # Ensure the system message is included
         if not any(msg.get("role") == "system" for msg in messages):
             messages = [{"role": "system", "content": system}] + messages
-
-        # Extract any images from the messages
+
+        # Shallow-copy the message list; individual messages are copied below before mutation
+        processed_messages = messages.copy()
+
+        # Extract images and process messages if force_resolution is set
         images = []
-        for msg in messages:
+        original_sizes = {}  # Track original sizes of images for coordinate remapping
+        image_index = 0
+
+        for msg_idx, msg in enumerate(messages):
             content = msg.get("content", [])
-            if isinstance(content, list):
-                for item in content:
-                    if item.get("type") == "image_url":
-                        image_url = item.get("image_url", {}).get("url", "")
-                        if image_url.startswith("data:image/"):
-                            # Extract base64 data
-                            base64_data = image_url.split(',')[1]
-
-                            # Convert base64 to PIL Image
-                            image_data = base64.b64decode(base64_data)
-                            pil_image = Image.open(io.BytesIO(image_data))
-                            images.append(pil_image)
-                        else:
-                            # Handle file path or URL
-                            pil_image = Image.open(image_url)
-                            images.append(pil_image)
+            if not isinstance(content, list):
+                continue
+
+            # Create a copy of the content list to modify
+            processed_content = []
+
+            for item_idx, item in enumerate(content):
+                if item.get("type") == "image_url":
+                    image_url = item.get("image_url", {}).get("url", "")
+                    pil_image = None
+
+                    if image_url.startswith("data:image/"):
+                        # Extract base64 data
+                        base64_data = image_url.split(',')[1]
+                        # Convert base64 to PIL Image
+                        image_data = base64.b64decode(base64_data)
+                        pil_image = Image.open(io.BytesIO(image_data))
+                    else:
+                        # Treat as a local file path (Image.open does not fetch remote URLs)
+                        pil_image = Image.open(image_url)
+
+                    # Store original image size for coordinate mapping
+                    original_sizes[image_index] = pil_image.size
+
+                    # Resize image if force_resolution is set
+                    if self.force_resolution:
+                        pil_image = pil_image.resize(self.force_resolution)
+
+                    images.append(pil_image)
+                    image_index += 1
+
+                # Copy items to processed content list
+                processed_content.append(item.copy())
+
+            # Update the processed message content
+            processed_messages[msg_idx] = msg.copy()
+            processed_messages[msg_idx]["content"] = processed_content
+
+        # Remap coordinates in messages with box tokens if force_resolution is set
+        if self.force_resolution and original_sizes:
+            for msg_idx, msg in enumerate(processed_messages):
+                content = msg.get("content", [])
+                if not isinstance(content, list):
+                    continue
+
+                for item_idx, item in enumerate(content):
+                    if item.get("type") == "text":
+                        text_content = item.get("text", "")
+
+                        # Check if there are any box tokens to remap
+                        if "<|box_start|>" in text_content:
+                            # Use the first image's dimensions as reference (most common case)
+                            if 0 in original_sizes:
+                                orig_size = original_sizes[0]
+                                processed_messages[msg_idx]["content"][item_idx]["text"] = self._remap_coordinates(
+                                    text_content, orig_size, self.force_resolution
+                                )
 
         try:
             # Format prompt according to model requirements using the processor directly
             prompt = self.processor.apply_chat_template(
-                messages,
+                processed_messages,  # Use processed messages instead of original
                 tokenize=False,
                 add_generation_prompt=True
             )
@@ -108,6 +186,17 @@
                 "error": str(e)
             }
 
+        # Remap coordinates in the response back to original image space if needed
+        if self.force_resolution and original_sizes and 0 in original_sizes:
+            # Get original image size (using the first image)
+            orig_size = original_sizes[0]
+
+            # Check if output contains box tokens that need remapping
+            if "<|box_start|>" in output:
+                # Remap from model space back to original image space by swapping
+                # the arguments: from force_resolution back to the original size
+                output = self._remap_coordinates(output, self.force_resolution, orig_size)
+
         # Format response to match OpenAI format
         response = {
             "choices": [
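For reference, a minimal standalone sketch of the round trip that `_remap_coordinates` performs: coordinates in the prompt are scaled from the original screenshot size into the forced model resolution, and coordinates in the model's output are scaled back by swapping the two size arguments. The `remap` helper name, the sample prompt, and the 2880x1800 screenshot size below are illustrative, not part of the patch.

import re

BOX_PATTERN = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"

def remap(text: str, original_size: tuple, target_size: tuple) -> str:
    # Scale every (x, y) inside a box token from original_size to target_size
    def scale(match):
        x, y = int(match.group(1)), int(match.group(2))
        new_x = int(x * target_size[0] / original_size[0])
        new_y = int(y * target_size[1] / original_size[1])
        return f"<|box_start|>({new_x},{new_y})<|box_end|>"
    return re.sub(BOX_PATTERN, scale, text)

# Hypothetical 2880x1800 screenshot, resized to the default 1512x982 model resolution
prompt = "click <|box_start|>(1440,900)<|box_end|>"
model_space = remap(prompt, (2880, 1800), (1512, 982))
print(model_space)                                    # (756,491)
print(remap(model_space, (1512, 982), (2880, 1800)))  # back to (1440,900)

Because int() truncates rather than rounds, the inverse mapping can drift by a pixel for coordinates that do not divide evenly; using round() instead would tighten the round trip if that matters for click accuracy.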