import argparse
import base64
import io
import logging
import signal
import time
from contextlib import contextmanager
from pathlib import Path
from typing import List, Optional, Tuple, Union, cast

import cv2
import torch
import torchvision.ops
from PIL import Image

from .detection import DetectionProcessor
from .ocr import OCRProcessor
from .visualization import BoxAnnotator
from .models import BoundingBox, UIElement, IconElement, TextElement, ParserMetadata, ParseResult

logger = logging.getLogger(__name__)


class TimeoutException(Exception):
    pass


@contextmanager
def timeout(seconds: int):
    """Interrupt the wrapped block with TimeoutException after `seconds`.

    Relies on SIGALRM, so it works only on Unix and only in the main thread.
    """

    def timeout_handler(signum, frame):
        raise TimeoutException("OCR process timed out")

    # Register the signal handler and arm the alarm
    original_handler = signal.signal(signal.SIGALRM, timeout_handler)
    signal.alarm(seconds)

    try:
        yield
    finally:
        # Disarm the alarm and restore the previous handler
        signal.alarm(0)
        signal.signal(signal.SIGALRM, original_handler)
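

# Illustrative use of the timeout guard above (kept as a comment so it does
# not run on import); any slow call inside the block is interrupted:
#
#     try:
#         with timeout(5):
#             results = easyocr.Reader(["en"]).readtext(image)
#     except TimeoutException:
#         results = []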


def process_text_box(box, image):
    """Process a single text box with OCR.

    `box` is a sequence of (x, y) corner points; `image` is an OpenCV array.
    Returns (text, [x1, y1, x2, y2], confidence) or None.
    """
    try:
        import easyocr

        x1 = int(min(point[0] for point in box))
        y1 = int(min(point[1] for point in box))
        x2 = int(max(point[0] for point in box))
        y2 = int(max(point[1] for point in box))

        # Add padding, clamped to the image bounds
        pad = 2
        x1 = max(0, x1 - pad)
        y1 = max(0, y1 - pad)
        x2 = min(image.shape[1], x2 + pad)
        y2 = min(image.shape[0], y2 + pad)

        region = image[y1:y2, x1:x2]
        if region.size > 0:
            reader = easyocr.Reader(["en"])
            results = reader.readtext(region)
            if results:
                # EasyOCR returns a list of (bbox, text, confidence) tuples
                first_result = results[0]
                if isinstance(first_result, (list, tuple)) and len(first_result) >= 3:
                    text = str(first_result[1])
                    confidence = float(first_result[2])
                    if confidence > 0.5:
                        return text, [x1, y1, x2, y2], confidence
    except Exception:
        # Swallow per-box OCR errors; the caller treats None as "no text"
        pass
    return None
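

# Sketch of a direct call, assuming `boxes` holds EasyOCR-style detections
# (four corner points each); the input path is hypothetical:
#
#     image = cv2.imread("screenshot.png")
#     for box in boxes:
#         result = process_text_box(box, image)
#         if result is not None:
#             text, bbox, confidence = result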


def check_ocr_box(image_path: Union[str, Path]) -> Tuple[List[str], List[List[float]]]:
    """Run EasyOCR over an image file, returning (texts, boxes as [x1, y1, x2, y2])."""
    if isinstance(image_path, str):
        image_path = Path(image_path)

    # Read the image into memory once
    image_cv = cv2.imread(str(image_path))
    if image_cv is None:
        logger.error(f"Failed to read image: {image_path}")
        return [], []

    confidence_threshold = 0.5

    # Use EasyOCR
    import ssl

    import easyocr

    # Create unverified SSL context for development (model downloads)
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        reader = easyocr.Reader(["en"])
        with timeout(5):  # 5 second timeout for EasyOCR
            results = reader.readtext(image_cv, paragraph=False, text_threshold=0.5)
    except TimeoutException:
        logger.warning("EasyOCR timed out, returning no results")
        return [], []
    except Exception as e:
        logger.warning(f"EasyOCR failed: {str(e)}")
        return [], []
    finally:
        # Restore the default SSL context
        ssl._create_default_https_context = ssl.create_default_context

    texts = []
    boxes = []

    for box, text, conf in results:
        # Convert the four corner points to [x1, y1, x2, y2]
        x1 = min(point[0] for point in box)
        y1 = min(point[1] for point in box)
        x2 = max(point[0] for point in box)
        y2 = max(point[1] for point in box)

        if float(conf) > confidence_threshold:  # Keep only higher-confidence detections
            texts.append(text)
            boxes.append([x1, y1, x2, y2])

    return texts, boxes
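

# Minimal usage sketch for check_ocr_box; the path is hypothetical:
#
#     texts, boxes = check_ocr_box("screenshot.png")
#     for text, (x1, y1, x2, y2) in zip(texts, boxes):
#         print(f"'{text}' at ({x1}, {y1}, {x2}, {y2})")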


class OmniParser:
    """Enhanced UI parser using computer vision and OCR for detecting interactive elements."""

    def __init__(
        self,
        model_path: Optional[Union[str, Path]] = None,
        cache_dir: Optional[Union[str, Path]] = None,
        force_device: Optional[str] = None,
    ):
        """Initialize the OmniParser.

        Args:
            model_path: Optional path to the YOLO model
            cache_dir: Optional directory to cache model files
            force_device: Force a specific device (cpu/cuda/mps)
        """
        self.detector = DetectionProcessor(
            model_path=Path(model_path) if model_path else None,
            cache_dir=Path(cache_dir) if cache_dir else None,
            force_device=force_device,
        )
        self.ocr = OCRProcessor()
        self.visualizer = BoxAnnotator()
    def process_image(
        self,
        image: Image.Image,
        box_threshold: float = 0.3,
        iou_threshold: float = 0.1,
        use_ocr: bool = True,
    ) -> Tuple[Image.Image, List[UIElement]]:
        """Process an image to detect UI elements and optionally text.

        Args:
            image: Input PIL Image
            box_threshold: Confidence threshold for detection
            iou_threshold: IOU threshold for NMS
            use_ocr: Whether to enable OCR processing

        Returns:
            Tuple of (annotated image, list of detections)
        """
        try:
            logger.info("Starting UI element detection...")

            # Detect icons
            icon_detections = self.detector.detect_icons(
                image=image, box_threshold=box_threshold, iou_threshold=iou_threshold
            )
            logger.info(f"Found {len(icon_detections)} interactive elements")

            # Convert icon detections to typed objects
            elements: List[UIElement] = cast(
                List[UIElement],
                [
                    IconElement(
                        id=i + 1,
                        bbox=BoundingBox(
                            x1=det["bbox"][0],
                            y1=det["bbox"][1],
                            x2=det["bbox"][2],
                            y2=det["bbox"][3],
                        ),
                        confidence=det["confidence"],
                        scale=det.get("scale"),
                    )
                    for i, det in enumerate(icon_detections)
                ],
            )

            # Run OCR if enabled
            if use_ocr:
                logger.info("Running OCR detection...")
                text_detections = self.ocr.detect_text(image=image, confidence_threshold=0.5)
                if text_detections is None:
                    text_detections = []
                logger.info(f"Found {len(text_detections)} text regions")

                # Convert text detections to typed objects
                text_elements = cast(
                    List[UIElement],
                    [
                        TextElement(
                            id=len(elements) + i + 1,
                            bbox=BoundingBox(
                                x1=det["bbox"][0],
                                y1=det["bbox"][1],
                                x2=det["bbox"][2],
                                y2=det["bbox"][3],
                            ),
                            content=det["content"],
                            confidence=det["confidence"],
                        )
                        for i, det in enumerate(text_detections)
                    ],
                )

                if elements and text_elements:
                    # Filter out icon elements whose box contains the center
                    # point of an OCR text element
                    filtered_elements = []
                    for elem in elements:  # elements holds only icon elements here
                        should_keep = True
                        for text_elem in text_elements:
                            # Center point of the text element
                            center_x = (text_elem.bbox.x1 + text_elem.bbox.x2) / 2
                            center_y = (text_elem.bbox.y1 + text_elem.bbox.y2) / 2

                            # Check whether that center lies inside the icon element
                            if (
                                elem.bbox.x1 <= center_x <= elem.bbox.x2
                                and elem.bbox.y1 <= center_y <= elem.bbox.y2
                            ):
                                should_keep = False
                                break

                        if should_keep:
                            filtered_elements.append(elem)
                    elements = filtered_elements

                    # Merge detections using NMS
                    all_elements = elements + text_elements
                    boxes = torch.tensor([elem.bbox.coordinates for elem in all_elements])
                    scores = torch.tensor([elem.confidence for elem in all_elements])
                    keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
                    elements = [all_elements[i] for i in keep_indices]
                else:
                    # No NMS needed when one of the lists is empty; just append
                    elements.extend(text_elements)

            # Calculate drawing parameters based on image size
            box_overlay_ratio = max(image.size) / 3200
            draw_config = {
                "font_size": int(12 * box_overlay_ratio),
                "box_thickness": max(int(2 * box_overlay_ratio), 1),
                "text_padding": max(int(3 * box_overlay_ratio), 1),
            }

            # Convert elements back to dict format for visualization
            detection_dicts = [
                {
                    "type": elem.type,
                    "bbox": elem.bbox.coordinates,
                    "confidence": elem.confidence,
                    "content": elem.content if isinstance(elem, TextElement) else None,
                }
                for elem in elements
            ]

            # Create visualization
            logger.info("Creating visualization...")
            annotated_image = self.visualizer.draw_boxes(
                image=image.copy(), detections=detection_dicts, draw_config=draw_config
            )
            logger.info("Visualization complete")

            return annotated_image, elements

        except Exception as e:
            logger.error(f"Error in process_image: {str(e)}")
            import traceback

            logger.error(traceback.format_exc())
            raise
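
    # Illustrative call pattern for process_image (file names are hypothetical):
    #
    #     parser = OmniParser()
    #     image = Image.open("screenshot.png").convert("RGB")
    #     annotated, elements = parser.process_image(image, box_threshold=0.3)
    #     annotated.save("screenshot_annotated.png")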
    def parse(
        self,
        screenshot_data: Union[bytes, str],
        box_threshold: float = 0.3,
        iou_threshold: float = 0.1,
        use_ocr: bool = True,
    ) -> ParseResult:
        """Parse a UI screenshot to detect interactive elements and text.

        Args:
            screenshot_data: Raw bytes or base64 string of the screenshot
            box_threshold: Confidence threshold for detection
            iou_threshold: IOU threshold for NMS
            use_ocr: Whether to enable OCR processing

        Returns:
            ParseResult object containing elements, annotated image, and metadata
        """
        try:
            start_time = time.time()

            # Convert input to a PIL Image
            if isinstance(screenshot_data, str):
                screenshot_data = base64.b64decode(screenshot_data)
            image = Image.open(io.BytesIO(screenshot_data)).convert("RGB")

            # Process image
            annotated_image, elements = self.process_image(
                image=image,
                box_threshold=box_threshold,
                iou_threshold=iou_threshold,
                use_ocr=use_ocr,
            )

            # Convert annotated image to base64
            buffered = io.BytesIO()
            annotated_image.save(buffered, format="PNG")
            annotated_image_base64 = base64.b64encode(buffered.getvalue()).decode("utf-8")

            # Generate screen info text
            screen_info = []
            parsed_content_list = []

            # Set element IDs and generate human-readable descriptions
            for i, elem in enumerate(elements):
                # IDs are 1-indexed
                elem.id = i + 1

                if isinstance(elem, IconElement):
                    screen_info.append(
                        f"Box #{i + 1}: Icon (confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates})"
                    )
                    parsed_content_list.append(
                        {
                            "id": i + 1,
                            "type": "icon",
                            "bbox": elem.bbox.coordinates,
                            "confidence": elem.confidence,
                            "content": None,
                        }
                    )
                elif isinstance(elem, TextElement):
                    screen_info.append(
                        f"Box #{i + 1}: Text '{elem.content}' (confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates})"
                    )
                    parsed_content_list.append(
                        {
                            "id": i + 1,
                            "type": "text",
                            "bbox": elem.bbox.coordinates,
                            "confidence": elem.confidence,
                            "content": elem.content,
                        }
                    )

            # Calculate metadata
            latency = time.time() - start_time
            width, height = image.size

            # Create the ParseResult object
            result = ParseResult(
                elements=elements,
                annotated_image_base64=annotated_image_base64,
                screen_info=screen_info,
                parsed_content_list=parsed_content_list,
                metadata=ParserMetadata(
                    image_size=(width, height),
                    num_icons=len([e for e in elements if isinstance(e, IconElement)]),
                    num_text=len([e for e in elements if isinstance(e, TextElement)]),
                    device=self.detector.device,
                    ocr_enabled=use_ocr,
                    latency=latency,
                ),
            )

            return result

        except Exception as e:
            logger.error(f"Error in parse: {str(e)}")
            import traceback

            logger.error(traceback.format_exc())
            raise
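

# End-to-end usage sketch with raw screenshot bytes (the file name is
# hypothetical; ParseResult fields follow the .models definitions imported above):
#
#     parser = OmniParser()
#     screenshot_bytes = Path("screenshot.png").read_bytes()
#     result = parser.parse(screenshot_bytes, box_threshold=0.3, use_ocr=True)
#     print(result.metadata.num_icons, result.metadata.num_text)
#     for line in result.screen_info:
#         print(line)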


def main():
    """Command line interface for UI element detection."""
    arg_parser = argparse.ArgumentParser(description="Detect UI elements and text in images")
    arg_parser.add_argument("image_path", help="Path to the input image")
    arg_parser.add_argument("--model-path", help="Path to YOLO model")
    arg_parser.add_argument(
        "--box-threshold", type=float, default=0.3, help="Box confidence threshold (default: 0.3)"
    )
    arg_parser.add_argument(
        "--iou-threshold", type=float, default=0.1, help="IOU threshold (default: 0.1)"
    )
    arg_parser.add_argument(
        "--no-ocr", action="store_false", dest="ocr", help="Disable OCR processing (enabled by default)"
    )
    arg_parser.add_argument("--output", help="Output path for annotated image")
    args = arg_parser.parse_args()

    # Setup logging
    logging.basicConfig(level=logging.INFO)

    try:
        # Initialize parser
        omni_parser = OmniParser(model_path=args.model_path)

        # Load image
        logger.info(f"Loading image from: {args.image_path}")
        image = Image.open(args.image_path).convert("RGB")
        logger.info(f"Image loaded successfully, size: {image.size}")

        # Process image
        annotated_image, elements = omni_parser.process_image(
            image=image,
            box_threshold=args.box_threshold,
            iou_threshold=args.iou_threshold,
            use_ocr=args.ocr,
        )

        # Save output image
        output_path = args.output or str(
            Path(args.image_path).parent
            / f"{Path(args.image_path).stem}_analyzed{Path(args.image_path).suffix}"
        )
        logger.info(f"Saving annotated image to: {output_path}")

        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        annotated_image.save(output_path)
        logger.info(f"Image saved successfully to {output_path}")

        # Print detections
        logger.info("\nDetections:")
        for i, elem in enumerate(elements):
            if isinstance(elem, IconElement):
                logger.info(
                    f"Interactive element {i}: confidence={elem.confidence:.3f}, bbox={elem.bbox.coordinates}"
                )
            elif isinstance(elem, TextElement):
                logger.info(f"Text {i}: '{elem.content}', bbox={elem.bbox.coordinates}")

    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        import traceback

        logger.error(traceback.format_exc())
        return 1

    return 0


if __name__ == "__main__":
    import sys

    sys.exit(main())
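
# Example invocation (hypothetical package name; the relative imports above
# mean this file must be run as a module, not as a standalone script):
#
#     python -m your_package.omniparser screenshot.png --box-threshold 0.3 --no-ocr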