mirror of https://github.com/trycua/computer.git (synced 2026-01-04 12:30:08 -06:00)
Added "cua-agent[internvl-hf]" dep
@@ -9,6 +9,7 @@ except ImportError:
from .generic import GenericHFModel
from .opencua import OpenCUAModel
from .qwen2_5_vl import Qwen2_5_VLModel
from .internvl import InternVLModel


def load_model(model_name: str, device: str = "auto", trust_remote_code: bool = False):
    """Factory function to load and return the right model handler instance.
@@ -22,9 +23,11 @@ def load_model(model_name: str, device: str = "auto", trust_remote_code: bool =
    )
    cfg = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code)
    cls = cfg.__class__.__name__
    # print(f"cls: {cls}")
    print(f"cls: {cls}")
    if "OpenCUA" in cls:
        return OpenCUAModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
    elif "Qwen2_5_VLConfig" in cls:
        return Qwen2_5_VLModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
    elif "InternVLChatConfig" in cls:
        return InternVLModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
    return GenericHFModel(model_name=model_name, device=device, trust_remote_code=trust_remote_code)
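For context, load_model dispatches on the checkpoint's AutoConfig class name, so InternVL checkpoints (config class InternVLChatConfig) now route to the new InternVLModel handler. A minimal caller-side sketch follows; the import path and checkpoint name are illustrative assumptions, not part of this commit:

    # Sketch only: import path and checkpoint name are assumptions for illustration.
    from agent.adapters.models import load_model

    handler = load_model("OpenGVLab/InternVL3_5-8B", device="auto", trust_remote_code=True)
    # An InternVLChatConfig checkpoint returns an InternVLModel; unknown configs fall back to GenericHFModel.
    reply = handler.generate(
        [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
        max_new_tokens=32,
    )
    print(reply)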
libs/python/agent/agent/adapters/models/internvl.py (new file, 78 lines)
@@ -0,0 +1,78 @@
from typing import List, Dict, Any, Optional

# Hugging Face imports are local to avoid hard dependency at module import
try:
    import torch  # type: ignore
    from transformers import AutoModel, AutoProcessor  # type: ignore
    # Attempt to import InternVL's model dependencies
    import einops as _  # type: ignore
    import timm as _  # type: ignore
    HF_AVAILABLE = True
except Exception:
    HF_AVAILABLE = False


class InternVLModel:
    """InternVL Hugging Face vision-language model handler.

    Loads an AutoModel and AutoProcessor and generates text.
    """

    def __init__(self, model_name: str, device: str = "auto", trust_remote_code: bool = False) -> None:
        if not HF_AVAILABLE:
            raise ImportError(
                "InternVL dependencies not found. Install with: pip install \"cua-agent[internvl-hf]\""
            )
        self.model_name = model_name
        self.device = device
        self.model = None
        self.processor = None
        self.trust_remote_code = trust_remote_code
        self._load()

    def _load(self) -> None:
        # Load model
        self.model = AutoModel.from_pretrained(
            self.model_name,
            torch_dtype=torch.float16,
            device_map=self.device,
            attn_implementation="sdpa",
            trust_remote_code=self.trust_remote_code,
        )
        # Load processor
        self.processor = AutoProcessor.from_pretrained(
            self.model_name,
            min_pixels=3136,
            max_pixels=4096 * 2160,
            device_map=self.device,
            trust_remote_code=self.trust_remote_code,
        )

    def generate(self, messages: List[Dict[str, Any]], max_new_tokens: int = 128) -> str:
        """Generate text for the given HF-format messages.

        messages: [{ role, content: [{type:'text'|'image', text|image}] }]
        """
        assert self.model is not None and self.processor is not None
        # Apply chat template and tokenize
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
        )
        # Move inputs to the same device as model
        inputs = inputs.to(self.model.device)
        # Generate
        with torch.no_grad():
            generated_ids = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        # Trim prompt tokens from output
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
        ]
        # Decode
        output_text = self.processor.batch_decode(
            generated_ids_trimmed,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        return output_text[0] if output_text else ""
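A minimal usage sketch for the new handler, using the HF-style message format described in the generate docstring; the checkpoint name, image path, and import path are illustrative assumptions:

    # Sketch only: checkpoint name, image path, and import path are assumptions.
    from agent.adapters.models.internvl import InternVLModel

    model = InternVLModel(
        model_name="OpenGVLab/InternVL3_5-8B",
        device="auto",
        trust_remote_code=True,  # InternVL repos ship custom modeling code
    )
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": "screenshot.png"},
                {"type": "text", "text": "Which button submits the form?"},
            ],
        }
    ]
    print(model.generate(messages, max_new_tokens=64))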
@@ -18,6 +18,15 @@ try:
    import json
    from typing import List, Dict, Any
    import dotenv
    import base64
    import time
    import platform
    from pathlib import Path
    try:
        from PIL import Image, ImageDraw
        PIL_AVAILABLE = True
    except Exception:
        PIL_AVAILABLE = False
    from yaspin import yaspin
except ImportError:
    if __name__ == "__main__":
@@ -248,6 +257,13 @@ Examples:
        help="Initial prompt to send to the agent. Leave blank for interactive mode."
    )

    parser.add_argument(
        "--predict-click",
        dest="predict_click",
        type=str,
        help="Instruction for click prediction. If set, runs predict_click, draws a crosshair on a fresh screenshot, then saves and opens it."
    )

    parser.add_argument(
        "-c", "--cache",
        action="store_true",
@@ -354,7 +370,79 @@ Examples:

    agent = ComputerAgent(**agent_kwargs)

    # Start chat loop
    # If predict-click mode is requested, run once and exit
    if args.predict_click:
        if not PIL_AVAILABLE:
            print_colored("❌ Pillow (PIL) is required for --predict-click visualization. Install with: pip install pillow", Colors.RED, bold=True)
            sys.exit(1)

        instruction = args.predict_click
        print_colored(f"Predicting click for: '{instruction}'", Colors.CYAN)

        # Take a fresh screenshot FIRST
        try:
            img_bytes = await computer.interface.screenshot()
        except Exception as e:
            print_colored(f"❌ Failed to take screenshot: {e}", Colors.RED, bold=True)
            sys.exit(1)

        # Encode screenshot to base64 for predict_click
        try:
            image_b64 = base64.b64encode(img_bytes).decode("utf-8")
        except Exception as e:
            print_colored(f"❌ Failed to encode screenshot: {e}", Colors.RED, bold=True)
            sys.exit(1)

        try:
            coords = await agent.predict_click(instruction, image_b64=image_b64)
        except Exception as e:
            print_colored(f"❌ predict_click failed: {e}", Colors.RED, bold=True)
            sys.exit(1)

        if not coords:
            print_colored("⚠️ No coordinates returned.", Colors.YELLOW)
            sys.exit(2)

        x, y = coords
        print_colored(f"✅ Predicted coordinates: ({x}, {y})", Colors.GREEN)

        try:
            from io import BytesIO
            with Image.open(BytesIO(img_bytes)) as img:
                img = img.convert("RGB")
                draw = ImageDraw.Draw(img)
                # Draw crosshair
                size = 12
                color = (255, 0, 0)
                draw.line([(x - size, y), (x + size, y)], fill=color, width=3)
                draw.line([(x, y - size), (x, y + size)], fill=color, width=3)
                # Optional small circle
                r = 6
                draw.ellipse([(x - r, y - r), (x + r, y + r)], outline=color, width=2)

                out_path = Path.cwd() / f"predict_click_{int(time.time())}.png"
                img.save(out_path)
                print_colored(f"🖼️ Saved to {out_path}")

                # Open the image with default viewer
                try:
                    system = platform.system().lower()
                    if system == "windows":
                        os.startfile(str(out_path))  # type: ignore[attr-defined]
                    elif system == "darwin":
                        os.system(f"open \"{out_path}\"")
                    else:
                        os.system(f"xdg-open \"{out_path}\"")
                except Exception:
                    pass
        except Exception as e:
            print_colored(f"❌ Failed to render/save screenshot: {e}", Colors.RED, bold=True)
            sys.exit(1)

        # Done
        sys.exit(0)

    # Start chat loop (default interactive mode)
    await chat_loop(agent, args.model, container_name, args.prompt, args.usage)
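The predict-click flow can also be driven programmatically. A minimal sketch that assumes an already-constructed ComputerAgent (agent) and computer handle, reusing only the calls shown in the diff above:

    # Sketch only: agent/computer construction is elided; see the CLI block above for the full flow.
    import base64

    async def predict_click_coords(agent, computer, instruction: str):
        # Fresh screenshot, base64-encoded, exactly as the --predict-click mode does
        img_bytes = await computer.interface.screenshot()
        image_b64 = base64.b64encode(img_bytes).decode("utf-8")
        # Returns (x, y) pixel coordinates, or a falsy value when no prediction is made
        # (the CLI exits with code 2 in that case)
        return await agent.predict_click(instruction, image_b64=image_b64)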
@@ -53,6 +53,13 @@ opencua-hf = [
    "tiktoken>=0.11.0",
    "blobfile>=3.0.0"
]
internvl-hf = [
    "accelerate",
    "torch",
    "transformers>=4.55.0",
    "einops",
    "timm"
]
ui = [
    "gradio>=5.23.3",
    "python-dotenv>=1.0.1",
@@ -68,7 +75,10 @@ all = [
    "mlx-vlm>=0.1.27; sys_platform == 'darwin'",
    "accelerate",
    "torch",
    "transformers>=4.54.0",
    "transformers>=4.55.0",
    # internvl requirements
    "einops",
    "timm",
    # opencua requirements
    "tiktoken>=0.11.0",
    "blobfile>=3.0.0",
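The new internvl-hf extra mirrors the ImportError hint in internvl.py (pip install "cua-agent[internvl-hf]"). An optional sanity-check sketch that the extra's dependencies resolved after installation:

    # Sketch only: verifies the packages listed in the internvl-hf extra are importable.
    import importlib.util

    for mod in ("accelerate", "torch", "transformers", "einops", "timm"):
        print(mod, "ok" if importlib.util.find_spec(mod) else "missing")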