Added local inference routing for different models

Dillon DuPont
2025-08-21 10:51:39 -04:00
parent d7e25048be
commit dad6634ffd
6 changed files with 236 additions and 85 deletions
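
A rough sketch of the call path this commit sets up: the adapter no longer loads models and processors itself; it asks a small factory for a per-model handler, caches it, and delegates generation to it. The import path and model id below are assumptions, not taken from the diff.

# Sketch only: illustrative usage; the package path and model id are assumptions.
from agent.adapters.models import load_model  # factory added in this commit; real path may differ

hf_messages = [
    {"role": "user", "content": [{"type": "text", "text": "Describe the current screen."}]}
]
handler = load_model(model_name="some-org/some-vlm", device="auto")  # OpenCUAModel or GenericHFModel
print(handler.generate(hf_messages, max_new_tokens=128))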

View File

@@ -15,6 +15,7 @@ try:
 except ImportError:
     HF_AVAILABLE = False
 
+from .models import load_model as load_model_handler
 
 class HuggingFaceLocalAdapter(CustomLLM):
     """HuggingFace Local Adapter for running vision-language models locally."""
@@ -28,41 +29,15 @@ class HuggingFaceLocalAdapter(CustomLLM):
"""
super().__init__()
self.device = device
self.models = {} # Cache for loaded models
self.processors = {} # Cache for loaded processors
# Cache for model handlers keyed by model_name
self._handlers: Dict[str, Any] = {}
self._executor = ThreadPoolExecutor(max_workers=1) # Single thread pool
def _load_model_and_processor(self, model_name: str):
"""Load model and processor if not already cached.
Args:
model_name: Name of the model to load
Returns:
Tuple of (model, processor)
"""
if model_name not in self.models:
# Load model
model = AutoModelForImageTextToText.from_pretrained(
model_name,
torch_dtype=torch.float16,
device_map=self.device,
attn_implementation="sdpa"
)
# Load processor
processor = AutoProcessor.from_pretrained(
model_name,
min_pixels=3136,
max_pixels=4096 * 2160,
device_map=self.device
)
# Cache them
self.models[model_name] = model
self.processors[model_name] = processor
return self.models[model_name], self.processors[model_name]
def _get_handler(self, model_name: str):
"""Get or create a model handler for the given model name."""
if model_name not in self._handlers:
self._handlers[model_name] = load_model_handler(model_name=model_name, device=self.device)
return self._handlers[model_name]
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Convert OpenAI format messages to HuggingFace format.
@@ -133,41 +108,13 @@ class HuggingFaceLocalAdapter(CustomLLM):
         if ignored_kwargs:
             warnings.warn(f"Ignoring unsupported kwargs: {ignored_kwargs}")
 
-        # Load model and processor
-        model, processor = self._load_model_and_processor(model_name)
-
         # Convert messages to HuggingFace format
         hf_messages = self._convert_messages(messages)
 
-        # Apply chat template and tokenize
-        inputs = processor.apply_chat_template(
-            hf_messages,
-            add_generation_prompt=True,
-            tokenize=True,
-            return_dict=True,
-            return_tensors="pt"
-        )
-
-        # Move inputs to the same device as model
-        inputs = inputs.to(model.device)
-
-        # Generate response
-        with torch.no_grad():
-            generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
-
-        # Trim input tokens from output
-        generated_ids_trimmed = [
-            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-        ]
-
-        # Decode output
-        output_text = processor.batch_decode(
-            generated_ids_trimmed,
-            skip_special_tokens=True,
-            clean_up_tokenization_spaces=False
-        )
-
-        return output_text[0] if output_text else ""
+        # Delegate to model handler
+        handler = self._get_handler(model_name)
+        generated_text = handler.generate(hf_messages, max_new_tokens=max_new_tokens)
+        return generated_text
 
     def completion(self, *args, **kwargs) -> ModelResponse:
         """Synchronous completion method.

View File

@@ -0,0 +1,28 @@
from typing import Optional

try:
    from transformers import AutoConfig
    HF_AVAILABLE = True
except ImportError:
    HF_AVAILABLE = False

from .generic import GenericHFModel
from .opencua import OpenCUAModel


def load_model(model_name: str, device: str = "auto"):
    """Factory function to load and return the right model handler instance.

    - If the underlying transformers config class matches OpenCUA, return OpenCUAModel
    - Otherwise, return GenericHFModel
    """
    if not HF_AVAILABLE:
        raise ImportError(
            "HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
        )
    cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    cls = cfg.__class__.__name__
    print(f"cls: {cls}")
    if "OpenCUA" in cls:
        return OpenCUAModel(model_name=model_name, device=device)
    return GenericHFModel(model_name=model_name, device=device)
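
The routing decision above hinges only on the name of the checkpoint's transformers config class. A hedged illustration of what the factory inspects; the model ids are examples, not pinned anywhere in this diff:

# Illustration only: example model ids; the factory applies the same check shown above.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("xlangai/OpenCUA-7B", trust_remote_code=True)
print(cfg.__class__.__name__)  # expected to contain "OpenCUA" -> routed to OpenCUAModel

cfg = AutoConfig.from_pretrained("ByteDance-Seed/UI-TARS-1.5-7B", trust_remote_code=True)
print(cfg.__class__.__name__)  # no "OpenCUA" in the name -> routed to GenericHFModel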

View File

@@ -0,0 +1,72 @@
from typing import List, Dict, Any, Optional

# Hugging Face imports are local to avoid hard dependency at module import
try:
    import torch # type: ignore
    from transformers import AutoModelForImageTextToText, AutoProcessor # type: ignore
    HF_AVAILABLE = True
except Exception:
    HF_AVAILABLE = False


class GenericHFModel:
    """Generic Hugging Face vision-language model handler.

    Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
    """

    def __init__(self, model_name: str, device: str = "auto") -> None:
        if not HF_AVAILABLE:
            raise ImportError(
                "HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
            )
        self.model_name = model_name
        self.device = device
        self.model = None
        self.processor = None
        self._load()

    def _load(self) -> None:
        # Load model
        self.model = AutoModelForImageTextToText.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map=self.device,
            attn_implementation="sdpa",
        )
        # Load processor
        self.processor = AutoProcessor.from_pretrained(
            self.model_name,
            min_pixels=3136,
            max_pixels=4096 * 2160,
            device_map=self.device,
        )

    def generate(self, messages: List[Dict[str, Any]], max_new_tokens: int = 128) -> str:
        """Generate text for the given HF-format messages.

        messages: [{ role, content: [{type:'text'|'image', text|image}] }]
        """
        assert self.model is not None and self.processor is not None
        # Apply chat template and tokenize
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        # Move inputs to the same device as model
        inputs = inputs.to(self.model.device)
        # Generate
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Trim prompt tokens from output
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        # Decode
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0] if output_text else ""
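
The generate docstring above pins the expected message shape; a minimal usage sketch follows. The model id is an example and the image is a blank placeholder; neither comes from this diff.

# Sketch only: one text part plus one data-URL image part, as produced by the adapter's _convert_messages.
import base64
from io import BytesIO
from PIL import Image

buf = BytesIO()
Image.new("RGB", (64, 64), "white").save(buf, format="PNG")  # placeholder screenshot
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

messages = [
    {"role": "user", "content": [
        {"type": "image", "image": data_url},
        {"type": "text", "text": "What is on the screen?"},
    ]}
]
model = GenericHFModel("some-org/some-vlm")  # example model id
print(model.generate(messages, max_new_tokens=64))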

View File

@@ -0,0 +1,98 @@
from typing import List, Dict, Any
import re
import base64
from io import BytesIO

try:
    import torch # type: ignore
    from transformers import AutoTokenizer, AutoModel, AutoImageProcessor # type: ignore
    from PIL import Image # type: ignore
    import blobfile as _ # assert blobfile is installed
    OPENCUA_AVAILABLE = True
except Exception:
    OPENCUA_AVAILABLE = False


class OpenCUAModel:
    """OpenCUA model handler using AutoTokenizer, AutoModel and AutoImageProcessor."""

    def __init__(self, model_name: str, device: str = "auto") -> None:
        if not OPENCUA_AVAILABLE:
            raise ImportError(
                "OpenCUA requirements not found. Install with: pip install \"cua-agent[opencua-hf]\""
            )
        self.model_name = model_name
        self.device = device
        self.model = None
        self.tokenizer = None
        self.image_processor = None
        self._load()

    def _load(self) -> None:
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_name, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            self.model_name,
            torch_dtype="auto",
            device_map=self.device,
            trust_remote_code=True,
        )
        self.image_processor = AutoImageProcessor.from_pretrained(
            self.model_name, trust_remote_code=True
        )

    @staticmethod
    def _extract_last_image_b64(messages: List[Dict[str, Any]]) -> str:
        # Expect HF-format messages with content items type: "image" with data URL
        for msg in reversed(messages):
            for item in reversed(msg.get("content", [])):
                if isinstance(item, dict) and item.get("type") == "image":
                    url = item.get("image", "")
                    if isinstance(url, str) and url.startswith("data:image/"):
                        return url.split(",", 1)[1]
        return ""

    def generate(self, messages: List[Dict[str, Any]], max_new_tokens: int = 512) -> str:
        assert self.model is not None and self.tokenizer is not None and self.image_processor is not None
        # Tokenize text side using chat template
        input_ids = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True
        )
        input_ids = torch.tensor([input_ids]).to(self.model.device)
        # Prepare image inputs from last data URL image
        image_b64 = self._extract_last_image_b64(messages)
        pixel_values = None
        grid_thws = None
        if image_b64:
            image = Image.open(BytesIO(base64.b64decode(image_b64))).convert("RGB")
            image_info = self.image_processor.preprocess(images=[image])
            pixel_values = torch.tensor(image_info["pixel_values"]).to(
                dtype=torch.bfloat16, device=self.model.device
            )
            grid_thws = torch.tensor(image_info["image_grid_thw"]) if "image_grid_thw" in image_info else None
        gen_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_new_tokens,
            "temperature": 0,
        }
        if pixel_values is not None:
            gen_kwargs["pixel_values"] = pixel_values
        if grid_thws is not None:
            gen_kwargs["grid_thws"] = grid_thws
        with torch.no_grad():
            generated_ids = self.model.generate(
                input_ids,
                **gen_kwargs,
            )
        # Remove prompt tokens
        prompt_len = input_ids.shape[1]
        generated_ids = generated_ids[:, prompt_len:]
        output_text = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
        return output_text
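
Two details of the handler above worth calling out: only the most recent data-URL image in the conversation is preprocessed (earlier screenshots are skipped by _extract_last_image_b64), and the generation kwargs pin temperature to 0. A sketch of the calling convention, with placeholder ids and data URLs that are not part of this diff:

# Sketch only: placeholder values; shows that only the newest image content item reaches the model.
older = "data:image/png;base64,<...>"   # hypothetical earlier screenshot, ignored
newest = "data:image/png;base64,<...>"  # hypothetical latest screenshot, the one preprocessed

history = [
    {"role": "user", "content": [{"type": "image", "image": older},
                                 {"type": "text", "text": "Open Settings."}]},
    {"role": "user", "content": [{"type": "image", "image": newest},
                                 {"type": "text", "text": "Click the Wi-Fi toggle."}]},
]
handler = OpenCUAModel("some-org/opencua-checkpoint")  # example model id
print(handler.generate(history, max_new_tokens=512))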

View File

@@ -90,8 +90,10 @@ class OpenCUAConfig(AsyncAgentConfig):
"role": "user",
"content": [
{
"type": "image",
"image": f"data:image/png;base64,{image_b64}"
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{image_b64}"
}
},
{
"type": "text",
@@ -109,24 +111,18 @@ class OpenCUAConfig(AsyncAgentConfig):
             **kwargs
         }
 
-        try:
-            # Use liteLLM acompletion
-            response = await litellm.acompletion(**api_kwargs)
-
-            # Extract response text
-            output_text = response.choices[0].message.content
-            if not output_text:
-                return None
-
-            # Extract coordinates from pyautogui format
-            coordinates = extract_coordinates_from_pyautogui(output_text)
-            return coordinates
-
-        except Exception as e:
-            print(f"Error in OpenCUA predict_click: {e}")
-            return None
+        # Use liteLLM acompletion
+        response = await litellm.acompletion(**api_kwargs)
+
+        # Extract response text
+        output_text = response.choices[0].message.content
+        print(output_text)
+
+        # Extract coordinates from pyautogui format
+        coordinates = extract_coordinates_from_pyautogui(output_text)
+        return coordinates
 
     def get_capabilities(self) -> List[AgentCapability]:
         """Return the capabilities supported by this agent."""

View File

@@ -47,6 +47,13 @@ glm45v-hf = [
"torch",
"transformers-v4.55.0-GLM-4.5V-preview"
]
opencua-hf = [
"accelerate",
"torch",
"transformers>=4.54.0",
"tiktoken>=0.11.0",
"blobfile>=3.0.0"
]
ui = [
"gradio>=5.23.3",
"python-dotenv>=1.0.1",
@@ -66,6 +73,9 @@ all = [
"accelerate",
"torch",
"transformers>=4.54.0",
# opencua requirements
"tiktoken>=0.11.0",
"blobfile>=3.0.0"
# ui requirements
"gradio>=5.23.3",
"python-dotenv>=1.0.1",