Mirror of https://github.com/trycua/computer.git (synced 2026-01-05 04:50:08 -06:00)

increased max tokens, added trust_remote_code kwarg
@@ -20,15 +20,17 @@ from .models import load_model as load_model_handler
 class HuggingFaceLocalAdapter(CustomLLM):
     """HuggingFace Local Adapter for running vision-language models locally."""
 
-    def __init__(self, device: str = "auto", **kwargs):
+    def __init__(self, device: str = "auto", trust_remote_code: bool = False, **kwargs):
         """Initialize the adapter.
 
         Args:
             device: Device to load model on ("auto", "cuda", "cpu", etc.)
+            trust_remote_code: Whether to trust remote code
             **kwargs: Additional arguments
         """
         super().__init__()
         self.device = device
+        self.trust_remote_code = trust_remote_code
         # Cache for model handlers keyed by model_name
         self._handlers: Dict[str, Any] = {}
         self._executor = ThreadPoolExecutor(max_workers=1)  # Single thread pool
@@ -36,7 +38,7 @@ class HuggingFaceLocalAdapter(CustomLLM):
     def _get_handler(self, model_name: str):
         """Get or create a model handler for the given model name."""
         if model_name not in self._handlers:
-            self._handlers[model_name] = load_model_handler(model_name=model_name, device=self.device)
+            self._handlers[model_name] = load_model_handler(model_name=model_name, device=self.device, trust_remote_code=self.trust_remote_code)
         return self._handlers[model_name]
 
     def _convert_messages(self, messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
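For orientation, a minimal usage sketch of the adapter after this change; the import path is an assumption, and trust_remote_code should only be enabled for model repos whose custom code has been reviewed:

    from agent.adapters import HuggingFaceLocalAdapter  # import path is an assumption

    # The flag set here is forwarded to every load_model_handler() call;
    # handlers are created lazily and cached per model name, so it applies
    # to all models this adapter instance loads.
    adapter = HuggingFaceLocalAdapter(device="auto", trust_remote_code=True)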
@@ -10,7 +10,7 @@ from .generic import GenericHFModel
 from .opencua import OpenCUAModel
 
 
-def load_model(model_name: str, device: str = "auto"):
+def load_model(model_name: str, device: str = "auto", trust_remote_code: bool = False):
     """Factory function to load and return the right model handler instance.
 
     - If the underlying transformers config class matches OpenCUA, return OpenCUAModel
@@ -20,9 +20,9 @@ def load_model(model_name: str, device: str = "auto"):
         raise ImportError(
             "HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
         )
-    cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+    cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
     cls = cfg.__class__.__name__
     # print(f"cls: {cls}")
     if "OpenCUA" in cls:
-        return OpenCUAModel(model_name=model_name, device=device)
-    return GenericHFModel(model_name=model_name, device=device)
+        return OpenCUAModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
+    return GenericHFModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
@@ -3,7 +3,7 @@ from typing import List, Dict, Any, Optional
 # Hugging Face imports are local to avoid hard dependency at module import
 try:
     import torch  # type: ignore
-    from transformers import AutoModelForImageTextToText, AutoProcessor  # type: ignore
+    from transformers import AutoModel, AutoProcessor  # type: ignore
     HF_AVAILABLE = True
 except Exception:
     HF_AVAILABLE = False
@@ -14,7 +14,7 @@ class GenericHFModel:
     Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
     """
 
-    def __init__(self, model_name: str, device: str = "auto") -> None:
+    def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
         if not HF_AVAILABLE:
             raise ImportError(
                 "HuggingFace transformers dependencies not found. Install with: pip install \"cua-agent[uitars-hf]\""
@@ -23,15 +23,17 @@ class GenericHFModel:
         self.device = device
         self.model = None
         self.processor = None
+        self.trust_remote_code = trust_remote_code
         self._load()
 
     def _load(self) -> None:
         # Load model
-        self.model = AutoModelForImageTextToText.from_pretrained(
+        self.model = AutoModel.from_pretrained(
             self.model_name,
             torch_dtype=torch.float16,
             device_map=self.device,
             attn_implementation="sdpa",
+            trust_remote_code=self.trust_remote_code,
         )
         # Load processor
         self.processor = AutoProcessor.from_pretrained(
@@ -39,6 +41,7 @@ class GenericHFModel:
             min_pixels=3136,
             max_pixels=4096 * 2160,
             device_map=self.device,
+            trust_remote_code=self.trust_remote_code,
         )
 
     def generate(self, messages: List[Dict[str, Any]], max_new_tokens: int = 128) -> str:
@@ -16,7 +16,7 @@ except Exception:
 class OpenCUAModel:
     """OpenCUA model handler using AutoTokenizer, AutoModel and AutoImageProcessor."""
 
-    def __init__(self, model_name: str, device: str = "auto") -> None:
+    def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
         if not OPENCUA_AVAILABLE:
             raise ImportError(
                 "OpenCUA requirements not found. Install with: pip install \"cua-agent[opencua-hf]\""
@@ -26,21 +26,22 @@ class OpenCUAModel:
         self.model = None
         self.tokenizer = None
         self.image_processor = None
+        self.trust_remote_code = trust_remote_code
         self._load()
 
     def _load(self) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained(
-            self.model_name, trust_remote_code=True
+            self.model_name, trust_remote_code=self.trust_remote_code
         )
         self.model = AutoModel.from_pretrained(
             self.model_name,
             torch_dtype="auto",
             device_map=self.device,
-            trust_remote_code=True,
+            trust_remote_code=self.trust_remote_code,
             attn_implementation="sdpa",
         )
         self.image_processor = AutoImageProcessor.from_pretrained(
-            self.model_name, trust_remote_code=True
+            self.model_name, trust_remote_code=self.trust_remote_code
         )
 
     @staticmethod
@@ -166,6 +166,7 @@ class ComputerAgent:
         use_prompt_caching: Optional[bool] = False,
         max_trajectory_budget: Optional[float | dict] = None,
         telemetry_enabled: Optional[bool] = True,
+        trust_remote_code: Optional[bool] = False,
         **kwargs
     ):
         """
@@ -184,6 +185,7 @@ class ComputerAgent:
             use_prompt_caching: If set, use prompt caching to avoid reprocessing the same prompt. Intended for use with anthropic providers.
             max_trajectory_budget: If set, adds BudgetManagerCallback to track usage costs and stop when budget is exceeded
             telemetry_enabled: If set, adds TelemetryCallback to track anonymized usage data. Enabled by default.
+            trust_remote_code: If set, trust remote code when loading local models. Disabled by default.
             **kwargs: Additional arguments passed to the agent loop
         """
         self.model = model
@@ -198,6 +200,7 @@ class ComputerAgent:
         self.use_prompt_caching = use_prompt_caching
         self.telemetry_enabled = telemetry_enabled
         self.kwargs = kwargs
+        self.trust_remote_code = trust_remote_code
 
         # == Add built-in callbacks ==
 
@@ -231,7 +234,8 @@ class ComputerAgent:
 
         # Register local model providers
         hf_adapter = HuggingFaceLocalAdapter(
-            device="auto"
+            device="auto",
+            trust_remote_code=self.trust_remote_code or False
         )
         human_adapter = HumanAdapter()
         litellm.custom_provider_map = [
@@ -331,6 +331,7 @@ Examples:
     agent_kwargs = {
         "model": args.model,
         "tools": [computer],
+        "trust_remote_code": True,  # needed for some local models (e.g., InternVL, OpenCUA)
         "verbosity": 20 if args.verbose else 30,  # DEBUG vs WARNING
         "max_retries": args.max_retries
     }
@@ -155,7 +155,7 @@ class GTA1Config(AsyncAgentConfig):
     api_kwargs = {
         "model": model,
         "messages": [system_message, user_message],
-        "max_tokens": 32,
+        "max_tokens": 2056,
         "temperature": 0.0,
         **kwargs
     }
@@ -106,7 +106,7 @@ class OpenCUAConfig(AsyncAgentConfig):
     api_kwargs = {
         "model": model,
         "messages": [system_message, user_message],
-        "max_new_tokens": 512,
+        "max_new_tokens": 2056,
         "temperature": 0,
         **kwargs
     }
@@ -771,7 +771,7 @@ class UITARSConfig:
     api_kwargs = {
         "model": model,
         "messages": litellm_messages,
-        "max_tokens": 100,
+        "max_tokens": 2056,
         "temperature": 0.0,
         "do_sample": False
     }
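Taken together, the new flag threads from the ComputerAgent constructor through HuggingFaceLocalAdapter and load_model() down to the from_pretrained calls. A hedged end-to-end sketch; the import path and model string are illustrative assumptions, not verbatim from this commit:

    from agent import ComputerAgent  # import path is an assumption

    agent = ComputerAgent(
        model="huggingface-local/OpenCUA/OpenCUA-7B",  # hypothetical local model id
        trust_remote_code=True,  # needed for models shipping custom code (e.g., OpenCUA)
    )
    # Internally: HuggingFaceLocalAdapter(device="auto", trust_remote_code=True)
    # -> load_model(..., trust_remote_code=True)
    # -> AutoConfig/AutoModel/AutoProcessor.from_pretrained(..., trust_remote_code=True)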