diff --git a/examples/som_examples.py b/examples/som_examples.py
index 75b798ac..4dc3e38b 100644
--- a/examples/som_examples.py
+++ b/examples/som_examples.py
@@ -318,7 +318,6 @@ def run_experiments(input_path: str, output_dir: Path, use_ocr: bool = False):
                 # Update totals
                 total_icons += result.metadata.num_icons
                 total_text += result.metadata.num_text
-                total_time += t.elapsed_time
 
                 # Log detailed results
                 detail_file = combo_dir / f"{Path(image_path).stem}_details.txt"
@@ -360,6 +359,9 @@ def run_experiments(input_path: str, output_dir: Path, use_ocr: bool = False):
                     )
                     text_count += 1
 
+            # Update timing totals
+            total_time += t.elapsed_time
+
             # Write summary for this combination
             avg_time = total_time / len(image_files)
             f.write(
diff --git a/libs/som/som/detect.py b/libs/som/som/detect.py
index 10a21804..79e64886 100644
--- a/libs/som/som/detect.py
+++ b/libs/som/som/detect.py
@@ -213,26 +213,54 @@ class OmniParser:
             text_detections = []
         logger.info(f"Found {len(text_detections)} text regions")
 
-        # Convert text detections to typed objects and extend the list
-        elements.extend(
-            cast(
-                List[UIElement],
-                [
-                    TextElement(
-                        id=len(elements) + i + 1,
-                        bbox=BoundingBox(
-                            x1=det["bbox"][0],
-                            y1=det["bbox"][1],
-                            x2=det["bbox"][2],
-                            y2=det["bbox"][3],
-                        ),
-                        content=det["content"],
-                        confidence=det["confidence"],
-                    )
-                    for i, det in enumerate(text_detections)
-                ],
-            )
+        # Convert text detections to typed objects
+        text_elements = cast(
+            List[UIElement],
+            [
+                TextElement(
+                    id=len(elements) + i + 1,
+                    bbox=BoundingBox(
+                        x1=det["bbox"][0],
+                        y1=det["bbox"][1],
+                        x2=det["bbox"][2],
+                        y2=det["bbox"][3],
+                    ),
+                    content=det["content"],
+                    confidence=det["confidence"],
+                )
+                for i, det in enumerate(text_detections)
+            ],
         )
+
+        if elements and text_elements:
+            # Filter out non-OCR elements whose boxes contain the center point of an OCR element
+            filtered_elements = []
+            for elem in elements:  # elements at this point contains only non-OCR elements
+                should_keep = True
+                for text_elem in text_elements:
+                    # Calculate the center point of the text element
+                    center_x = (text_elem.bbox.x1 + text_elem.bbox.x2) / 2
+                    center_y = (text_elem.bbox.y1 + text_elem.bbox.y2) / 2
+
+                    # Check if this center point is inside the non-OCR element
+                    if (elem.bbox.x1 <= center_x <= elem.bbox.x2
+                            and elem.bbox.y1 <= center_y <= elem.bbox.y2):
+                        should_keep = False
+                        break
+
+                if should_keep:
+                    filtered_elements.append(elem)
+            elements = filtered_elements
+
+            # Merge detections using NMS
+            all_elements = elements + text_elements
+            boxes = torch.tensor([elem.bbox.coordinates for elem in all_elements])
+            scores = torch.tensor([elem.confidence for elem in all_elements])
+            keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold)
+            elements = [all_elements[i] for i in keep_indices]
+        else:
+            # Just add text elements to the list if NMS doesn't need to be applied
+            elements.extend(text_elements)
 
         # Calculate drawing parameters based on image size
         box_overlay_ratio = max(image.size) / 3200
diff --git a/libs/som/som/ocr.py b/libs/som/som/ocr.py
index a206e057..6d10e85a 100644
--- a/libs/som/som/ocr.py
+++ b/libs/som/som/ocr.py
@@ -67,17 +67,9 @@ class OCRProcessor:
             import easyocr
 
             # Use GPU if available
-            use_gpu = self.device in ["cuda"]  # MPS not directly supported by EasyOCR
-
-            # If using MPS, add warnings to explain why CPU is used
-            if self.device == "mps":
-                logger.warning("EasyOCR doesn't support MPS directly. Using CPU instead.")
-                logger.warning(
-                    "To silence this warning, set environment variable: PYTORCH_ENABLE_MPS_FALLBACK=1"
-                )
-
+            use_gpu = self.device in ["cuda", "mps"]
             self.reader = easyocr.Reader(["en"], gpu=use_gpu)
-
+
             # Verify reader initialization
             if self.reader is None:
                 raise ValueError("Failed to initialize EasyOCR reader")
diff --git a/libs/som/som/visualization.py b/libs/som/som/visualization.py
index 4212379c..038af0f5 100644
--- a/libs/som/som/visualization.py
+++ b/libs/som/som/visualization.py
@@ -112,10 +112,18 @@ class BoxAnnotator:
         # Keep track of used label areas to check for collisions
        used_areas = []
 
-        # Store label information for second pass
+        # Store label information for third pass
         labels_to_draw = []
 
-        # First pass: Draw all bounding boxes
+        # First pass: Initialize used_areas with all bounding boxes
+        for detection in detections:
+            box = detection["bbox"]
+            x1, y1, x2, y2 = [
+                int(coord * dim) for coord, dim in zip(box, [image.width, image.height] * 2)
+            ]
+            used_areas.append((x1, y1, x2, y2))
+
+        # Second pass: Draw all bounding boxes
         for idx, detection in enumerate(detections, 1):
             # Get box coordinates
             box = detection["bbox"]
@@ -166,22 +174,31 @@ class BoxAnnotator:
                 lambda: (x1 - box_width - spacing, y2 + spacing),
             ]
 
-            def check_collision(x, y):
-                """Check if a label box collides with any existing ones or is inside bbox."""
+            def check_occlusion(x, y):
+                """Check if a label box occludes any existing ones or is inside bbox."""
                 # First check if it's inside the bounding box
                 if is_inside_bbox(x, y):
                     return True
 
                 # Then check collision with other labels
                 new_box = (x, y, x + box_width, y + box_height)
+                label_width = new_box[2] - new_box[0]
+                label_height = new_box[3] - new_box[1]
+
                 for used_box in used_areas:
                     if not (
                         new_box[2] < used_box[0]  # new box is left of used box
                         or new_box[0] > used_box[2]  # new box is right of used box
                         or new_box[3] < used_box[1]  # new box is above used box
-                        or new_box[1] > used_box[3]
-                    ):  # new box is below used box
-                        return True
+                        or new_box[1] > used_box[3]  # new box is below used box
+                    ):
+                        # Calculate dimensions of the used box
+                        used_box_width = used_box[2] - used_box[0]
+                        used_box_height = used_box[3] - used_box[1]
+
+                        # Only treat this as a collision if the used box is NOT more than 5x bigger in both dimensions
+                        if not (used_box_width > 5 * label_width and used_box_height > 5 * label_height):
+                            return True
                 return False
 
             # Try each position until we find one without collision
@@ -193,7 +210,7 @@ class BoxAnnotator:
                 # Ensure position is within image bounds
                 if x < 0 or y < 0 or x + box_width > image.width or y + box_height > image.height:
                     continue
-                if not check_collision(x, y):
+                if not check_occlusion(x, y):
                     label_x = x
                     label_y = y
                     break
@@ -233,7 +250,7 @@ class BoxAnnotator:
                 }
             )
 
-        # Second pass: Draw all labels on top
+        # Third pass: Draw all labels on top
         for label_info in labels_to_draw:
             # Draw background box with white outline
             draw.rectangle(
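
For reference, here is a standalone sketch of the new OCR/icon merge step in detect.py. It is a minimal reconstruction, not the patch itself: it uses plain (x1, y1, x2, y2) tuples in place of the library's BoundingBox/TextElement types, and the helper name merge_detections and the default iou_threshold of 0.5 are illustrative. The explicit torch/torchvision imports mirror what detect.py is assumed to already have in scope.

# Minimal sketch of the center-point filter + NMS merge, assuming plain
# (x1, y1, x2, y2) tuples; merge_detections is a hypothetical helper name.
from typing import List, Tuple

import torch
import torchvision

Box = Tuple[float, float, float, float]


def merge_detections(
    icon_boxes: List[Box],
    icon_scores: List[float],
    text_boxes: List[Box],
    text_scores: List[float],
    iou_threshold: float = 0.5,
) -> List[Box]:
    """Drop icon boxes containing an OCR box's center, then NMS-merge the rest."""
    kept_boxes: List[Box] = []
    kept_scores: List[float] = []
    for box, score in zip(icon_boxes, icon_scores):
        x1, y1, x2, y2 = box
        # An icon box is suppressed if any text box's center falls inside it
        contains_text_center = any(
            x1 <= (tx1 + tx2) / 2 <= x2 and y1 <= (ty1 + ty2) / 2 <= y2
            for tx1, ty1, tx2, ty2 in text_boxes
        )
        if not contains_text_center:
            kept_boxes.append(box)
            kept_scores.append(score)

    all_boxes = kept_boxes + list(text_boxes)
    all_scores = kept_scores + list(text_scores)
    if not all_boxes:
        return []

    # torchvision.ops.nms keeps the highest-scoring box among overlapping ones
    keep = torchvision.ops.nms(
        torch.tensor(all_boxes, dtype=torch.float32),
        torch.tensor(all_scores, dtype=torch.float32),
        iou_threshold,
    )
    return [all_boxes[i] for i in keep.tolist()]


if __name__ == "__main__":
    icons = [(0.0, 0.0, 100.0, 40.0), (200.0, 0.0, 260.0, 40.0)]
    texts = [(10.0, 10.0, 90.0, 30.0)]  # center (50, 20) lies inside the first icon
    print(merge_detections(icons, [0.9, 0.8], texts, [0.95]))
    # -> first icon box dropped by the center-point filter; others survive NMS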
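
The relaxed label-collision rule in visualization.py can likewise be exercised in isolation. A sketch under the same caveats: occludes and Rect are hypothetical names, and the 5x factor is the one hard-coded in the patch. Overlap with a rectangle more than 5x larger in both dimensions (typically one of the bounding boxes that now seed used_areas in the first pass) is tolerated, since a small label sitting inside it hides little.

# Minimal sketch of the occlusion test with the 5x size exception;
# occludes/Rect are hypothetical names for illustration.
from typing import List, Tuple

Rect = Tuple[int, int, int, int]  # (x1, y1, x2, y2)


def occludes(label: Rect, used_areas: List[Rect]) -> bool:
    """Return True if `label` overlaps a used area of comparable size."""
    label_width = label[2] - label[0]
    label_height = label[3] - label[1]
    for used in used_areas:
        # Standard separation test: no overlap if strictly apart on any side
        if (label[2] < used[0] or label[0] > used[2]
                or label[3] < used[1] or label[1] > used[3]):
            continue
        used_width = used[2] - used[0]
        used_height = used[3] - used[1]
        # Tolerate overlap with much larger rectangles (> 5x in both dimensions)
        if not (used_width > 5 * label_width and used_height > 5 * label_height):
            return True
    return False


if __name__ == "__main__":
    label = (10, 10, 40, 20)  # 30x10 label
    print(occludes(label, [(0, 0, 500, 300)]))  # False: area dwarfs the label
    print(occludes(label, [(20, 5, 60, 25)]))   # True: similar-sized neighbor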