mirror of
https://github.com/trycua/computer.git
synced 2026-01-07 05:50:13 -06:00
Merge pull request #87 from ddupont808/feature/som/fix-overlap-cpu
[SOM] Fix overlapping bounding boxes and add GPU/MPS support
This commit is contained in:
@@ -318,7 +318,6 @@ def run_experiments(input_path: str, output_dir: Path, use_ocr: bool = False):
|
||||
# Update totals
|
||||
total_icons += result.metadata.num_icons
|
||||
total_text += result.metadata.num_text
|
||||
total_time += t.elapsed_time
|
||||
|
||||
# Log detailed results
|
||||
detail_file = combo_dir / f"{Path(image_path).stem}_details.txt"
|
||||
@@ -360,6 +359,9 @@ def run_experiments(input_path: str, output_dir: Path, use_ocr: bool = False):
|
||||
)
|
||||
text_count += 1
|
||||
|
||||
# Update timing totals
|
||||
total_time += t.elapsed_time
|
||||
|
||||
# Write summary for this combination
|
||||
avg_time = total_time / len(image_files)
|
||||
f.write(
|
||||
|
||||
@@ -213,26 +213,54 @@ class OmniParser:
|
||||
text_detections = []
|
||||
logger.info(f"Found {len(text_detections)} text regions")
|
||||
|
||||
# Convert text detections to typed objects and extend the list
|
||||
elements.extend(
|
||||
cast(
|
||||
List[UIElement],
|
||||
[
|
||||
TextElement(
|
||||
id=len(elements) + i + 1,
|
||||
bbox=BoundingBox(
|
||||
x1=det["bbox"][0],
|
||||
y1=det["bbox"][1],
|
||||
x2=det["bbox"][2],
|
||||
y2=det["bbox"][3],
|
||||
),
|
||||
content=det["content"],
|
||||
confidence=det["confidence"],
|
||||
)
|
||||
for i, det in enumerate(text_detections)
|
||||
],
|
||||
)
|
||||
# Convert text detections to typed objects
|
||||
text_elements = cast(
|
||||
List[UIElement],
|
||||
[
|
||||
TextElement(
|
||||
id=len(elements) + i + 1,
|
||||
bbox=BoundingBox(
|
||||
x1=det["bbox"][0],
|
||||
y1=det["bbox"][1],
|
||||
x2=det["bbox"][2],
|
||||
y2=det["bbox"][3],
|
||||
),
|
||||
content=det["content"],
|
||||
confidence=det["confidence"],
|
||||
)
|
||||
for i, det in enumerate(text_detections)
|
||||
],
|
||||
)
|
||||
|
||||
if elements and text_elements:
|
||||
# Filter out non-OCR elements that have OCR elements with center points colliding with them
|
||||
filtered_elements = []
|
||||
for elem in elements: # elements at this point contains only non-OCR elements
|
||||
should_keep = True
|
||||
for text_elem in text_elements:
|
||||
# Calculate center point of the text element
|
||||
center_x = (text_elem.bbox.x1 + text_elem.bbox.x2) / 2
|
||||
center_y = (text_elem.bbox.y1 + text_elem.bbox.y2) / 2
|
||||
|
||||
# Check if this center point is inside the non-OCR element
|
||||
if (center_x >= elem.bbox.x1 and center_x <= elem.bbox.x2 and
|
||||
center_y >= elem.bbox.y1 and center_y <= elem.bbox.y2):
|
||||
should_keep = False
|
||||
break
|
||||
|
||||
if should_keep:
|
||||
filtered_elements.append(elem)
|
||||
elements = filtered_elements
|
||||
|
||||
# Merge detections using NMS
|
||||
all_elements = elements + text_elements
|
||||
boxes = torch.tensor([elem.bbox.coordinates for elem in all_elements])
|
||||
scores = torch.tensor([elem.confidence for elem in all_elements])
|
||||
keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
|
||||
elements = [all_elements[i] for i in keep_indices]
|
||||
else:
|
||||
# Just add text elements to the list if IOU doesn't need to be applied
|
||||
elements.extend(text_elements)
|
||||
|
||||
# Calculate drawing parameters based on image size
|
||||
box_overlay_ratio = max(image.size) / 3200
|
||||
|
||||
@@ -67,17 +67,9 @@ class OCRProcessor:
|
||||
import easyocr
|
||||
|
||||
# Use GPU if available
|
||||
use_gpu = self.device in ["cuda"] # MPS not directly supported by EasyOCR
|
||||
|
||||
# If using MPS, add warnings to explain why CPU is used
|
||||
if self.device == "mps":
|
||||
logger.warning("EasyOCR doesn't support MPS directly. Using CPU instead.")
|
||||
logger.warning(
|
||||
"To silence this warning, set environment variable: PYTORCH_ENABLE_MPS_FALLBACK=1"
|
||||
)
|
||||
|
||||
use_gpu = self.device in ["cuda", "mps"]
|
||||
self.reader = easyocr.Reader(["en"], gpu=use_gpu)
|
||||
|
||||
|
||||
# Verify reader initialization
|
||||
if self.reader is None:
|
||||
raise ValueError("Failed to initialize EasyOCR reader")
|
||||
|
||||
@@ -112,10 +112,18 @@ class BoxAnnotator:
|
||||
# Keep track of used label areas to check for collisions
|
||||
used_areas = []
|
||||
|
||||
# Store label information for second pass
|
||||
# Store label information for third pass
|
||||
labels_to_draw = []
|
||||
|
||||
# First pass: Draw all bounding boxes
|
||||
# First pass: Initialize used_areas with all bounding boxes
|
||||
for detection in detections:
|
||||
box = detection["bbox"]
|
||||
x1, y1, x2, y2 = [
|
||||
int(coord * dim) for coord, dim in zip(box, [image.width, image.height] * 2)
|
||||
]
|
||||
used_areas.append((x1, y1, x2, y2))
|
||||
|
||||
# Second pass: Draw all bounding boxes
|
||||
for idx, detection in enumerate(detections, 1):
|
||||
# Get box coordinates
|
||||
box = detection["bbox"]
|
||||
@@ -166,22 +174,31 @@ class BoxAnnotator:
|
||||
lambda: (x1 - box_width - spacing, y2 + spacing),
|
||||
]
|
||||
|
||||
def check_collision(x, y):
|
||||
"""Check if a label box collides with any existing ones or is inside bbox."""
|
||||
def check_occlusion(x, y):
|
||||
"""Check if a label box occludes any existing ones or is inside bbox."""
|
||||
# First check if it's inside the bounding box
|
||||
if is_inside_bbox(x, y):
|
||||
return True
|
||||
|
||||
# Then check collision with other labels
|
||||
new_box = (x, y, x + box_width, y + box_height)
|
||||
label_width = new_box[2] - new_box[0]
|
||||
label_height = new_box[3] - new_box[1]
|
||||
|
||||
for used_box in used_areas:
|
||||
if not (
|
||||
new_box[2] < used_box[0] # new box is left of used box
|
||||
or new_box[0] > used_box[2] # new box is right of used box
|
||||
or new_box[3] < used_box[1] # new box is above used box
|
||||
or new_box[1] > used_box[3]
|
||||
): # new box is below used box
|
||||
return True
|
||||
or new_box[1] > used_box[3] # new box is below used box
|
||||
):
|
||||
# Calculate dimensions of the used box
|
||||
used_box_width = used_box[2] - used_box[0]
|
||||
used_box_height = used_box[3] - used_box[1]
|
||||
|
||||
# Only consider as collision if used box is NOT more than 5x bigger in both dimensions
|
||||
if not (used_box_width > 5 * label_width and used_box_height > 5 * label_height):
|
||||
return True
|
||||
return False
|
||||
|
||||
# Try each position until we find one without collision
|
||||
@@ -193,7 +210,7 @@ class BoxAnnotator:
|
||||
# Ensure position is within image bounds
|
||||
if x < 0 or y < 0 or x + box_width > image.width or y + box_height > image.height:
|
||||
continue
|
||||
if not check_collision(x, y):
|
||||
if not check_occlusion(x, y):
|
||||
label_x = x
|
||||
label_y = y
|
||||
break
|
||||
@@ -233,7 +250,7 @@ class BoxAnnotator:
|
||||
}
|
||||
)
|
||||
|
||||
# Second pass: Draw all labels on top
|
||||
# Third pass: Draw all labels on top
|
||||
for label_info in labels_to_draw:
|
||||
# Draw background box with white outline
|
||||
draw.rectangle(
|
||||
|
||||
Reference in New Issue
Block a user