Mirror of https://github.com/trycua/computer.git (synced 2026-01-08 06:20:00 -06:00)

Commit: added GTA1 agent and click benchmarks (ss-pro, repl)
@@ -48,7 +48,11 @@ class HuggingFaceLocalAdapter(CustomLLM):
         )
 
         # Load processor
-        processor = AutoProcessor.from_pretrained(model_name)
+        processor = AutoProcessor.from_pretrained(
+            model_name,
+            min_pixels=3136,
+            max_pixels=4096 * 2160
+        )
 
         # Cache them
         self.models[model_name] = model
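Note: min_pixels / max_pixels bound how many pixels the Qwen-VL processor keeps when it tokenizes a screenshot; images outside the budget get resized. A minimal sketch of the constraint (helper name is hypothetical, not part of this commit):

# Illustrative only: Qwen-VL style processors resize images so that their
# total pixel count stays within [min_pixels, max_pixels].
def within_pixel_budget(width: int, height: int,
                        min_pixels: int = 3136,
                        max_pixels: int = 4096 * 2160) -> bool:
    return min_pixels <= width * height <= max_pixels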
@@ -7,5 +7,6 @@ from . import anthropic
 from . import openai
 from . import uitars
 from . import omniparser
+from . import gta1
 
-__all__ = ["anthropic", "openai", "uitars", "omniparser"]
+__all__ = ["anthropic", "openai", "uitars", "omniparser", "gta1"]
@@ -1,5 +1,7 @@
 """
 GTA1 agent loop implementation for click prediction using litellm.acompletion
+Paper: https://arxiv.org/pdf/2507.05791
+Code: https://github.com/Yan98/GTA1
 """
 
 import asyncio
@@ -20,7 +22,7 @@ You are an expert UI element locator. Given a GUI image and a user's element des
 
 Output the coordinate pair exactly:
 (x,y)
-'''
+'''.strip()
 
 def extract_coordinates(raw_string: str) -> Tuple[float, float]:
     """Extract coordinates from model output."""
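Note: the system prompt asks the model to answer with a bare "(x,y)" pair, which extract_coordinates recovers with a regex. A self-contained sketch of that parsing (parse_xy is a hypothetical name; the pattern mirrors the one used in the benchmark's GTA1Model):

import re
from typing import Tuple

def parse_xy(raw: str) -> Tuple[float, float]:
    # Match a "(x,y)" pair, allowing negative and fractional values.
    m = re.search(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw)
    if m is None:
        raise ValueError(f"no coordinate pair found in {raw!r}")
    return float(m.group(1)), float(m.group(2))

assert parse_xy("(812,430)") == (812.0, 430.0)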
@@ -1,5 +1,7 @@
 """
 OpenAI computer-use-preview agent loop implementation using liteLLM
+Paper: https://arxiv.org/abs/2408.00203
+Code: https://github.com/microsoft/OmniParser
 """
 
 import asyncio
@@ -1,5 +1,7 @@
 """
 UITARS agent loop implementation using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B
+Paper: https://arxiv.org/abs/2501.12326
+Code: https://github.com/bytedance/UI-TARS
 """
 
 import asyncio
@@ -79,6 +81,18 @@ Action: ...
 {instruction}
 """
 
+GROUNDING_UITARS_PROMPT_TEMPLATE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
+
+## Output Format
+
+Action: ...
+
+
+## Action Space
+click(point='<|box_start|>(x1,y1)<|box_end|>')
+
+## User Instruction
+{instruction}"""
 
 def round_by_factor(number: float, factor: int) -> int:
     """Returns the closest integer to 'number' that is divisible by 'factor'."""
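Note: replies in this action space are matched later in predict_click with a regex. A quick self-contained check (the reply text is illustrative):

import re

# Sample reply in the grounding template's action space.
reply = "Action: click(point='<|box_start|>(512,384)<|box_end|>')"
pattern = r"click\(point='<\|box_start\|>\((\d+),(\d+)\)<\|box_end\|>'\)"
m = re.search(pattern, reply)
assert m and (int(m.group(1)), int(m.group(2))) == (512, 384)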
@@ -511,7 +525,7 @@ class UITARSConfig:
 
     async def predict_step(
         self,
-        messages: Messages,
+        messages: List[Dict[str, Any]],
         model: str,
         tools: Optional[List[Dict[str, Any]]] = None,
        max_retries: Optional[int] = None,
@@ -729,13 +743,15 @@ class UITARSConfig:
             Tuple with (x, y) coordinates or None
         """
         try:
-            # Create a simple click instruction for UITARS
-            user_prompt = UITARS_PROMPT_TEMPLATE.format(
-                instruction=f"Click on: {instruction}",
-                action_space=UITARS_ACTION_SPACE,
-                language="English"
+            # Create prompt using grounding template
+            user_prompt = GROUNDING_UITARS_PROMPT_TEMPLATE.format(
+                instruction=instruction
             )
 
+            # Process image for UITARS
+            processed_image, original_width, original_height = process_image_for_uitars(image_b64)
+            encoded_image = pil_to_base64(processed_image)
+
             # Prepare messages for liteLLM
             litellm_messages = [
                 {
@@ -746,46 +762,47 @@ class UITARSConfig:
                     "role": "user",
                     "content": [
                         {"type": "text", "text": user_prompt},
-                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
+                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
                     ]
                 }
             ]
 
+            # Prepare API call kwargs
+            api_kwargs = {
+                "model": model,
+                "messages": litellm_messages,
+                "max_tokens": 100,
+                "temperature": 0.0,
+                "do_sample": False
+            }
+
             # Call liteLLM with UITARS model
-            response = await litellm.acompletion(
-                model=model,
-                messages=litellm_messages,
-                max_tokens=100,
-                temperature=0.0
-            )
+            response = await litellm.acompletion(**api_kwargs)
 
             # Extract response content
             response_content = response.choices[0].message.content.strip()  # type: ignore
 
-            # Parse UITARS response to extract click coordinates
-            parsed_responses = parse_uitars_response(response_content, 1024, 768)  # Default dimensions
+            # Parse the response to extract click coordinates
+            # Look for click action with coordinates
+            click_pattern = r"click\(point='<\|box_start\|>\((\d+),(\d+)\)<\|box_end\|>'\)"
+            match = re.search(click_pattern, response_content)
 
-            if parsed_responses and len(parsed_responses) > 0:
-                action_type = parsed_responses[0].get("action_type")
-                if action_type == "click":
-                    action_inputs = parsed_responses[0].get("action_inputs", {})
-                    start_box = action_inputs.get("start_box")
-                    if start_box:
-                        # Parse coordinates from start_box
-                        try:
-                            coords = eval(start_box)  # Parse the coordinate list
-                            if len(coords) >= 2:
-                                # Convert normalized coordinates back to pixel coordinates
-                                x = int(coords[0] * 1024)
-                                y = int(coords[1] * 768)
-                                return (x, y)
-                        except:
-                            pass
+            if match:
+                x, y = int(match.group(1)), int(match.group(2))
+                # Scale coordinates back to original image dimensions
+                scale_x = original_width / processed_image.width
+                scale_y = original_height / processed_image.height
+
+                scaled_x = int(x * scale_x)
+                scaled_y = int(y * scale_y)
+
+                return (scaled_x, scaled_y)
 
             return None
 
         except Exception as e:
-            print(f"Error in UITARS predict_click: {e}")
+            # Log error and return None
+            print(f"Error in predict_click: {e}")
             return None
 
     def get_capabilities(self) -> List[AgentCapability]:
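Note: the model predicts coordinates in the resized screenshot's space, so they must be mapped back to original pixels. A minimal sketch of that rescaling step (function name hypothetical, values illustrative):

def rescale(point, processed_size, original_size):
    # Map a point from processed-image space back to original-image space.
    x, y = point
    pw, ph = processed_size
    ow, oh = original_size
    return int(x * ow / pw), int(y * oh / ph)

assert rescale((512, 384), (1024, 768), (2048, 1536)) == (1024, 768)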
libs/python/agent/benchmarks/.gitignore (vendored, new file, 2 lines)
@@ -0,0 +1,2 @@
output/
interactive_output/
libs/python/agent/benchmarks/interactive.py (new file, 201 lines)
@@ -0,0 +1,201 @@
#!/usr/bin/env python3
"""
Interactive Click Prediction Tool

Takes screenshots and allows testing multiple models interactively.
Models are loaded/unloaded one at a time to avoid memory issues.
"""

import asyncio
import os
from datetime import datetime
from typing import List, Dict, Any

from utils import (
    ModelWrapper,
    take_screenshot,
    save_prediction_visualization,
    get_available_models
)


async def predict_with_all_models(image, instruction: str, models) -> List[Dict[str, Any]]:
    """
    Predict click coordinates with all models sequentially.

    Args:
        image: PIL Image to analyze
        instruction: Instruction text
        models: List of model instances

    Returns:
        List of prediction results
    """
    predictions = []

    for model in models:
        model_wrapper = ModelWrapper(model)
        print(f"\n🔄 Loading {model_wrapper.model_name}...")

        try:
            # Load model
            await model_wrapper.load_model()

            # Predict
            coords = await model_wrapper.predict_click(image, instruction)

            predictions.append({
                'model_name': model_wrapper.model_name,
                'coords': coords,
                'error': None
            })

            if coords:
                print(f"✅ {model_wrapper.model_name}: ({coords[0]}, {coords[1]})")
            else:
                print(f"❌ {model_wrapper.model_name}: No prediction")

        except Exception as e:
            print(f"❌ {model_wrapper.model_name}: ERROR - {str(e)}")
            predictions.append({
                'model_name': model_wrapper.model_name,
                'coords': None,
                'error': str(e)
            })

        finally:
            # Always unload model to free memory
            try:
                await model_wrapper.unload_model()
                print(f"🗑️ Unloaded {model_wrapper.model_name}")
            except Exception as e:
                print(f"⚠️ Error unloading {model_wrapper.model_name}: {e}")

    return predictions


def print_header():
    """Print the interactive tool header."""
    print("=" * 60)
    print("🖱️ Interactive Click Prediction Tool")
    print("=" * 60)
    print("Commands:")
    print("  • Type an instruction to test models on last screenshot")
    print("  • 'screenshot' - Take a new screenshot")
    print("  • 'models' - List available models")
    print("  • 'quit' or 'exit' - Exit the tool")
    print("=" * 60)
    print("💡 Tip: Take a screenshot first, then send instructions to test models!")


def print_models(models):
    """Print available models."""
    print("\n📋 Available Models:")
    for i, model in enumerate(models, 1):
        if isinstance(model, str):
            print(f"  {i}. {model}")
        else:
            print(f"  {i}. models.{model.__class__.__name__}")


async def main():
    """
    Main interactive loop.
    """
    print_header()

    # Get available models
    models = get_available_models()
    print_models(models)

    # Create output directory for visualizations
    output_dir = "interactive_output"
    os.makedirs(output_dir, exist_ok=True)

    session_count = 0
    last_screenshot = None
    screenshot_timestamp = None

    while True:
        try:
            # Get user input
            print(f"\n{'='*40}")
            user_input = input("🎯 Enter instruction (or command): ").strip()

            if not user_input:
                continue

            # Handle commands
            if user_input.lower() in ['quit', 'exit', 'q']:
                print("👋 Goodbye!")
                break

            elif user_input.lower() == 'models':
                print_models(models)
                continue

            elif user_input.lower() == 'screenshot':
                print("📸 Taking screenshot...")
                try:
                    last_screenshot = take_screenshot()
                    screenshot_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                    screenshot_path = os.path.join(output_dir, f"screenshot_{screenshot_timestamp}.png")
                    last_screenshot.save(screenshot_path)
                    print(f"✅ Screenshot captured and saved to: {screenshot_path}")
                    print(f"📝 Ready for instructions! Screenshot size: {last_screenshot.size}")
                except Exception as e:
                    print(f"❌ Error taking screenshot: {e}")
                continue

            # Handle instruction input
            if last_screenshot is None:
                print("⚠️ No screenshot available! Please take a screenshot first using 'screenshot' command.")
                continue

            session_count += 1
            print(f"\n🎯 Session {session_count}: '{user_input}'")
            print(f"📷 Using screenshot from: {screenshot_timestamp}")

            # Predict with all models using last screenshot
            print(f"\n🤖 Testing {len(models)} models on screenshot...")
            predictions = await predict_with_all_models(last_screenshot, user_input, models)

            # Display results summary
            print(f"\n📊 Results Summary:")
            print("-" * 50)
            for pred in predictions:
                if pred['coords']:
                    print(f"✅ {pred['model_name']}: ({pred['coords'][0]}, {pred['coords'][1]})")
                elif pred['error']:
                    print(f"❌ {pred['model_name']}: ERROR - {pred['error']}")
                else:
                    print(f"❌ {pred['model_name']}: No prediction")

            # Save visualization
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            vis_filename = f"session_{session_count:03d}_{timestamp}.png"
            vis_path = os.path.join(output_dir, vis_filename)

            try:
                save_prediction_visualization(last_screenshot, user_input, predictions, vis_path)
                print(f"\n💾 Visualization saved to: {vis_path}")
            except Exception as e:
                print(f"⚠️ Error saving visualization: {e}")

            print(f"\n✨ Session {session_count} completed!")

        except KeyboardInterrupt:
            print("\n\n👋 Interrupted by user. Goodbye!")
            break
        except Exception as e:
            print(f"\n❌ Unexpected error: {e}")
            print("Continuing...")


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n👋 Goodbye!")
    except Exception as e:
        print(f"❌ Fatal error: {e}")
libs/python/agent/benchmarks/models/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .base import ModelProtocol
from .gta1 import GTA1Model

__all__ = ["ModelProtocol", "GTA1Model"]
libs/python/agent/benchmarks/models/base.py (new file, 36 lines)
@@ -0,0 +1,36 @@
"""
Base protocol for benchmark models.
"""

from typing import Protocol, Optional, Tuple
from PIL import Image


class ModelProtocol(Protocol):
    """Protocol for benchmark models that can predict click coordinates."""

    @property
    def model_name(self) -> str:
        """Return the name of the model."""
        ...

    async def load_model(self) -> None:
        """Load the model into memory."""
        ...

    async def unload_model(self) -> None:
        """Unload the model from memory."""
        ...

    async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates for the given image and instruction.

        Args:
            image: PIL Image to analyze
            instruction: Text instruction describing what to click

        Returns:
            Tuple of (x, y) coordinates or None if prediction fails
        """
        ...
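Note: since ModelProtocol is a structural Protocol, any class with these members qualifies; no inheritance is needed. A minimal (hypothetical) implementation that always clicks the image center, useful for smoke-testing the harness:

from typing import Optional, Tuple
from PIL import Image

class CenterClickModel:
    """Baseline that satisfies ModelProtocol by clicking the image center."""

    @property
    def model_name(self) -> str:
        return "center-click-baseline"

    async def load_model(self) -> None:
        pass  # nothing to load

    async def unload_model(self) -> None:
        pass  # nothing to free

    async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
        return image.width // 2, image.height // 2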
libs/python/agent/benchmarks/models/gta1.py (new file, 162 lines)
@@ -0,0 +1,162 @@
"""
GTA1 model implementation for benchmarking.
"""

from typing import Optional, Tuple
from PIL import Image
import torch
import re
import gc
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

from .base import ModelProtocol


class GTA1Model:
    """Ground truth GTA1 model implementation."""

    def __init__(self, model_path: str = "HelloKKMe/GTA1-7B"):
        self.model_path = model_path
        self.model = None
        self.processor = None
        self.max_new_tokens = 32

        self.system_prompt = '''
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.

Output the coordinate pair exactly:
(x,y)
'''.strip()

    @property
    def model_name(self) -> str:
        """Return the name of the model."""
        return f"GTA1-{self.model_path.split('/')[-1]}"

    async def load_model(self) -> None:
        """Load the model into memory."""
        if self.model is None:
            print(f"Loading GTA1 model: {self.model_path}")
            self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
                self.model_path,
                torch_dtype=torch.bfloat16,
                device_map="auto"
            )
            self.processor = AutoProcessor.from_pretrained(
                self.model_path,
                min_pixels=3136,
                max_pixels=4096 * 2160
            )
            print("GTA1 model loaded successfully")

    async def unload_model(self) -> None:
        """Unload the model from memory."""
        if self.model is not None:
            print("Unloading GTA1 model from GPU...")
            del self.model
            del self.processor
            self.model = None
            self.processor = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("GTA1 model unloaded")

    def _extract_coordinates(self, raw_string: str) -> Tuple[int, int]:
        """Extract coordinates from model output."""
        try:
            matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
            return tuple(map(int, map(float, matches[0])))  # type: ignore
        except (IndexError, ValueError):
            return (0, 0)

    async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates for the given image and instruction.

        Args:
            image: PIL Image to analyze
            instruction: Text instruction describing what to click

        Returns:
            Tuple of (x, y) coordinates or None if prediction fails
        """
        if self.model is None or self.processor is None:
            await self.load_model()

        assert self.processor is not None
        assert self.model is not None

        try:
            width, height = image.width, image.height

            # Resize image according to processor requirements
            resized_height, resized_width = smart_resize(
                image.height,
                image.width,
                factor=self.processor.image_processor.patch_size * self.processor.image_processor.merge_size,
                min_pixels=self.processor.image_processor.min_pixels,
                max_pixels=self.processor.image_processor.max_pixels,
            )
            resized_image = image.resize((resized_width, resized_height))
            scale_x, scale_y = width / resized_width, height / resized_height

            # Prepare messages
            system_message = {
                "role": "system",
                "content": self.system_prompt.format(height=resized_height, width=resized_width)
            }

            user_message = {
                "role": "user",
                "content": [
                    {"type": "image", "image": resized_image},
                    {"type": "text", "text": instruction}
                ]
            }

            # Process inputs
            image_inputs, video_inputs = process_vision_info([system_message, user_message])
            text = self.processor.apply_chat_template(
                [system_message, user_message],
                tokenize=False,
                add_generation_prompt=True
            )
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt"
            )
            inputs = inputs.to(self.model.device)

            # Generate prediction
            output_ids = self.model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                do_sample=False,
                temperature=1.0,
                use_cache=True
            )
            generated_ids = [
                output_ids[len(input_ids):]
                for input_ids, output_ids in zip(inputs.input_ids, output_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            )[0]

            # Extract and rescale coordinates
            pred_x, pred_y = self._extract_coordinates(output_text)
            pred_x = int(pred_x * scale_x)
            pred_y = int(pred_y * scale_y)

            return (pred_x, pred_y)

        except Exception as e:
            print(f"Error in GTA1 prediction: {e}")
            return None
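Note: a standalone usage sketch for the reference model (the blank test image stands in for a real screenshot; predict_click loads the model lazily). Assumes the HelloKKMe/GTA1-7B weights are available locally or downloadable:

import asyncio
from PIL import Image
from models.gta1 import GTA1Model

async def demo() -> None:
    model = GTA1Model("HelloKKMe/GTA1-7B")
    image = Image.new("RGB", (1920, 1080), "white")  # stand-in screenshot
    print(await model.predict_click(image, "the Save button"))
    await model.unload_model()  # free GPU memory when done

asyncio.run(demo())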
libs/python/agent/benchmarks/ss-pro.py (new file, 157 lines)
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
"""
ScreenSpot-Pro Benchmark Script

Evaluates models on the ScreenSpot-Pro dataset for click prediction accuracy.
Supports both ComputerAgent model strings and custom model classes.
"""

import asyncio
from typing import Optional

from datasets import load_dataset
from tqdm import tqdm

from utils import (
    ModelWrapper,
    is_click_in_bbox,
    save_results_to_markdown,
    save_visualizations,
    get_available_models
)


async def evaluate_model(model_wrapper: ModelWrapper, dataset, max_samples: Optional[int] = None) -> dict:
    """
    Evaluate a model on the ScreenSpot-Pro dataset.

    Args:
        model_wrapper: ModelWrapper instance
        dataset: ScreenSpot-Pro dataset (list of samples)
        max_samples: Maximum number of samples to evaluate (None for all)

    Returns:
        Dictionary with evaluation results
    """
    print(f"\nEvaluating model: {model_wrapper.model_name}")

    # Load model
    await model_wrapper.load_model()

    total_samples = len(dataset)
    if max_samples is not None:
        total_samples = min(max_samples, total_samples)

    correct_predictions = 0
    failed_predictions = 0
    results = []

    try:
        for i in tqdm(range(total_samples), desc=f"Evaluating {model_wrapper.model_name}"):
            sample = dataset[i]

            # Extract sample data
            image = sample['image']
            instruction = sample['instruction']
            bbox = sample['bbox']  # [x1, y1, x2, y2]
            sample_id = sample['id']

            # Predict click coordinates
            try:
                click_coords = await model_wrapper.predict_click(image, instruction)

                # Check if prediction is correct
                is_correct = is_click_in_bbox(click_coords, bbox)

                if is_correct:
                    correct_predictions += 1

                results.append({
                    'id': sample_id,
                    'instruction': instruction,
                    'bbox': bbox,
                    'predicted_coords': click_coords,
                    'is_correct': is_correct,
                    'failed': False
                })

            except Exception as e:
                print(f"\nError predicting sample {sample_id}: {e}")
                failed_predictions += 1
                results.append({
                    'id': sample_id,
                    'instruction': instruction,
                    'bbox': bbox,
                    'predicted_coords': None,
                    'is_correct': False,
                    'failed': True,
                    'error': str(e)
                })

    finally:
        # Unload model
        await model_wrapper.unload_model()

    # Calculate metrics
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
    failure_rate = failed_predictions / total_samples if total_samples > 0 else 0.0

    return {
        'model_name': model_wrapper.model_name,
        'total_samples': total_samples,
        'correct_predictions': correct_predictions,
        'failed_predictions': failed_predictions,
        'accuracy': accuracy,
        'failure_rate': failure_rate,
        'results': results
    }


async def main():
    """
    Main function to run the benchmark.
    """
    # Load dataset
    print("Loading ScreenSpot-Pro dataset...")
    ds = load_dataset("lmms-lab/ScreenSpot-Pro")
    dataset = ds['train']  # type: ignore
    # Convert to list to support indexing
    dataset_list = list(dataset)
    print(f"Dataset loaded: {len(dataset_list)} samples")

    # Get available models
    models = get_available_models()

    # Evaluation settings
    max_samples = 5  # Set to None to evaluate on full dataset

    # Run evaluations
    all_results = []

    for model in models:
        try:
            model_wrapper = ModelWrapper(model)
            result = await evaluate_model(model_wrapper, dataset_list, max_samples)
            all_results.append(result)

            # Print summary
            print(f"\n{result['model_name']} Results:")
            print(f"  Accuracy: {result['accuracy']*100:.2f}%")
            print(f"  Correct: {result['correct_predictions']}/{result['total_samples']}")
            print(f"  Failed: {result['failed_predictions']}")

        except Exception as e:
            print(f"\nError evaluating model {model}: {e}")
            continue

    # Save results
    if all_results:
        save_results_to_markdown(all_results)
        save_visualizations(all_results, dataset_list)
        print("\nBenchmark completed successfully!")
    else:
        print("\nNo successful evaluations completed.")


if __name__ == "__main__":
    asyncio.run(main())
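Note: a sample counts as correct exactly when the predicted point falls inside the ground-truth box, per utils.is_click_in_bbox. A quick illustration with made-up numbers:

from utils import is_click_in_bbox

bbox = [100, 200, 300, 260]  # [x1, y1, x2, y2]
assert is_click_in_bbox((150, 230), bbox)       # inside  -> counted correct
assert not is_click_in_bbox((400, 230), bbox)   # outside -> counted incorrect
assert not is_click_in_bbox(None, bbox)         # no prediction -> incorrect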
libs/python/agent/benchmarks/utils.py (new file, 316 lines)
@@ -0,0 +1,316 @@
#!/usr/bin/env python3
"""
Shared utilities for ScreenSpot-Pro benchmarking and interactive testing.
"""

import asyncio
import base64
import os
import sys
from datetime import datetime
from io import BytesIO
from typing import List, Union, Tuple, Optional

from PIL import Image, ImageDraw
from tqdm import tqdm
import gc
import torch

# Add parent directory to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from agent.agent import ComputerAgent
from models import GTA1Model
from models.base import ModelProtocol


def get_available_models() -> List[Union[str, ModelProtocol]]:
    """
    Get list of available models for testing.

    Returns:
        List of model strings and model classes
    """
    local_provider = "huggingface-local/"  # Options: huggingface-local/ or mlx/

    models = [
        # === ComputerAgent model strings ===
        f"{local_provider}HelloKKMe/GTA1-7B",
        # f"{local_provider}HelloKKMe/GTA1-32B",  # Uncomment if you have this model

        # === Reference model classes ===
        GTA1Model("HelloKKMe/GTA1-7B"),
        # GTA1Model("HelloKKMe/GTA1-32B"),  # Uncomment if you have this model
    ]

    return models


def is_click_in_bbox(click_coords: Optional[Tuple[int, int]], bbox: List[int]) -> bool:
    """
    Check if click coordinates are within the bounding box.

    Args:
        click_coords: (x, y) coordinates or None
        bbox: [x1, y1, x2, y2] bounding box

    Returns:
        True if click is within bbox, False otherwise
    """
    if click_coords is None:
        return False

    x, y = click_coords
    x1, y1, x2, y2 = bbox

    return x1 <= x <= x2 and y1 <= y <= y2


def image_to_base64(image: Image.Image) -> str:
    """
    Convert PIL Image to base64 string.

    Args:
        image: PIL Image

    Returns:
        Base64 encoded image string
    """
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


class ModelWrapper:
    """
    Wrapper to provide unified interface for both ComputerAgent and custom models.
    """

    def __init__(self, model: Union[str, ModelProtocol]):
        self.model = model
        self.is_computer_agent = isinstance(model, str)
        self.agent: Optional[ComputerAgent] = None

        if self.is_computer_agent:
            self.model_name = str(model)
        else:
            self.model_name = f"models.{model.__class__.__name__}"

    async def load_model(self) -> None:
        """Load the model."""
        if self.is_computer_agent:
            self.agent = ComputerAgent(model=str(self.model))
        else:
            await self.model.load_model()  # type: ignore

    async def unload_model(self) -> None:
        """Unload the model."""
        if not self.is_computer_agent:
            await self.model.unload_model()  # type: ignore
        else:
            del self.agent
            self.agent = None
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
        """Predict click coordinates."""
        if self.is_computer_agent:
            if self.agent is None:
                await self.load_model()

            if self.agent is not None:
                image_b64 = image_to_base64(image)
                result = await self.agent.predict_click(instruction=instruction, image_b64=image_b64)
                return result
            return None
        else:
            return await self.model.predict_click(image, instruction)  # type: ignore


def save_results_to_markdown(all_results: List[dict], output_file: str = "screenspot_pro_results.md") -> None:
    """
    Save evaluation results to a markdown table.

    Args:
        all_results: List of evaluation results for each model
        output_file: Output markdown file path
    """
    with open(output_file, 'w', encoding='utf-8') as f:
        f.write("# ScreenSpot-Pro Benchmark Results\n\n")
        f.write(f"**Evaluation Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

        # Summary table
        f.write("## Summary\n\n")
        f.write("| Model | Total Samples | Correct | Failed | Accuracy | Failure Rate |\n")
        f.write("|-------|---------------|---------|--------|----------|--------------|\n")

        for result in all_results:
            model_name = result['model_name']
            total = result['total_samples']
            correct = result['correct_predictions']
            failed = result['failed_predictions']
            accuracy = result['accuracy'] * 100
            failure_rate = result['failure_rate'] * 100

            f.write(f"| {model_name} | {total} | {correct} | {failed} | {accuracy:.2f}% | {failure_rate:.2f}% |\n")

        # Detailed results for each model
        for result in all_results:
            f.write(f"\n## {result['model_name']} - Detailed Results\n\n")
            f.write("| Sample ID | Instruction | BBox | Predicted | Correct | Failed |\n")
            f.write("|-----------|-------------|------|-----------|---------|--------|\n")

            for sample_result in result['results'][:10]:  # Show first 10 samples
                sample_id = sample_result['id']
                instruction = sample_result['instruction'][:50] + "..." if len(sample_result['instruction']) > 50 else sample_result['instruction']
                bbox = str(sample_result['bbox'])
                predicted = str(sample_result['predicted_coords']) if sample_result['predicted_coords'] else "None"
                correct = "PASS" if sample_result['is_correct'] else "FAIL"
                failed = "YES" if sample_result['failed'] else "NO"

                f.write(f"| {sample_id} | {instruction} | {bbox} | {predicted} | {correct} | {failed} |\n")

            if len(result['results']) > 10:
                f.write(f"\n*Showing first 10 of {len(result['results'])} samples*\n")

    print(f"\nResults saved to: {output_file}")


def save_visualizations(all_results: List[dict], dataset_list, output_dir: str = "output") -> None:
    """
    Save visualizations of predicted coordinates vs bboxes to an output folder.

    Args:
        all_results: List of evaluation results for each model
        dataset_list: List of dataset samples
        output_dir: Output directory path
    """
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    for result in all_results:
        model_name = result['model_name'].replace('/', '_').replace('.', '_')
        model_dir = os.path.join(output_dir, model_name)
        os.makedirs(model_dir, exist_ok=True)

        print(f"\nSaving visualizations for {result['model_name']}...")

        for i, sample_result in enumerate(tqdm(result['results'][:10], desc=f"Saving {model_name} visualizations")):
            try:
                # Find the original sample
                sample_id = sample_result['id']
                sample = None
                for s in dataset_list:
                    if s['id'] == sample_id:
                        sample = s
                        break

                if sample is None:
                    continue

                # Get image and data
                image = sample['image'].copy()
                bbox = sample_result['bbox']  # [x1, y1, x2, y2]
                predicted_coords = sample_result['predicted_coords']
                is_correct = sample_result['is_correct']

                # Draw on image
                draw = ImageDraw.Draw(image)

                # Draw bounding box (ground truth) in green
                x1, y1, x2, y2 = bbox
                draw.rectangle([x1, y1, x2, y2], outline="green", width=3)
                draw.text((x1, y1 - 20), "Ground Truth", fill="green")

                # Draw predicted click in red or blue
                if predicted_coords is not None:
                    px, py = predicted_coords
                    color = "blue" if is_correct else "red"
                    # Draw crosshair
                    crosshair_size = 15
                    draw.line([(px - crosshair_size, py), (px + crosshair_size, py)], fill=color, width=3)
                    draw.line([(px, py - crosshair_size), (px, py + crosshair_size)], fill=color, width=3)
                    draw.text((px + 10, py - 20), f"Predicted ({px},{py})", fill=color)

                # Add status text
                status = "CORRECT" if is_correct else "INCORRECT"
                status_color = "blue" if is_correct else "red"
                draw.text((10, 10), f"Status: {status}", fill=status_color)
                draw.text((10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black")

                # Save image
                filename = f"sample_{i+1:02d}_{sample_id}_{status.lower()}.png"
                filepath = os.path.join(model_dir, filename)
                image.save(filepath)

            except Exception as e:
                print(f"Error saving visualization for sample {sample_id}: {e}")
                continue

        print(f"Visualizations saved to: {model_dir}")


def save_prediction_visualization(image: Image.Image, instruction: str, predictions: List[dict],
                                  output_file: str = "interactive_prediction.png") -> None:
    """
    Save visualization of multiple model predictions on a single image.

    Args:
        image: PIL Image to visualize
        instruction: Instruction text
        predictions: List of prediction dicts with keys: model_name, coords, error
        output_file: Output file path
    """
    # Create a copy of the image
    vis_image = image.copy()
    draw = ImageDraw.Draw(vis_image)

    # Colors for different models
    colors = ["red", "blue", "orange", "purple", "brown", "pink", "gray", "olive"]

    # Draw predictions
    for i, pred in enumerate(predictions):
        color = colors[i % len(colors)]
        model_name = pred['model_name']
        coords = pred.get('coords')
        error = pred.get('error')

        if coords is not None:
            px, py = coords
            # Draw crosshair
            crosshair_size = 20
            draw.line([(px - crosshair_size, py), (px + crosshair_size, py)], fill=color, width=4)
            draw.line([(px, py - crosshair_size), (px, py + crosshair_size)], fill=color, width=4)
            # Draw model name
            draw.text((px + 15, py + 15), f"{model_name}: ({px},{py})", fill=color)
        else:
            # Draw error text
            draw.text((10, 50 + i * 20), f"{model_name}: ERROR - {error}", fill=color)

    # Add instruction at the top
    draw.text((10, 10), f"Instruction: {instruction}", fill="black")

    # Save image
    vis_image.save(output_file)
    print(f"Prediction visualization saved to: {output_file}")


def take_screenshot() -> Image.Image:
    """
    Take a screenshot of the current screen.

    Returns:
        PIL Image of the screenshot
    """
    try:
        import pyautogui
        screenshot = pyautogui.screenshot()
        return screenshot
    except ImportError:
        print("pyautogui not installed. Please install it with: pip install pyautogui")
        raise
    except Exception as e:
        print(f"Error taking screenshot: {e}")
        raise
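Note: ModelWrapper is what lets the benchmarks mix ComputerAgent model strings and custom model classes behind one interface. A usage sketch (assumes the GTA1-7B weights are available; the blank image stands in for a screenshot):

import asyncio
from PIL import Image
from utils import ModelWrapper

async def demo() -> None:
    wrapper = ModelWrapper("huggingface-local/HelloKKMe/GTA1-7B")
    await wrapper.load_model()
    image = Image.new("RGB", (1280, 800), "white")  # stand-in screenshot
    print(await wrapper.predict_click(image, "the address bar"))
    await wrapper.unload_model()

asyncio.run(demo())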