From 6ddddf8f880dd1eb86a724b395a0a57ba0bba7e6 Mon Sep 17 00:00:00 2001
From: Dillon DuPont
Date: Tue, 16 Sep 2025 12:56:07 -0400
Subject: [PATCH] fix internVL inference

---
 .../agent/agent/adapters/models/__init__.py |   4 +-
 .../agent/agent/adapters/models/internvl.py | 247 +++++++++++++++---
 libs/python/agent/agent/loops/internvl.py   |  12 +-
 3 files changed, 222 insertions(+), 41 deletions(-)

diff --git a/libs/python/agent/agent/adapters/models/__init__.py b/libs/python/agent/agent/adapters/models/__init__.py
index b36fda1b..3ed48404 100644
--- a/libs/python/agent/agent/adapters/models/__init__.py
+++ b/libs/python/agent/agent/adapters/models/__init__.py
@@ -26,8 +26,8 @@ def load_model(model_name: str, device: str = "auto", trust_remote_code: bool =
     print(f"cls: {cls}")
     if "OpenCUA" in cls:
         return OpenCUAModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
-    elif "Qwen2_5_VLConfig" in cls:
+    elif "Qwen2_5_VL" in cls:
        return Qwen2_5_VLModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
-    elif "InternVLChatConfig" in cls:
+    elif "InternVL" in cls:
        return InternVLModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
     return GenericHFModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
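
For reference, a minimal sketch of how the relaxed substring checks above behave; the config class names below are illustrative examples of what `cls` may hold, not an exhaustive list:

    # Hypothetical config class names; the substring tests route both the
    # remote-code "...ChatConfig" variants and other InternVL / Qwen2.5-VL configs.
    for cls in ("InternVLChatConfig", "InternVLConfig", "Qwen2_5_VLConfig", "SomeOtherConfig"):
        if "Qwen2_5_VL" in cls:
            print(cls, "-> Qwen2_5_VLModel")
        elif "InternVL" in cls:
            print(cls, "-> InternVLModel")
        else:
            print(cls, "-> GenericHFModel")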
""" def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None: @@ -25,7 +32,7 @@ class InternVLModel: self.model_name = model_name self.device = device self.model = None - self.processor = None + self.tokenizer = None self.trust_remote_code = trust_remote_code self._load() @@ -33,46 +40,214 @@ class InternVLModel: # Load model self.model = AutoModel.from_pretrained( self.model_name, - torch_dtype=torch.float16, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + use_flash_attn=True, device_map=self.device, - attn_implementation="sdpa", trust_remote_code=self.trust_remote_code, - ) - # Load processor - self.processor = AutoProcessor.from_pretrained( + ).eval() + # Load tokenizer (InternVL requires trust_remote_code=True and often use_fast=False) + self.tokenizer = AutoTokenizer.from_pretrained( self.model_name, - min_pixels=3136, - max_pixels=4096 * 2160, - device_map=self.device, trust_remote_code=self.trust_remote_code, + use_fast=False, ) + # ---- Image preprocessing utilities adapted from InternVL docs ---- + IMAGENET_MEAN = (0.485, 0.456, 0.406) + IMAGENET_STD = (0.229, 0.224, 0.225) + + def _build_transform(self, input_size: int) -> T.Compose: + MEAN, STD = self.IMAGENET_MEAN, self.IMAGENET_STD + transform = T.Compose([ + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD) + ]) + return transform + + def _find_closest_aspect_ratio(self, aspect_ratio: float, target_ratios: List[tuple], width: int, height: int, image_size: int): + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def _dynamic_preprocess(self, image: Image.Image, min_num: int = 1, max_num: int = 12, image_size: int = 448, use_thumbnail: bool = True) -> List[Image.Image]: + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if + i * j <= max_num and i * j >= min_num) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + target_aspect_ratio = self._find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size) + + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + resized_img = image.resize((target_width, target_height)) + processed_images: List[Image.Image] = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size + ) + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + def _load_image_from_source(self, 
+
+    def _load_image_from_source(self, src: str) -> Image.Image:
+        """Load PIL image from various sources: data URL, http(s), or local path."""
+        if src.startswith("data:image/"):
+            # data URL base64
+            header, b64data = src.split(",", 1)
+            img_bytes = base64.b64decode(b64data)
+            return Image.open(BytesIO(img_bytes)).convert('RGB')
+        if src.startswith("http://") or src.startswith("https://"):
+            resp = requests.get(src, timeout=10)
+            resp.raise_for_status()
+            return Image.open(BytesIO(resp.content)).convert('RGB')
+        # Assume local file path
+        return Image.open(src).convert('RGB')
+
+    def _images_to_pixel_values(self, images: List[Image.Image], input_size: int = 448, max_num: int = 12):
+        transform = self._build_transform(input_size=input_size)
+        pixel_values_list = []
+        num_patches_list: List[int] = []
+        for img in images:
+            tiles = self._dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num)
+            pv = [transform(tile) for tile in tiles]
+            pv = torch.stack(pv)
+            num_patches_list.append(pv.shape[0])
+            pixel_values_list.append(pv)
+        if not pixel_values_list:
+            return None, []
+        pixel_values = torch.cat(pixel_values_list)
+        return pixel_values, num_patches_list
+
     def generate(self, messages: List[Dict[str, Any]], max_new_tokens: int = 128) -> str:
         """Generate text for the given HF-format messages.
 
         messages: [{ role, content: [{type:'text'|'image', text|image}] }]
+
+        This implementation constructs InternVL-compatible inputs and uses
+        `model.chat(tokenizer, pixel_values, question, history=...)` to avoid
+        relying on AutoProcessor (which fails for some tokenizers).
         """
-        assert self.model is not None and self.processor is not None
-        # Apply chat template and tokenize
-        inputs = self.processor.apply_chat_template(
-            messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt",
-        )
-        # Move inputs to the same device as model
-        inputs = inputs.to(self.model.device)
-        # Generate
-        with torch.no_grad():
-            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
-        # Trim prompt tokens from output
-        generated_ids_trimmed = [
-            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
-        # Decode
-        output_text = self.processor.batch_decode(
-            generated_ids_trimmed,
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False,
-        )
-        return output_text[0] if output_text else ""
+        assert self.model is not None and self.tokenizer is not None
+
+        # Build textual context and collect images and the final question
+        context_lines: List[str] = []
+        all_images: List[Image.Image] = []
+        last_user_text_parts: List[str] = []
+
+        for msg in messages:
+            role = msg.get("role", "user")
+            content = msg.get("content", [])
+            if isinstance(content, str):
+                content_items = [{"type": "text", "text": content}]
+            else:
+                content_items = content
+
+            if role == "user":
+                # Collect text and images
+                parts_text: List[str] = []
+                for item in content_items:
+                    if item.get("type") == "text":
+                        t = item.get("text", "")
+                        if t:
+                            parts_text.append(t)
+                    elif item.get("type") == "image":
+                        url = item.get("image", "")
+                        if url:
+                            try:
+                                all_images.append(self._load_image_from_source(url))
+                            except Exception:
+                                # Ignore failed image loads but keep going
+                                pass
+                text = "\n".join(parts_text).strip()
+                if text:
+                    context_lines.append(f"User: {text}")
+                # Track last user text separately for question
+                last_user_text_parts = parts_text or last_user_text_parts
+            elif role == "assistant":
+                # Only keep text content for history
+                parts_text = [item.get("text", "") for item in content_items if item.get("type") == "text"]
+                text = "\n".join(parts_text).strip()
+                if text:
+                    context_lines.append(f"Assistant: {text}")
+
+        # Prepare pixel values for all collected images (across turns)
+        pixel_values = None
+        num_patches_list: List[int] = []
+        if all_images:
+            pixel_values, num_patches_list = self._images_to_pixel_values(all_images, input_size=448, max_num=12)
+            if pixel_values is not None:
+                # Convert dtype/device as in docs
+                pixel_values = pixel_values.to(torch.bfloat16)
+                # Chat API expects tensors on CUDA when model is on CUDA
+                try:
+                    pixel_values = pixel_values.to(self.model.device)
+                except Exception:
+                    pass
+
+        # Build question with any prior context and numbered image placeholders
+        if all_images:
+            # Separate images layout: Image-1: <image> ... then question text
+            prefix_lines = [f"Image-{i+1}: <image>" for i in range(len(all_images))]
+            prefix = "\n".join(prefix_lines) + "\n"
+        else:
+            prefix = ""
+
+        last_user_text = "\n".join(last_user_text_parts).strip()
+        # Combine prior text-only turns as context to emulate multi-turn
+        context_text = "\n".join(context_lines[:-1]) if len(context_lines) > 1 else ""
+        base_question = last_user_text if last_user_text else "Describe the image(s) in detail."
+        if context_text:
+            question = (context_text + "\n" + prefix + base_question).strip()
+        else:
+            question = (prefix + base_question).strip()
+
+        # Generation config
+        generation_config = dict(max_new_tokens=max_new_tokens, do_sample=False)
+
+        # Call InternVL chat
+        try:
+            if pixel_values is None:
+                # Pure-text conversation (embed prior turns in question)
+                response = self.model.chat(self.tokenizer, None, question, generation_config)
+            else:
+                # Multi-image: pass num_patches_list if >1 image
+                if len(num_patches_list) > 1:
+                    response = self.model.chat(
+                        self.tokenizer,
+                        pixel_values,
+                        question,
+                        generation_config,
+                        num_patches_list=num_patches_list,
+                    )
+                else:
+                    response = self.model.chat(self.tokenizer, pixel_values, question, generation_config)
+        except Exception as e:
+            # Fallback: return empty string to avoid crashing the adapter
+            return ""
+
+        return response or ""
diff --git a/libs/python/agent/agent/loops/internvl.py b/libs/python/agent/agent/loops/internvl.py
index d1b8c3fe..a857ffe3 100644
--- a/libs/python/agent/agent/loops/internvl.py
+++ b/libs/python/agent/agent/loops/internvl.py
@@ -26,9 +26,13 @@ from .composed_grounded import ComposedGroundedConfig
 from ..types import AgentCapability
 
 
-# Regex patterns matching ScreenSpot baseline extractors
-_POINT_PATTERN = re.compile(r"\[\[(\d+),(\d+)\]\]")
-_BBOX_PATTERN = re.compile(r"\[\[(\d+),(\d+),(\d+),(\d+)\]\]")
+# Regex patterns for extracting coordinates
+# Accept optional whitespace and optional decimal fractions
+_NUM = r"(\d+(?:\.\d+)?)"
+_POINT_PATTERN = re.compile(r"\[\[\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*\]\]")
+_BBOX_PATTERN = re.compile(
+    r"\[\[\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*,\s*" + _NUM + r"\s*\]\]"
+)
 
 
 def _extract_first_point(text: str) -> Optional[Tuple[float, float]]:
@@ -160,6 +164,8 @@ class InternVLConfig(ComposedGroundedConfig):
         response = await litellm.acompletion(**api_kwargs)
         output_text = (response.choices[0].message.content or "").strip()  # type: ignore
 
+        print(f"InternVL output: {output_text}")
+
         # Try to parse a point first; if absent, parse bbox and take center
         point = _extract_first_point(output_text)
         if point is None:
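
For reference, a minimal usage sketch of the updated adapter; the checkpoint name and screenshot path are placeholders, and the message structure follows the HF-style format documented in `generate()`:

    # Minimal usage sketch, assuming the package layout shown above; the model
    # name and image path are placeholders, not part of the patch.
    from agent.adapters.models import load_model

    model = load_model("OpenGVLab/InternVL3-8B", device="auto", trust_remote_code=True)
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "screenshot.png"},  # data URL, http(s) URL, or local path
                {"type": "text", "text": "Click the Submit button. Answer with a point like [[x,y]]."},
            ],
        }
    ]
    print(model.generate(messages, max_new_tokens=64))
    # A reply such as "[[512, 384]]" (decimals also accepted) is the shape that the
    # relaxed _POINT_PATTERN in loops/internvl.py is designed to parse.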