extra coordinate processing

This commit is contained in:
Dillon DuPont
2025-04-28 16:32:48 -04:00
parent 00eb09209c
commit 8e8200dc17

View File

@@ -6,6 +6,7 @@ import base64
import tempfile
import os
import re
import math
from typing import Dict, List, Optional, Any, cast, Tuple
from PIL import Image
@@ -18,53 +19,95 @@ from transformers.tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
# Constants for smart_resize
IMAGE_FACTOR = 28
MIN_PIXELS = 100 * 28 * 28
MAX_PIXELS = 16384 * 28 * 28
MAX_RATIO = 200
def round_by_factor(number: float, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
return round(number / factor) * factor
def ceil_by_factor(number: float, factor: int) -> int:
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
return math.ceil(number / factor) * factor
def floor_by_factor(number: float, factor: int) -> int:
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return math.floor(number / factor) * factor
def smart_resize(
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
"""
Rescales the image so that the following conditions are met:
1. Both dimensions (height and width) are divisible by 'factor'.
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
3. The aspect ratio of the image is maintained as closely as possible.
"""
if max(height, width) / min(height, width) > MAX_RATIO:
raise ValueError(
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
)
h_bar = max(factor, round_by_factor(height, factor))
w_bar = max(factor, round_by_factor(width, factor))
if h_bar * w_bar > max_pixels:
beta = math.sqrt((height * width) / max_pixels)
h_bar = floor_by_factor(height / beta, factor)
w_bar = floor_by_factor(width / beta, factor)
elif h_bar * w_bar < min_pixels:
beta = math.sqrt(min_pixels / (height * width))
h_bar = ceil_by_factor(height * beta, factor)
w_bar = ceil_by_factor(width * beta, factor)
return h_bar, w_bar
class MLXVLMUITarsClient(BaseUITarsClient):
"""MLX LVM client implementation class."""
def __init__(
self,
model: str = "mlx-community/UI-TARS-1.5-7B-4bit",
force_resolution: Optional[Tuple[int, int]] = (1512, 982)
model: str = "mlx-community/UI-TARS-1.5-7B-4bit"
):
"""Initialize MLX LVM client.
Args:
model: Model name or path (defaults to mlx-community/UI-TARS-1.5-7B-4bit)
force_resolution: Optional target resolution to resize images to (width, height).
If None, images will not be resized.
"""
# Load model and processor
model_obj, processor = load(model)
model_obj, processor = load(
model,
processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
)
self.config = load_config(model)
self.model = model_obj
self.processor = processor
self.model_name = model
self.force_resolution = force_resolution
def _remap_coordinates(self, text: str, original_size: Tuple[int, int], target_size: Tuple[int, int]) -> str:
"""Remap coordinates in box tokens based on image resizing.
def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
"""Process coordinates in box tokens based on image resizing using smart_resize approach.
Args:
text: Text containing box tokens
original_size: Original image size (width, height)
target_size: Target image size (width, height)
model_size: Model processed image size (width, height)
Returns:
Text with remapped coordinates
Text with processed coordinates
"""
# Find all box tokens
box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
def remap_coords(match):
x, y = int(match.group(1)), int(match.group(2))
# Scale coordinates to new dimensions
new_x = int(x * target_size[0] / original_size[0])
new_y = int(y * target_size[1] / original_size[1])
def process_coords(match):
model_x, model_y = int(match.group(1)), int(match.group(2))
# Scale coordinates from model space to original image space
# Note that model_size is (height, width) while original_size is (width, height)
new_x = int(model_x * original_size[0] / model_size[1]) # Width
new_y = int(model_y * original_size[1] / model_size[0]) # Height
return f"<|box_start|>({new_x},{new_y})<|box_end|>"
return re.sub(box_pattern, remap_coords, text)
return re.sub(box_pattern, process_coords, text)
async def run_interleaved(
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
@@ -86,9 +129,10 @@ class MLXVLMUITarsClient(BaseUITarsClient):
# Create a deep copy of messages to avoid modifying the original
processed_messages = messages.copy()
# Extract images and process messages if force_resolution is set
# Extract images and process messages
images = []
original_sizes = {} # Track original sizes of images for coordinate remapping
original_sizes = {} # Track original sizes of images for coordinate mapping
model_sizes = {} # Track model processed sizes
image_index = 0
for msg_idx, msg in enumerate(messages):
@@ -115,13 +159,18 @@ class MLXVLMUITarsClient(BaseUITarsClient):
pil_image = Image.open(image_url)
# Store original image size for coordinate mapping
original_sizes[image_index] = pil_image.size
original_size = pil_image.size
original_sizes[image_index] = original_size
# Resize image if force_resolution is set
if self.force_resolution:
pil_image = pil_image.resize(self.force_resolution)
# Use smart_resize to determine model size
# Note: smart_resize expects (height, width) but PIL gives (width, height)
height, width = original_size[1], original_size[0]
new_height, new_width = smart_resize(height, width)
model_sizes[image_index] = (new_height, new_width)
images.append(pil_image)
# Resize the image using the calculated dimensions from smart_resize
resized_image = pil_image.resize((new_width, new_height))
images.append(resized_image)
image_index += 1
# Copy items to processed content list
@@ -131,30 +180,10 @@ class MLXVLMUITarsClient(BaseUITarsClient):
processed_messages[msg_idx] = msg.copy()
processed_messages[msg_idx]["content"] = processed_content
# Remap coordinates in messages with box tokens if force_resolution is set
if self.force_resolution and original_sizes:
for msg_idx, msg in enumerate(processed_messages):
content = msg.get("content", [])
if not isinstance(content, list):
continue
for item_idx, item in enumerate(content):
if item.get("type") == "text":
text_content = item.get("text", "")
# Check if there are any box tokens to remap
if "<|box_start|>" in text_content:
# Use the first image's dimensions as reference (most common case)
if 0 in original_sizes:
orig_size = original_sizes[0]
processed_messages[msg_idx]["content"][item_idx]["text"] = self._remap_coordinates(
text_content, orig_size, self.force_resolution
)
try:
# Format prompt according to model requirements using the processor directly
prompt = self.processor.apply_chat_template(
processed_messages, # Use processed messages instead of original
processed_messages,
tokenize=False,
add_generation_prompt=True
)
@@ -186,16 +215,16 @@ class MLXVLMUITarsClient(BaseUITarsClient):
"error": str(e)
}
# Remap coordinates in the response back to original image space if needed
if self.force_resolution and original_sizes and 0 in original_sizes:
# Get original image size (using the first image)
# Process coordinates in the response back to original image space
if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
# Get original image size and model size (using the first image)
orig_size = original_sizes[0]
model_size = model_sizes[0]
# Check if output contains box tokens that need remapping
# Check if output contains box tokens that need processing
if "<|box_start|>" in output:
# Remap coordinates from model space back to original image space
# We just swap the arguments - from force_resolution back to original size
output = self._remap_coordinates(output, self.force_resolution, orig_size)
# Process coordinates from model space back to original image space
output = self._process_coordinates(output, orig_size, model_size)
# Format response to match OpenAI format
response = {