| import base64 |
| import io |
| from typing import Any, Dict, List, Literal, Optional, Tuple, Union |
|
|
| import cv2 |
| import easyocr |
| import numpy as np |
| import torch |
| from PIL import Image |
| from PIL.Image import Image as ImageType |
| from supervision.detection.core import Detections |
| from supervision.draw.color import Color, ColorPalette |
| from torchvision.ops import box_convert |
| from torchvision.transforms import ToPILImage |
| from transformers import AutoModelForCausalLM, AutoProcessor |
| from transformers.image_utils import load_image |
| from ultralytics import YOLO |
|
|
| |
| |
# Build (and immediately discard) an English EasyOCR reader at import time.
# NOTE(review): presumably this exists so EasyOCR fetches its model weights
# when the module is loaded rather than on the first request — confirm.
easyocr.Reader(lang_list=["en"])
|
|
|
|
| class EndpointHandler: |
| def __init__(self, model_dir: str = "/repository") -> None: |
| self.device = ( |
| torch.device("cuda") if torch.cuda.is_available() |
| else (torch.device("mps") if torch.backends.mps.is_available() |
| else torch.device("cpu")) |
| ) |
|
|
| |
| self.yolo = YOLO(f"{model_dir}/icon_detect/model.pt") |
|
|
| |
| self.processor = AutoProcessor.from_pretrained( |
| "microsoft/Florence-2-base", trust_remote_code=True |
| ) |
| self.model = AutoModelForCausalLM.from_pretrained( |
| f"{model_dir}/icon_caption", |
| torch_dtype=torch.float16, |
| trust_remote_code=True, |
| ).to(self.device) |
|
|
| |
| self.ocr = easyocr.Reader(["en"]) |
|
|
| |
| self.annotator = BoxAnnotator() |
|
|
| def __call__(self, data: Dict[str, Any]) -> Any: |
| |
| |
| |
| |
| |
| |
| |
| data = data.pop("inputs") |
|
|
| |
| image = load_image(data["image"]) |
|
|
| ocr_texts, ocr_bboxes = self.check_ocr_bboxes( |
| image, |
| out_format="xyxy", |
| ocr_kwargs={"text_threshold": 0.8}, |
| ) |
| annotated_image, filtered_bboxes_out = self.get_som_labeled_img( |
| image, |
| image_size=data.get("image_size", None), |
| ocr_texts=ocr_texts, |
| ocr_bboxes=ocr_bboxes, |
| bbox_threshold=data.get("bbox_threshold", 0.05), |
| iou_threshold=data.get("iou_threshold", None), |
| ) |
| return { |
| "image": annotated_image, |
| "bboxes": filtered_bboxes_out, |
| } |
|
|
| def check_ocr_bboxes( |
| self, |
| image: ImageType, |
| out_format: Literal["xywh", "xyxy"] = "xywh", |
| ocr_kwargs: Optional[Dict[str, Any]] = {}, |
| ) -> Tuple[List[str], List[List[int]]]: |
| if image.mode == "RBGA": |
| image = image.convert("RGB") |
|
|
| result = self.ocr.readtext(np.array(image), **ocr_kwargs) |
| texts = [str(item[1]) for item in result] |
| bboxes = [ |
| self.coordinates_to_bbox(item[0], format=out_format) for item in result |
| ] |
| return (texts, bboxes) |
|
|
| @staticmethod |
| def coordinates_to_bbox( |
| coordinates: np.ndarray, format: Literal["xywh", "xyxy"] = "xywh" |
| ) -> List[int]: |
| match format: |
| case "xywh": |
| return [ |
| int(coordinates[0][0]), |
| int(coordinates[0][1]), |
| int(coordinates[2][0] - coordinates[0][0]), |
| int(coordinates[2][1] - coordinates[0][1]), |
| ] |
| case "xyxy": |
| return [ |
| int(coordinates[0][0]), |
| int(coordinates[0][1]), |
| int(coordinates[2][0]), |
| int(coordinates[2][1]), |
| ] |
|
|
| @staticmethod |
| def bbox_area(bbox: List[int], w: int, h: int) -> int: |
| bbox = [bbox[0] * w, bbox[1] * h, bbox[2] * w, bbox[3] * h] |
| return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) |
|
|
| @staticmethod |
| def remove_bbox_overlap( |
| xyxy_bboxes: List[Dict[str, Any]], |
| ocr_bboxes: Optional[List[Dict[str, Any]]] = None, |
| iou_threshold: Optional[float] = 0.7, |
| ) -> List[Dict[str, Any]]: |
| filtered_bboxes = [] |
| if ocr_bboxes is not None: |
| filtered_bboxes.extend(ocr_bboxes) |
|
|
| for i, bbox_outter in enumerate(xyxy_bboxes): |
| bbox_left = bbox_outter["bbox"] |
| valid_bbox = True |
|
|
| for j, bbox_inner in enumerate(xyxy_bboxes): |
| if i == j: |
| continue |
|
|
| bbox_right = bbox_inner["bbox"] |
| if ( |
| intersection_over_union( |
| bbox_left, |
| bbox_right, |
| ) |
| > iou_threshold |
| ) and (area(bbox_left) > area(bbox_right)): |
| valid_bbox = False |
| break |
|
|
| if valid_bbox is False: |
| continue |
|
|
| if ocr_bboxes is None: |
| filtered_bboxes.append(bbox_outter) |
| continue |
|
|
| box_added = False |
| ocr_labels = [] |
| for ocr_bbox in ocr_bboxes: |
| if not box_added: |
| bbox_right = ocr_bbox["bbox"] |
| if overlap(bbox_right, bbox_left): |
| try: |
| ocr_labels.append(ocr_bbox["content"]) |
| filtered_bboxes.remove(ocr_bbox) |
| except Exception: |
| continue |
| elif overlap(bbox_left, bbox_right): |
| box_added = True |
| break |
|
|
| if not box_added: |
| filtered_bboxes.append( |
| { |
| "type": "icon", |
| "bbox": bbox_outter["bbox"], |
| "interactivity": True, |
| "content": " ".join(ocr_labels) if ocr_labels else None, |
| } |
| ) |
|
|
| return filtered_bboxes |
|
|
| def get_som_labeled_img( |
| self, |
| image: ImageType, |
| image_size: Optional[Dict[Literal["w", "h"], int]] = None, |
| ocr_texts: Optional[List[str]] = None, |
| ocr_bboxes: Optional[List[List[int]]] = None, |
| bbox_threshold: float = 0.01, |
| iou_threshold: Optional[float] = None, |
| caption_prompt: Optional[str] = None, |
| caption_batch_size: int = 64, |
| ) -> Tuple[str, List[Dict[str, Any]]]: |
| if image.mode == "RBGA": |
| image = image.convert("RGB") |
|
|
| w, h = image.size |
| if image_size is None: |
| imgsz = {"h": h, "w": w} |
| else: |
| imgsz = [image_size.get("h", h), image_size.get("w", w)] |
|
|
| out = self.yolo.predict( |
| image, |
| imgsz=imgsz, |
| conf=bbox_threshold, |
| iou=iou_threshold or 0.7, |
| verbose=False, |
| )[0] |
| if out.boxes is None: |
| raise RuntimeError( |
| "YOLO prediction failed to produce the bounding boxes..." |
| ) |
|
|
| xyxy_bboxes = out.boxes.xyxy |
| xyxy_bboxes = xyxy_bboxes / torch.Tensor([w, h, w, h]).to(xyxy_bboxes.device) |
| image_np = np.asarray(image) |
|
|
| if ocr_bboxes: |
| ocr_bboxes = torch.tensor(ocr_bboxes) / torch.Tensor([w, h, w, h]) |
| ocr_bboxes = ocr_bboxes.tolist() |
|
|
| ocr_bboxes = [ |
| { |
| "type": "text", |
| "bbox": bbox, |
| "interactivity": False, |
| "content": text, |
| "source": "box_ocr_content_ocr", |
| } |
| for bbox, text in zip(ocr_bboxes, ocr_texts) |
| if self.bbox_area(bbox, w, h) > 0 |
| ] |
| xyxy_bboxes = [ |
| { |
| "type": "icon", |
| "bbox": bbox, |
| "interactivity": True, |
| "content": None, |
| "source": "box_yolo_content_yolo", |
| } |
| for bbox in xyxy_bboxes.tolist() |
| if self.bbox_area(bbox, w, h) > 0 |
| ] |
|
|
| filtered_bboxes = self.remove_bbox_overlap( |
| xyxy_bboxes=xyxy_bboxes, |
| ocr_bboxes=ocr_bboxes, |
| iou_threshold=iou_threshold or 0.7, |
| ) |
|
|
| filtered_bboxes_out = sorted( |
| filtered_bboxes, key=lambda x: x["content"] is None |
| ) |
| starting_idx = next( |
| ( |
| idx |
| for idx, bbox in enumerate(filtered_bboxes_out) |
| if bbox["content"] is None |
| ), |
| -1, |
| ) |
|
|
| filtered_bboxes = torch.tensor([box["bbox"] for box in filtered_bboxes_out]) |
| non_ocr_bboxes = filtered_bboxes[starting_idx:] |
|
|
| bbox_images = [] |
| for _, coordinates in enumerate(non_ocr_bboxes): |
| try: |
| xmin, xmax = ( |
| int(coordinates[0] * image_np.shape[1]), |
| int(coordinates[2] * image_np.shape[1]), |
| ) |
| ymin, ymax = ( |
| int(coordinates[1] * image_np.shape[0]), |
| int(coordinates[3] * image_np.shape[0]), |
| ) |
| cropped_image = image_np[ymin:ymax, xmin:xmax, :] |
| cropped_image = cv2.resize(cropped_image, (64, 64)) |
| bbox_images.append(ToPILImage()(cropped_image)) |
| except Exception: |
| continue |
|
|
| if caption_prompt is None: |
| caption_prompt = "<CAPTION>" |
|
|
| captions = [] |
| for idx in range(0, len(bbox_images), caption_batch_size): |
| batch = bbox_images[idx : idx + caption_batch_size] |
| inputs = self.processor( |
| images=batch, |
| text=[caption_prompt] * len(batch), |
| return_tensors="pt", |
| do_resize=False, |
| ) |
| if self.device.type in {"cuda", "mps"}: |
| inputs = inputs.to(device=self.device, dtype=torch.float16) |
|
|
| with torch.inference_mode(): |
| generated_ids = self.model.generate( |
| input_ids=inputs["input_ids"], |
| pixel_values=inputs["pixel_values"], |
| max_new_tokens=20, |
| num_beams=1, |
| do_sample=False, |
| early_stopping=False, |
| ) |
|
|
| generated_texts = self.processor.batch_decode( |
| generated_ids, skip_special_tokens=True |
| ) |
| captions.extend([text.strip() for text in generated_texts]) |
|
|
| ocr_texts = [f"Text Box ID {idx}: {text}" for idx, text in enumerate(ocr_texts)] |
| for _, bbox in enumerate(filtered_bboxes_out): |
| if bbox["content"] is None: |
| bbox["content"] = captions.pop(0) |
|
|
| filtered_bboxes = box_convert( |
| boxes=filtered_bboxes, in_fmt="xyxy", out_fmt="cxcywh" |
| ) |
|
|
| annotated_image = image_np.copy() |
| bboxes_annotate = filtered_bboxes * torch.Tensor([w, h, w, h]) |
| xyxy_annotate = box_convert( |
| bboxes_annotate, in_fmt="cxcywh", out_fmt="xyxy" |
| ).numpy() |
| detections = Detections(xyxy=xyxy_annotate) |
| labels = [str(idx) for idx in range(bboxes_annotate.shape[0])] |
|
|
| annotated_image = self.annotator.annotate( |
| scene=annotated_image, |
| detections=detections, |
| labels=labels, |
| image_size=(w, h), |
| ) |
| assert w == annotated_image.shape[1] and h == annotated_image.shape[0] |
|
|
| out_image = Image.fromarray(annotated_image) |
| out_buffer = io.BytesIO() |
| out_image.save(out_buffer, format="PNG") |
| encoded_image = base64.b64encode(out_buffer.getvalue()).decode("ascii") |
|
|
| return encoded_image, filtered_bboxes_out |
|
|
|
|
def area(bbox: List[int]) -> int:
    """Return ``width * height`` of an ``[x1, y1, x2, y2]`` box.

    Degenerate boxes give 0. A box inverted on exactly one axis gives a
    negative value.
    """
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    return width * height
|
|
|
|
def intersection_area(bbox_left: List[int], bbox_right: List[int]) -> int:
    """Return the overlapping area of two ``[x1, y1, x2, y2]`` boxes.

    The overlap spans from the larger of the two left/top edges to the smaller
    of the two right/bottom edges. Disjoint or merely touching boxes give 0.
    """
    overlap_w = min(bbox_left[2], bbox_right[2]) - max(bbox_left[0], bbox_right[0])
    overlap_h = min(bbox_left[3], bbox_right[3]) - max(bbox_left[1], bbox_right[1])
    return max(0, overlap_w) * max(0, overlap_h)
