added GTA1 agent and click benchmarks (ss-pro, repl)

Dillon DuPont
2025-07-29 20:48:44 -04:00
parent 3a67485e42
commit 2076ec7596
12 changed files with 939 additions and 35 deletions

View File

@@ -48,7 +48,11 @@ class HuggingFaceLocalAdapter(CustomLLM):
)
# Load processor
processor = AutoProcessor.from_pretrained(model_name)
processor = AutoProcessor.from_pretrained(
model_name,
min_pixels=3136,
max_pixels=4096 * 2160
)
# Cache them
self.models[model_name] = model
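Annotation (not part of the diff): the min_pixels/max_pixels bounds passed to AutoProcessor are the same knobs that smart_resize uses later in this commit to cap the vision-token budget for Qwen2.5-VL-style models. A minimal sketch of the effect, assuming qwen_vl_utils is installed:

from qwen_vl_utils import smart_resize

# A 2560x1440 screenshot stays under max_pixels (4096 * 2160), so it is only
# snapped to multiples of the patch factor (28 for Qwen2.5-VL); a larger capture
# would be scaled down until it fits the pixel budget.
height, width = smart_resize(1440, 2560, factor=28, min_pixels=3136, max_pixels=4096 * 2160)
print(height, width)  # 1428 2548 with the current qwen_vl_utils implementation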

View File

@@ -7,5 +7,6 @@ from . import anthropic
from . import openai
from . import uitars
from . import omniparser
from . import gta1
__all__ = ["anthropic", "openai", "uitars", "omniparser"]
__all__ = ["anthropic", "openai", "uitars", "omniparser", "gta1"]

View File

@@ -1,5 +1,7 @@
"""
GTA1 agent loop implementation for click prediction using litellm.acompletion
Paper: https://arxiv.org/pdf/2507.05791
Code: https://github.com/Yan98/GTA1
"""
import asyncio
@@ -20,7 +22,7 @@ You are an expert UI element locator. Given a GUI image and a user's element des
Output the coordinate pair exactly:
(x,y)
'''
'''.strip()
def extract_coordinates(raw_string: str) -> Tuple[float, float]:
"""Extract coordinates from model output."""

View File

@@ -1,5 +1,7 @@
"""
OpenAI computer-use-preview agent loop implementation using liteLLM
Paper: https://arxiv.org/abs/2408.00203
Code: https://github.com/microsoft/OmniParser
"""
import asyncio

View File

@@ -1,5 +1,7 @@
"""
UITARS agent loop implementation using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B
Paper: https://arxiv.org/abs/2501.12326
Code: https://github.com/bytedance/UI-TARS
"""
import asyncio
@@ -79,6 +81,18 @@ Action: ...
{instruction}
"""
GROUNDING_UITARS_PROMPT_TEMPLATE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
## Output Format
Action: ...
## Action Space
click(point='<|box_start|>(x1,y1)<|box_end|>')
## User Instruction
{instruction}"""
def round_by_factor(number: float, factor: int) -> int:
"""Returns the closest integer to 'number' that is divisible by 'factor'."""
@@ -511,7 +525,7 @@ class UITARSConfig:
async def predict_step(
self,
messages: Messages,
messages: List[Dict[str, Any]],
model: str,
tools: Optional[List[Dict[str, Any]]] = None,
max_retries: Optional[int] = None,
@@ -729,13 +743,15 @@ class UITARSConfig:
Tuple with (x, y) coordinates or None
"""
try:
# Create a simple click instruction for UITARS
user_prompt = UITARS_PROMPT_TEMPLATE.format(
instruction=f"Click on: {instruction}",
action_space=UITARS_ACTION_SPACE,
language="English"
# Create prompt using grounding template
user_prompt = GROUNDING_UITARS_PROMPT_TEMPLATE.format(
instruction=instruction
)
# Process image for UITARS
processed_image, original_width, original_height = process_image_for_uitars(image_b64)
encoded_image = pil_to_base64(processed_image)
# Prepare messages for liteLLM
litellm_messages = [
{
@@ -746,46 +762,47 @@ class UITARSConfig:
"role": "user",
"content": [
{"type": "text", "text": user_prompt},
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
]
}
]
# Prepare API call kwargs
api_kwargs = {
"model": model,
"messages": litellm_messages,
"max_tokens": 100,
"temperature": 0.0,
"do_sample": False
}
# Call liteLLM with UITARS model
response = await litellm.acompletion(
model=model,
messages=litellm_messages,
max_tokens=100,
temperature=0.0
)
response = await litellm.acompletion(**api_kwargs)
# Extract response content
response_content = response.choices[0].message.content.strip() # type: ignore
# Parse UITARS response to extract click coordinates
parsed_responses = parse_uitars_response(response_content, 1024, 768) # Default dimensions
# Parse the response to extract click coordinates
# Look for click action with coordinates
click_pattern = r"click\(point='<\|box_start\|>\((\d+),(\d+)\)<\|box_end\|>'\)"
match = re.search(click_pattern, response_content)
if parsed_responses and len(parsed_responses) > 0:
action_type = parsed_responses[0].get("action_type")
if action_type == "click":
action_inputs = parsed_responses[0].get("action_inputs", {})
start_box = action_inputs.get("start_box")
if start_box:
# Parse coordinates from start_box
try:
coords = eval(start_box) # Parse the coordinate list
if len(coords) >= 2:
# Convert normalized coordinates back to pixel coordinates
x = int(coords[0] * 1024)
y = int(coords[1] * 768)
return (x, y)
except:
pass
if match:
x, y = int(match.group(1)), int(match.group(2))
# Scale coordinates back to original image dimensions
scale_x = original_width / processed_image.width
scale_y = original_height / processed_image.height
scaled_x = int(x * scale_x)
scaled_y = int(y * scale_y)
return (scaled_x, scaled_y)
return None
except Exception as e:
print(f"Error in UITARS predict_click: {e}")
# Log error and return None
print(f"Error in predict_click: {e}")
return None
def get_capabilities(self) -> List[AgentCapability]:
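The rewritten predict_click above parses click(point='<|box_start|>(x,y)<|box_end|>') from the model text and rescales it from the processed image back to the original capture. A small worked example with hypothetical dimensions:

original_width, original_height = 2560, 1440     # size of the captured screenshot
processed_width, processed_height = 1280, 720    # size after process_image_for_uitars (assumed)
x, y = 640, 360                                  # coordinates reported by the model

scale_x = original_width / processed_width       # 2.0
scale_y = original_height / processed_height     # 2.0
print(int(x * scale_x), int(y * scale_y))        # 1280 720 in original-image pixels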

View File

@@ -0,0 +1,2 @@
output/
interactive_output/

View File

@@ -0,0 +1,201 @@
#!/usr/bin/env python3
"""
Interactive Click Prediction Tool
Takes screenshots and allows testing multiple models interactively.
Models are loaded/unloaded one at a time to avoid memory issues.
"""
import asyncio
import os
from datetime import datetime
from typing import List, Dict, Any
from utils import (
ModelWrapper,
take_screenshot,
save_prediction_visualization,
get_available_models
)
async def predict_with_all_models(image, instruction: str, models) -> List[Dict[str, Any]]:
"""
Predict click coordinates with all models sequentially.
Args:
image: PIL Image to analyze
instruction: Instruction text
models: List of model instances
Returns:
List of prediction results
"""
predictions = []
for model in models:
model_wrapper = ModelWrapper(model)
print(f"\n🔄 Loading {model_wrapper.model_name}...")
try:
# Load model
await model_wrapper.load_model()
# Predict
coords = await model_wrapper.predict_click(image, instruction)
predictions.append({
'model_name': model_wrapper.model_name,
'coords': coords,
'error': None
})
if coords:
print(f"{model_wrapper.model_name}: ({coords[0]}, {coords[1]})")
else:
print(f"{model_wrapper.model_name}: No prediction")
except Exception as e:
print(f"{model_wrapper.model_name}: ERROR - {str(e)}")
predictions.append({
'model_name': model_wrapper.model_name,
'coords': None,
'error': str(e)
})
finally:
# Always unload model to free memory
try:
await model_wrapper.unload_model()
print(f"🗑️ Unloaded {model_wrapper.model_name}")
except Exception as e:
print(f"⚠️ Error unloading {model_wrapper.model_name}: {e}")
return predictions
def print_header():
"""Print the interactive tool header."""
print("=" * 60)
print("🖱️ Interactive Click Prediction Tool")
print("=" * 60)
print("Commands:")
print(" • Type an instruction to test models on last screenshot")
print("'screenshot' - Take a new screenshot")
print("'models' - List available models")
print("'quit' or 'exit' - Exit the tool")
print("=" * 60)
print("💡 Tip: Take a screenshot first, then send instructions to test models!")
def print_models(models):
"""Print available models."""
print("\n📋 Available Models:")
for i, model in enumerate(models, 1):
if isinstance(model, str):
print(f" {i}. {model}")
else:
print(f" {i}. models.{model.__class__.__name__}")
async def main():
"""
Main interactive loop.
"""
print_header()
# Get available models
models = get_available_models()
print_models(models)
# Create output directory for visualizations
output_dir = "interactive_output"
os.makedirs(output_dir, exist_ok=True)
session_count = 0
last_screenshot = None
screenshot_timestamp = None
while True:
try:
# Get user input
print(f"\n{'='*40}")
user_input = input("🎯 Enter instruction (or command): ").strip()
if not user_input:
continue
# Handle commands
if user_input.lower() in ['quit', 'exit', 'q']:
print("👋 Goodbye!")
break
elif user_input.lower() == 'models':
print_models(models)
continue
elif user_input.lower() == 'screenshot':
print("📸 Taking screenshot...")
try:
last_screenshot = take_screenshot()
screenshot_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
screenshot_path = os.path.join(output_dir, f"screenshot_{screenshot_timestamp}.png")
last_screenshot.save(screenshot_path)
print(f"✅ Screenshot captured and saved to: {screenshot_path}")
print(f"📝 Ready for instructions! Screenshot size: {last_screenshot.size}")
except Exception as e:
print(f"❌ Error taking screenshot: {e}")
continue
# Handle instruction input
if last_screenshot is None:
print("⚠️ No screenshot available! Please take a screenshot first using 'screenshot' command.")
continue
session_count += 1
print(f"\n🎯 Session {session_count}: '{user_input}'")
print(f"📷 Using screenshot from: {screenshot_timestamp}")
# Predict with all models using last screenshot
print(f"\n🤖 Testing {len(models)} models on screenshot...")
predictions = await predict_with_all_models(last_screenshot, user_input, models)
# Display results summary
print(f"\n📊 Results Summary:")
print("-" * 50)
for pred in predictions:
if pred['coords']:
print(f"{pred['model_name']}: ({pred['coords'][0]}, {pred['coords'][1]})")
elif pred['error']:
print(f"{pred['model_name']}: ERROR - {pred['error']}")
else:
print(f"{pred['model_name']}: No prediction")
# Save visualization
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
vis_filename = f"session_{session_count:03d}_{timestamp}.png"
vis_path = os.path.join(output_dir, vis_filename)
try:
save_prediction_visualization(last_screenshot, user_input, predictions, vis_path)
print(f"\n💾 Visualization saved to: {vis_path}")
except Exception as e:
print(f"⚠️ Error saving visualization: {e}")
print(f"\n✨ Session {session_count} completed!")
except KeyboardInterrupt:
print("\n\n👋 Interrupted by user. Goodbye!")
break
except Exception as e:
print(f"\n❌ Unexpected error: {e}")
print("Continuing...")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
print("\n👋 Goodbye!")
except Exception as e:
print(f"❌ Fatal error: {e}")

View File

@@ -0,0 +1,4 @@
from .base import ModelProtocol
from .gta1 import GTA1Model
__all__ = ["ModelProtocol", "GTA1Model"]

View File

@@ -0,0 +1,36 @@
"""
Base protocol for benchmark models.
"""
from typing import Protocol, Optional, Tuple
from PIL import Image
class ModelProtocol(Protocol):
"""Protocol for benchmark models that can predict click coordinates."""
@property
def model_name(self) -> str:
"""Return the name of the model."""
...
async def load_model(self) -> None:
"""Load the model into memory."""
...
async def unload_model(self) -> None:
"""Unload the model from memory."""
...
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates for the given image and instruction.
Args:
image: PIL Image to analyze
instruction: Text instruction describing what to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
...
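Because ModelProtocol is a structural Protocol, any class with these four members type-checks against it; a hypothetical trivial baseline (not part of the commit) illustrates the contract:

from typing import Optional, Tuple
from PIL import Image

class CenterClickModel:
    """Toy baseline: always clicks the center of the screenshot."""

    @property
    def model_name(self) -> str:
        return "center-baseline"

    async def load_model(self) -> None:
        pass  # nothing to load

    async def unload_model(self) -> None:
        pass  # nothing to free

    async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
        return (image.width // 2, image.height // 2)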

View File

@@ -0,0 +1,162 @@
"""
GTA1 model implementation for benchmarking.
"""
from typing import Optional, Tuple
from PIL import Image
import torch
import re
import gc
from qwen_vl_utils import process_vision_info, smart_resize
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from .base import ModelProtocol
class GTA1Model:
"""Ground truth GTA1 model implementation."""
def __init__(self, model_path: str = "HelloKKMe/GTA1-7B"):
self.model_path = model_path
self.model = None
self.processor = None
self.max_new_tokens = 32
self.system_prompt = '''
You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
Output the coordinate pair exactly:
(x,y)
'''.strip()
@property
def model_name(self) -> str:
"""Return the name of the model."""
return f"GTA1-{self.model_path.split('/')[-1]}"
async def load_model(self) -> None:
"""Load the model into memory."""
if self.model is None:
print(f"Loading GTA1 model: {self.model_path}")
self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
self.model_path,
torch_dtype=torch.bfloat16,
device_map="auto"
)
self.processor = AutoProcessor.from_pretrained(
self.model_path,
min_pixels=3136,
max_pixels=4096 * 2160
)
print("GTA1 model loaded successfully")
async def unload_model(self) -> None:
"""Unload the model from memory."""
if self.model is not None:
print("Unloading GTA1 model from GPU...")
del self.model
del self.processor
self.model = None
self.processor = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
print("GTA1 model unloaded")
def _extract_coordinates(self, raw_string: str) -> Tuple[int, int]:
"""Extract coordinates from model output."""
try:
matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string)
return tuple(map(int, map(float, matches[0]))) # type: ignore
except:
return (0, 0)
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
"""
Predict click coordinates for the given image and instruction.
Args:
image: PIL Image to analyze
instruction: Text instruction describing what to click
Returns:
Tuple of (x, y) coordinates or None if prediction fails
"""
if self.model is None or self.processor is None:
await self.load_model()
assert self.processor is not None
assert self.model is not None
try:
width, height = image.width, image.height
# Resize image according to processor requirements
resized_height, resized_width = smart_resize(
image.height,
image.width,
factor=self.processor.image_processor.patch_size * self.processor.image_processor.merge_size,
min_pixels=self.processor.image_processor.min_pixels,
max_pixels=self.processor.image_processor.max_pixels,
)
resized_image = image.resize((resized_width, resized_height))
scale_x, scale_y = width / resized_width, height / resized_height
# Prepare messages
system_message = {
"role": "system",
"content": self.system_prompt.format(height=resized_height, width=resized_width)
}
user_message = {
"role": "user",
"content": [
{"type": "image", "image": resized_image},
{"type": "text", "text": instruction}
]
}
# Process inputs
image_inputs, video_inputs = process_vision_info([system_message, user_message])
text = self.processor.apply_chat_template(
[system_message, user_message],
tokenize=False,
add_generation_prompt=True
)
inputs = self.processor(
text=[text],
images=image_inputs,
videos=video_inputs,
padding=True,
return_tensors="pt"
)
inputs = inputs.to(self.model.device)
# Generate prediction
output_ids = self.model.generate(
**inputs,
max_new_tokens=self.max_new_tokens,
do_sample=False,
temperature=1.0,
use_cache=True
)
generated_ids = [
output_ids[len(input_ids):]
for input_ids, output_ids in zip(inputs.input_ids, output_ids)
]
output_text = self.processor.batch_decode(
generated_ids,
skip_special_tokens=True,
clean_up_tokenization_spaces=True
)[0]
# Extract and rescale coordinates
pred_x, pred_y = self._extract_coordinates(output_text)
pred_x = int(pred_x * scale_x)
pred_y = int(pred_y * scale_y)
return (pred_x, pred_y)
except Exception as e:
print(f"Error in GTA1 prediction: {e}")
return None
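A minimal usage sketch of the class above, following its load/predict/unload contract (the image path and instruction text are placeholders, and the import assumes running from the benchmark directory):

import asyncio
from PIL import Image
from models import GTA1Model  # exported by the new models package in this commit

async def demo() -> None:
    model = GTA1Model("HelloKKMe/GTA1-7B")
    await model.load_model()
    try:
        screenshot = Image.open("screenshot.png")      # any screenshot image
        coords = await model.predict_click(screenshot, "the Submit button")
        print(coords)                                  # (x, y) in original pixels, or None
    finally:
        await model.unload_model()

# asyncio.run(demo())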

View File

@@ -0,0 +1,157 @@
#!/usr/bin/env python3
"""
ScreenSpot-Pro Benchmark Script
Evaluates models on the ScreenSpot-Pro dataset for click prediction accuracy.
Supports both ComputerAgent model strings and custom model classes.
"""
import asyncio
from typing import Optional
from datasets import load_dataset
from tqdm import tqdm
from utils import (
ModelWrapper,
is_click_in_bbox,
save_results_to_markdown,
save_visualizations,
get_available_models
)
async def evaluate_model(model_wrapper: ModelWrapper, dataset, max_samples: Optional[int] = None) -> dict:
"""
Evaluate a model on the ScreenSpot-Pro dataset.
Args:
model_wrapper: ModelWrapper instance
dataset: ScreenSpot-Pro dataset (list of samples)
max_samples: Maximum number of samples to evaluate (None for all)
Returns:
Dictionary with evaluation results
"""
print(f"\nEvaluating model: {model_wrapper.model_name}")
# Load model
await model_wrapper.load_model()
total_samples = len(dataset)
if max_samples is not None:
total_samples = min(max_samples, total_samples)
correct_predictions = 0
failed_predictions = 0
results = []
try:
for i in tqdm(range(total_samples), desc=f"Evaluating {model_wrapper.model_name}"):
sample = dataset[i]
# Extract sample data
image = sample['image']
instruction = sample['instruction']
bbox = sample['bbox'] # [x1, y1, x2, y2]
sample_id = sample['id']
# Predict click coordinates
try:
click_coords = await model_wrapper.predict_click(image, instruction)
# Check if prediction is correct
is_correct = is_click_in_bbox(click_coords, bbox)
if is_correct:
correct_predictions += 1
results.append({
'id': sample_id,
'instruction': instruction,
'bbox': bbox,
'predicted_coords': click_coords,
'is_correct': is_correct,
'failed': False
})
except Exception as e:
print(f"\nError predicting sample {sample_id}: {e}")
failed_predictions += 1
results.append({
'id': sample_id,
'instruction': instruction,
'bbox': bbox,
'predicted_coords': None,
'is_correct': False,
'failed': True,
'error': str(e)
})
finally:
# Unload model
await model_wrapper.unload_model()
# Calculate metrics
accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
failure_rate = failed_predictions / total_samples if total_samples > 0 else 0.0
return {
'model_name': model_wrapper.model_name,
'total_samples': total_samples,
'correct_predictions': correct_predictions,
'failed_predictions': failed_predictions,
'accuracy': accuracy,
'failure_rate': failure_rate,
'results': results
}
async def main():
"""
Main function to run the benchmark.
"""
# Load dataset
print("Loading ScreenSpot-Pro dataset...")
ds = load_dataset("lmms-lab/ScreenSpot-Pro")
dataset = ds['train'] # type: ignore
# Convert to list to support indexing
dataset_list = list(dataset)
print(f"Dataset loaded: {len(dataset_list)} samples")
# Get available models
models = get_available_models()
# Evaluation settings
max_samples = 5 # Set to None to evaluate on full dataset
# Run evaluations
all_results = []
for model in models:
try:
model_wrapper = ModelWrapper(model)
result = await evaluate_model(model_wrapper, dataset_list, max_samples)
all_results.append(result)
# Print summary
print(f"\n{result['model_name']} Results:")
print(f" Accuracy: {result['accuracy']*100:.2f}%")
print(f" Correct: {result['correct_predictions']}/{result['total_samples']}")
print(f" Failed: {result['failed_predictions']}")
except Exception as e:
print(f"\nError evaluating model {model}: {e}")
continue
# Save results
if all_results:
save_results_to_markdown(all_results)
save_visualizations(all_results, dataset_list)
print("\nBenchmark completed successfully!")
else:
print("\nNo successful evaluations completed.")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -0,0 +1,316 @@
#!/usr/bin/env python3
"""
Shared utilities for ScreenSpot-Pro benchmarking and interactive testing.
"""
import asyncio
import base64
import os
import sys
from datetime import datetime
from io import BytesIO
from typing import List, Union, Tuple, Optional
from PIL import Image, ImageDraw
from tqdm import tqdm
import gc
import torch
# Add parent directory to path for imports
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from agent.agent import ComputerAgent
from models import GTA1Model
from models.base import ModelProtocol
def get_available_models() -> List[Union[str, ModelProtocol]]:
"""
Get list of available models for testing.
Returns:
List of model strings and model classes
"""
local_provider = "huggingface-local/" # Options: huggingface-local/ or mlx/
models = [
# === ComputerAgent model strings ===
f"{local_provider}HelloKKMe/GTA1-7B",
# f"{local_provider}HelloKKMe/GTA1-32B", # Uncomment if you have this model
# === Reference model classes ===
GTA1Model("HelloKKMe/GTA1-7B"),
# GTA1Model("HelloKKMe/GTA1-32B"), # Uncomment if you have this model
]
return models
def is_click_in_bbox(click_coords: Optional[Tuple[int, int]], bbox: List[int]) -> bool:
"""
Check if click coordinates are within the bounding box.
Args:
click_coords: (x, y) coordinates or None
bbox: [x1, y1, x2, y2] bounding box
Returns:
True if click is within bbox, False otherwise
"""
if click_coords is None:
return False
x, y = click_coords
x1, y1, x2, y2 = bbox
return x1 <= x <= x2 and y1 <= y <= y2
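A quick sanity check of the hit-test convention (bbox is [x1, y1, x2, y2] in pixels; a prediction counts as correct when it lands inside the box):

assert is_click_in_bbox((150, 250), [100, 200, 300, 400])      # inside -> correct
assert not is_click_in_bbox((50, 250), [100, 200, 300, 400])   # left of the box
assert not is_click_in_bbox(None, [100, 200, 300, 400])        # failed prediction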
def image_to_base64(image: Image.Image) -> str:
"""
Convert PIL Image to base64 string.
Args:
image: PIL Image
Returns:
Base64 encoded image string
"""
buffered = BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode()
class ModelWrapper:
"""
Wrapper to provide unified interface for both ComputerAgent and custom models.
"""
def __init__(self, model: Union[str, ModelProtocol]):
self.model = model
self.is_computer_agent = isinstance(model, str)
self.agent: Optional[ComputerAgent] = None
if self.is_computer_agent:
self.model_name = str(model)
else:
self.model_name = f"models.{model.__class__.__name__}"
async def load_model(self) -> None:
"""Load the model."""
if self.is_computer_agent:
self.agent = ComputerAgent(model=str(self.model))
else:
await self.model.load_model() # type: ignore
async def unload_model(self) -> None:
"""Unload the model."""
if not self.is_computer_agent:
await self.model.unload_model() # type: ignore
else:
del self.agent
self.agent = None
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
"""Predict click coordinates."""
if self.is_computer_agent:
if self.agent is None:
await self.load_model()
if self.agent is not None:
image_b64 = image_to_base64(image)
result = await self.agent.predict_click(instruction=instruction, image_b64=image_b64)
return result
return None
else:
return await self.model.predict_click(image, instruction) # type: ignore
def save_results_to_markdown(all_results: List[dict], output_file: str = "screenspot_pro_results.md") -> None:
"""
Save evaluation results to a markdown table.
Args:
all_results: List of evaluation results for each model
output_file: Output markdown file path
"""
with open(output_file, 'w', encoding='utf-8') as f:
f.write("# ScreenSpot-Pro Benchmark Results\n\n")
f.write(f"**Evaluation Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
# Summary table
f.write("## Summary\n\n")
f.write("| Model | Total Samples | Correct | Failed | Accuracy | Failure Rate |\n")
f.write("|-------|---------------|---------|--------|----------|--------------|\n")
for result in all_results:
model_name = result['model_name']
total = result['total_samples']
correct = result['correct_predictions']
failed = result['failed_predictions']
accuracy = result['accuracy'] * 100
failure_rate = result['failure_rate'] * 100
f.write(f"| {model_name} | {total} | {correct} | {failed} | {accuracy:.2f}% | {failure_rate:.2f}% |\n")
# Detailed results for each model
for result in all_results:
f.write(f"\n## {result['model_name']} - Detailed Results\n\n")
f.write("| Sample ID | Instruction | BBox | Predicted | Correct | Failed |\n")
f.write("|-----------|-------------|------|-----------|---------|--------|\n")
for sample_result in result['results'][:10]: # Show first 10 samples
sample_id = sample_result['id']
instruction = sample_result['instruction'][:50] + "..." if len(sample_result['instruction']) > 50 else sample_result['instruction']
bbox = str(sample_result['bbox'])
predicted = str(sample_result['predicted_coords']) if sample_result['predicted_coords'] else "None"
correct = "PASS" if sample_result['is_correct'] else "FAIL"
failed = "YES" if sample_result['failed'] else "NO"
f.write(f"| {sample_id} | {instruction} | {bbox} | {predicted} | {correct} | {failed} |\n")
if len(result['results']) > 10:
f.write(f"\n*Showing first 10 of {len(result['results'])} samples*\n")
print(f"\nResults saved to: {output_file}")
def save_visualizations(all_results: List[dict], dataset_list, output_dir: str = "output") -> None:
"""
Save visualizations of predicted coordinates vs bboxes to an output folder.
Args:
all_results: List of evaluation results for each model
dataset_list: List of dataset samples
output_dir: Output directory path
"""
# Create output directory
os.makedirs(output_dir, exist_ok=True)
for result in all_results:
model_name = result['model_name'].replace('/', '_').replace('.', '_')
model_dir = os.path.join(output_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
print(f"\nSaving visualizations for {result['model_name']}...")
for i, sample_result in enumerate(tqdm(result['results'][:10], desc=f"Saving {model_name} visualizations")):
try:
# Find the original sample
sample_id = sample_result['id']
sample = None
for s in dataset_list:
if s['id'] == sample_id:
sample = s
break
if sample is None:
continue
# Get image and data
image = sample['image'].copy()
bbox = sample_result['bbox'] # [x1, y1, x2, y2]
predicted_coords = sample_result['predicted_coords']
is_correct = sample_result['is_correct']
# Draw on image
draw = ImageDraw.Draw(image)
# Draw bounding box (ground truth) in green
x1, y1, x2, y2 = bbox
draw.rectangle([x1, y1, x2, y2], outline="green", width=3)
draw.text((x1, y1-20), "Ground Truth", fill="green")
# Draw predicted click in red or blue
if predicted_coords is not None:
px, py = predicted_coords
color = "blue" if is_correct else "red"
# Draw crosshair
crosshair_size = 15
draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=3)
draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=3)
draw.text((px+10, py-20), f"Predicted ({px},{py})", fill=color)
# Add status text
status = "CORRECT" if is_correct else "INCORRECT"
status_color = "blue" if is_correct else "red"
draw.text((10, 10), f"Status: {status}", fill=status_color)
draw.text((10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black")
# Save image
filename = f"sample_{i+1:02d}_{sample_id}_{status.lower()}.png"
filepath = os.path.join(model_dir, filename)
image.save(filepath)
except Exception as e:
print(f"Error saving visualization for sample {sample_id}: {e}")
continue
print(f"Visualizations saved to: {model_dir}")
def save_prediction_visualization(image: Image.Image, instruction: str, predictions: List[dict],
output_file: str = "interactive_prediction.png") -> None:
"""
Save visualization of multiple model predictions on a single image.
Args:
image: PIL Image to visualize
instruction: Instruction text
predictions: List of prediction dicts with keys: model_name, coords, error
output_file: Output file path
"""
# Create a copy of the image
vis_image = image.copy()
draw = ImageDraw.Draw(vis_image)
# Colors for different models
colors = ["red", "blue", "orange", "purple", "brown", "pink", "gray", "olive"]
# Draw predictions
for i, pred in enumerate(predictions):
color = colors[i % len(colors)]
model_name = pred['model_name']
coords = pred.get('coords')
error = pred.get('error')
if coords is not None:
px, py = coords
# Draw crosshair
crosshair_size = 20
draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=4)
draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=4)
# Draw model name
draw.text((px+15, py+15), f"{model_name}: ({px},{py})", fill=color)
else:
# Draw error text
draw.text((10, 50 + i*20), f"{model_name}: ERROR - {error}", fill=color)
# Add instruction at the top
draw.text((10, 10), f"Instruction: {instruction}", fill="black")
# Save image
vis_image.save(output_file)
print(f"Prediction visualization saved to: {output_file}")
def take_screenshot() -> Image.Image:
"""
Take a screenshot of the current screen.
Returns:
PIL Image of the screenshot
"""
try:
import pyautogui
screenshot = pyautogui.screenshot()
return screenshot
except ImportError:
print("pyautogui not installed. Please install it with: pip install pyautogui")
raise
except Exception as e:
print(f"Error taking screenshot: {e}")
raise