Mirror of https://github.com/trycua/computer.git, synced 2026-01-05 04:50:08 -06:00

Commit: added agent benchmarks
libs/python/agent/benchmarks/.gitignore (vendored, 1 line changed)
```
@@ -1,2 +1,3 @@
output/
interactive_output/
*_results.md
```
libs/python/agent/benchmarks/README.md (new file, 177 lines)
@@ -0,0 +1,177 @@
# Computer Agent Benchmarks

This directory contains benchmarks designed to test agent providers in the Computer Agent SDK against reference agent implementations.

## Overview

The benchmark system evaluates models on GUI grounding tasks, specifically click prediction accuracy. It supports both:

- **Computer Agent SDK providers** (using model strings like `"huggingface-local/HelloKKMe/GTA1-7B"`)
- **Reference agent implementations** (custom model classes implementing the `ModelProtocol`); a minimal sketch of both paths appears right after this list
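Both paths expose the same idea: give the model a screenshot and an instruction, get back pixel coordinates. The sketch below assumes the `ComputerAgent.predict_click(...)` call and the `image_to_base64` helper that `ModelWrapper` in `utils.py` uses, plus a hypothetical `screenshot.png` on disk; treat it as an illustration rather than canonical usage.

```python
import asyncio

from PIL import Image

from agent.agent import ComputerAgent  # Computer Agent SDK provider path
from models import GTA1Model           # reference implementation path
from utils import image_to_base64


async def demo() -> None:
    image = Image.open("screenshot.png")  # hypothetical screenshot file
    instruction = "Click the search box"

    # 1) SDK provider, addressed by a model string
    agent = ComputerAgent(model="huggingface-local/HelloKKMe/GTA1-7B")
    sdk_coords = await agent.predict_click(
        instruction=instruction,
        image_b64=image_to_base64(image),
    )

    # 2) Reference implementation, called through the ModelProtocol methods
    model = GTA1Model("HelloKKMe/GTA1-7B")
    await model.load_model()
    ref_coords = await model.predict_click(image, instruction)
    await model.unload_model()

    print(sdk_coords, ref_coords)


asyncio.run(demo())
```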
## Available Benchmarks

### 1. ScreenSpot-v2 (`ss-v2.py`)
- **Dataset**: ScreenSpot-v2 (click-only GUI grounding)
- **Format**: Standard-resolution screenshots
- **Task**: Predict click coordinates given an instruction and image
- **Metrics**: Accuracy, Error Rate, Timing, VRAM usage

### 2. ScreenSpot-Pro (`ss-pro.py`)
- **Dataset**: ScreenSpot-Pro (high-resolution click-only GUI grounding)
- **Format**: High-resolution screenshots
- **Task**: Predict click coordinates given an instruction and image
- **Metrics**: Accuracy, Error Rate, Timing, VRAM usage

### 3. Interactive Testing (`interactive.py`)
- **Real-time testing**: Take screenshots and visualize model predictions
- **Commands**:
  - Type an instruction → take a screenshot and test all models
  - `screenshot` → take a screenshot without prediction
  - `models` → list available models
  - `quit`/`exit` → exit the tool
- **Output**: Visual predictions with crosshairs for each model

## Adding Reference Agent Implementations

### 1. Implement the ModelProtocol

Create a new file in the `models/` directory implementing the `ModelProtocol`:

```python
from models.base import ModelProtocol
from typing import Optional, Tuple
from PIL import Image


class YourModelName(ModelProtocol):
    def __init__(self, model_path: str):
        self.model_path = model_path
        self._model = None

    @property
    def model_name(self) -> str:
        return self.model_path

    async def load_model(self) -> None:
        """Load the model into memory."""
        # Your model loading logic here
        pass

    async def unload_model(self) -> None:
        """Unload the model from memory."""
        # Your model cleanup logic here
        pass

    async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
        """
        Predict click coordinates for the given image and instruction.

        Args:
            image: PIL Image to analyze
            instruction: Text instruction describing what to click

        Returns:
            Tuple of (x, y) coordinates, or None if prediction fails
        """
        # Your prediction logic here
        return (x, y)  # Return predicted coordinates
```

### 2. Register Your Model

Add your model to the `get_available_models()` function in `utils.py`:

```python
def get_available_models() -> List[Union[str, ModelProtocol]]:
    models = [
        # Computer Agent SDK providers
        "huggingface-local/HelloKKMe/GTA1-7B",

        # Reference implementations
        GTA1Model("HelloKKMe/GTA1-7B"),
        YourModelName("path/to/your/model"),  # Add your model here
    ]
    return models
```

## Running Benchmarks

### 1. Configure Models
Edit `utils.py` to specify which models you want to test in `get_available_models()`.

### 2. Set Sample Count
Edit the benchmark script to change the number of samples:
```python
max_samples = 50  # Set to None to evaluate on the full dataset
```

### 3. Run Benchmark
```bash
# ScreenSpot-v2 benchmark
python ss-v2.py

# ScreenSpot-Pro benchmark
python ss-pro.py

# Interactive testing
python interactive.py
```

## Output

### Console Output
```
Model Results:
  Accuracy: 85.50%
  Correct: 171/200
  Errors: 5
  Error Rate: 2.50%
  Avg Time: 1.23s
  Time Range: 0.89s - 2.45s
  VRAM Max: 4.5GB
  VRAM Avg: 3.4GB
```

### Generated Files
- **Markdown Report**: `*_results.md` with detailed results tables
- **Visualizations**: `output/` directory with prediction visualizations
- **Interactive Output**: `interactive_output/` for interactive session results

## Metrics Tracked

- **Accuracy**: Percentage of predicted clicks that land inside the ground-truth bounding box (see the sketch below)
- **Error Rate**: Percentage of failed predictions
- **Timing**: Average, minimum, and maximum prediction times
- **VRAM Usage**: Maximum and average GPU memory usage
- **Per-sample Results**: Detailed breakdown for debugging
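A click counts as correct when it falls inside the ground-truth box. The check is equivalent to the sketch below; the actual `is_click_in_bbox` lives in `utils.py`, so treat this as an illustration of the rule rather than its exact code.

```python
from typing import List, Optional, Tuple


def is_click_in_bbox(click: Optional[Tuple[int, int]], bbox: List[float]) -> bool:
    """Return True if the predicted click lands inside the [x1, y1, x2, y2] box."""
    if click is None:  # a failed prediction never counts as correct
        return False
    x, y = click
    x1, y1, x2, y2 = bbox
    return x1 <= x <= x2 and y1 <= y <= y2
```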
## Requirements

- Python 3.8+
- PyTorch (for VRAM tracking)
- PIL/Pillow (for image processing)
- datasets (for Hugging Face datasets)
- tqdm (for progress bars)
- Computer Agent SDK

## Architecture

The benchmark system is designed for:
- **Modularity**: Easy to add new models and benchmarks
- **Flexibility**: Works with any iterable of dicts carrying `image`, `bbox`, and `instruction` keys (see the sample sketch below)
- **Performance**: VRAM tracking and timing analysis
- **Visualization**: Automatic generation of prediction visualizations
- **No Exception Handling**: Fails fast to surface real issues
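Concretely, a benchmark sample is just a plain dict; any iterable that yields dicts of this shape can be fed to a benchmark's `evaluate_model()` loop (values below are illustrative):

```python
from PIL import Image

sample = {
    "image": Image.open("screenshot.png"),   # PIL image of the screen (hypothetical file)
    "bbox": [120, 340, 260, 380],            # ground-truth box as [x1, y1, x2, y2]
    "instruction": "Click the Save button",  # natural-language description of the target
}
```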
## Results Table

| Model | Dataset | Accuracy | Error Rate | Avg Time | VRAM Max | VRAM Avg |
|-------|---------|----------|------------|----------|----------|----------|
| (coming soon) | | | | | | |

## Contributing

To add a new benchmark:
1. Create a new script following the pattern in `ss-v2.py`
2. Reuse the `evaluate_model()` loop together with the helpers in `utils.py` (`ModelWrapper`, `save_results_to_markdown`, `save_visualizations`)
3. Ensure your dataset yields dicts with `image`, `bbox`, and `instruction` keys
4. Update this README with benchmark details (a skeleton script is sketched below)
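For example, a new benchmark can start from the following skeleton. This is a sketch modeled on `ss-v2.py`: `load_samples()` is a placeholder for your own dataset, and `evaluate_model` is reused from `ss-v2.py` via `importlib` only because that file name contains a dash; copying the loop into your script works just as well.

```python
#!/usr/bin/env python3
"""Skeleton for a new benchmark script."""

import asyncio
import importlib
from typing import List

from utils import (
    ModelWrapper,
    get_available_models,
    save_results_to_markdown,
    save_visualizations,
)

# Reuse the evaluation loop defined in ss-v2.py (or copy it into this file and adapt it).
evaluate_model = importlib.import_module("ss-v2").evaluate_model


def load_samples() -> List[dict]:
    """Placeholder: return dicts with 'image', 'bbox' ([x1, y1, x2, y2]) and 'instruction'."""
    raise NotImplementedError("Load your dataset here")


async def main() -> None:
    samples = load_samples()
    all_results = []

    for model in get_available_models():
        wrapper = ModelWrapper(model)
        result = await evaluate_model(wrapper, samples, max_samples=100)
        all_results.append(result)

    if all_results:
        save_results_to_markdown(all_results, "my_benchmark_results.md", title="My Benchmark Results")
        save_visualizations(all_results, samples)


if __name__ == "__main__":
    asyncio.run(main())
```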
```diff
@@ -117,7 +117,7 @@ Output the coordinate pair exactly:
         }

         # Process inputs
-        image_inputs, video_inputs = process_vision_info([system_message, user_message])
+        image_inputs, video_inputs = process_vision_info([system_message, user_message])  # type: ignore
         text = self.processor.apply_chat_template(
             [system_message, user_message],
             tokenize=False,
```
```diff
@@ -7,6 +7,7 @@ Supports both ComputerAgent model strings and custom model classes.
 """

 import asyncio
+import time
 from typing import Optional

 from datasets import load_dataset
@@ -43,66 +44,67 @@ async def evaluate_model(model_wrapper: ModelWrapper, dataset, max_samples: Opti
         total_samples = min(max_samples, total_samples)

     correct_predictions = 0
-    failed_predictions = 0
+    error_predictions = 0
     results = []

-    try:
-        for i in tqdm(range(total_samples), desc=f"Evaluating {model_wrapper.model_name}"):
-            sample = dataset[i]
-
-            # Extract sample data
-            image = sample['image']
-            instruction = sample['instruction']
-            bbox = sample['bbox']  # [x1, y1, x2, y2]
-            sample_id = sample['id']
-
-            # Predict click coordinates
-            try:
-                click_coords = await model_wrapper.predict_click(image, instruction)
-
-                # Check if prediction is correct
-                is_correct = is_click_in_bbox(click_coords, bbox)
-
-                if is_correct:
-                    correct_predictions += 1
-
-                results.append({
-                    'id': sample_id,
-                    'instruction': instruction,
-                    'bbox': bbox,
-                    'predicted_coords': click_coords,
-                    'is_correct': is_correct,
-                    'failed': False
-                })
-
-            except Exception as e:
-                print(f"\nError predicting sample {sample_id}: {e}")
-                failed_predictions += 1
-                results.append({
-                    'id': sample_id,
-                    'instruction': instruction,
-                    'bbox': bbox,
-                    'predicted_coords': None,
-                    'is_correct': False,
-                    'failed': True,
-                    'error': str(e)
-                })
-
-    finally:
-        # Unload model
-        await model_wrapper.unload_model()
+    for i in tqdm(range(total_samples), desc=f"Evaluating {model_wrapper.model_name}"):
+        sample = dataset[i]
+
+        # Extract sample data
+        image = sample['image']
+        instruction = sample['instruction']
+        bbox = sample['bbox']  # [x1, y1, x2, y2]
+        sample_id = sample['img_filename']
+
+        # Predict click coordinates with timing
+        start_time = time.time()
+        click_coords = await model_wrapper.predict_click(image, instruction)
+        prediction_time = time.time() - start_time
+
+        # Check if prediction is correct
+        is_correct = is_click_in_bbox(click_coords, bbox)
+
+        if is_correct:
+            correct_predictions += 1
+
+        results.append({
+            'id': sample_id,
+            'instruction': instruction,
+            'bbox': bbox,
+            'predicted_coords': click_coords,
+            'is_correct': is_correct,
+            'failed': False,
+            'prediction_time': prediction_time
+        })
+
+    # Unload model
+    await model_wrapper.unload_model()

     # Calculate metrics
     accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
-    failure_rate = failed_predictions / total_samples if total_samples > 0 else 0.0
+    error_rate = error_predictions / total_samples if total_samples > 0 else 0.0
+
+    # Calculate timing statistics
+    successful_times = [r['prediction_time'] for r in results if not r['failed']]
+    avg_prediction_time = sum(successful_times) / len(successful_times) if successful_times else 0.0
+    min_prediction_time = min(successful_times) if successful_times else 0.0
+    max_prediction_time = max(successful_times) if successful_times else 0.0
+
+    # Get VRAM statistics
+    vram_stats = model_wrapper.get_vram_stats()

     return {
         'model_name': model_wrapper.model_name,
         'total_samples': total_samples,
         'correct_predictions': correct_predictions,
-        'failed_predictions': failed_predictions,
+        'failed_predictions': error_predictions,
         'accuracy': accuracy,
-        'failure_rate': failure_rate,
+        'failure_rate': error_rate,
+        'avg_prediction_time': avg_prediction_time,
+        'min_prediction_time': min_prediction_time,
+        'max_prediction_time': max_prediction_time,
+        'vram_max_mb': vram_stats['max_mb'],
+        'vram_avg_mb': vram_stats['avg_mb'],
         'results': results
     }
@@ -123,26 +125,26 @@ async def main():
     models = get_available_models()

     # Evaluation settings
-    max_samples = 5  # Set to None to evaluate on full dataset
+    max_samples = 300  # Set to None to evaluate on full dataset

     # Run evaluations
     all_results = []

     for model in models:
-        try:
-            model_wrapper = ModelWrapper(model)
-            result = await evaluate_model(model_wrapper, dataset_list, max_samples)
-            all_results.append(result)
-
-            # Print summary
-            print(f"\n{result['model_name']} Results:")
-            print(f"  Accuracy: {result['accuracy']*100:.2f}%")
-            print(f"  Correct: {result['correct_predictions']}/{result['total_samples']}")
-            print(f"  Failed: {result['failed_predictions']}")
-
-        except Exception as e:
-            print(f"\nError evaluating model {model}: {e}")
-            continue
+        model_wrapper = ModelWrapper(model)
+        result = await evaluate_model(model_wrapper, dataset_list, max_samples)
+        all_results.append(result)
+
+        # Print summary
+        print(f"\n{result['model_name']} Results:")
+        print(f"  Accuracy: {result['accuracy']*100:.2f}%")
+        print(f"  Correct: {result['correct_predictions']}/{result['total_samples']}")
+        print(f"  Errors: {result['failed_predictions']}")
+        print(f"  Error Rate: {result['failure_rate']*100:.2f}%")
+        print(f"  Avg Time: {result['avg_prediction_time']:.2f}s")
+        print(f"  Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s")
+        print(f"  VRAM Max: {result['vram_max_mb']:.1f}MB")
+        print(f"  VRAM Avg: {result['vram_avg_mb']:.1f}MB")

     # Save results
     if all_results:
```
libs/python/agent/benchmarks/ss-v2.py (new file, 179 lines)
@@ -0,0 +1,179 @@
```python
#!/usr/bin/env python3
"""
ScreenSpot-v2 Benchmark Script

Evaluates models on the ScreenSpot-v2 dataset for click prediction accuracy.
Supports both ComputerAgent model strings and custom model classes.
"""

import asyncio
import time
from typing import Optional

from datasets import load_dataset
from tqdm import tqdm

from utils import (
    ModelWrapper,
    is_click_in_bbox,
    save_results_to_markdown,
    save_visualizations,
    get_available_models
)


async def evaluate_model(model_wrapper: ModelWrapper, samples, max_samples: Optional[int] = None) -> dict:
    """
    Evaluate a model on any iterable of samples.

    Args:
        model_wrapper: ModelWrapper instance
        samples: Iterable of dicts with keys: image, bbox, instruction
        max_samples: Maximum number of samples to evaluate (None for all)

    Returns:
        Dictionary with evaluation results
    """
    print(f"\nEvaluating model: {model_wrapper.model_name}")

    # Load model
    await model_wrapper.load_model()

    # Convert to list if needed and limit samples
    if hasattr(samples, '__len__'):
        total_samples = len(samples)
        if max_samples is not None:
            total_samples = min(max_samples, total_samples)
        sample_list = list(samples)[:total_samples]
    else:
        # For iterators, take max_samples or all
        sample_list = list(samples)
        if max_samples is not None:
            sample_list = sample_list[:max_samples]
        total_samples = len(sample_list)

    correct_predictions = 0
    error_predictions = 0
    results = []

    for i, sample in enumerate(tqdm(sample_list, desc=f"Evaluating {model_wrapper.model_name}")):
        # Extract required data (only these 3 keys matter)
        image = sample['image']
        instruction = sample['instruction']
        bbox = sample['bbox']  # [x1, y1, x2, y2]

        # Predict click coordinates with timing
        start_time = time.time()
        click_coords = await model_wrapper.predict_click(image, instruction)
        prediction_time = time.time() - start_time

        # Check if prediction is correct
        is_correct = is_click_in_bbox(click_coords, bbox)

        if is_correct:
            correct_predictions += 1

        results.append({
            'sample_idx': i,
            'instruction': instruction,
            'bbox': bbox,
            'predicted_coords': click_coords,
            'is_correct': is_correct,
            'failed': False,
            'prediction_time': prediction_time
        })

    # Unload model
    await model_wrapper.unload_model()

    # Calculate metrics
    accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0
    error_rate = error_predictions / total_samples if total_samples > 0 else 0.0

    # Calculate timing statistics
    successful_times = [r['prediction_time'] for r in results if not r['failed']]
    avg_prediction_time = sum(successful_times) / len(successful_times) if successful_times else 0.0
    min_prediction_time = min(successful_times) if successful_times else 0.0
    max_prediction_time = max(successful_times) if successful_times else 0.0

    # Get VRAM statistics
    vram_stats = model_wrapper.get_vram_stats()

    return {
        'model_name': model_wrapper.model_name,
        'total_samples': total_samples,
        'correct_predictions': correct_predictions,
        'failed_predictions': error_predictions,
        'accuracy': accuracy,
        'failure_rate': error_rate,
        'avg_prediction_time': avg_prediction_time,
        'min_prediction_time': min_prediction_time,
        'max_prediction_time': max_prediction_time,
        'vram_max_mb': vram_stats['max_mb'],
        'vram_avg_mb': vram_stats['avg_mb'],
        'results': results
    }


async def main():
    """
    Main function to run the benchmark.
    """
    # Load dataset
    print("Loading ScreenSpot-v2 dataset...")
    ds = load_dataset("lmms-lab/ScreenSpot-v2")
    dataset = ds['train']  # type: ignore

    # Convert to simple list of dicts with only required keys
    samples = []
    for item in dataset:
        # Convert dataset item to dict if needed
        item_dict = dict(item) if hasattr(item, 'keys') else item

        # Convert ScreenSpot-v2 bbox format [x, y, w, h] to [x1, y1, x2, y2]
        bbox_xywh = item_dict['bbox']  # type: ignore
        x, y, w, h = bbox_xywh
        bbox_xyxy = [x, y, x + w, y + h]

        samples.append({
            'image': item_dict['image'],  # type: ignore
            'instruction': item_dict['instruction'],  # type: ignore
            'bbox': bbox_xyxy
        })
    print(f"Dataset loaded: {len(samples)} samples")

    # Get available models
    models = get_available_models()

    # Evaluation settings
    max_samples = 500  # Set to None to evaluate on full dataset

    # Run evaluations
    all_results = []

    for model in models:
        model_wrapper = ModelWrapper(model)
        result = await evaluate_model(model_wrapper, samples, max_samples)
        all_results.append(result)

        # Print summary
        print(f"\n{result['model_name']} Results:")
        print(f"  Accuracy: {result['accuracy']*100:.2f}%")
        print(f"  Correct: {result['correct_predictions']}/{result['total_samples']}")
        print(f"  Errors: {result['failed_predictions']}")
        print(f"  Error Rate: {result['failure_rate']*100:.2f}%")
        print(f"  Avg Time: {result['avg_prediction_time']:.2f}s")
        print(f"  Time Range: {result['min_prediction_time']:.2f}s - {result['max_prediction_time']:.2f}s")
        print(f"  VRAM Max: {result['vram_max_mb']:.1f}MB")
        print(f"  VRAM Avg: {result['vram_avg_mb']:.1f}MB")

    # Save results
    if all_results:
        save_results_to_markdown(all_results, "screenspot_v2_results.md", title="ScreenSpot-v2 Benchmark Results")
        save_visualizations(all_results, samples)
        print("\nBenchmark completed successfully!")
    else:
        print("\nNo successful evaluations completed.")


if __name__ == "__main__":
    asyncio.run(main())
```
```diff
@@ -22,6 +22,33 @@ from agent.agent import ComputerAgent
 from models import GTA1Model
 from models.base import ModelProtocol

+
+def get_vram_usage() -> dict:
+    """
+    Get current VRAM usage statistics.
+
+    Returns:
+        Dictionary with VRAM usage info (in MB)
+    """
+    if torch.cuda.is_available():
+        device = torch.cuda.current_device()
+        allocated = torch.cuda.memory_allocated(device) / 1024 / 1024  # Convert to MB
+        reserved = torch.cuda.memory_reserved(device) / 1024 / 1024  # Convert to MB
+        total = torch.cuda.get_device_properties(device).total_memory / 1024 / 1024
+        return {
+            'allocated_mb': allocated,
+            'reserved_mb': reserved,
+            'total_mb': total,
+            'free_mb': total - reserved
+        }
+    else:
+        return {
+            'allocated_mb': 0.0,
+            'reserved_mb': 0.0,
+            'total_mb': 0.0,
+            'free_mb': 0.0
+        }
+
+
 def get_available_models() -> List[Union[str, ModelProtocol]]:
     """
     Get list of available models for testing.
@@ -34,11 +61,11 @@ def get_available_models() -> List[Union[str, ModelProtocol]]:
     models = [
         # === ComputerAgent model strings ===
         f"{local_provider}HelloKKMe/GTA1-7B",
-        # f"{local_provider}HelloKKMe/GTA1-32B",  # Uncomment if you have this model
+        f"{local_provider}HelloKKMe/GTA1-32B",

         # === Reference model classes ===
         GTA1Model("HelloKKMe/GTA1-7B"),
-        # GTA1Model("HelloKKMe/GTA1-32B"),  # Uncomment if you have this model
+        GTA1Model("HelloKKMe/GTA1-32B"),
     ]

     return models
@@ -88,11 +115,12 @@
         self.model = model
         self.is_computer_agent = isinstance(model, str)
         self.agent: Optional[ComputerAgent] = None
+        self.vram_usage_history: List[float] = []  # Track VRAM usage over time

         if self.is_computer_agent:
             self.model_name = str(model)
         else:
-            self.model_name = f"models.{model.__class__.__name__}"
+            self.model_name = f"{model.__class__.__name__}('{getattr(model, 'model_name', 'unknown')}')"

     async def load_model(self) -> None:
         """Load the model."""
@@ -100,6 +128,10 @@
             self.agent = ComputerAgent(model=str(self.model))
         else:
             await self.model.load_model()  # type: ignore

+        # Record initial VRAM usage after loading
+        vram_info = get_vram_usage()
+        self.vram_usage_history.append(vram_info['allocated_mb'])
+
     async def unload_model(self) -> None:
         """Unload the model."""
@@ -111,10 +143,28 @@
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()

+        # Record VRAM usage after unloading
+        vram_info = get_vram_usage()
+        self.vram_usage_history.append(vram_info['allocated_mb'])
+
+    def get_vram_stats(self) -> dict:
+        """Get VRAM usage statistics for this model."""
+        if not self.vram_usage_history:
+            return {'max_mb': 0.0, 'avg_mb': 0.0}
+
+        return {
+            'max_mb': max(self.vram_usage_history),
+            'avg_mb': sum(self.vram_usage_history) / len(self.vram_usage_history)
+        }
+
     async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]:
         """Predict click coordinates."""
+        # Record VRAM usage before prediction
+        vram_info = get_vram_usage()
+        self.vram_usage_history.append(vram_info['allocated_mb'])
+
         if self.is_computer_agent:
             if self.agent is None:
                 await self.load_model()
@@ -122,13 +172,24 @@
             if self.agent is not None:
                 image_b64 = image_to_base64(image)
                 result = await self.agent.predict_click(instruction=instruction, image_b64=image_b64)
+
+                # Record VRAM usage after prediction
+                vram_info = get_vram_usage()
+                self.vram_usage_history.append(vram_info['allocated_mb'])
+
                 return result
             return None
         else:
-            return await self.model.predict_click(image, instruction)  # type: ignore
+            result = await self.model.predict_click(image, instruction)  # type: ignore
+
+            # Record VRAM usage after prediction
+            vram_info = get_vram_usage()
+            self.vram_usage_history.append(vram_info['allocated_mb'])
+
+            return result


-def save_results_to_markdown(all_results: List[dict], output_file: str = "screenspot_pro_results.md") -> None:
+def save_results_to_markdown(all_results: List[dict], output_file: str = "screenspot_pro_results.md", title: str = "ScreenSpot-Pro Benchmark Results") -> None:
     """
     Save evaluation results to a markdown table.

```
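Taken together, these wrapper changes give every model, whether an SDK model string or a reference `ModelProtocol` class, the same interface plus VRAM bookkeeping. A minimal usage sketch, assuming the `utils.py` definitions above and a hypothetical `screenshot.png`:

```python
import asyncio

from PIL import Image

from utils import ModelWrapper


async def demo() -> None:
    # A model string works the same way as e.g. ModelWrapper(GTA1Model("HelloKKMe/GTA1-7B"))
    wrapper = ModelWrapper("huggingface-local/HelloKKMe/GTA1-7B")
    await wrapper.load_model()

    image = Image.open("screenshot.png")  # hypothetical screenshot
    coords = await wrapper.predict_click(image, "Click the settings icon")
    print("predicted:", coords)

    await wrapper.unload_model()
    print("VRAM stats (MB):", wrapper.get_vram_stats())


asyncio.run(demo())
```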
```diff
@@ -137,39 +198,46 @@ def save_results_to_markdown(all_results: List[dict], output_file: str = "screen
         output_file: Output markdown file path
     """
     with open(output_file, 'w', encoding='utf-8') as f:
-        f.write("# ScreenSpot-Pro Benchmark Results\n\n")
+        f.write(f"# {title}\n\n")
         f.write(f"**Evaluation Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")

         # Summary table
         f.write("## Summary\n\n")
-        f.write("| Model | Total Samples | Correct | Failed | Accuracy | Failure Rate |\n")
-        f.write("|-------|---------------|---------|--------|----------|--------------|\n")
+        f.write("| Model | Total Samples | Correct | Errors | Accuracy | Error Rate | Avg Time (s) | Time Range (s) | VRAM Max (GB) | VRAM Avg (GB) |\n")
+        f.write("|-------|---------------|---------|--------|----------|------------|--------------|----------------|---------------|---------------|\n")

         for result in all_results:
             model_name = result['model_name']
             total = result['total_samples']
             correct = result['correct_predictions']
-            failed = result['failed_predictions']
+            errors = result['failed_predictions']
             accuracy = result['accuracy'] * 100
-            failure_rate = result['failure_rate'] * 100
+            error_rate = result['failure_rate'] * 100
+            avg_time = result.get('avg_prediction_time', 0.0)
+            min_time = result.get('min_prediction_time', 0.0)
+            max_time = result.get('max_prediction_time', 0.0)
+            time_range = f"{min_time:.2f} - {max_time:.2f}"
+            vram_max = result.get('vram_max_mb', 0.0) / 1024
+            vram_avg = result.get('vram_avg_mb', 0.0) / 1024

-            f.write(f"| {model_name} | {total} | {correct} | {failed} | {accuracy:.2f}% | {failure_rate:.2f}% |\n")
+            f.write(f"| {model_name} | {total} | {correct} | {errors} | {accuracy:.2f}% | {error_rate:.2f}% | {avg_time:.2f} | {time_range} | {vram_max:.1f} | {vram_avg:.1f} |\n")

         # Detailed results for each model
         for result in all_results:
             f.write(f"\n## {result['model_name']} - Detailed Results\n\n")
-            f.write("| Sample ID | Instruction | BBox | Predicted | Correct | Failed |\n")
-            f.write("|-----------|-------------|------|-----------|---------|--------|\n")
+            f.write("| Sample Index | Instruction | BBox | Predicted | Correct | Error | Time (s) |\n")
+            f.write("|-----------|-------------|------|-----------|---------|-------|----------|\n")

             for sample_result in result['results'][:10]:  # Show first 10 samples
-                sample_id = sample_result['id']
+                sample_idx = sample_result['sample_idx']
                 instruction = sample_result['instruction'][:50] + "..." if len(sample_result['instruction']) > 50 else sample_result['instruction']
                 bbox = str(sample_result['bbox'])
                 predicted = str(sample_result['predicted_coords']) if sample_result['predicted_coords'] else "None"
                 correct = "PASS" if sample_result['is_correct'] else "FAIL"
-                failed = "YES" if sample_result['failed'] else "NO"
+                error = "YES" if sample_result['failed'] else "NO"
+                pred_time = sample_result.get('prediction_time', 0.0)

-                f.write(f"| {sample_id} | {instruction} | {bbox} | {predicted} | {correct} | {failed} |\n")
+                f.write(f"| {sample_idx} | {instruction} | {bbox} | {predicted} | {correct} | {error} | {pred_time:.2f} |\n")

             if len(result['results']) > 10:
                 f.write(f"\n*Showing first 10 of {len(result['results'])} samples*\n")
@@ -177,76 +245,68 @@
     print(f"\nResults saved to: {output_file}")


-def save_visualizations(all_results: List[dict], dataset_list, output_dir: str = "output") -> None:
+def save_visualizations(all_results: List[dict], samples, output_dir: str = "output") -> None:
     """
     Save visualizations of predicted coordinates vs bboxes to an output folder.

     Args:
         all_results: List of evaluation results for each model
-        dataset_list: List of dataset samples
+        samples: List of sample dicts with image, bbox, instruction keys
         output_dir: Output directory path
     """
     # Create output directory
     os.makedirs(output_dir, exist_ok=True)

     for result in all_results:
-        model_name = result['model_name'].replace('/', '_').replace('.', '_')
+        model_name = result['model_name'].replace('/', '_').replace('\\', '_')
         model_dir = os.path.join(output_dir, model_name)
         os.makedirs(model_dir, exist_ok=True)

-        print(f"\nSaving visualizations for {result['model_name']}...")
+        print(f"Saving visualizations for {result['model_name']}...")

         # Save first 10 samples for visualization
         for i, sample_result in enumerate(tqdm(result['results'][:10], desc=f"Saving {model_name} visualizations")):
-            try:
-                # Find the original sample
-                sample_id = sample_result['id']
-                sample = None
-                for s in dataset_list:
-                    if s['id'] == sample_id:
-                        sample = s
-                        break
-
-                if sample is None:
-                    continue
-
-                # Get image and data
-                image = sample['image'].copy()
-                bbox = sample_result['bbox']  # [x1, y1, x2, y2]
-                predicted_coords = sample_result['predicted_coords']
-                is_correct = sample_result['is_correct']
-
-                # Draw on image
-                draw = ImageDraw.Draw(image)
-
-                # Draw bounding box (ground truth) in green
-                x1, y1, x2, y2 = bbox
-                draw.rectangle([x1, y1, x2, y2], outline="green", width=3)
-                draw.text((x1, y1-20), "Ground Truth", fill="green")
-
-                # Draw predicted click in red or blue
-                if predicted_coords is not None:
-                    px, py = predicted_coords
-                    color = "blue" if is_correct else "red"
-                    # Draw crosshair
-                    crosshair_size = 15
-                    draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=3)
-                    draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=3)
-                    draw.text((px+10, py-20), f"Predicted ({px},{py})", fill=color)
-
-                # Add status text
-                status = "CORRECT" if is_correct else "INCORRECT"
-                status_color = "blue" if is_correct else "red"
-                draw.text((10, 10), f"Status: {status}", fill=status_color)
-                draw.text((10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black")
-
-                # Save image
-                filename = f"sample_{i+1:02d}_{sample_id}_{status.lower()}.png"
-                filepath = os.path.join(model_dir, filename)
-                image.save(filepath)
-
-            except Exception as e:
-                print(f"Error saving visualization for sample {sample_id}: {e}")
+            # Get sample data using index
+            sample_idx = sample_result['sample_idx']
+
+            if sample_idx < len(samples):
+                sample = samples[sample_idx]
+                image = sample['image'].copy()  # Make a copy to avoid modifying original
+            else:
+                print(f"Warning: Could not find sample at index {sample_idx}")
+                continue
+
+            bbox = sample_result['bbox']
+            predicted_coords = sample_result['predicted_coords']
+            is_correct = sample_result['is_correct']
+
+            # Draw on image
+            draw = ImageDraw.Draw(image)
+
+            # Draw bounding box (ground truth) in green
+            x1, y1, x2, y2 = bbox
+            draw.rectangle([x1, y1, x2, y2], outline="green", width=3)
+            draw.text((x1, y1-20), "Ground Truth", fill="green")
+
+            # Draw predicted click in red or blue
+            if predicted_coords is not None:
+                px, py = predicted_coords
+                color = "blue" if is_correct else "red"
+                # Draw crosshair
+                crosshair_size = 15
+                draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=3)
+                draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=3)
+                draw.text((px+10, py-20), f"Predicted ({px},{py})", fill=color)
+
+            # Add status text
+            status = "CORRECT" if is_correct else "INCORRECT"
+            status_color = "blue" if is_correct else "red"
+            draw.text((10, 10), f"Status: {status}", fill=status_color)
+            draw.text((10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black")
+
+            # Save image
+            filename = f"sample_{i+1:02d}_idx{sample_idx}_{status.lower()}.png"
+            filepath = os.path.join(model_dir, filename)
+            image.save(filepath)

         print(f"Visualizations saved to: {model_dir}")
```