Added "cua-agent[internvl-hf]" dep

This commit is contained in:
Dillon DuPont
2025-09-16 12:02:07 -04:00
parent 7cf27b1cc3
commit 9147e8eeaf
4 changed files with 182 additions and 3 deletions

View File

@@ -9,6 +9,7 @@ except ImportError:
from .generic import GenericHFModel
from .opencua import OpenCUAModel
from .qwen2_5_vl import Qwen2_5_VLModel
from .internvl import InternVLModel
def load_model(model_name: str, device: str = "auto", trust_remote_code: bool = False):
"""Factory function to load and return the right model handler instance.
@@ -22,9 +23,11 @@ def load_model(model_name: str, device: str = "auto", trust_remote_code: bool =
)
cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
cls = cfg.__class__.__name__
# print(f"cls: {cls}")
print(f"cls: {cls}")
if "OpenCUA" in cls:
return OpenCUAModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
elif "Qwen2_5_VLConfig" in cls:
return Qwen2_5_VLModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
elif "InternVLChatConfig" in cls:
return InternVLModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
return GenericHFModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)

View File

@@ -0,0 +1,78 @@
from typing import List, Dict, Any, Optional
# Hugging Face imports are local to avoid hard dependency at module import
try:
import torch # type: ignore
from transformers import AutoModel, AutoProcessor # type: ignore
# Attempt to import InternVL's model dependencies
import einops as _ # type: ignore
import timm as _ # type: ignore
HF_AVAILABLE = True
except Exception:
HF_AVAILABLE = False
class InternVLModel:
"""Generic Hugging Face vision-language model handler.
Loads an AutoModelForImageTextToText and AutoProcessor and generates text.
"""
def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
if not HF_AVAILABLE:
raise ImportError(
"InternVL dependencies not found. Install with: pip install \"cua-agent[internvl-hf]\""
)
self.model_name = model_name
self.device = device
self.model = None
self.processor = None
self.trust_remote_code = trust_remote_code
self._load()
def _load(self) -> None:
# Load model
self.model = AutoModel.from_pretrained(
self.model_name,
torch_dtype=torch.float16,
device_map=self.device,
attn_implementation="sdpa",
trust_remote_code=self.trust_remote_code,
)
# Load processor
self.processor = AutoProcessor.from_pretrained(
self.model_name,
min_pixels=3136,
max_pixels=4096 * 2160,
device_map=self.device,
trust_remote_code=self.trust_remote_code,
)
def generate(self, messages: List[Dict[str, Any]], max_new_tokens: int = 128) -> str:
"""Generate text for the given HF-format messages.
messages: [{ role, content: [{type:'text'|'image', text|image}] }]
"""
assert self.model is not None and self.processor is not None
# Apply chat template and tokenize
inputs = self.processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
# Move inputs to the same device as model
inputs = inputs.to(self.model.device)
# Generate
with torch.no_grad():
generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
# Trim prompt tokens from output
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
# Decode
output_text = self.processor.batch_decode(
generated_ids_trimmed,
skip_special_tokens=True,
clean_up_tokenization_spaces=False,
)
return output_text[0] if output_text else ""

View File

@@ -18,6 +18,15 @@ try:
import json
from typing import List, Dict, Any
import dotenv
import base64
import time
import platform
from pathlib import Path
try:
from PIL import Image, ImageDraw
PIL_AVAILABLE = True
except Exception:
PIL_AVAILABLE = False
from yaspin import yaspin
except ImportError:
if __name__ == "__main__":
@@ -248,6 +257,13 @@ Examples:
help="Initial prompt to send to the agent. Leave blank for interactive mode."
)
parser.add_argument(
"--predict-click",
dest="predict_click",
type=str,
help="Instruction for click prediction. If set, runs predict_click, draws crosshair on a fresh screenshot, saves and opens it."
)
parser.add_argument(
"-c", "--cache",
action="store_true",
@@ -354,7 +370,79 @@ Examples:
agent = ComputerAgent(**agent_kwargs)
# Start chat loop
# If predict-click mode is requested, run once and exit
if args.predict_click:
if not PIL_AVAILABLE:
print_colored("❌ Pillow (PIL) is required for --predict-click visualization. Install with: pip install pillow", Colors.RED, bold=True)
sys.exit(1)
instruction = args.predict_click
print_colored(f"Predicting click for: '{instruction}'", Colors.CYAN)
# Take a fresh screenshot FIRST
try:
img_bytes = await computer.interface.screenshot()
except Exception as e:
print_colored(f"❌ Failed to take screenshot: {e}", Colors.RED, bold=True)
sys.exit(1)
# Encode screenshot to base64 for predict_click
try:
image_b64 = base64.b64encode(img_bytes).decode("utf-8")
except Exception as e:
print_colored(f"❌ Failed to encode screenshot: {e}", Colors.RED, bold=True)
sys.exit(1)
try:
coords = await agent.predict_click(instruction, image_b64=image_b64)
except Exception as e:
print_colored(f"❌ predict_click failed: {e}", Colors.RED, bold=True)
sys.exit(1)
if not coords:
print_colored("⚠️ No coordinates returned.", Colors.YELLOW)
sys.exit(2)
x, y = coords
print_colored(f"✅ Predicted coordinates: ({x}, {y})", Colors.GREEN)
try:
from io import BytesIO
with Image.open(BytesIO(img_bytes)) as img:
img = img.convert("RGB")
draw = ImageDraw.Draw(img)
# Draw crosshair
size = 12
color = (255, 0, 0)
draw.line([(x - size, y), (x + size, y)], fill=color, width=3)
draw.line([(x, y - size), (x, y + size)], fill=color, width=3)
# Optional small circle
r = 6
draw.ellipse([(x - r, y - r), (x + r, y + r)], outline=color, width=2)
out_path = Path.cwd() / f"predict_click_{int(time.time())}.png"
img.save(out_path)
print_colored(f"🖼️ Saved to {out_path}")
# Open the image with default viewer
try:
system = platform.system().lower()
if system == "windows":
os.startfile(str(out_path)) # type: ignore[attr-defined]
elif system == "darwin":
os.system(f"open \"{out_path}\"")
else:
os.system(f"xdg-open \"{out_path}\"")
except Exception:
pass
except Exception as e:
print_colored(f"❌ Failed to render/save screenshot: {e}", Colors.RED, bold=True)
sys.exit(1)
# Done
sys.exit(0)
# Start chat loop (default interactive mode)
await chat_loop(agent, args.model, container_name, args.prompt, args.usage)

View File

@@ -53,6 +53,13 @@ opencua-hf = [
"tiktoken>=0.11.0",
"blobfile>=3.0.0"
]
internvl-hf = [
"accelerate",
"torch",
"transformers>=4.55.0",
"einops",
"timm"
]
ui = [
"gradio>=5.23.3",
"python-dotenv>=1.0.1",
@@ -68,7 +75,10 @@ all = [
"mlx-vlm>=0.1.27; sys_platform == 'darwin'",
"accelerate",
"torch",
"transformers>=4.54.0",
"transformers>=4.55.0",
# internvl requirements,
"einops",
"timm",
# opencua requirements
"tiktoken>=0.11.0",
"blobfile>=3.0.0",