added agent benchmarks

This commit is contained in:
Dillon DuPont
2025-07-30 13:41:58 -04:00
parent 2076ec7596
commit ffc88e2031
6 changed files with 553 additions and 134 deletions

View File

@@ -1,2 +1,3 @@
output/
interactive_output/
*_results.md

View File

@@ -0,0 +1,177 @@
# Computer Agent Benchmarks
This directory contains benchmarks designed to test agent providers in the Computer Agent SDK against reference agent implementations.
## Overview
The benchmark system evaluates models on GUI grounding tasks, specifically click prediction accuracy. It supports both:
- **Computer Agent SDK providers** (using model strings like `"huggingface-local/HelloKKMe/GTA1-7B"`)
- **Reference agent implementations** (custom model classes implementing the `ModelProtocol`)
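Both kinds of entries go through the same `ModelWrapper` interface from `utils.py`, so the benchmark loops do not care which one they are handed. A minimal sketch (assuming the helpers shown later in this README and a sample dict in the format described under Architecture):
```python
# Minimal sketch: a provider string and a reference implementation share one code path.
import asyncio

from models import GTA1Model
from utils import ModelWrapper

async def demo(sample: dict) -> None:
    candidates = [
        "huggingface-local/HelloKKMe/GTA1-7B",  # Computer Agent SDK provider string
        GTA1Model("HelloKKMe/GTA1-7B"),         # reference ModelProtocol implementation
    ]
    for model in candidates:
        wrapper = ModelWrapper(model)
        await wrapper.load_model()
        coords = await wrapper.predict_click(sample["image"], sample["instruction"])
        print(f"{wrapper.model_name} -> {coords}")
        await wrapper.unload_model()
```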
## Available Benchmarks
### 1. ScreenSpot-v2 (`ss-v2.py`)
- **Dataset**: ScreenSpot-v2 (click-only GUI grounding)
- **Format**: Standard resolution screenshots
- **Task**: Predict click coordinates given an instruction and image
- **Metrics**: Accuracy, Error Rate, Timing, VRAM usage
### 2. ScreenSpot-Pro (`ss-pro.py`)
- **Dataset**: ScreenSpot-Pro (high-resolution click-only GUI grounding)
- **Format**: High-resolution screenshots
- **Task**: Predict click coordinates given an instruction and image
- **Metrics**: Accuracy, Error Rate, Timing, VRAM usage
### 3. Interactive Testing (`interactive.py`)
- **Real-time testing**: Take screenshots and visualize model predictions
- **Commands**:
  - Type an instruction → take a screenshot and test all models
  - `screenshot` → take a screenshot without prediction
  - `models` → list available models
  - `quit`/`exit` → exit the tool
- **Output**: Visual predictions with crosshairs for each model
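A rough sketch of what such a command loop can look like is shown below. This is not the actual `interactive.py`: it assumes `PIL.ImageGrab` for screen capture and the `ModelWrapper`/`get_available_models` helpers from `utils.py`, and it prints coordinates instead of drawing the crosshair visualizations the real tool produces.
```python
# Hypothetical, simplified sketch of the interactive command loop (not interactive.py itself).
import asyncio
import os

from PIL import ImageGrab
from utils import ModelWrapper, get_available_models

async def repl() -> None:
    os.makedirs("interactive_output", exist_ok=True)
    wrappers = [ModelWrapper(m) for m in get_available_models()]
    for w in wrappers:
        await w.load_model()  # simplified: the real tool may load/unload models to save VRAM
    while True:
        command = input("> ").strip()
        if command in ("quit", "exit"):
            break
        if command == "models":
            for w in wrappers:
                print(w.model_name)
            continue
        screenshot = ImageGrab.grab()  # capture the current screen
        if command == "screenshot":
            screenshot.save("interactive_output/screenshot.png")
            continue
        # Anything else is treated as an instruction: ask every model where it would click
        for w in wrappers:
            coords = await w.predict_click(screenshot, command)
            print(f"{w.model_name}: {coords}")

if __name__ == "__main__":
    asyncio.run(repl())
```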
## Adding Reference Agent Implementations
### 1. Implement the ModelProtocol
Create a new file in `models/` directory implementing the `ModelProtocol`:
```python
from models.base import ModelProtocol
from typing import Optional, Tuple
from PIL import Image


class YourModelName(ModelProtocol):
    def __init__(self, model_path: str):
        self.model_path = model_path
        self._model = None

    @property
    def model_name(self) -> str:
        return self.model_path

    async def load_model(self) -> None:
        """Load the model into memory."""
        # Your model loading logic here
        pass

    async def unload_model(self) -> None:
        """Unload the model from memory."""
        # Your model cleanup logic here
        pass

    async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates for the given image and instruction.

        Args:
            image: PIL Image to analyze
            instruction: Text instruction describing what to click

        Returns:
            Tuple of (x, y) coordinates or None if prediction fails
        """
        # Your prediction logic here
        return (x, y)  # Return predicted coordinates
```
### 2. Register Your Model
Add your model to the `get_available_models()` function in `utils.py`:
```python
def get_available_models() -> List[Union[str, ModelProtocol]]:
    models = [
        # Computer Agent SDK providers
        "huggingface-local/HelloKKMe/GTA1-7B",

        # Reference implementations
        GTA1Model("HelloKKMe/GTA1-7B"),
        YourModelName("path/to/your/model"),  # Add your model here
    ]
    return models
```
## Running Benchmarks
### 1. Configure Models
Edit `utils.py` to specify which models you want to test in `get_available_models()`.
### 2. Set Sample Count
Edit the benchmark script to change the number of samples:
```python
max_samples = 50 # Set to None to evaluate on full dataset
```
### 3. Run Benchmark
```bash
# ScreenSpot-v2 benchmark
python ss-v2.py
# ScreenSpot-Pro benchmark
python ss-pro.py
# Interactive testing
python interactive.py
```
## Output
### Console Output
```
Model Results:
Accuracy: 85.50%
Correct: 171/200
Errors: 5
Error Rate: 2.50%
Avg Time: 1.23s
Time Range: 0.89s - 2.45s
VRAM Max: 4.5GB
VRAM Avg: 3.4GB
```
### Generated Files
- **Markdown Report**: `*_results.md` with detailed results tables
- **Visualizations**: `output/` directory with prediction visualizations
- **Interactive Output**: `interactive_output/` for interactive session results
## Metrics Tracked
- **Accuracy**: Percentage of predicted clicks that land inside the ground-truth bounding box (see the sketch after this list)
- **Error Rate**: Percentage of failed predictions
- **Timing**: Average, min, max prediction times
- **VRAM Usage**: Maximum and average GPU memory usage
- **Per-sample Results**: Detailed breakdown for debugging
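Accuracy reduces to a point-in-rectangle test. Below is a minimal sketch of such a check; the real helper is `is_click_in_bbox` in `utils.py`, and this version assumes `bbox = [x1, y1, x2, y2]` and counts a `None` prediction as incorrect:
```python
# Minimal sketch of the point-in-box check behind the accuracy metric (assumed behavior).
from typing import List, Optional, Tuple

def is_click_in_bbox(click: Optional[Tuple[int, int]], bbox: List[float]) -> bool:
    if click is None:  # failed prediction counts as incorrect
        return False
    x, y = click
    x1, y1, x2, y2 = bbox
    return x1 <= x <= x2 and y1 <= y <= y2
```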
## Requirements
- Python 3.8+
- PyTorch (for VRAM tracking)
- PIL/Pillow (for image processing)
- datasets (for HuggingFace datasets)
- tqdm (for progress bars)
- Computer Agent SDK
## Architecture
The benchmark system is designed for:
- **Modularity**: Easy to add new models and benchmarks
- **Flexibility**: Works with any iterator of dicts with `image`, `bbox`, `instruction` keys (sample format shown after this list)
- **Performance**: VRAM tracking and timing analysis
- **Visualization**: Automatic generation of prediction visualizations
- **No Exception Handling**: Fails fast to surface real issues
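Concretely, any source that yields records shaped like the following hypothetical example can be passed to the evaluation loop:
```python
# Hypothetical sample record — the only keys the evaluation loop reads.
from PIL import Image

sample = {
    "image": Image.open("screenshot.png"),   # PIL image of the screen
    "instruction": "Click the Save button",  # natural-language target
    "bbox": [860, 412, 918, 440],            # ground-truth box as [x1, y1, x2, y2]
}
```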
## Results Table
| Model | Dataset | Accuracy | Error Rate | Avg Time | VRAM Max | VRAM Avg |
|-------|---------|----------|------------|----------|----------|----------|
| (coming soon) | | | | | | |
## Contributing
To add a new benchmark:
1. Create a new script following the pattern in `ss-v2.py`
2. Use an `evaluate_model()` loop like the one in the existing scripts, built on the helpers from `utils.py` (see the skeleton below)
3. Ensure your dataset yields dicts with `image`, `bbox`, `instruction` keys
4. Update this README with benchmark details
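The steps above condense into a skeleton like the one below, modeled on `ss-v2.py`. Dataset loading, the sample cap, and the output file name (`my_benchmark_results.md`) are placeholders; the result dictionaries mirror the keys the `utils.py` reporting helpers expect.
```python
#!/usr/bin/env python3
"""Skeleton for a new benchmark (hypothetical dataset and output names)."""
import asyncio
import time

from utils import (ModelWrapper, get_available_models, is_click_in_bbox,
                   save_results_to_markdown, save_visualizations)

def load_samples() -> list:
    """Return dicts with 'image', 'bbox' ([x1, y1, x2, y2]) and 'instruction'."""
    raise NotImplementedError  # plug in your dataset here

async def main() -> None:
    samples = load_samples()[:100]  # cap samples while setting things up
    all_results = []
    for model in get_available_models():
        wrapper = ModelWrapper(model)
        await wrapper.load_model()
        results, correct = [], 0
        for idx, s in enumerate(samples):
            start = time.time()
            coords = await wrapper.predict_click(s["image"], s["instruction"])
            ok = is_click_in_bbox(coords, s["bbox"])
            correct += ok
            results.append({"sample_idx": idx, "instruction": s["instruction"],
                            "bbox": s["bbox"], "predicted_coords": coords,
                            "is_correct": ok, "failed": False,
                            "prediction_time": time.time() - start})
        await wrapper.unload_model()
        vram = wrapper.get_vram_stats()
        all_results.append({"model_name": wrapper.model_name,
                            "total_samples": len(samples),
                            "correct_predictions": correct,
                            "failed_predictions": 0,
                            "accuracy": correct / len(samples) if samples else 0.0,
                            "failure_rate": 0.0,
                            "vram_max_mb": vram["max_mb"],
                            "vram_avg_mb": vram["avg_mb"],
                            "results": results})
    save_results_to_markdown(all_results, "my_benchmark_results.md",
                             title="My Benchmark Results")
    save_visualizations(all_results, samples)

if __name__ == "__main__":
    asyncio.run(main())
```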

View File

@@ -117,7 +117,7 @@ Output the coordinate pair exactly:
}
# Process inputs
image_inputs, video_inputs = process_vision_info([system_message, user_message])
image_inputs, video_inputs = process_vision_info([system_message, user_message]) # type: ignore
text = self.processor.apply_chat_template(
[system_message, user_message],
tokenize=False,

View File

@@ -7,6 +7,7 @@ Supports both ComputerAgent model strings and custom model classes.
"""
import asyncio
import time
from typing import Optional
from datasets import load_dataset
@@ -43,66 +44,67 @@ async def evaluate_model(model_wrapper: ModelWrapper, dataset, max_samples: Opti
total_samples = min(max_samples, total_samples)
correct_predictions = 0
failed_predictions = 0
error_predictions = 0
results = []
try:
for i in tqdm(range(total_samples), desc=f"Evaluating {model_wrapper.model_name}"):
sample = dataset[i]
# Extract sample data
image = sample['image']
instruction = sample['instruction']
bbox = sample['bbox'] # [x1, y1, x2, y2]
sample_id = sample['id']
# Predict click coordinates
try:
click_coords = await model_wrapper.predict_click(image, instruction)
# Check if prediction is correct
is_correct = is_click_in_bbox(click_coords, bbox)
if is_correct:
correct_predictions += 1
results.append({
'id': sample_id,
'instruction': instruction,
'bbox': bbox,
'predicted_coords': click_coords,
'is_correct': is_correct,
'failed': False
})
except Exception as e:
print(f"\nError predicting sample {sample_id}: {e}")
failed_predictions += 1
results.append({
'id': sample_id,
'instruction': instruction,
'bbox': bbox,
'predicted_coords': None,
'is_correct': False,
'failed': True,
'error': str(e)
})
for i in tqdm(range(total_samples), desc=f"Evaluating {model_wrapper.model_name}"):
sample = dataset[i]
# Extract sample data
image = sample['image']
instruction = sample['instruction']
bbox = sample['bbox'] # [x1, y1, x2, y2]
sample_id = sample['img_filename']
# Predict click coordinates with timing
start_time = time.time()
click_coords = await model_wrapper.predict_click(image, instruction)
prediction_time = time.time() - start_time
# Check if prediction is correct
is_correct = is_click_in_bbox(click_coords, bbox)
if is_correct:
correct_predictions += 1
results.append({
'id': sample_id,
'instruction': instruction,
'bbox': bbox,
'predicted_coords': click_coords,
'is_correct': is_correct,
'failed': False,
'prediction_time': prediction_time
})
finally:
# Unload model
await model_wrapper.unload_model()
# Unload model
await model_wrapper.unload_model()
# Calculate metrics
accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
failure_rate = failed_predictions / total_samples if total_samples > 0 else 0.0
error_rate = error_predictions / total_samples if total_samples > 0 else 0.0
# Calculate timing statistics
successful_times = [r['prediction_time'] for r in results if not r['failed']]
avg_prediction_time = sum(successful_times) / len(successful_times) if successful_times else 0.0
min_prediction_time = min(successful_times) if successful_times else 0.0
max_prediction_time = max(successful_times) if successful_times else 0.0
# Get VRAM statistics
vram_stats = model_wrapper.get_vram_stats()
return {
'model_name': model_wrapper.model_name,
'total_samples': total_samples,
'correct_predictions': correct_predictions,
'failed_predictions': failed_predictions,
'failed_predictions': error_predictions,
'accuracy': accuracy,
'failure_rate': failure_rate,
'failure_rate': error_rate,
'avg_prediction_time': avg_prediction_time,
'min_prediction_time': min_prediction_time,
'max_prediction_time': max_prediction_time,
'vram_max_mb': vram_stats['max_mb'],
'vram_avg_mb': vram_stats['avg_mb'],
'results': results
}
@@ -123,26 +125,26 @@ async def main():
models = get_available_models()
# Evaluation settings
max_samples = 5 # Set to None to evaluate on full dataset
max_samples = 300 # Set to None to evaluate on full dataset
# Run evaluations
all_results = []
for model in models:
try:
model_wrapper = ModelWrapper(model)
result = await evaluate_model(model_wrapper, dataset_list, max_samples)
all_results.append(result)
# Print summary
print(f"\n{result['model_name']} Results:")
print(f" Accuracy: {result['accuracy']*100:.2f}%")
print(f" Correct: {result['correct_predictions']}/{result['total_samples']}")
print(f" Failed: {result['failed_predictions']}")
except Exception as e:
print(f"\nError evaluating model {model}: {e}")
continue
model_wrapper = ModelWrapper(model)
result = await evaluate_model(model_wrapper, dataset_list, max_samples)
all_results.append(result)
# Print summary
print(f"\n{result['model_name']} Results:")
print(f" Accuracy: {result['accuracy']*100:.2f}%")
print(f" Correct: {result['correct_predictions']}/{result['total_samples']}")
print(f" Errors: {result['failed_predictions']}")
print(f" Error Rate: {result['failure_rate']*100:.2f}%")
print(f" Avg Time: {result['avg_prediction_time']:.2f}s")
print(f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s")
print(f" VRAM Max: {result['vram_max_mb']:.1f}MB")
print(f" VRAM Avg: {result['vram_avg_mb']:.1f}MB")
# Save results
if all_results:

View File

@@ -0,0 +1,179 @@
#!/usr/bin/env python3
"""
ScreenSpot-v2 Benchmark Script
Evaluates models on the ScreenSpot-v2 dataset for click prediction accuracy.
Supports both ComputerAgent model strings and custom model classes.
"""
import asyncio
import time
from typing import Optional
from datasets import load_dataset
from tqdm import tqdm
from utils import (
    ModelWrapper,
    is_click_in_bbox,
    save_results_to_markdown,
    save_visualizations,
    get_available_models
)
async def evaluate_model(model_wrapper: ModelWrapper, samples, max_samples: Optional[int] = None) -> dict:
    """
    Evaluate a model on any iterable of samples.

    Args:
        model_wrapper: ModelWrapper instance
        samples: Iterable of dicts with keys: image, bbox, instruction
        max_samples: Maximum number of samples to evaluate (None for all)

    Returns:
        Dictionary with evaluation results
    """
    print(f"\nEvaluating model: {model_wrapper.model_name}")

    # Load model
    await model_wrapper.load_model()

    # Convert to list if needed and limit samples
    if hasattr(samples, '__len__'):
        total_samples = len(samples)
        if max_samples is not None:
            total_samples = min(max_samples, total_samples)
        sample_list = list(samples)[:total_samples]
    else:
        # For iterators, take max_samples or all
        sample_list = list(samples)
        if max_samples is not None:
            sample_list = sample_list[:max_samples]
        total_samples = len(sample_list)

    correct_predictions = 0
    error_predictions = 0
    results = []

    for i, sample in enumerate(tqdm(sample_list, desc=f"Evaluating {model_wrapper.model_name}")):
        # Extract required data (only these 3 keys matter)
        image = sample['image']
        instruction = sample['instruction']
        bbox = sample['bbox']  # [x1, y1, x2, y2]

        # Predict click coordinates with timing
        start_time = time.time()
        click_coords = await model_wrapper.predict_click(image, instruction)
        prediction_time = time.time() - start_time

        # Check if prediction is correct
        is_correct = is_click_in_bbox(click_coords, bbox)
        if is_correct:
            correct_predictions += 1

        results.append({
            'sample_idx': i,
            'instruction': instruction,
            'bbox': bbox,
            'predicted_coords': click_coords,
            'is_correct': is_correct,
            'failed': False,
            'prediction_time': prediction_time
        })

    # Unload model
    await model_wrapper.unload_model()

    # Calculate metrics
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
    error_rate = error_predictions / total_samples if total_samples > 0 else 0.0

    # Calculate timing statistics
    successful_times = [r['prediction_time'] for r in results if not r['failed']]
    avg_prediction_time = sum(successful_times) / len(successful_times) if successful_times else 0.0
    min_prediction_time = min(successful_times) if successful_times else 0.0
    max_prediction_time = max(successful_times) if successful_times else 0.0

    # Get VRAM statistics
    vram_stats = model_wrapper.get_vram_stats()

    return {
        'model_name': model_wrapper.model_name,
        'total_samples': total_samples,
        'correct_predictions': correct_predictions,
        'failed_predictions': error_predictions,
        'accuracy': accuracy,
        'failure_rate': error_rate,
        'avg_prediction_time': avg_prediction_time,
        'min_prediction_time': min_prediction_time,
        'max_prediction_time': max_prediction_time,
        'vram_max_mb': vram_stats['max_mb'],
        'vram_avg_mb': vram_stats['avg_mb'],
        'results': results
    }
async def main():
    """
    Main function to run the benchmark.
    """
    # Load dataset
    print("Loading ScreenSpot-v2 dataset...")
    ds = load_dataset("lmms-lab/ScreenSpot-v2")
    dataset = ds['train']  # type: ignore

    # Convert to simple list of dicts with only required keys
    samples = []
    for item in dataset:
        # Convert dataset item to dict if needed
        item_dict = dict(item) if hasattr(item, 'keys') else item

        # Convert ScreenSpot-v2 bbox format [x, y, w, h] to [x1, y1, x2, y2]
        bbox_xywh = item_dict['bbox']  # type: ignore
        x, y, w, h = bbox_xywh
        bbox_xyxy = [x, y, x + w, y + h]

        samples.append({
            'image': item_dict['image'],  # type: ignore
            'instruction': item_dict['instruction'],  # type: ignore
            'bbox': bbox_xyxy
        })

    print(f"Dataset loaded: {len(samples)} samples")

    # Get available models
    models = get_available_models()

    # Evaluation settings
    max_samples = 500  # Set to None to evaluate on full dataset

    # Run evaluations
    all_results = []

    for model in models:
        model_wrapper = ModelWrapper(model)
        result = await evaluate_model(model_wrapper, samples, max_samples)
        all_results.append(result)

        # Print summary
        print(f"\n{result['model_name']} Results:")
        print(f" Accuracy: {result['accuracy']*100:.2f}%")
        print(f" Correct: {result['correct_predictions']}/{result['total_samples']}")
        print(f" Errors: {result['failed_predictions']}")
        print(f" Error Rate: {result['failure_rate']*100:.2f}%")
        print(f" Avg Time: {result['avg_prediction_time']:.2f}s")
        print(f" Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s")
        print(f" VRAM Max: {result['vram_max_mb']:.1f}MB")
        print(f" VRAM Avg: {result['vram_avg_mb']:.1f}MB")

    # Save results
    if all_results:
        save_results_to_markdown(all_results, "screenspot_v2_results.md", title="ScreenSpot-v2 Benchmark Results")
        save_visualizations(all_results, samples)
        print("\nBenchmark completed successfully!")
    else:
        print("\nNo successful evaluations completed.")


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -22,6 +22,33 @@ from agent.agent import ComputerAgent
from models import GTA1Model
from models.base import ModelProtocol
def get_vram_usage() -> dict:
"""
Get current VRAM usage statistics.
Returns:
Dictionary with VRAM usage info (in MB)
"""
if torch.cuda.is_available():
device = torch.cuda.current_device()
allocated = torch.cuda.memory_allocated(device) / 1024 / 1024 # Convert to MB
reserved = torch.cuda.memory_reserved(device) / 1024 / 1024 # Convert to MB
total = torch.cuda.get_device_properties(device).total_memory / 1024 / 1024
return {
'allocated_mb': allocated,
'reserved_mb': reserved,
'total_mb': total,
'free_mb': total - reserved
}
else:
return {
'allocated_mb': 0.0,
'reserved_mb': 0.0,
'total_mb': 0.0,
'free_mb': 0.0
}
def get_available_models() -> List[Union[str, ModelProtocol]]:
"""
Get list of available models for testing.
@@ -34,11 +61,11 @@ def get_available_models() -> List[Union[str, ModelProtocol]]:
models = [
# === ComputerAgent model strings ===
f"{local_provider}HelloKKMe/GTA1-7B",
# f"{local_provider}HelloKKMe/GTA1-32B", # Uncomment if you have this model
f"{local_provider}HelloKKMe/GTA1-32B",
# === Reference model classes ===
GTA1Model("HelloKKMe/GTA1-7B"),
# GTA1Model("HelloKKMe/GTA1-32B"), # Uncomment if you have this model
GTA1Model("HelloKKMe/GTA1-32B"),
]
return models
@@ -88,11 +115,12 @@ class ModelWrapper:
self.model = model
self.is_computer_agent = isinstance(model, str)
self.agent: Optional[ComputerAgent] = None
self.vram_usage_history: List[float] = [] # Track VRAM usage over time
if self.is_computer_agent:
self.model_name = str(model)
else:
self.model_name = f"models.{model.__class__.__name__}"
self.model_name = f"{model.__class__.__name__}('{getattr(model, 'model_name', 'unknown')}')"
async def load_model(self) -> None:
"""Load the model."""
@@ -100,6 +128,10 @@ class ModelWrapper:
self.agent = ComputerAgent(model=str(self.model))
else:
await self.model.load_model() # type: ignore
# Record initial VRAM usage after loading
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
async def unload_model(self) -> None:
"""Unload the model."""
@@ -111,10 +143,28 @@ class ModelWrapper:
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Record VRAM usage after unloading
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
def get_vram_stats(self) -> dict:
"""Get VRAM usage statistics for this model."""
if not self.vram_usage_history:
return {'max_mb': 0.0, 'avg_mb': 0.0}
return {
'max_mb': max(self.vram_usage_history),
'avg_mb': sum(self.vram_usage_history) / len(self.vram_usage_history)
}
async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
"""Predict click coordinates."""
# Record VRAM usage before prediction
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
if self.is_computer_agent:
if self.agent is None:
await self.load_model()
@@ -122,13 +172,24 @@ class ModelWrapper:
if self.agent is not None:
image_b64 = image_to_base64(image)
result = await self.agent.predict_click(instruction=instruction, image_b64=image_b64)
# Record VRAM usage after prediction
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
return result
return None
else:
return await self.model.predict_click(image, instruction) # type: ignore
result = await self.model.predict_click(image, instruction) # type: ignore
# Record VRAM usage after prediction
vram_info = get_vram_usage()
self.vram_usage_history.append(vram_info['allocated_mb'])
return result
def save_results_to_markdown(all_results: List[dict], output_file: str = "screenspot_pro_results.md") -> None:
def save_results_to_markdown(all_results: List[dict], output_file: str = "screenspot_pro_results.md", title: str = "ScreenSpot-Pro Benchmark Results") -> None:
"""
Save evaluation results to a markdown table.
@@ -137,39 +198,46 @@ def save_results_to_markdown(all_results: List[dict], output_file: str = "screen
output_file: Output markdown file path
"""
with open(output_file, 'w', encoding='utf-8') as f:
f.write("# ScreenSpot-Pro Benchmark Results\n\n")
f.write(f"# {title}\n\n")
f.write(f"**Evaluation Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
# Summary table
f.write("## Summary\n\n")
f.write("| Model | Total Samples | Correct | Failed | Accuracy | Failure Rate |\n")
f.write("|-------|---------------|---------|--------|----------|--------------|\n")
f.write("| Model | Total Samples | Correct | Errors | Accuracy | Error Rate | Avg Time (s) | Time Range (s) | VRAM Max (GB) | VRAM Avg (GB) |\n")
f.write("|-------|---------------|---------|--------|----------|------------|--------------|----------------|---------------|---------------|\n")
for result in all_results:
model_name = result['model_name']
total = result['total_samples']
correct = result['correct_predictions']
failed = result['failed_predictions']
errors = result['failed_predictions']
accuracy = result['accuracy'] * 100
failure_rate = result['failure_rate'] * 100
error_rate = result['failure_rate'] * 100
avg_time = result.get('avg_prediction_time', 0.0)
min_time = result.get('min_prediction_time', 0.0)
max_time = result.get('max_prediction_time', 0.0)
time_range = f"{min_time:.2f} - {max_time:.2f}"
vram_max = result.get('vram_max_mb', 0.0) / 1024
vram_avg = result.get('vram_avg_mb', 0.0) / 1024
f.write(f"| {model_name} | {total} | {correct} | {failed} | {accuracy:.2f}% | {failure_rate:.2f}% |\n")
f.write(f"| {model_name} | {total} | {correct} | {errors} | {accuracy:.2f}% | {error_rate:.2f}% | {avg_time:.2f} | {time_range} | {vram_max:.1f} | {vram_avg:.1f} |\n")
# Detailed results for each model
for result in all_results:
f.write(f"\n## {result['model_name']} - Detailed Results\n\n")
f.write("| Sample ID | Instruction | BBox | Predicted | Correct | Failed |\n")
f.write("|-----------|-------------|------|-----------|---------|--------|\n")
f.write("| Sample Index | Instruction | BBox | Predicted | Correct | Error | Time (s) |\n")
f.write("|-----------|-------------|------|-----------|---------|-------|----------|\n")
for sample_result in result['results'][:10]: # Show first 10 samples
sample_id = sample_result['id']
sample_idx = sample_result['sample_idx']
instruction = sample_result['instruction'][:50] + "..." if len(sample_result['instruction']) > 50 else sample_result['instruction']
bbox = str(sample_result['bbox'])
predicted = str(sample_result['predicted_coords']) if sample_result['predicted_coords'] else "None"
correct = "PASS" if sample_result['is_correct'] else "FAIL"
failed = "YES" if sample_result['failed'] else "NO"
error = "YES" if sample_result['failed'] else "NO"
pred_time = sample_result.get('prediction_time', 0.0)
f.write(f"| {sample_id} | {instruction} | {bbox} | {predicted} | {correct} | {failed} |\n")
f.write(f"| {sample_idx} | {instruction} | {bbox} | {predicted} | {correct} | {error} | {pred_time:.2f} |\n")
if len(result['results']) > 10:
f.write(f"\n*Showing first 10 of {len(result['results'])} samples*\n")
@@ -177,76 +245,68 @@ def save_results_to_markdown(all_results: List[dict], output_file: str = "screen
print(f"\nResults saved to: {output_file}")
def save_visualizations(all_results: List[dict], dataset_list, output_dir: str = "output") -> None:
def save_visualizations(all_results: List[dict], samples, output_dir: str = "output") -> None:
"""
Save visualizations of predicted coordinates vs bboxes to an output folder.
Args:
all_results: List of evaluation results for each model
dataset_list: List of dataset samples
samples: List of sample dicts with image, bbox, instruction keys
output_dir: Output directory path
"""
# Create output directory
os.makedirs(output_dir, exist_ok=True)
for result in all_results:
model_name = result['model_name'].replace('/', '_').replace('.', '_')
model_name = result['model_name'].replace('/', '_').replace('\\', '_')
model_dir = os.path.join(output_dir, model_name)
os.makedirs(model_dir, exist_ok=True)
print(f"\nSaving visualizations for {result['model_name']}...")
print(f"Saving visualizations for {result['model_name']}...")
# Save first 10 samples for visualization
for i, sample_result in enumerate(tqdm(result['results'][:10], desc=f"Saving {model_name} visualizations")):
try:
# Find the original sample
sample_id = sample_result['id']
sample = None
for s in dataset_list:
if s['id'] == sample_id:
sample = s
break
if sample is None:
continue
# Get image and data
image = sample['image'].copy()
bbox = sample_result['bbox'] # [x1, y1, x2, y2]
predicted_coords = sample_result['predicted_coords']
is_correct = sample_result['is_correct']
# Draw on image
draw = ImageDraw.Draw(image)
# Draw bounding box (ground truth) in green
x1, y1, x2, y2 = bbox
draw.rectangle([x1, y1, x2, y2], outline="green", width=3)
draw.text((x1, y1-20), "Ground Truth", fill="green")
# Draw predicted click in red or blue
if predicted_coords is not None:
px, py = predicted_coords
color = "blue" if is_correct else "red"
# Draw crosshair
crosshair_size = 15
draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=3)
draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=3)
draw.text((px+10, py-20), f"Predicted ({px},{py})", fill=color)
# Add status text
status = "CORRECT" if is_correct else "INCORRECT"
status_color = "blue" if is_correct else "red"
draw.text((10, 10), f"Status: {status}", fill=status_color)
draw.text((10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black")
# Save image
filename = f"sample_{i+1:02d}_{sample_id}_{status.lower()}.png"
filepath = os.path.join(model_dir, filename)
image.save(filepath)
except Exception as e:
print(f"Error saving visualization for sample {sample_id}: {e}")
# Get sample data using index
sample_idx = sample_result['sample_idx']
if sample_idx < len(samples):
sample = samples[sample_idx]
image = sample['image'].copy() # Make a copy to avoid modifying original
else:
print(f"Warning: Could not find sample at index {sample_idx}")
continue
bbox = sample_result['bbox']
predicted_coords = sample_result['predicted_coords']
is_correct = sample_result['is_correct']
# Draw on image
draw = ImageDraw.Draw(image)
# Draw bounding box (ground truth) in green
x1, y1, x2, y2 = bbox
draw.rectangle([x1, y1, x2, y2], outline="green", width=3)
draw.text((x1, y1-20), "Ground Truth", fill="green")
# Draw predicted click in red or blue
if predicted_coords is not None:
px, py = predicted_coords
color = "blue" if is_correct else "red"
# Draw crosshair
crosshair_size = 15
draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=3)
draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=3)
draw.text((px+10, py-20), f"Predicted ({px},{py})", fill=color)
# Add status text
status = "CORRECT" if is_correct else "INCORRECT"
status_color = "blue" if is_correct else "red"
draw.text((10, 10), f"Status: {status}", fill=status_color)
draw.text((10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black")
# Save image
filename = f"sample_{i+1:02d}_idx{sample_idx}_{status.lower()}.png"
filepath = os.path.join(model_dir, filename)
image.save(filepath)
print(f"Visualizations saved to: {model_dir}")