|
|
|
|
def intersection_over_union(bbox_left: List[int], bbox_right: List[int]) -> float:
    """Return an overlap score for two ``[x1, y1, x2, y2]`` boxes.

    The score is not plain IoU. It is the maximum of the IoU and the fraction
    of each box covered by their intersection, so a small box fully inside a
    large one scores about 1.0. The two coverage fractions are only considered
    when both boxes have positive area. The ``1e-6`` term in the union guards
    against division by zero.
    """
    shared = intersection_area(bbox_left, bbox_right)
    left_area = area(bbox_left)
    right_area = area(bbox_right)
    iou = shared / (left_area + right_area - shared + 1e-6)

    if left_area > 0 and right_area > 0:
        return max(iou, shared / left_area, shared / right_area)
    return max(iou, 0)
|
|
|
|
def overlap(
    bbox_left: List[int], bbox_right: List[int], threshold: float = 0.80
) -> bool:
    """Return True when more than ``threshold`` of ``bbox_left`` lies inside ``bbox_right``.

    Despite the name, the test is asymmetric: it asks whether the left box is
    (mostly) contained in the right one.

    Args:
        bbox_left: ``[x1, y1, x2, y2]`` box whose coverage is measured.
        bbox_right: ``[x1, y1, x2, y2]`` box that may contain it.
        threshold: Required fraction of ``bbox_left``'s area (default 0.80).

    Returns:
        False for a degenerate ``bbox_left`` (area <= 0) instead of dividing
        by zero.
    """
    left_area = area(bbox_left)
    if left_area <= 0:
        return False
    return intersection_area(bbox_left, bbox_right) / left_area > threshold
