from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from PIL import Image, ImageDraw, ImageFont
import streamlit as st
from transformers import pipeline

from const import WHITE, color_for_label

# Padding (pixels) between label text and the edge of its filled background.
_LABEL_PADDING = 2
# Outline width (pixels) of each detection rectangle.
_BOX_WIDTH = 3
# Rough glyph size used to estimate text extents when Pillow cannot measure them.
_FALLBACK_CHAR_WIDTH = 8
_FALLBACK_TEXT_HEIGHT = 16

_Font = Union[ImageFont.ImageFont, ImageFont.FreeTypeFont]


@st.cache_resource(show_spinner=False)
def get_detector(model_id: str):
    """Build the Hugging Face object-detection pipeline for *model_id*.

    Runs on CUDA device 0 in float16 when CUDA is available, otherwise on the
    CPU in float32. ``st.cache_resource`` reuses one pipeline per ``model_id``
    instead of reloading the model on every Streamlit rerun.
    """
    has_cuda = torch.cuda.is_available()
    device = 0 if has_cuda else -1  # transformers convention: -1 means CPU
    torch_dtype = torch.float16 if has_cuda else torch.float32
    return pipeline(
        task="object-detection",
        model=model_id,
        device=device,
        torch_dtype=torch_dtype,
    )


@st.cache_data(show_spinner=False, ttl=600)
def run_detection(model_id: str, image: Image.Image) -> List[Dict[str, Any]]:
    """Run object detection on *image* with the cached pipeline for *model_id*.

    Results are cached by Streamlit for 600 seconds, keyed on the arguments.

    Returns:
        The pipeline output: a list of dicts with ``label``, ``score`` and
        ``box`` entries (the shape ``draw_boxes`` consumes).
    """
    detector = get_detector(model_id)
    return detector(image)


def _get_font() -> Optional[_Font]:
    """Return Pillow's default font, or ``None`` if it cannot be loaded.

    ``None`` is a deliberate best-effort fallback: Pillow's drawing calls
    accept ``font=None`` and then use their own built-in default.
    """
    try:
        return ImageFont.load_default()
    except Exception:
        return None


def _box_corners(box: Dict[str, Any]) -> Tuple[float, float, float, float]:
    """Return ``(left, top, right, bottom)`` from a prediction's box dict.

    Accepts ``xmin``-style or ``x_min``-style keys (missing values become 0)
    and orders each coordinate pair, because recent Pillow releases reject
    rectangles whose second corner lies above/left of the first.
    """
    x0 = float(box.get("xmin", box.get("x_min", 0)))
    y0 = float(box.get("ymin", box.get("y_min", 0)))
    x1 = float(box.get("xmax", box.get("x_max", 0)))
    y1 = float(box.get("ymax", box.get("y_max", 0)))
    return min(x0, x1), min(y0, y1), max(x0, x1), max(y0, y1)


def _text_bbox(
    draw: ImageDraw.ImageDraw,
    xy: Tuple[int, int],
    text: str,
    font: Optional[_Font],
) -> Tuple[int, int, int, int]:
    """Return the bounding box of *text* drawn at *xy*, estimating if needed."""
    try:
        return draw.textbbox(xy, text, font=font)
    except (AttributeError, ValueError):
        # AttributeError: Pillow < 8 has no textbbox.
        # ValueError: older Pillow only measures TrueType fonts.
        x, y = xy
        return (
            x,
            y,
            x + _FALLBACK_CHAR_WIDTH * len(text),
            y + _FALLBACK_TEXT_HEIGHT,
        )


def _draw_label(
    draw: ImageDraw.ImageDraw,
    x0: float,
    y0: float,
    text: str,
    color: Any,
    font: Optional[_Font],
) -> None:
    """Draw *text* in white on a *color* background at a box's top-left corner.

    The label sits just above the box. If that would place it above the top
    of the image, it is drawn just inside the box instead.
    """
    pad = _LABEL_PADDING
    # Measure at the origin to learn the glyph extents relative to the anchor.
    _, rel_top, _, rel_bottom = _text_bbox(draw, (0, 0), text, font)
    text_x = int(x0) + pad
    # Put the padded text bottom exactly on the box's top edge.
    text_y = int(y0) - rel_bottom - pad
    if text_y + rel_top - pad < 0:
        # No room above the box: start the padded text at the box's top edge.
        text_y = int(y0) - rel_top + pad
    # Measure again at the final position so the background covers the text.
    left, top, right, bottom = _text_bbox(draw, (text_x, text_y), text, font)
    draw.rectangle(
        [(left - pad, top - pad), (right + pad, bottom + pad)],
        fill=color,
    )
    draw.text((text_x, text_y), text, fill=WHITE, font=font)


def draw_boxes(
    image: Image.Image,
    predictions: List[Dict[str, Any]],
    threshold: float,
) -> Image.Image:
    """Return a copy of *image* annotated with the confident predictions.

    Args:
        image: Source image; it is not modified.
        predictions: Dicts with optional ``score`` (default 0.0), ``label``
            (default ``"logo"``) and ``box`` entries, as returned by
            ``run_detection``.
        threshold: Predictions with ``score < threshold`` are skipped.

    Returns:
        A new image. Each kept prediction gets an outlined rectangle and a
        ``"<label> <score>"`` caption, both in the label's colour.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    font = _get_font()

    for pred in predictions:
        score = float(pred.get("score", 0.0))
        if score < threshold:
            continue
        label = str(pred.get("label", "logo"))
        # `or {}` also covers an explicit None box.
        x0, y0, x1, y1 = _box_corners(pred.get("box") or {})
        color = color_for_label(label)

        draw.rectangle([(x0, y0), (x1, y1)], outline=color, width=_BOX_WIDTH)
        _draw_label(draw, x0, y0, f"{label} {score:.2f}", color, font)

    return annotated