diff --git a/libs/python/agent/agent/adapters/mlxvlm_adapter.py b/libs/python/agent/agent/adapters/mlxvlm_adapter.py
index 8daf4bdd..c38f4ad6 100644
--- a/libs/python/agent/agent/adapters/mlxvlm_adapter.py
+++ b/libs/python/agent/agent/adapters/mlxvlm_adapter.py
@@ -1,27 +1,75 @@
 import asyncio
 import functools
 import warnings
+import io
+import base64
+import math
+import re
 from concurrent.futures import ThreadPoolExecutor
-from typing import Iterator, AsyncIterator, Dict, List, Any, Optional
+from typing import Iterator, AsyncIterator, Dict, List, Any, Optional, Tuple, cast
+from PIL import Image
 
 from litellm.types.utils import GenericStreamingChunk, ModelResponse
 from litellm.llms.custom_llm import CustomLLM
 from litellm import completion, acompletion
-import base64
-from io import BytesIO
-from PIL import Image
 
-# Try to import MLX-VLM dependencies
+# Try to import MLX dependencies
 try:
     import mlx.core as mx
-    from mlx_vlm import load
-    from mlx_vlm.utils import generate
-    MLX_VLM_AVAILABLE = True
+    from mlx_vlm import load, generate
+    from mlx_vlm.prompt_utils import apply_chat_template
+    from mlx_vlm.utils import load_config
+    from transformers.tokenization_utils import PreTrainedTokenizer
+    MLX_AVAILABLE = True
 except ImportError:
-    MLX_VLM_AVAILABLE = False
+    MLX_AVAILABLE = False
+
+# Constants for smart_resize
+IMAGE_FACTOR = 28
+MIN_PIXELS = 100 * 28 * 28
+MAX_PIXELS = 16384 * 28 * 28
+MAX_RATIO = 200
+
+def round_by_factor(number: float, factor: int) -> int:
+    """Returns the closest integer to 'number' that is divisible by 'factor'."""
+    return round(number / factor) * factor
+
+def ceil_by_factor(number: float, factor: int) -> int:
+    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+    return math.ceil(number / factor) * factor
+
+def floor_by_factor(number: float, factor: int) -> int:
+    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+    return math.floor(number / factor) * factor
+
+def smart_resize(
+    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
+) -> tuple[int, int]:
+    """
+    Rescales the image so that the following conditions are met:
+
+    1. Both dimensions (height and width) are divisible by 'factor'.
+    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+    3. The aspect ratio of the image is maintained as closely as possible.
+    """
+    if max(height, width) / min(height, width) > MAX_RATIO:
+        raise ValueError(
+            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
+        )
+    h_bar = max(factor, round_by_factor(height, factor))
+    w_bar = max(factor, round_by_factor(width, factor))
+    if h_bar * w_bar > max_pixels:
+        beta = math.sqrt((height * width) / max_pixels)
+        h_bar = floor_by_factor(height / beta, factor)
+        w_bar = floor_by_factor(width / beta, factor)
+    elif h_bar * w_bar < min_pixels:
+        beta = math.sqrt(min_pixels / (height * width))
+        h_bar = ceil_by_factor(height * beta, factor)
+        w_bar = ceil_by_factor(width * beta, factor)
+    return h_bar, w_bar
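+
+# Worked examples with the defaults above (hand-computed, illustrative only):
+# smart_resize(1080, 1920) rounds both sides to the nearest multiple of 28
+# and returns (1092, 1932); smart_resize(10, 10) scales up to reach
+# MIN_PIXELS and returns (280, 280).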
 
 class MLXVLMAdapter(CustomLLM):
-    """MLX-VLM Adapter for running vision-language models locally using Apple's MLX framework."""
+    """MLX VLM Adapter for running vision-language models locally using MLX."""
 
     def __init__(self, **kwargs):
         """Initialize the adapter.
@@ -30,13 +78,14 @@ class MLXVLMAdapter(CustomLLM):
             **kwargs: Additional arguments
         """
         super().__init__()
+        if not MLX_AVAILABLE:
+            raise ImportError("MLX VLM dependencies not available. Please install mlx-vlm.")
+
         self.models = {}  # Cache for loaded models
         self.processors = {}  # Cache for loaded processors
+        self.configs = {}  # Cache for loaded configs
         self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool
 
-        if not MLX_VLM_AVAILABLE:
-            raise ImportError("MLX-VLM dependencies not available. Please install mlx-vlm.")
-
     def _load_model_and_processor(self, model_name: str):
         """Load model and processor if not already cached.
 
@@ -44,37 +93,64 @@ class MLXVLMAdapter(CustomLLM):
             model_name: Name of the model to load
 
         Returns:
-            Tuple of (model, processor)
+            Tuple of (model, processor, config)
         """
         if model_name not in self.models:
-            # Load model and processor using mlx-vlm
-            model, processor = load(
-                model_name,
-                processor_kwargs={
-                    "min_pixels": 256 * 28 * 28,
-                    "max_pixels": 1512 * 982
-                }
+            # Load model and processor
+            model_obj, processor = load(
+                model_name,
+                processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
             )
+            config = load_config(model_name)
 
             # Cache them
-            self.models[model_name] = model
+            self.models[model_name] = model_obj
             self.processors[model_name] = processor
+            self.configs[model_name] = config
 
-        return self.models[model_name], self.processors[model_name]
+        return self.models[model_name], self.processors[model_name], self.configs[model_name]
 
-    def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
-        """Convert OpenAI format messages to MLX-VLM format.
+    def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
+        """Rescale box-token coordinates from model space to original image space.
+
+        Args:
+            text: Text containing box tokens
+            original_size: Original image size (width, height)
+            model_size: Model processed image size (width, height)
+
+        Returns:
+            Text with rescaled coordinates
+        """
+        # Find all box tokens
+        box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
+
+        def process_coords(match):
+            model_x, model_y = int(match.group(1)), int(match.group(2))
+            # Scale coordinates from model space to original image space
+            # Both original_size and model_size are in (width, height) format
+            new_x = int(model_x * original_size[0] / model_size[0])  # Width
+            new_y = int(model_y * original_size[1] / model_size[1])  # Height
+            return f"<|box_start|>({new_x},{new_y})<|box_end|>"
+
+        return re.sub(box_pattern, process_coords, text)
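+
+    # Example with assumed sizes: given original_size=(1920, 1080) and
+    # model_size=(1932, 1092), "<|box_start|>(966,546)<|box_end|>" becomes
+    # "<|box_start|>(960,540)<|box_end|>", since 966 * 1920 / 1932 == 960
+    # and 546 * 1080 / 1092 == 540.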
+
+    def _convert_messages(self, messages: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], List[Image.Image], Dict[int, Tuple[int, int]], Dict[int, Tuple[int, int]]]:
+        """Convert OpenAI format messages to MLX VLM format and extract images.
 
         Args:
             messages: Messages in OpenAI format
 
         Returns:
-            Messages in MLX-VLM format
+            Tuple of (processed_messages, images, original_sizes, model_sizes)
         """
-        converted_messages = []
+        processed_messages = []
+        images = []
+        original_sizes = {}  # Track original sizes of images for coordinate mapping
+        model_sizes = {}  # Track model processed sizes
+        image_index = 0
 
         for message in messages:
-            converted_message = {
+            processed_message = {
                 "role": message["role"],
                 "content": []
             }
@@ -82,76 +158,60 @@ class MLXVLMAdapter(CustomLLM):
             content = message.get("content", [])
             if isinstance(content, str):
                 # Simple text content
-                converted_message["content"].append({
-                    "type": "text",
-                    "text": content
-                })
+                processed_message["content"] = content
             elif isinstance(content, list):
                 # Multi-modal content
+                processed_content = []
                 for item in content:
                     if item.get("type") == "text":
-                        converted_message["content"].append({
+                        processed_content.append({
                             "type": "text",
                             "text": item.get("text", "")
                         })
                     elif item.get("type") == "image_url":
-                        # Convert image_url format to image format
                         image_url = item.get("image_url", {}).get("url", "")
-                        converted_message["content"].append({
-                            "type": "image",
-                            "image": image_url
+                        pil_image = None
+
+                        if image_url.startswith("data:image/"):
+                            # Extract base64 data
+                            base64_data = image_url.split(',')[1]
+                            # Convert base64 to PIL Image
+                            image_data = base64.b64decode(base64_data)
+                            pil_image = Image.open(io.BytesIO(image_data))
+                        else:
+                            # Handle file path or URL
+                            pil_image = Image.open(image_url)
+
+                        # Store original image size for coordinate mapping
+                        original_size = pil_image.size
+                        original_sizes[image_index] = original_size
+
+                        # Use smart_resize to determine model size
+                        # Note: smart_resize expects (height, width) but PIL gives (width, height)
+                        height, width = original_size[1], original_size[0]
+                        new_height, new_width = smart_resize(height, width)
+                        # Store model size in (width, height) format for consistent coordinate processing
+                        model_sizes[image_index] = (new_width, new_height)
+
+                        # Resize the image using the calculated dimensions from smart_resize
+                        resized_image = pil_image.resize((new_width, new_height))
+                        images.append(resized_image)
+
+                        # Add image placeholder to content
+                        processed_content.append({
+                            "type": "image"
                         })
-                    elif item.get("type") == "image":
-                        # Direct image format - pass through
-                        converted_message["content"].append(item)
+
+                        image_index += 1
+
+                processed_message["content"] = processed_content
 
-            converted_messages.append(converted_message)
-
-        return converted_messages
-
-    def _process_image_from_url(self, image_url: str) -> Image.Image:
-        """Process image from URL (base64 or file path).
+            processed_messages.append(processed_message)
 
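+        # The size maps are keyed by image index: e.g. a 1920x1080 screenshot
+        # gives original_sizes[0] == (1920, 1080) and, after smart_resize,
+        # model_sizes[0] == (1932, 1092) (assumed sizes, for illustration).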
-        Args:
-            image_url: Image URL (data:image/... or file path)
-
-        Returns:
-            PIL Image object
-        """
-        if image_url.startswith("data:image/"):
-            # Base64 encoded image
-            header, data = image_url.split(",", 1)
-            image_data = base64.b64decode(data)
-            return Image.open(BytesIO(image_data))
-        else:
-            # File path or URL
-            return Image.open(image_url)
-
-    def _extract_image_from_messages(self, messages: List[Dict[str, Any]]) -> Optional[Image.Image]:
-        """Extract the first image from messages.
-
-        Args:
-            messages: List of messages
-
-        Returns:
-            PIL Image object or None
-        """
-        for message in messages:
-            content = message.get("content", [])
-            if isinstance(content, list):
-                for item in content:
-                    if item.get("type") == "image":
-                        image_url = item.get("image", "")
-                        if image_url:
-                            return self._process_image_from_url(image_url)
-                    elif item.get("type") == "image_url":
-                        image_url = item.get("image_url", {}).get("url", "")
-                        if image_url:
-                            return self._process_image_from_url(image_url)
-        return None
+        return processed_messages, images, original_sizes, model_sizes
 
     def _generate(self, **kwargs) -> str:
-        """Generate response using the local MLX-VLM model.
+        """Generate response using the local MLX VLM model.
 
         Args:
             **kwargs: Keyword arguments containing messages and model info
 
@@ -160,52 +220,66 @@ class MLXVLMAdapter(CustomLLM):
             Generated text response
         """
         messages = kwargs.get('messages', [])
-        model_name = kwargs.get('model', 'mlx-community/Qwen2.5-VL-7B-Instruct-4bit')
-        max_tokens = kwargs.get('max_tokens', 1000)
-        temperature = kwargs.get('temperature', 0.1)
+        model_name = kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')
+        max_tokens = kwargs.get('max_tokens', 128)
 
         # Warn about ignored kwargs
-        ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens', 'temperature'}
+        ignored_kwargs = set(kwargs.keys()) - {'messages', 'model', 'max_tokens'}
         if ignored_kwargs:
             warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")
 
         # Load model and processor
-        model, processor = self._load_model_and_processor(model_name)
+        model, processor, config = self._load_model_and_processor(model_name)
 
-        # Convert messages to MLX-VLM format
-        mlx_messages = self._convert_messages(messages)
+        # Convert messages and extract images
+        processed_messages, images, original_sizes, model_sizes = self._convert_messages(messages)
 
-        # Extract image from messages
-        image = self._extract_image_from_messages(mlx_messages)
+        # Rewrite box coordinates in user-supplied text into model space by
+        # swapping the original_size and model_size arguments (inverse mapping)
+        for msg_idx, msg in enumerate(processed_messages):
+            if msg.get("role") == "user" and isinstance(msg.get("content"), str):
+                content = msg.get("content", "")
+                if "<|box_start|>" in content and original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
+                    orig_size = original_sizes[0]
+                    model_size = model_sizes[0]
+                    # Swap arguments to perform inverse transformation for user input
+                    processed_messages[msg_idx]["content"] = self._process_coordinates(content, model_size, orig_size)
 
-        # Apply chat template
-        prompt = processor.apply_chat_template(
-            mlx_messages,
-            tokenize=False,
-            add_generation_prompt=True,
-        )
-
-        # Generate response using mlx-vlm
         try:
-            response = generate(
-                model,
-                processor,
-                prompt,
-                image,  # type: ignore
-                temperature=temperature,
-                max_tokens=max_tokens,
+            # Format prompt according to model requirements using the processor directly
+            prompt = processor.apply_chat_template(
+                processed_messages,
+                tokenize=False,
+                add_generation_prompt=True,
+                return_tensors='pt'
+            )
+            tokenizer = cast(PreTrainedTokenizer, processor)
+
+            # Generate response
+            text_content, usage = generate(
+                model,
+                tokenizer,
+                str(prompt),
+                images,  # type: ignore
                 verbose=False,
+                max_tokens=max_tokens
             )
-            # Clear MLX cache to free memory
-            mx.metal.clear_cache()
-
-            return response
-
         except Exception as e:
-            # Clear cache on error too
-            mx.metal.clear_cache()
-            raise e
+            raise RuntimeError(f"Error generating response: {str(e)}") from e
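+
+        # At this point text_content is the raw decoded string returned by
+        # generate(); any <|box_start|>(x,y)<|box_end|> tokens it contains
+        # are still in model-image coordinates.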
+
+        # Process coordinates in the response back to original image space
+        if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
+            # Get original image size and model size (using the first image)
+            orig_size = original_sizes[0]
+            model_size = model_sizes[0]
+
+            # Check if output contains box tokens that need processing
+            if "<|box_start|>" in text_content:
+                # Process coordinates from model space back to original image space
+                text_content = self._process_coordinates(text_content, orig_size, model_size)
+
+        return text_content
 
     def completion(self, *args, **kwargs) -> ModelResponse:
         """Synchronous completion method.
@@ -215,11 +289,11 @@ class MLXVLMAdapter(CustomLLM):
         """
         generated_text = self._generate(**kwargs)
 
-        response = completion(
-            model=f"mlx/{kwargs.get('model', 'default')}",
+        result = completion(
+            model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
             mock_response=generated_text,
         )
 
-        return response  # type: ignore
+        return cast(ModelResponse, result)
 
     async def acompletion(self, *args, **kwargs) -> ModelResponse:
         """Asynchronous completion method.
@@ -234,11 +308,11 @@ class MLXVLMAdapter(CustomLLM):
             functools.partial(self._generate, **kwargs)
         )
 
-        response = await acompletion(
-            model=f"mlx/{kwargs.get('model', 'default')}",
+        result = await acompletion(
+            model=f"mlx/{kwargs.get('model', 'mlx-community/UI-TARS-1.5-7B-4bit')}",
             mock_response=generated_text,
         )
 
-        return response  # type: ignore
+        return cast(ModelResponse, result)
 
     def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
         """Synchronous streaming method.
@@ -281,4 +355,4 @@ class MLXVLMAdapter(CustomLLM):
             "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
         }
 
-        yield generic_streaming_chunk
+        yield generic_streaming_chunk
\ No newline at end of file
diff --git a/libs/python/agent/agent/loops/uitars.py b/libs/python/agent/agent/loops/uitars.py
index 10e0e45a..b5d5423c 100644
--- a/libs/python/agent/agent/loops/uitars.py
+++ b/libs/python/agent/agent/loops/uitars.py
@@ -228,15 +228,24 @@ def parse_uitars_response(text: str, image_width: int, image_height: int) -> List[Dict[str, Any]]:
             # Handle coordinate parameters
             if "start_box" in param_name or "end_box" in param_name:
-                # Parse coordinates like '(x,y)' or '(x1,y1,x2,y2)'
-                numbers = param.replace("(", "").replace(")", "").split(",")
-                float_numbers = [float(num.strip()) / 1000 for num in numbers]  # Normalize to 0-1 range
+                # Parse coordinates like '<|box_start|>(x,y)<|box_end|>' or '(x,y)'
+                # First, remove special tokens
+                clean_param = param.replace("<|box_start|>", "").replace("<|box_end|>", "")
+                # Then remove parentheses and split
+                numbers = clean_param.replace("(", "").replace(")", "").split(",")
 
-                if len(float_numbers) == 2:
-                    # Single point, duplicate for box format
-                    float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
-
-                action_inputs[param_name.strip()] = str(float_numbers)
+                try:
+                    float_numbers = [float(num.strip()) / 1000 for num in numbers]  # Normalize to 0-1 range
+
+                    if len(float_numbers) == 2:
+                        # Single point, duplicate for box format
+                        float_numbers = [float_numbers[0], float_numbers[1], float_numbers[0], float_numbers[1]]
+
+                    action_inputs[param_name.strip()] = str(float_numbers)
+                except ValueError as e:
+                    # If parsing fails, keep the original parameter value
+                    print(f"Warning: Could not parse coordinates '{param}': {e}")
+                    action_inputs[param_name.strip()] = param
 
     return [{
         "thought": thought,