File size: 2,400 Bytes
a0709ea
4485c2a
a0709ea
4485c2a
 
a0709ea
4485c2a
94f6d87
4485c2a
 
 
a0709ea
 
 
 
 
 
 
 
 
 
4485c2a
 
 
a0709ea
 
 
4485c2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0709ea
4485c2a
 
57f7dec
4485c2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from typing import Any, Dict, List, Optional

import torch
from PIL import Image, ImageDraw, ImageFont
import streamlit as st
from transformers import pipeline

from const import WHITE, color_for_label


@st.cache_resource(show_spinner=False)
def get_detector(model_id: str):
    """Build the object-detection pipeline for ``model_id``.

    The pipeline is placed on CUDA device 0 with float16 weights when
    ``torch.cuda.is_available()`` reports a GPU; otherwise it runs on CPU
    (device ``-1``) with float32. The function is wrapped in
    ``st.cache_resource``, with the Streamlit spinner disabled.

    Args:
        model_id: Model identifier passed to ``transformers.pipeline`` as
            ``model``.

    Returns:
        The object returned by ``transformers.pipeline``.
    """
    if torch.cuda.is_available():
        target_device, weights_dtype = 0, torch.float16
    else:
        target_device, weights_dtype = -1, torch.float32

    return pipeline(
        task="object-detection",
        model=model_id,
        device=target_device,
        torch_dtype=weights_dtype,
    )


@st.cache_data(show_spinner=False, ttl=600)
def run_detection(model_id: str, image: Image.Image) -> List[Dict[str, Any]]:
    """Run the detector for ``model_id`` on ``image`` and return its output.

    The call is wrapped in ``st.cache_data`` with ``ttl=600`` and the
    Streamlit spinner disabled.

    Args:
        model_id: Identifier forwarded to ``get_detector``.
        image: Image handed to the detector unchanged.

    Returns:
        Whatever the detector returns; expected to be a list of dicts with
        ``label``, ``score`` and ``box`` entries.
    """
    predictions = get_detector(model_id)(image)
    return predictions


def _get_font() -> Optional[ImageFont.FreeTypeFont]:
    """Return Pillow's default font, or ``None`` if loading it fails."""
    try:
        default_font = ImageFont.load_default()
    except Exception:
        # Best-effort: callers pass ``None`` through to Pillow.
        default_font = None
    return default_font


def draw_boxes(
    image: Image.Image,
    predictions: List[Dict[str, Any]],
    threshold: float,
) -> Image.Image:
    """Return a copy of ``image`` annotated with the given detections.

    Each prediction whose ``score`` is at least ``threshold`` gets an
    outlined rectangle plus a filled label tag reading
    ``"<label> <score>"``. The tag sits directly above the box, or just
    inside its top edge when there is not enough room above it.

    Args:
        image: Source image; it is copied, never modified.
        predictions: Dicts with optional ``score``, ``label`` and ``box``
            keys. ``box`` may use ``xmin``/``ymin``/``xmax``/``ymax`` or the
            ``x_min``/``y_min``/``x_max``/``y_max`` spellings; missing
            values default to 0.
        threshold: Minimum score for a prediction to be drawn.

    Returns:
        The annotated copy of ``image``.
    """
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)
    font = _get_font()
    pad = 2  # pixels of background around the label text

    for pred in predictions:
        score = float(pred.get("score", 0.0))
        if score < threshold:
            continue
        label = str(pred.get("label", "logo"))
        box = pred.get("box") or {}
        xa = float(box.get("xmin", box.get("x_min", 0)))
        ya = float(box.get("ymin", box.get("y_min", 0)))
        xb = float(box.get("xmax", box.get("x_max", 0)))
        yb = float(box.get("ymax", box.get("y_max", 0)))
        # Recent Pillow rejects rectangles whose second corner precedes the
        # first, so order the coordinates before drawing.
        x0, x1 = sorted((xa, xb))
        y0, y1 = sorted((ya, yb))

        color = color_for_label(label)

        # Rectangle
        draw.rectangle([(x0, y0), (x1, y1)], outline=color, width=3)

        # Measure the label once, relative to the origin, so the background
        # and the text can be placed from the same anchor.
        text = f"{label} {score:.2f}"
        try:
            left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
        except Exception:
            # Best-effort estimate when the installed Pillow cannot measure
            # the text (e.g. no ``textbbox`` or an unsupported font type).
            left, top, right, bottom = 0, 0, 8 * len(text), 16
        text_w = right - left
        text_h = bottom - top

        # Label background: above the box when it fits, otherwise inside it.
        bg_x0 = int(x0)
        bg_y0 = int(y0) - text_h - 2 * pad
        if bg_y0 < 0:
            bg_y0 = max(int(y0), 0)
        bg_x1 = bg_x0 + text_w + 2 * pad
        bg_y1 = bg_y0 + text_h + 2 * pad
        draw.rectangle([(bg_x0, bg_y0), (bg_x1, bg_y1)], fill=color)

        # Text: subtract the bbox offsets so the glyphs start exactly
        # ``pad`` pixels inside the background.
        draw.text(
            (bg_x0 + pad - left, bg_y0 + pad - top),
            text,
            fill=WHITE,
            font=font,
        )

    return annotated