mirror of
https://github.com/trycua/computer.git
synced 2026-01-07 05:50:13 -06:00
extra coordinate processing
This commit is contained in:
@@ -6,6 +6,7 @@ import base64
|
||||
import tempfile
|
||||
import os
|
||||
import re
|
||||
import math
|
||||
from typing import Dict, List, Optional, Any, cast, Tuple
|
||||
from PIL import Image
|
||||
|
||||
@@ -18,53 +19,95 @@ from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Constants for smart_resize
|
||||
IMAGE_FACTOR = 28
|
||||
MIN_PIXELS = 100 * 28 * 28
|
||||
MAX_PIXELS = 16384 * 28 * 28
|
||||
MAX_RATIO = 200
|
||||
|
||||
def round_by_factor(number: float, factor: int) -> int:
|
||||
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
|
||||
return round(number / factor) * factor
|
||||
|
||||
def ceil_by_factor(number: float, factor: int) -> int:
|
||||
"""Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.ceil(number / factor) * factor
|
||||
|
||||
def floor_by_factor(number: float, factor: int) -> int:
|
||||
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
|
||||
return math.floor(number / factor) * factor
|
||||
|
||||
def smart_resize(
|
||||
height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
|
||||
) -> tuple[int, int]:
|
||||
"""
|
||||
Rescales the image so that the following conditions are met:
|
||||
|
||||
1. Both dimensions (height and width) are divisible by 'factor'.
|
||||
2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
|
||||
3. The aspect ratio of the image is maintained as closely as possible.
|
||||
"""
|
||||
if max(height, width) / min(height, width) > MAX_RATIO:
|
||||
raise ValueError(
|
||||
f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
|
||||
)
|
||||
h_bar = max(factor, round_by_factor(height, factor))
|
||||
w_bar = max(factor, round_by_factor(width, factor))
|
||||
if h_bar * w_bar > max_pixels:
|
||||
beta = math.sqrt((height * width) / max_pixels)
|
||||
h_bar = floor_by_factor(height / beta, factor)
|
||||
w_bar = floor_by_factor(width / beta, factor)
|
||||
elif h_bar * w_bar < min_pixels:
|
||||
beta = math.sqrt(min_pixels / (height * width))
|
||||
h_bar = ceil_by_factor(height * beta, factor)
|
||||
w_bar = ceil_by_factor(width * beta, factor)
|
||||
return h_bar, w_bar
|
||||
|
||||
class MLXVLMUITarsClient(BaseUITarsClient):
|
||||
"""MLX LVM client implementation class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "mlx-community/UI-TARS-1.5-7B-4bit",
|
||||
force_resolution: Optional[Tuple[int, int]] = (1512, 982)
|
||||
model: str = "mlx-community/UI-TARS-1.5-7B-4bit"
|
||||
):
|
||||
"""Initialize MLX LVM client.
|
||||
|
||||
Args:
|
||||
model: Model name or path (defaults to mlx-community/UI-TARS-1.5-7B-4bit)
|
||||
force_resolution: Optional target resolution to resize images to (width, height).
|
||||
If None, images will not be resized.
|
||||
"""
|
||||
# Load model and processor
|
||||
model_obj, processor = load(model)
|
||||
model_obj, processor = load(
|
||||
model,
|
||||
processor_kwargs={"min_pixels": MIN_PIXELS, "max_pixels": MAX_PIXELS}
|
||||
)
|
||||
self.config = load_config(model)
|
||||
self.model = model_obj
|
||||
self.processor = processor
|
||||
self.model_name = model
|
||||
self.force_resolution = force_resolution
|
||||
|
||||
|
||||
def _remap_coordinates(self, text: str, original_size: Tuple[int, int], target_size: Tuple[int, int]) -> str:
|
||||
"""Remap coordinates in box tokens based on image resizing.
|
||||
def _process_coordinates(self, text: str, original_size: Tuple[int, int], model_size: Tuple[int, int]) -> str:
|
||||
"""Process coordinates in box tokens based on image resizing using smart_resize approach.
|
||||
|
||||
Args:
|
||||
text: Text containing box tokens
|
||||
original_size: Original image size (width, height)
|
||||
target_size: Target image size (width, height)
|
||||
model_size: Model processed image size (width, height)
|
||||
|
||||
Returns:
|
||||
Text with remapped coordinates
|
||||
Text with processed coordinates
|
||||
"""
|
||||
# Find all box tokens
|
||||
box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
|
||||
|
||||
def remap_coords(match):
|
||||
x, y = int(match.group(1)), int(match.group(2))
|
||||
# Scale coordinates to new dimensions
|
||||
new_x = int(x * target_size[0] / original_size[0])
|
||||
new_y = int(y * target_size[1] / original_size[1])
|
||||
def process_coords(match):
|
||||
model_x, model_y = int(match.group(1)), int(match.group(2))
|
||||
# Scale coordinates from model space to original image space
|
||||
# Note that model_size is (height, width) while original_size is (width, height)
|
||||
new_x = int(model_x * original_size[0] / model_size[1]) # Width
|
||||
new_y = int(model_y * original_size[1] / model_size[0]) # Height
|
||||
return f"<|box_start|>({new_x},{new_y})<|box_end|>"
|
||||
|
||||
return re.sub(box_pattern, remap_coords, text)
|
||||
return re.sub(box_pattern, process_coords, text)
|
||||
|
||||
async def run_interleaved(
|
||||
self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
|
||||
@@ -86,9 +129,10 @@ class MLXVLMUITarsClient(BaseUITarsClient):
|
||||
# Create a deep copy of messages to avoid modifying the original
|
||||
processed_messages = messages.copy()
|
||||
|
||||
# Extract images and process messages if force_resolution is set
|
||||
# Extract images and process messages
|
||||
images = []
|
||||
original_sizes = {} # Track original sizes of images for coordinate remapping
|
||||
original_sizes = {} # Track original sizes of images for coordinate mapping
|
||||
model_sizes = {} # Track model processed sizes
|
||||
image_index = 0
|
||||
|
||||
for msg_idx, msg in enumerate(messages):
|
||||
@@ -115,13 +159,18 @@ class MLXVLMUITarsClient(BaseUITarsClient):
|
||||
pil_image = Image.open(image_url)
|
||||
|
||||
# Store original image size for coordinate mapping
|
||||
original_sizes[image_index] = pil_image.size
|
||||
original_size = pil_image.size
|
||||
original_sizes[image_index] = original_size
|
||||
|
||||
# Resize image if force_resolution is set
|
||||
if self.force_resolution:
|
||||
pil_image = pil_image.resize(self.force_resolution)
|
||||
# Use smart_resize to determine model size
|
||||
# Note: smart_resize expects (height, width) but PIL gives (width, height)
|
||||
height, width = original_size[1], original_size[0]
|
||||
new_height, new_width = smart_resize(height, width)
|
||||
model_sizes[image_index] = (new_height, new_width)
|
||||
|
||||
images.append(pil_image)
|
||||
# Resize the image using the calculated dimensions from smart_resize
|
||||
resized_image = pil_image.resize((new_width, new_height))
|
||||
images.append(resized_image)
|
||||
image_index += 1
|
||||
|
||||
# Copy items to processed content list
|
||||
@@ -131,30 +180,10 @@ class MLXVLMUITarsClient(BaseUITarsClient):
|
||||
processed_messages[msg_idx] = msg.copy()
|
||||
processed_messages[msg_idx]["content"] = processed_content
|
||||
|
||||
# Remap coordinates in messages with box tokens if force_resolution is set
|
||||
if self.force_resolution and original_sizes:
|
||||
for msg_idx, msg in enumerate(processed_messages):
|
||||
content = msg.get("content", [])
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
|
||||
for item_idx, item in enumerate(content):
|
||||
if item.get("type") == "text":
|
||||
text_content = item.get("text", "")
|
||||
|
||||
# Check if there are any box tokens to remap
|
||||
if "<|box_start|>" in text_content:
|
||||
# Use the first image's dimensions as reference (most common case)
|
||||
if 0 in original_sizes:
|
||||
orig_size = original_sizes[0]
|
||||
processed_messages[msg_idx]["content"][item_idx]["text"] = self._remap_coordinates(
|
||||
text_content, orig_size, self.force_resolution
|
||||
)
|
||||
|
||||
try:
|
||||
# Format prompt according to model requirements using the processor directly
|
||||
prompt = self.processor.apply_chat_template(
|
||||
processed_messages, # Use processed messages instead of original
|
||||
processed_messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True
|
||||
)
|
||||
@@ -186,16 +215,16 @@ class MLXVLMUITarsClient(BaseUITarsClient):
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
# Remap coordinates in the response back to original image space if needed
|
||||
if self.force_resolution and original_sizes and 0 in original_sizes:
|
||||
# Get original image size (using the first image)
|
||||
# Process coordinates in the response back to original image space
|
||||
if original_sizes and model_sizes and 0 in original_sizes and 0 in model_sizes:
|
||||
# Get original image size and model size (using the first image)
|
||||
orig_size = original_sizes[0]
|
||||
model_size = model_sizes[0]
|
||||
|
||||
# Check if output contains box tokens that need remapping
|
||||
# Check if output contains box tokens that need processing
|
||||
if "<|box_start|>" in output:
|
||||
# Remap coordinates from model space back to original image space
|
||||
# We just swap the arguments - from force_resolution back to original size
|
||||
output = self._remap_coordinates(output, self.force_resolution, orig_size)
|
||||
# Process coordinates from model space back to original image space
|
||||
output = self._process_coordinates(output, orig_size, model_size)
|
||||
|
||||
# Format response to match OpenAI format
|
||||
response = {
|
||||
|
||||
Reference in New Issue
Block a user