From 2076ec75966de2c4b2961be7d8585f3e816e91b0 Mon Sep 17 00:00:00 2001 From: Dillon DuPont Date: Tue, 29 Jul 2025 20:48:44 -0400 Subject: [PATCH] added GTA1 agent and click benchmarks (ss-pro, repl) --- .../adapters/huggingfacelocal_adapter.py | 6 +- libs/python/agent/agent/loops/__init__.py | 3 +- libs/python/agent/agent/loops/gta1.py | 4 +- libs/python/agent/agent/loops/omniparser.py | 2 + libs/python/agent/agent/loops/uitars.py | 81 +++-- libs/python/agent/benchmarks/.gitignore | 2 + libs/python/agent/benchmarks/interactive.py | 201 +++++++++++ .../agent/benchmarks/models/__init__.py | 4 + libs/python/agent/benchmarks/models/base.py | 36 ++ libs/python/agent/benchmarks/models/gta1.py | 162 +++++++++ libs/python/agent/benchmarks/ss-pro.py | 157 +++++++++ libs/python/agent/benchmarks/utils.py | 316 ++++++++++++++++++ 12 files changed, 939 insertions(+), 35 deletions(-) create mode 100644 libs/python/agent/benchmarks/.gitignore create mode 100644 libs/python/agent/benchmarks/interactive.py create mode 100644 libs/python/agent/benchmarks/models/__init__.py create mode 100644 libs/python/agent/benchmarks/models/base.py create mode 100644 libs/python/agent/benchmarks/models/gta1.py create mode 100644 libs/python/agent/benchmarks/ss-pro.py create mode 100644 libs/python/agent/benchmarks/utils.py diff --git a/libs/python/agent/agent/adapters/huggingfacelocal_adapter.py b/libs/python/agent/agent/adapters/huggingfacelocal_adapter.py index f8706868..5692401d 100644 --- a/libs/python/agent/agent/adapters/huggingfacelocal_adapter.py +++ b/libs/python/agent/agent/adapters/huggingfacelocal_adapter.py @@ -48,7 +48,11 @@ class HuggingFaceLocalAdapter(CustomLLM): ) # Load processor - processor = AutoProcessor.from_pretrained(model_name) + processor = AutoProcessor.from_pretrained( + model_name, + min_pixels=3136, + max_pixels=4096 * 2160 + ) # Cache them self.models[model_name] = model diff --git a/libs/python/agent/agent/loops/__init__.py b/libs/python/agent/agent/loops/__init__.py index aa159411..91722e55 100644 --- a/libs/python/agent/agent/loops/__init__.py +++ b/libs/python/agent/agent/loops/__init__.py @@ -7,5 +7,6 @@ from . import anthropic from . import openai from . import uitars from . import omniparser +from . import gta1 -__all__ = ["anthropic", "openai", "uitars", "omniparser"] +__all__ = ["anthropic", "openai", "uitars", "omniparser", "gta1"] diff --git a/libs/python/agent/agent/loops/gta1.py b/libs/python/agent/agent/loops/gta1.py index 4d0d3349..fb272f30 100644 --- a/libs/python/agent/agent/loops/gta1.py +++ b/libs/python/agent/agent/loops/gta1.py @@ -1,5 +1,7 @@ """ GTA1 agent loop implementation for click prediction using litellm.acompletion +Paper: https://arxiv.org/pdf/2507.05791 +Code: https://github.com/Yan98/GTA1 """ import asyncio @@ -20,7 +22,7 @@ You are an expert UI element locator. 
Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point.
 
 Output the coordinate pair exactly:
 (x,y)
-'''
+'''.strip()
 
 def extract_coordinates(raw_string: str) -> Tuple[float, float]:
     """Extract coordinates from model output."""
diff --git a/libs/python/agent/agent/loops/omniparser.py b/libs/python/agent/agent/loops/omniparser.py
index e92ef660..aff73edf 100644
--- a/libs/python/agent/agent/loops/omniparser.py
+++ b/libs/python/agent/agent/loops/omniparser.py
@@ -1,5 +1,7 @@
 """
 OpenAI computer-use-preview agent loop implementation using liteLLM
+Paper: https://arxiv.org/abs/2408.00203
+Code: https://github.com/microsoft/OmniParser
 """
 
 import asyncio
diff --git a/libs/python/agent/agent/loops/uitars.py b/libs/python/agent/agent/loops/uitars.py
index f5188288..f715ef61 100644
--- a/libs/python/agent/agent/loops/uitars.py
+++ b/libs/python/agent/agent/loops/uitars.py
@@ -1,5 +1,7 @@
 """
 UITARS agent loop implementation using liteLLM for ByteDance-Seed/UI-TARS-1.5-7B
+Paper: https://arxiv.org/abs/2501.12326
+Code: https://github.com/bytedance/UI-TARS
 """
 
 import asyncio
@@ -79,6 +81,18 @@ Action: ...
 {instruction}
 """
 
+GROUNDING_UITARS_PROMPT_TEMPLATE = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
+
+## Output Format
+
+Action: ...
+
+
+## Action Space
+click(point='<|box_start|>(x1,y1)<|box_end|>')
+
+## User Instruction
+{instruction}"""
 
 def round_by_factor(number: float, factor: int) -> int:
     """Returns the closest integer to 'number' that is divisible by 'factor'."""
@@ -511,7 +525,7 @@ class UITARSConfig:
 
     async def predict_step(
         self,
-        messages: Messages,
+        messages: List[Dict[str, Any]],
         model: str,
         tools: Optional[List[Dict[str, Any]]] = None,
         max_retries: Optional[int] = None,
@@ -729,13 +743,15 @@ class UITARSConfig:
             Tuple with (x, y) coordinates or None
         """
         try:
-            # Create a simple click instruction for UITARS
-            user_prompt = UITARS_PROMPT_TEMPLATE.format(
-                instruction=f"Click on: {instruction}",
-                action_space=UITARS_ACTION_SPACE,
-                language="English"
+            # Create prompt using grounding template
+            user_prompt = GROUNDING_UITARS_PROMPT_TEMPLATE.format(
+                instruction=instruction
             )
 
+            # Process image for UITARS
+            processed_image, original_width, original_height = process_image_for_uitars(image_b64)
+            encoded_image = pil_to_base64(processed_image)
+
             # Prepare messages for liteLLM
             litellm_messages = [
                 {
@@ -746,46 +762,47 @@ class UITARSConfig:
                     "role": "user",
                     "content": [
                         {"type": "text", "text": user_prompt},
-                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}}
+                        {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}}
                     ]
                 }
             ]
 
+            # Prepare API call kwargs
+            api_kwargs = {
+                "model": model,
+                "messages": litellm_messages,
+                "max_tokens": 100,
+                "temperature": 0.0,
+                "do_sample": False
+            }
+
             # Call liteLLM with UITARS model
-            response = await litellm.acompletion(
-                model=model,
-                messages=litellm_messages,
-                max_tokens=100,
-                temperature=0.0
-            )
+            response = await litellm.acompletion(**api_kwargs)
 
             # Extract response content
             response_content = response.choices[0].message.content.strip()  # type: ignore
 
-            # Parse UITARS response to extract click coordinates
-            parsed_responses = parse_uitars_response(response_content, 1024, 768)  # Default dimensions
+            # Parse the response to extract click coordinates
+            # Look for click action with coordinates
+            click_pattern = r"click\(point='<\|box_start\|>\((\d+),(\d+)\)<\|box_end\|>'\)"
+            match = 
re.search(click_pattern, response_content) - if parsed_responses and len(parsed_responses) > 0: - action_type = parsed_responses[0].get("action_type") - if action_type == "click": - action_inputs = parsed_responses[0].get("action_inputs", {}) - start_box = action_inputs.get("start_box") - if start_box: - # Parse coordinates from start_box - try: - coords = eval(start_box) # Parse the coordinate list - if len(coords) >= 2: - # Convert normalized coordinates back to pixel coordinates - x = int(coords[0] * 1024) - y = int(coords[1] * 768) - return (x, y) - except: - pass + if match: + x, y = int(match.group(1)), int(match.group(2)) + # Scale coordinates back to original image dimensions + scale_x = original_width / processed_image.width + scale_y = original_height / processed_image.height + + scaled_x = int(x * scale_x) + scaled_y = int(y * scale_y) + + return (scaled_x, scaled_y) return None except Exception as e: - print(f"Error in UITARS predict_click: {e}") + # Log error and return None + print(f"Error in predict_click: {e}") return None def get_capabilities(self) -> List[AgentCapability]: diff --git a/libs/python/agent/benchmarks/.gitignore b/libs/python/agent/benchmarks/.gitignore new file mode 100644 index 00000000..b9f463f1 --- /dev/null +++ b/libs/python/agent/benchmarks/.gitignore @@ -0,0 +1,2 @@ +output/ +interactive_output/ diff --git a/libs/python/agent/benchmarks/interactive.py b/libs/python/agent/benchmarks/interactive.py new file mode 100644 index 00000000..6d0aba82 --- /dev/null +++ b/libs/python/agent/benchmarks/interactive.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python3 +""" +Interactive Click Prediction Tool + +Takes screenshots and allows testing multiple models interactively. +Models are loaded/unloaded one at a time to avoid memory issues. +""" + +import asyncio +import os +from datetime import datetime +from typing import List, Dict, Any + +from utils import ( + ModelWrapper, + take_screenshot, + save_prediction_visualization, + get_available_models +) + + +async def predict_with_all_models(image, instruction: str, models) -> List[Dict[str, Any]]: + """ + Predict click coordinates with all models sequentially. 
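+
+    Example result shape (names and coordinates illustrative, mirroring the
+    dicts built below):
+
+        [{'model_name': 'huggingface-local/HelloKKMe/GTA1-7B',
+          'coords': (412, 188), 'error': None},
+         {'model_name': 'models.GTA1Model', 'coords': None,
+          'error': 'CUDA out of memory'}]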
+ + Args: + image: PIL Image to analyze + instruction: Instruction text + models: List of model instances + + Returns: + List of prediction results + """ + predictions = [] + + for model in models: + model_wrapper = ModelWrapper(model) + print(f"\nšŸ”„ Loading {model_wrapper.model_name}...") + + try: + # Load model + await model_wrapper.load_model() + + # Predict + coords = await model_wrapper.predict_click(image, instruction) + + predictions.append({ + 'model_name': model_wrapper.model_name, + 'coords': coords, + 'error': None + }) + + if coords: + print(f"āœ… {model_wrapper.model_name}: ({coords[0]}, {coords[1]})") + else: + print(f"āŒ {model_wrapper.model_name}: No prediction") + + except Exception as e: + print(f"āŒ {model_wrapper.model_name}: ERROR - {str(e)}") + predictions.append({ + 'model_name': model_wrapper.model_name, + 'coords': None, + 'error': str(e) + }) + + finally: + # Always unload model to free memory + try: + await model_wrapper.unload_model() + print(f"šŸ—‘ļø Unloaded {model_wrapper.model_name}") + except Exception as e: + print(f"āš ļø Error unloading {model_wrapper.model_name}: {e}") + + return predictions + + +def print_header(): + """Print the interactive tool header.""" + print("=" * 60) + print("šŸ–±ļø Interactive Click Prediction Tool") + print("=" * 60) + print("Commands:") + print(" • Type an instruction to test models on last screenshot") + print(" • 'screenshot' - Take a new screenshot") + print(" • 'models' - List available models") + print(" • 'quit' or 'exit' - Exit the tool") + print("=" * 60) + print("šŸ’” Tip: Take a screenshot first, then send instructions to test models!") + + +def print_models(models): + """Print available models.""" + print("\nšŸ“‹ Available Models:") + for i, model in enumerate(models, 1): + if isinstance(model, str): + print(f" {i}. {model}") + else: + print(f" {i}. models.{model.__class__.__name__}") + + +async def main(): + """ + Main interactive loop. + """ + print_header() + + # Get available models + models = get_available_models() + print_models(models) + + # Create output directory for visualizations + output_dir = "interactive_output" + os.makedirs(output_dir, exist_ok=True) + + session_count = 0 + last_screenshot = None + screenshot_timestamp = None + + while True: + try: + # Get user input + print(f"\n{'='*40}") + user_input = input("šŸŽÆ Enter instruction (or command): ").strip() + + if not user_input: + continue + + # Handle commands + if user_input.lower() in ['quit', 'exit', 'q']: + print("šŸ‘‹ Goodbye!") + break + + elif user_input.lower() == 'models': + print_models(models) + continue + + elif user_input.lower() == 'screenshot': + print("šŸ“ø Taking screenshot...") + try: + last_screenshot = take_screenshot() + screenshot_timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + screenshot_path = os.path.join(output_dir, f"screenshot_{screenshot_timestamp}.png") + last_screenshot.save(screenshot_path) + print(f"āœ… Screenshot captured and saved to: {screenshot_path}") + print(f"šŸ“ Ready for instructions! Screenshot size: {last_screenshot.size}") + except Exception as e: + print(f"āŒ Error taking screenshot: {e}") + continue + + # Handle instruction input + if last_screenshot is None: + print("āš ļø No screenshot available! 
Please take a screenshot first using 'screenshot' command.") + continue + + session_count += 1 + print(f"\nšŸŽÆ Session {session_count}: '{user_input}'") + print(f"šŸ“· Using screenshot from: {screenshot_timestamp}") + + # Predict with all models using last screenshot + print(f"\nšŸ¤– Testing {len(models)} models on screenshot...") + predictions = await predict_with_all_models(last_screenshot, user_input, models) + + # Display results summary + print(f"\nšŸ“Š Results Summary:") + print("-" * 50) + for pred in predictions: + if pred['coords']: + print(f"āœ… {pred['model_name']}: ({pred['coords'][0]}, {pred['coords'][1]})") + elif pred['error']: + print(f"āŒ {pred['model_name']}: ERROR - {pred['error']}") + else: + print(f"āŒ {pred['model_name']}: No prediction") + + # Save visualization + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + vis_filename = f"session_{session_count:03d}_{timestamp}.png" + vis_path = os.path.join(output_dir, vis_filename) + + try: + save_prediction_visualization(last_screenshot, user_input, predictions, vis_path) + print(f"\nšŸ’¾ Visualization saved to: {vis_path}") + except Exception as e: + print(f"āš ļø Error saving visualization: {e}") + + print(f"\n✨ Session {session_count} completed!") + + except KeyboardInterrupt: + print("\n\nšŸ‘‹ Interrupted by user. Goodbye!") + break + except Exception as e: + print(f"\nāŒ Unexpected error: {e}") + print("Continuing...") + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nšŸ‘‹ Goodbye!") + except Exception as e: + print(f"āŒ Fatal error: {e}") diff --git a/libs/python/agent/benchmarks/models/__init__.py b/libs/python/agent/benchmarks/models/__init__.py new file mode 100644 index 00000000..51033a7b --- /dev/null +++ b/libs/python/agent/benchmarks/models/__init__.py @@ -0,0 +1,4 @@ +from .base import ModelProtocol +from .gta1 import GTA1Model + +__all__ = ["ModelProtocol", "GTA1Model"] diff --git a/libs/python/agent/benchmarks/models/base.py b/libs/python/agent/benchmarks/models/base.py new file mode 100644 index 00000000..8ad100a3 --- /dev/null +++ b/libs/python/agent/benchmarks/models/base.py @@ -0,0 +1,36 @@ +""" +Base protocol for benchmark models. +""" + +from typing import Protocol, Optional, Tuple +from PIL import Image + + +class ModelProtocol(Protocol): + """Protocol for benchmark models that can predict click coordinates.""" + + @property + def model_name(self) -> str: + """Return the name of the model.""" + ... + + async def load_model(self) -> None: + """Load the model into memory.""" + ... + + async def unload_model(self) -> None: + """Unload the model from memory.""" + ... + + async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]: + """ + Predict click coordinates for the given image and instruction. + + Args: + image: PIL Image to analyze + instruction: Text instruction describing what to click + + Returns: + Tuple of (x, y) coordinates or None if prediction fails + """ + ... diff --git a/libs/python/agent/benchmarks/models/gta1.py b/libs/python/agent/benchmarks/models/gta1.py new file mode 100644 index 00000000..2bb4fe1d --- /dev/null +++ b/libs/python/agent/benchmarks/models/gta1.py @@ -0,0 +1,162 @@ +""" +GTA1 model implementation for benchmarking. 
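+
+A minimal usage sketch (model ID as shipped in utils.get_available_models();
+assumes a GPU machine, a running event loop, and a PIL image in `image`):
+
+    model = GTA1Model("HelloKKMe/GTA1-7B")
+    await model.load_model()
+    coords = await model.predict_click(image, "the search field")
+    await model.unload_model()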
+""" + +from typing import Optional, Tuple +from PIL import Image +import torch +import re +import gc +from qwen_vl_utils import process_vision_info, smart_resize +from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor + +from .base import ModelProtocol + + +class GTA1Model: + """Ground truth GTA1 model implementation.""" + + def __init__(self, model_path: str = "HelloKKMe/GTA1-7B"): + self.model_path = model_path + self.model = None + self.processor = None + self.max_new_tokens = 32 + + self.system_prompt = ''' +You are an expert UI element locator. Given a GUI image and a user's element description, provide the coordinates of the specified element as a single (x,y) point. The image resolution is height {height} and width {width}. For elements with area, return the center point. + +Output the coordinate pair exactly: +(x,y) +'''.strip() + + @property + def model_name(self) -> str: + """Return the name of the model.""" + return f"GTA1-{self.model_path.split('/')[-1]}" + + async def load_model(self) -> None: + """Load the model into memory.""" + if self.model is None: + print(f"Loading GTA1 model: {self.model_path}") + self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained( + self.model_path, + torch_dtype=torch.bfloat16, + device_map="auto" + ) + self.processor = AutoProcessor.from_pretrained( + self.model_path, + min_pixels=3136, + max_pixels=4096 * 2160 + ) + print("GTA1 model loaded successfully") + + async def unload_model(self) -> None: + """Unload the model from memory.""" + if self.model is not None: + print("Unloading GTA1 model from GPU...") + del self.model + del self.processor + self.model = None + self.processor = None + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + print("GTA1 model unloaded") + + def _extract_coordinates(self, raw_string: str) -> Tuple[int, int]: + """Extract coordinates from model output.""" + try: + matches = re.findall(r"\((-?\d*\.?\d+),\s*(-?\d*\.?\d+)\)", raw_string) + return tuple(map(int, map(float, matches[0]))) # type: ignore + except: + return (0, 0) + + async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]: + """ + Predict click coordinates for the given image and instruction. 
+ + Args: + image: PIL Image to analyze + instruction: Text instruction describing what to click + + Returns: + Tuple of (x, y) coordinates or None if prediction fails + """ + if self.model is None or self.processor is None: + await self.load_model() + + assert self.processor is not None + assert self.model is not None + + try: + width, height = image.width, image.height + + # Resize image according to processor requirements + resized_height, resized_width = smart_resize( + image.height, + image.width, + factor=self.processor.image_processor.patch_size * self.processor.image_processor.merge_size, + min_pixels=self.processor.image_processor.min_pixels, + max_pixels=self.processor.image_processor.max_pixels, + ) + resized_image = image.resize((resized_width, resized_height)) + scale_x, scale_y = width / resized_width, height / resized_height + + # Prepare messages + system_message = { + "role": "system", + "content": self.system_prompt.format(height=resized_height, width=resized_width) + } + + user_message = { + "role": "user", + "content": [ + {"type": "image", "image": resized_image}, + {"type": "text", "text": instruction} + ] + } + + # Process inputs + image_inputs, video_inputs = process_vision_info([system_message, user_message]) + text = self.processor.apply_chat_template( + [system_message, user_message], + tokenize=False, + add_generation_prompt=True + ) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt" + ) + inputs = inputs.to(self.model.device) + + # Generate prediction + output_ids = self.model.generate( + **inputs, + max_new_tokens=self.max_new_tokens, + do_sample=False, + temperature=1.0, + use_cache=True + ) + generated_ids = [ + output_ids[len(input_ids):] + for input_ids, output_ids in zip(inputs.input_ids, output_ids) + ] + output_text = self.processor.batch_decode( + generated_ids, + skip_special_tokens=True, + clean_up_tokenization_spaces=True + )[0] + + # Extract and rescale coordinates + pred_x, pred_y = self._extract_coordinates(output_text) + pred_x = int(pred_x * scale_x) + pred_y = int(pred_y * scale_y) + + return (pred_x, pred_y) + + except Exception as e: + print(f"Error in GTA1 prediction: {e}") + return None diff --git a/libs/python/agent/benchmarks/ss-pro.py b/libs/python/agent/benchmarks/ss-pro.py new file mode 100644 index 00000000..57f2c971 --- /dev/null +++ b/libs/python/agent/benchmarks/ss-pro.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +""" +ScreenSpot-Pro Benchmark Script + +Evaluates models on the ScreenSpot-Pro dataset for click prediction accuracy. +Supports both ComputerAgent model strings and custom model classes. +""" + +import asyncio +from typing import Optional + +from datasets import load_dataset +from tqdm import tqdm + +from utils import ( + ModelWrapper, + is_click_in_bbox, + save_results_to_markdown, + save_visualizations, + get_available_models +) + + +async def evaluate_model(model_wrapper: ModelWrapper, dataset, max_samples: Optional[int] = None) -> dict: + """ + Evaluate a model on the ScreenSpot-Pro dataset. 
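+
+    The returned dict has the shape (values illustrative):
+
+        {'model_name': 'GTA1-GTA1-7B', 'total_samples': 5,
+         'correct_predictions': 4, 'failed_predictions': 0,
+         'accuracy': 0.8, 'failure_rate': 0.0, 'results': [...]}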
+ + Args: + model_wrapper: ModelWrapper instance + dataset: ScreenSpot-Pro dataset (list of samples) + max_samples: Maximum number of samples to evaluate (None for all) + + Returns: + Dictionary with evaluation results + """ + print(f"\nEvaluating model: {model_wrapper.model_name}") + + # Load model + await model_wrapper.load_model() + + total_samples = len(dataset) + if max_samples is not None: + total_samples = min(max_samples, total_samples) + + correct_predictions = 0 + failed_predictions = 0 + results = [] + + try: + for i in tqdm(range(total_samples), desc=f"Evaluating {model_wrapper.model_name}"): + sample = dataset[i] + + # Extract sample data + image = sample['image'] + instruction = sample['instruction'] + bbox = sample['bbox'] # [x1, y1, x2, y2] + sample_id = sample['id'] + + # Predict click coordinates + try: + click_coords = await model_wrapper.predict_click(image, instruction) + + # Check if prediction is correct + is_correct = is_click_in_bbox(click_coords, bbox) + + if is_correct: + correct_predictions += 1 + + results.append({ + 'id': sample_id, + 'instruction': instruction, + 'bbox': bbox, + 'predicted_coords': click_coords, + 'is_correct': is_correct, + 'failed': False + }) + + except Exception as e: + print(f"\nError predicting sample {sample_id}: {e}") + failed_predictions += 1 + results.append({ + 'id': sample_id, + 'instruction': instruction, + 'bbox': bbox, + 'predicted_coords': None, + 'is_correct': False, + 'failed': True, + 'error': str(e) + }) + + finally: + # Unload model + await model_wrapper.unload_model() + + # Calculate metrics + accuracy = correct_predictions / total_samples if total_samples > 0 else 0.0 + failure_rate = failed_predictions / total_samples if total_samples > 0 else 0.0 + + return { + 'model_name': model_wrapper.model_name, + 'total_samples': total_samples, + 'correct_predictions': correct_predictions, + 'failed_predictions': failed_predictions, + 'accuracy': accuracy, + 'failure_rate': failure_rate, + 'results': results + } + + +async def main(): + """ + Main function to run the benchmark. 
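+
+    To benchmark more models, extend get_available_models() in utils.py (e.g.
+    uncomment the GTA1-32B entries); set max_samples below to None to run the
+    full dataset instead of the first 5 samples.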
+ """ + # Load dataset + print("Loading ScreenSpot-Pro dataset...") + ds = load_dataset("lmms-lab/ScreenSpot-Pro") + dataset = ds['train'] # type: ignore + # Convert to list to support indexing + dataset_list = list(dataset) + print(f"Dataset loaded: {len(dataset_list)} samples") + + # Get available models + models = get_available_models() + + # Evaluation settings + max_samples = 5 # Set to None to evaluate on full dataset + + # Run evaluations + all_results = [] + + for model in models: + try: + model_wrapper = ModelWrapper(model) + result = await evaluate_model(model_wrapper, dataset_list, max_samples) + all_results.append(result) + + # Print summary + print(f"\n{result['model_name']} Results:") + print(f" Accuracy: {result['accuracy']*100:.2f}%") + print(f" Correct: {result['correct_predictions']}/{result['total_samples']}") + print(f" Failed: {result['failed_predictions']}") + + except Exception as e: + print(f"\nError evaluating model {model}: {e}") + continue + + # Save results + if all_results: + save_results_to_markdown(all_results) + save_visualizations(all_results, dataset_list) + print("\nBenchmark completed successfully!") + else: + print("\nNo successful evaluations completed.") + + +if __name__ == "__main__": + asyncio.run(main()) \ No newline at end of file diff --git a/libs/python/agent/benchmarks/utils.py b/libs/python/agent/benchmarks/utils.py new file mode 100644 index 00000000..c1fc41cf --- /dev/null +++ b/libs/python/agent/benchmarks/utils.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +""" +Shared utilities for ScreenSpot-Pro benchmarking and interactive testing. +""" + +import asyncio +import base64 +import os +import sys +from datetime import datetime +from io import BytesIO +from typing import List, Union, Tuple, Optional + +from PIL import Image, ImageDraw +from tqdm import tqdm +import gc +import torch + +# Add parent directory to path for imports +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) +from agent.agent import ComputerAgent +from models import GTA1Model +from models.base import ModelProtocol + +def get_available_models() -> List[Union[str, ModelProtocol]]: + """ + Get list of available models for testing. + + Returns: + List of model strings and model classes + """ + local_provider = "huggingface-local/" # Options: huggingface-local/ or mlx/ + + models = [ + # === ComputerAgent model strings === + f"{local_provider}HelloKKMe/GTA1-7B", + # f"{local_provider}HelloKKMe/GTA1-32B", # Uncomment if you have this model + + # === Reference model classes === + GTA1Model("HelloKKMe/GTA1-7B"), + # GTA1Model("HelloKKMe/GTA1-32B"), # Uncomment if you have this model + ] + + return models + + +def is_click_in_bbox(click_coords: Optional[Tuple[int, int]], bbox: List[int]) -> bool: + """ + Check if click coordinates are within the bounding box. + + Args: + click_coords: (x, y) coordinates or None + bbox: [x1, y1, x2, y2] bounding box + + Returns: + True if click is within bbox, False otherwise + """ + if click_coords is None: + return False + + x, y = click_coords + x1, y1, x2, y2 = bbox + + return x1 <= x <= x2 and y1 <= y <= y2 + + +def image_to_base64(image: Image.Image) -> str: + """ + Convert PIL Image to base64 string. + + Args: + image: PIL Image + + Returns: + Base64 encoded image string + """ + buffered = BytesIO() + image.save(buffered, format="PNG") + return base64.b64encode(buffered.getvalue()).decode() + + +class ModelWrapper: + """ + Wrapper to provide unified interface for both ComputerAgent and custom models. 
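+
+    Example (a sketch; accepts either a ComputerAgent model string or a model
+    class instance from models/):
+
+        wrapper = ModelWrapper("huggingface-local/HelloKKMe/GTA1-7B")
+        await wrapper.load_model()
+        coords = await wrapper.predict_click(image, "the OK button")
+        await wrapper.unload_model()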
+ """ + + def __init__(self, model: Union[str, ModelProtocol]): + self.model = model + self.is_computer_agent = isinstance(model, str) + self.agent: Optional[ComputerAgent] = None + + if self.is_computer_agent: + self.model_name = str(model) + else: + self.model_name = f"models.{model.__class__.__name__}" + + async def load_model(self) -> None: + """Load the model.""" + if self.is_computer_agent: + self.agent = ComputerAgent(model=str(self.model)) + else: + await self.model.load_model() # type: ignore + + async def unload_model(self) -> None: + """Unload the model.""" + if not self.is_computer_agent: + await self.model.unload_model() # type: ignore + else: + del self.agent + self.agent = None + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + + async def predict_click(self, image: Image.Image, instruction: str) -> Optional[Tuple[int, int]]: + """Predict click coordinates.""" + if self.is_computer_agent: + if self.agent is None: + await self.load_model() + + if self.agent is not None: + image_b64 = image_to_base64(image) + result = await self.agent.predict_click(instruction=instruction, image_b64=image_b64) + return result + return None + else: + return await self.model.predict_click(image, instruction) # type: ignore + + +def save_results_to_markdown(all_results: List[dict], output_file: str = "screenspot_pro_results.md") -> None: + """ + Save evaluation results to a markdown table. + + Args: + all_results: List of evaluation results for each model + output_file: Output markdown file path + """ + with open(output_file, 'w', encoding='utf-8') as f: + f.write("# ScreenSpot-Pro Benchmark Results\n\n") + f.write(f"**Evaluation Date:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n") + + # Summary table + f.write("## Summary\n\n") + f.write("| Model | Total Samples | Correct | Failed | Accuracy | Failure Rate |\n") + f.write("|-------|---------------|---------|--------|----------|--------------|\n") + + for result in all_results: + model_name = result['model_name'] + total = result['total_samples'] + correct = result['correct_predictions'] + failed = result['failed_predictions'] + accuracy = result['accuracy'] * 100 + failure_rate = result['failure_rate'] * 100 + + f.write(f"| {model_name} | {total} | {correct} | {failed} | {accuracy:.2f}% | {failure_rate:.2f}% |\n") + + # Detailed results for each model + for result in all_results: + f.write(f"\n## {result['model_name']} - Detailed Results\n\n") + f.write("| Sample ID | Instruction | BBox | Predicted | Correct | Failed |\n") + f.write("|-----------|-------------|------|-----------|---------|--------|\n") + + for sample_result in result['results'][:10]: # Show first 10 samples + sample_id = sample_result['id'] + instruction = sample_result['instruction'][:50] + "..." 
if len(sample_result['instruction']) > 50 else sample_result['instruction'] + bbox = str(sample_result['bbox']) + predicted = str(sample_result['predicted_coords']) if sample_result['predicted_coords'] else "None" + correct = "PASS" if sample_result['is_correct'] else "FAIL" + failed = "YES" if sample_result['failed'] else "NO" + + f.write(f"| {sample_id} | {instruction} | {bbox} | {predicted} | {correct} | {failed} |\n") + + if len(result['results']) > 10: + f.write(f"\n*Showing first 10 of {len(result['results'])} samples*\n") + + print(f"\nResults saved to: {output_file}") + + +def save_visualizations(all_results: List[dict], dataset_list, output_dir: str = "output") -> None: + """ + Save visualizations of predicted coordinates vs bboxes to an output folder. + + Args: + all_results: List of evaluation results for each model + dataset_list: List of dataset samples + output_dir: Output directory path + """ + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + for result in all_results: + model_name = result['model_name'].replace('/', '_').replace('.', '_') + model_dir = os.path.join(output_dir, model_name) + os.makedirs(model_dir, exist_ok=True) + + print(f"\nSaving visualizations for {result['model_name']}...") + + for i, sample_result in enumerate(tqdm(result['results'][:10], desc=f"Saving {model_name} visualizations")): + try: + # Find the original sample + sample_id = sample_result['id'] + sample = None + for s in dataset_list: + if s['id'] == sample_id: + sample = s + break + + if sample is None: + continue + + # Get image and data + image = sample['image'].copy() + bbox = sample_result['bbox'] # [x1, y1, x2, y2] + predicted_coords = sample_result['predicted_coords'] + is_correct = sample_result['is_correct'] + + # Draw on image + draw = ImageDraw.Draw(image) + + # Draw bounding box (ground truth) in green + x1, y1, x2, y2 = bbox + draw.rectangle([x1, y1, x2, y2], outline="green", width=3) + draw.text((x1, y1-20), "Ground Truth", fill="green") + + # Draw predicted click in red or blue + if predicted_coords is not None: + px, py = predicted_coords + color = "blue" if is_correct else "red" + # Draw crosshair + crosshair_size = 15 + draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=3) + draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=3) + draw.text((px+10, py-20), f"Predicted ({px},{py})", fill=color) + + # Add status text + status = "CORRECT" if is_correct else "INCORRECT" + status_color = "blue" if is_correct else "red" + draw.text((10, 10), f"Status: {status}", fill=status_color) + draw.text((10, 30), f"Instruction: {sample_result['instruction'][:50]}...", fill="black") + + # Save image + filename = f"sample_{i+1:02d}_{sample_id}_{status.lower()}.png" + filepath = os.path.join(model_dir, filename) + image.save(filepath) + + except Exception as e: + print(f"Error saving visualization for sample {sample_id}: {e}") + continue + + print(f"Visualizations saved to: {model_dir}") + + +def save_prediction_visualization(image: Image.Image, instruction: str, predictions: List[dict], + output_file: str = "interactive_prediction.png") -> None: + """ + Save visualization of multiple model predictions on a single image. 
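+
+    Each prediction dict mirrors the output of predict_with_all_models() in
+    interactive.py, e.g. {'model_name': 'models.GTA1Model',
+    'coords': (412, 188), 'error': None} (coordinates illustrative).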
+ + Args: + image: PIL Image to visualize + instruction: Instruction text + predictions: List of prediction dicts with keys: model_name, coords, error + output_file: Output file path + """ + # Create a copy of the image + vis_image = image.copy() + draw = ImageDraw.Draw(vis_image) + + # Colors for different models + colors = ["red", "blue", "orange", "purple", "brown", "pink", "gray", "olive"] + + # Draw predictions + for i, pred in enumerate(predictions): + color = colors[i % len(colors)] + model_name = pred['model_name'] + coords = pred.get('coords') + error = pred.get('error') + + if coords is not None: + px, py = coords + # Draw crosshair + crosshair_size = 20 + draw.line([(px-crosshair_size, py), (px+crosshair_size, py)], fill=color, width=4) + draw.line([(px, py-crosshair_size), (px, py+crosshair_size)], fill=color, width=4) + # Draw model name + draw.text((px+15, py+15), f"{model_name}: ({px},{py})", fill=color) + else: + # Draw error text + draw.text((10, 50 + i*20), f"{model_name}: ERROR - {error}", fill=color) + + # Add instruction at the top + draw.text((10, 10), f"Instruction: {instruction}", fill="black") + + # Save image + vis_image.save(output_file) + print(f"Prediction visualization saved to: {output_file}") + + +def take_screenshot() -> Image.Image: + """ + Take a screenshot of the current screen. + + Returns: + PIL Image of the screenshot + """ + try: + import pyautogui + screenshot = pyautogui.screenshot() + return screenshot + except ImportError: + print("pyautogui not installed. Please install it with: pip install pyautogui") + raise + except Exception as e: + print(f"Error taking screenshot: {e}") + raise +
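+
+
+if __name__ == "__main__":
+    # Minimal smoke test (assumes a desktop session with pyautogui installed):
+    # capture a screenshot and report its size.
+    img = take_screenshot()
+    print(f"Captured screenshot: {img.size[0]}x{img.size[1]}")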