increased max tokens, added trust_remote_code kwarg

This commit is contained in:
Dillon DuPont
2025-08-26 18:15:24 -04:00
parent 52afcd4c6f
commit bf3c3256df
9 changed files with 28 additions and 17 deletions

View File

@@ -20,15 +20,17 @@ from .models import load_model as load_model_handler
class HuggingFaceLocalAdapter(CustomLLM):
"""HuggingFace Local Adapter for running vision-language models locally."""
-def __init__(self, device: str = "auto", **kwargs):
+def __init__(self, device: str = "auto", trust_remote_code: bool = False, **kwargs):
"""Initialize the adapter.
Args:
device: Device to load model on ("auto", "cuda", "cpu", etc.)
trust_remote_code: Whether to trust remote code
**kwargs: Additional arguments
"""
super().__init__()
self.device = device
+self.trust_remote_code = trust_remote_code
# Cache for model handlers keyed by model_name
self._handlers: Dict[str, Any] = {}
self._executor = ThreadPoolExecutor(max_workers=1) # Single thread pool
@@ -36,7 +38,7 @@ class HuggingFaceLocalAdapter(CustomLLM):
def _get_handler(self, model_name: str):
"""Get or create a model handler for the given model name."""
if model_name not in self._handlers:
-self._handlers[model_name] = load_model_handler(model_name=model_name, device=self.device)
+self._handlers[model_name] = load_model_handler(model_name=model_name, device=self.device, trust_remote_code=self.trust_remote_code)
return self._handlers[model_name]
def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:

View File

@@ -10,7 +10,7 @@ from .generic import GenericHFModel
from .opencua import OpenCUAModel
-def load_model(model_name: str, device: str = "auto"):
+def load_model(model_name: str, device: str = "auto", trust_remote_code: bool = False):
"""Factory function to load and return the right model handler instance.
- If the underlying transformers config class matches OpenCUA, return OpenCUAModel
@@ -20,9 +20,9 @@ def load_model(model_name: str, device: str = "auto"):
raise ImportError(
"HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
)
-cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
cls = cfg.__class__.__name__
# print(f"cls: {cls}")
if "OpenCUA" in cls:
-return OpenCUAModel(model_name=model_name, device=device)
-return GenericHFModel(model_name=model_name, device=device)
+return OpenCUAModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
+return GenericHFModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)

View File

@@ -3,7 +3,7 @@ from typing import List, Dict, Any, Optional
# Hugging Face imports are local to avoid hard dependency at module import
try:
import torch # type: ignore
-from transformers import AutoModelForImageTextToText, AutoProcessor # type: ignore
+from transformers import AutoModel, AutoProcessor # type: ignore
HF_AVAILABLE = True
except Exception:
HF_AVAILABLE = False
@@ -14,7 +14,7 @@ class GenericHFModel:
Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
"""
-def __init__(self, model_name: str, device: str = "auto") -> None:
+def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
if not HF_AVAILABLE:
raise ImportError(
"HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
@@ -23,15 +23,17 @@ class GenericHFModel:
self.device = device
self.model = None
self.processor = None
+self.trust_remote_code = trust_remote_code
self._load()
def _load(self) -> None:
# Load model
-self.model = AutoModelForImageTextToText.from_pretrained(
+self.model = AutoModel.from_pretrained(
self.model_name,
torch_dtype=torch.float16,
device_map=self.device,
attn_implementation="sdpa",
+trust_remote_code=self.trust_remote_code,
)
# Load processor
self.processor = AutoProcessor.from_pretrained(
@@ -39,6 +41,7 @@ class GenericHFModel:
min_pixels=3136,
max_pixels=4096 * 2160,
device_map=self.device,
+trust_remote_code=self.trust_remote_code,
)
def generate(self, messages: List[Dict[str, Any]], max_new_tokens: int = 128) -> str:

View File

@@ -16,7 +16,7 @@ except Exception:
class OpenCUAModel:
"""OpenCUA model handler using AutoTokenizer, AutoModel and AutoImageProcessor."""
-def __init__(self, model_name: str, device: str = "auto") -> None:
+def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
if not OPENCUA_AVAILABLE:
raise ImportError(
"OpenCUA requirements not found. Install with: pip install \"cua-agent[opencua-hf]\""
@@ -26,21 +26,22 @@ class OpenCUAModel:
self.model = None
self.tokenizer = None
self.image_processor = None
+self.trust_remote_code = trust_remote_code
self._load()
def _load(self) -> None:
self.tokenizer = AutoTokenizer.from_pretrained(
-self.model_name, trust_remote_code=True
+self.model_name, trust_remote_code=self.trust_remote_code
)
self.model = AutoModel.from_pretrained(
self.model_name,
torch_dtype="auto",
device_map=self.device,
-trust_remote_code=True,
+trust_remote_code=self.trust_remote_code,
attn_implementation="sdpa",
)
self.image_processor = AutoImageProcessor.from_pretrained(
-self.model_name, trust_remote_code=True
+self.model_name, trust_remote_code=self.trust_remote_code
)
@staticmethod

View File

@@ -166,6 +166,7 @@ class ComputerAgent:
use_prompt_caching: Optional[bool] = False,
max_trajectory_budget: Optional[float | dict] = None,
telemetry_enabled: Optional[bool] = True,
+trust_remote_code: Optional[bool] = False,
**kwargs
):
"""
@@ -184,6 +185,7 @@ class ComputerAgent:
use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with anthropic providers.
max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded
telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
+trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
**kwargs: Additional arguments passed to the agent loop
"""
self.model = model
@@ -198,6 +200,7 @@ class ComputerAgent:
self.use_prompt_caching = use_prompt_caching
self.telemetry_enabled = telemetry_enabled
self.kwargs = kwargs
+self.trust_remote_code = trust_remote_code
# == Add built-in callbacks ==
@@ -231,7 +234,8 @@ class ComputerAgent:
# Register local model providers
hf_adapter = HuggingFaceLocalAdapter(
-device="auto"
+device="auto",
+trust_remote_code=self.trust_remote_code or False
)
human_adapter = HumanAdapter()
litellm.custom_provider_map = [

View File

@@ -331,6 +331,7 @@ Examples:
agent_kwargs = {
"model": args.model,
"tools": [computer],
+"trust_remote_code": True, # needed for some local models (e.g., InternVL, OpenCUA)
"verbosity": 20 if args.verbose else 30, # DEBUG vs WARNING
"max_retries": args.max_retries
}

View File

@@ -155,7 +155,7 @@ class GTA1Config(AsyncAgentConfig):
api_kwargs = {
"model": model,
"messages": [system_message, user_message],
-"max_tokens": 32,
+"max_tokens": 2056,
"temperature": 0.0,
**kwargs
}

View File

@@ -106,7 +106,7 @@ class OpenCUAConfig(AsyncAgentConfig):
api_kwargs = {
"model": model,
"messages": [system_message, user_message],
-"max_new_tokens": 512,
+"max_new_tokens": 2056,
"temperature": 0,
**kwargs
}

View File

@@ -771,7 +771,7 @@ class UITARSConfig:
api_kwargs = {
"model": model,
"messages": litellm_messages,
-"max_tokens": 100,
+"max_tokens": 2056,
"temperature": 0.0,
"do_sample": False
}