From 85404cccab24b5feead46df662d9f8f0e5ace6ff Mon Sep 17 00:00:00 2001 From: Dillon DuPont Date: Thu, 27 Mar 2025 11:27:04 -0400 Subject: [PATCH] apply nms merge to ocr detections --- libs/som/som/detect.py | 52 +++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 19 deletions(-) diff --git a/libs/som/som/detect.py b/libs/som/som/detect.py index 10a21804..41ab9ca3 100644 --- a/libs/som/som/detect.py +++ b/libs/som/som/detect.py @@ -213,26 +213,40 @@ class OmniParser: text_detections = [] logger.info(f"Found {len(text_detections)} text regions") - # Convert text detections to typed objects and extend the list - elements.extend( - cast( - List[UIElement], - [ - TextElement( - id=len(elements) + i + 1, - bbox=BoundingBox( - x1=det["bbox"][0], - y1=det["bbox"][1], - x2=det["bbox"][2], - y2=det["bbox"][3], - ), - content=det["content"], - confidence=det["confidence"], - ) - for i, det in enumerate(text_detections) - ], - ) + # Convert text detections to typed objects + text_elements = cast( + List[UIElement], + [ + TextElement( + id=len(elements) + i + 1, + bbox=BoundingBox( + x1=det["bbox"][0], + y1=det["bbox"][1], + x2=det["bbox"][2], + y2=det["bbox"][3], + ), + content=det["content"], + confidence=det["confidence"], + ) + for i, det in enumerate(text_detections) + ], ) + + # Merge detections using NMS + if elements and text_elements: + # Get all bounding boxes and scores + all_elements = elements + text_elements + boxes = torch.tensor([elem.bbox.coordinates for elem in all_elements]) + scores = torch.tensor([elem.confidence for elem in all_elements]) + + # Apply NMS with iou_threshold + keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold) + + # Keep only the elements that passed NMS + elements = [all_elements[i] for i in keep_indices] + else: + # Just add text elements to the list if IOU doesn't need to be applied + elements.extend(text_elements) # Calculate drawing parameters based on image size box_overlay_ratio = max(image.size) / 3200