|
|
|
|
class BoxAnnotator:
    """Draw bounding boxes with text labels onto an image using OpenCV.

    When ``avoid_overlap`` is set, each label is moved around its box (see
    ``get_optimal_label_pos``) so it does not cover other detections or leave
    the image.
    """

    def __init__(
        self,
        color: Union[Color, ColorPalette] = ColorPalette.DEFAULT,
        thickness: int = 3,
        text_color: Color = Color.BLACK,
        text_scale: float = 0.5,
        text_thickness: int = 2,
        text_padding: int = 10,
        avoid_overlap: bool = True,
    ):
        """Store the drawing options.

        Args:
            color: One colour for every box, or a palette indexed by class id
                (or by detection index when class ids are absent).
            thickness: Box outline thickness in pixels.
            text_color: Stored but not used by ``annotate``, which picks black
                or white text from the box colour's luminance.
            text_scale: OpenCV font scale for labels.
            text_thickness: OpenCV stroke thickness for labels.
            text_padding: Padding in pixels around the label text.
            avoid_overlap: Relocate labels that would collide with detections
                or the image border.
        """
        self.color: Union[Color, ColorPalette] = color
        self.thickness: int = thickness
        self.text_color: Color = text_color
        self.text_scale: float = text_scale
        self.text_thickness: int = text_thickness
        self.text_padding: int = text_padding
        self.avoid_overlap: bool = avoid_overlap

    def annotate(
        self,
        scene: np.ndarray,
        detections: Detections,
        labels: Optional[List[str]] = None,
        skip_label: bool = False,
        image_size: Optional[Tuple[int, int]] = None,
    ) -> np.ndarray:
        """Draw every detection, and optionally its label, onto ``scene`` in place.

        Args:
            scene: Image array that is drawn on in place and returned.
            detections: Boxes to draw; ``detections.xyxy`` holds pixel xyxy boxes.
            labels: One label per detection. If missing or of the wrong length,
                the class id is used as the label text.
            skip_label: Draw only the box outlines.
            image_size: ``(width, height)`` used to keep labels inside the
                image. Defaults to the scene's own size.

        Returns:
            ``scene`` with the annotations drawn.
        """
        if image_size is None:
            # Without a size, label placement would index into None.
            image_size = (scene.shape[1], scene.shape[0])

        font = cv2.FONT_HERSHEY_SIMPLEX
        for i in range(len(detections)):
            # Truncate to plain Python ints for the OpenCV drawing calls.
            x1, y1, x2, y2 = (int(v) for v in detections.xyxy[i])
            class_id = (
                detections.class_id[i] if detections.class_id is not None else None
            )
            idx = class_id if class_id is not None else i
            color = (
                self.color.by_idx(idx)
                if isinstance(self.color, ColorPalette)
                else self.color
            )
            cv2.rectangle(
                img=scene,
                pt1=(x1, y1),
                pt2=(x2, y2),
                color=color.as_bgr(),
                thickness=self.thickness,
            )
            if skip_label:
                continue

            text = (
                f"{class_id}"
                if (labels is None or len(detections) != len(labels))
                else labels[i]
            )

            text_width, text_height = cv2.getTextSize(
                text=text,
                fontFace=font,
                fontScale=self.text_scale,
                thickness=self.text_thickness,
            )[0]

            if self.avoid_overlap:
                (
                    text_x,
                    text_y,
                    text_background_x1,
                    text_background_y1,
                    text_background_x2,
                    text_background_y2,
                ) = self.get_optimal_label_pos(
                    self.text_padding,
                    text_width,
                    text_height,
                    x1,
                    y1,
                    x2,
                    y2,
                    detections,
                    image_size,
                )
            else:
                # Fixed placement just above the box's top-left corner.
                text_x = x1 + self.text_padding
                text_y = y1 - self.text_padding
                text_background_x1 = x1
                text_background_y1 = y1 - 2 * self.text_padding - text_height
                text_background_x2 = x1 + 2 * self.text_padding + text_width
                text_background_y2 = y1

            cv2.rectangle(
                img=scene,
                pt1=(text_background_x1, text_background_y1),
                pt2=(text_background_x2, text_background_y2),
                color=color.as_bgr(),
                thickness=cv2.FILLED,
            )
            # Choose black or white text for contrast with the label background.
            box_color = color.as_rgb()
            luminance = (
                0.299 * box_color[0] + 0.587 * box_color[1] + 0.114 * box_color[2]
            )
            text_color = (0, 0, 0) if luminance > 160 else (255, 255, 255)
            cv2.putText(
                img=scene,
                text=text,
                org=(text_x, text_y),
                fontFace=font,
                fontScale=self.text_scale,
                color=text_color,
                thickness=self.text_thickness,
                lineType=cv2.LINE_AA,
            )
        return scene

    @staticmethod
    def get_optimal_label_pos(
        text_padding, text_width, text_height, x1, y1, x2, y2, detections, image_size
    ):
        """Pick a label position that avoids other detections and the image border.

        Four candidates are tried in this order:

        1. above the box, aligned with its left edge;
        2. left of the box, hanging down from its top edge;
        3. right of the box, hanging down from its top edge;
        4. above the box, aligned with its right edge.

        A candidate collides when its background rectangle scores more than 0.3
        against any detection under ``intersection_over_union``, or when it
        leaves ``image_size`` (width, height). The first candidate that does not
        collide is returned. If all four collide, the last one is returned.

        Returns:
            ``(text_x, text_y, bg_x1, bg_y1, bg_x2, bg_y2)``.
        """

        def collides(bg_x1, bg_y1, bg_x2, bg_y2):
            """Return True if the label background hits a detection or the border."""
            if (
                bg_x1 < 0
                or bg_x2 > image_size[0]
                or bg_y1 < 0
                or bg_y2 > image_size[1]
            ):
                return True
            background = [bg_x1, bg_y1, bg_x2, bg_y2]
            return any(
                intersection_over_union(background, detections.xyxy[i].astype(int))
                > 0.3
                for i in range(len(detections))
            )

        label_w = 2 * text_padding + text_width
        label_h = 2 * text_padding + text_height
        candidates = [
            (
                x1 + text_padding,
                y1 - text_padding,
                x1,
                y1 - label_h,
                x1 + label_w,
                y1,
            ),
            (
                x1 - text_padding - text_width,
                y1 + text_padding + text_height,
                x1 - label_w,
                y1,
                x1,
                y1 + label_h,
            ),
            (
                x2 + text_padding,
                y1 + text_padding + text_height,
                x2,
                y1,
                x2 + label_w,
                y1 + label_h,
            ),
            (
                x2 - text_padding - text_width,
                y1 - text_padding,
                x2 - label_w,
                y1 - label_h,
                x2,
                y1,
            ),
        ]

        for candidate in candidates:
            if not collides(*candidate[2:]):
                return candidate
        return candidates[-1]