mirror of https://github.com/trycua/computer.git

added forced resolution
@@ -5,7 +5,8 @@ import logging
 import base64
 import tempfile
 import os
-from typing import Dict, List, Optional, Any, cast
+import re
+from typing import Dict, List, Optional, Any, cast, Tuple
+from PIL import Image
 
 from .base import BaseUITarsClient
@@ -21,11 +22,17 @@ logger = logging.getLogger(__name__)
 class MLXVLMUITarsClient(BaseUITarsClient):
     """MLX LVM client implementation class."""
 
-    def __init__(self, model: str = "mlx-community/UI-TARS-1.5-7B-4bit"):
+    def __init__(
+        self,
+        model: str = "mlx-community/UI-TARS-1.5-7B-4bit",
+        force_resolution: Optional[Tuple[int, int]] = (1512, 982)
+    ):
         """Initialize MLX LVM client.
 
         Args:
             model: Model name or path (defaults to mlx-community/UI-TARS-1.5-7B-4bit)
+            force_resolution: Optional target resolution to resize images to (width, height).
+                If None, images will not be resized.
         """
         # Load model and processor
         model_obj, processor = load(model)
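For reference, a hypothetical construction sketch; the import path is assumed, since this diff only shows the class body:

# Hypothetical usage; adjust the import to the actual module path.
# from .mlx_vlm import MLXVLMUITarsClient

# Default: input images are resized to 1512x982 and box coordinates are
# remapped into and back out of that space automatically.
client = MLXVLMUITarsClient()

# Opt out: images keep their native size and no remapping happens.
client_native = MLXVLMUITarsClient(force_resolution=None)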
@@ -33,8 +40,32 @@ class MLXVLMUITarsClient(BaseUITarsClient):
         self.model = model_obj
         self.processor = processor
         self.model_name = model
+        self.force_resolution = force_resolution
 
+    def _remap_coordinates(self, text: str, original_size: Tuple[int, int], target_size: Tuple[int, int]) -> str:
+        """Remap coordinates in box tokens based on image resizing.
+
+        Args:
+            text: Text containing box tokens
+            original_size: Original image size (width, height)
+            target_size: Target image size (width, height)
+
+        Returns:
+            Text with remapped coordinates
+        """
+        # Find all box tokens
+        box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"
+
+        def remap_coords(match):
+            x, y = int(match.group(1)), int(match.group(2))
+            # Scale coordinates to new dimensions
+            new_x = int(x * target_size[0] / original_size[0])
+            new_y = int(y * target_size[1] / original_size[1])
+            return f"<|box_start|>({new_x},{new_y})<|box_end|>"
+
+        return re.sub(box_pattern, remap_coords, text)
+
     async def run_interleaved(
         self, messages: List[Dict[str, Any]], system: str, max_tokens: Optional[int] = None
     ) -> Dict[str, Any]:
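For intuition, here is a standalone sketch of the same remapping logic; the (960, 540) click on a 1920x1080 screen is a made-up example. Swapping the two size arguments inverts the mapping, which is exactly how the output-side remap further down works:

import re

# Same pattern and scaling as _remap_coordinates above.
def remap(text, original_size, target_size):
    box_pattern = r"<\|box_start\|>\((\d+),\s*(\d+)\)<\|box_end\|>"

    def scale(match):
        x, y = int(match.group(1)), int(match.group(2))
        new_x = int(x * target_size[0] / original_size[0])
        new_y = int(y * target_size[1] / original_size[1])
        return f"<|box_start|>({new_x},{new_y})<|box_end|>"

    return re.sub(box_pattern, scale, text)

prompt = "click<|box_start|>(960,540)<|box_end|>"
in_model_space = remap(prompt, (1920, 1080), (1512, 982))
print(in_model_space)  # click<|box_start|>(756,491)<|box_end|>

# Inverse direction: back to original screen coordinates.
print(remap(in_model_space, (1512, 982), (1920, 1080)))  # click<|box_start|>(960,540)<|box_end|>

Note that int() truncates, so the round trip can be off by a pixel for coordinates that do not divide evenly.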
@@ -51,32 +82,79 @@ class MLXVLMUITarsClient(BaseUITarsClient):
         # Ensure the system message is included
         if not any(msg.get("role") == "system" for msg in messages):
             messages = [{"role": "system", "content": system}] + messages
 
-        # Extract any images from the messages
+        # Create a deep copy of messages to avoid modifying the original
+        processed_messages = messages.copy()
+
+        # Extract images and process messages if force_resolution is set
         images = []
-        for msg in messages:
+        original_sizes = {}  # Track original sizes of images for coordinate remapping
+        image_index = 0
+
+        for msg_idx, msg in enumerate(messages):
             content = msg.get("content", [])
-            if isinstance(content, list):
-                for item in content:
-                    if item.get("type") == "image_url":
-                        image_url = item.get("image_url", {}).get("url", "")
-                        if image_url.startswith("data:image/"):
-                            # Extract base64 data
-                            base64_data = image_url.split(',')[1]
-
-                            # Convert base64 to PIL Image
-                            image_data = base64.b64decode(base64_data)
-                            pil_image = Image.open(io.BytesIO(image_data))
-                            images.append(pil_image)
-                        else:
-                            # Handle file path or URL
-                            pil_image = Image.open(image_url)
-                            images.append(pil_image)
+            if not isinstance(content, list):
+                continue
+
+            # Create a copy of the content list to modify
+            processed_content = []
+
+            for item_idx, item in enumerate(content):
+                if item.get("type") == "image_url":
+                    image_url = item.get("image_url", {}).get("url", "")
+                    pil_image = None
+
+                    if image_url.startswith("data:image/"):
+                        # Extract base64 data
+                        base64_data = image_url.split(',')[1]
+                        # Convert base64 to PIL Image
+                        image_data = base64.b64decode(base64_data)
+                        pil_image = Image.open(io.BytesIO(image_data))
+                    else:
+                        # Handle file path or URL
+                        pil_image = Image.open(image_url)
+
+                    # Store original image size for coordinate mapping
+                    original_sizes[image_index] = pil_image.size
+
+                    # Resize image if force_resolution is set
+                    if self.force_resolution:
+                        pil_image = pil_image.resize(self.force_resolution)
+
+                    images.append(pil_image)
+                    image_index += 1
+
+                # Copy items to processed content list
+                processed_content.append(item.copy())
+
+            # Update the processed message content
+            processed_messages[msg_idx] = msg.copy()
+            processed_messages[msg_idx]["content"] = processed_content
+
+        # Remap coordinates in messages with box tokens if force_resolution is set
+        if self.force_resolution and original_sizes:
+            for msg_idx, msg in enumerate(processed_messages):
+                content = msg.get("content", [])
+                if not isinstance(content, list):
+                    continue
+
+                for item_idx, item in enumerate(content):
+                    if item.get("type") == "text":
+                        text_content = item.get("text", "")
+
+                        # Check if there are any box tokens to remap
+                        if "<|box_start|>" in text_content:
+                            # Use the first image's dimensions as reference (most common case)
+                            if 0 in original_sizes:
+                                orig_size = original_sizes[0]
+                                processed_messages[msg_idx]["content"][item_idx]["text"] = self._remap_coordinates(
+                                    text_content, orig_size, self.force_resolution
+                                )
 
         try:
             # Format prompt according to model requirements using the processor directly
             prompt = self.processor.apply_chat_template(
-                messages,
+                processed_messages,  # Use processed messages instead of original
                 tokenize=False,
                 add_generation_prompt=True
             )
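The image-extraction path above, in isolation: decode a base64 data URL, record the original size, then force-resize. A self-contained sketch (the solid-red 1920x1080 test image is fabricated for the demo):

import base64
import io

from PIL import Image

# Build a fake data URL the way a screenshot message would carry one.
buf = io.BytesIO()
Image.new("RGB", (1920, 1080), "red").save(buf, format="PNG")
image_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

# Same steps as the loop above: split off the base64 payload, decode, resize.
base64_data = image_url.split(',')[1]
pil_image = Image.open(io.BytesIO(base64.b64decode(base64_data)))
original_size = pil_image.size             # (1920, 1080), kept for remapping
pil_image = pil_image.resize((1512, 982))  # force_resolution
print(original_size, "->", pil_image.size)  # (1920, 1080) -> (1512, 982)

Note that resize() does not preserve aspect ratio: 1920x1080 (~1.78) stretched to 1512x982 (~1.54) distorts the image, and the per-axis linear coordinate remap compensates for exactly that distortion.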
@@ -108,6 +186,17 @@ class MLXVLMUITarsClient(BaseUITarsClient):
                 "error": str(e)
             }
 
+        # Remap coordinates in the response back to original image space if needed
+        if self.force_resolution and original_sizes and 0 in original_sizes:
+            # Get original image size (using the first image)
+            orig_size = original_sizes[0]
+
+            # Check if output contains box tokens that need remapping
+            if "<|box_start|>" in output:
+                # Remap coordinates from model space back to original image space
+                # We just swap the arguments - from force_resolution back to original size
+                output = self._remap_coordinates(output, self.force_resolution, orig_size)
+
         # Format response to match OpenAI format
         response = {
             "choices": [