katyan010 commited on
Commit
6214b0d
·
1 Parent(s): 33678f5

logo detection

Browse files
Files changed (2) hide show
  1. requirements.txt +3 -3
  2. src/streamlit_app.py +197 -38
requirements.txt CHANGED
@@ -1,3 +1,3 @@
1
- altair
2
- pandas
3
- streamlit
 
1
+ streamlit>=1.33.0
2
+ huggingface_hub>=0.23.0
3
+ pillow>=10.0.0
src/streamlit_app.py CHANGED
@@ -1,40 +1,199 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
4
  import streamlit as st
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import os
3
+ from typing import Any, Dict, List, Optional
4
+
5
+ from huggingface_hub import InferenceClient
6
+ from PIL import Image, ImageDraw, ImageFont
7
  import streamlit as st
8
 
9
+
10
+ # App config
11
+ st.set_page_config(
12
+ page_title="Logo Detector Hugging Face",
13
+ page_icon="🔎",
14
+ layout="centered",
15
+ )
16
+
17
+ DEFAULT_MODEL_ID = "keremberke/yolov8m-logo-detector"
18
+
19
+
20
+ def get_hf_token() -> Optional[str]:
21
+ """Prefer Streamlit secrets, then environment variable."""
22
+ token: Optional[str] = None
23
+ try:
24
+ # st.secrets behaves like a dict when available
25
+ token = st.secrets.get("HF_TOKEN") # type: ignore[attr-defined]
26
+ except Exception:
27
+ token = None
28
+ if not token:
29
+ token = os.environ.get("HF_TOKEN")
30
+ return token
31
+
32
+
33
+ @st.cache_resource(show_spinner=False)
34
+ def get_client(model_id: str, token: Optional[str]) -> InferenceClient:
35
+ return InferenceClient(model=model_id, token=token)
36
+
37
+
38
+ @st.cache_data(show_spinner=False, ttl=600)
39
+ def run_detection(
40
+ model_id: str,
41
+ token: Optional[str],
42
+ image_bytes: bytes,
43
+ ) -> List[Dict[str, Any]]:
44
+ client = InferenceClient(model=model_id, token=token)
45
+ # The object_detection endpoint returns a list of dicts with keys:
46
+ # label, score, and box {xmin, ymin, xmax, ymax}
47
+ return client.object_detection(image=image_bytes)
48
+
49
+
50
+ def draw_boxes(
51
+ image: Image.Image,
52
+ predictions: List[Dict[str, Any]],
53
+ threshold: float,
54
+ ) -> Image.Image:
55
+ annotated = image.copy()
56
+ draw = ImageDraw.Draw(annotated)
57
+ try:
58
+ font = ImageFont.load_default()
59
+ except Exception:
60
+ font = None # Pillow will fallback
61
+
62
+ for pred in predictions:
63
+ score = float(pred.get("score", 0.0))
64
+ if score < threshold:
65
+ continue
66
+ label = str(pred.get("label", "logo"))
67
+ box = pred.get("box", {})
68
+ # Support alternative key names just in case
69
+ x0 = box.get("xmin", box.get("x_min", 0))
70
+ y0 = box.get("ymin", box.get("y_min", 0))
71
+ x1 = box.get("xmax", box.get("x_max", 0))
72
+ y1 = box.get("ymax", box.get("y_max", 0))
73
+ x0, y0, x1, y1 = float(x0), float(y0), float(x1), float(y1)
74
+
75
+ # Rectangle
76
+ draw.rectangle([(x0, y0), (x1, y1)], outline=(255, 0, 0), width=3)
77
+
78
+ # Label background
79
+ text = f"{label} {score:.2f}"
80
+ try:
81
+ # Compute text bounding box for background
82
+ text_bbox = draw.textbbox((int(x0), int(y0)), text, font=font)
83
+ tx0, ty0, tx1, _ = text_bbox
84
+ except Exception:
85
+ # Fallback: rough estimate for background width/height
86
+ tx0, ty0 = int(x0), int(y0) - 20
87
+ tx1 = int(x0) + 8 * len(text)
88
+ bg_top = min(ty0, y0)
89
+ bg_bottom = max(ty0, y0)
90
+ draw.rectangle(
91
+ [(tx0, bg_top - 2), (tx1, bg_bottom + 2)],
92
+ fill=(255, 0, 0),
93
+ )
94
+
95
+ # Text
96
+ draw.text(
97
+ (int(x0) + 2, int(y0) - 18),
98
+ text,
99
+ fill=(255, 255, 255),
100
+ font=font,
101
+ )
102
+
103
+ return annotated
104
+
105
+
106
+ # Sidebar controls
107
+ st.sidebar.header("⚙️ Настройки")
108
+ model_id = st.sidebar.text_input(
109
+ "Hugging Face модель",
110
+ value=DEFAULT_MODEL_ID,
111
+ help="Например, YOLO модель для детекции логотипов",
112
+ )
113
+ threshold = st.sidebar.slider(
114
+ "Порог уверенности",
115
+ min_value=0.0,
116
+ max_value=1.0,
117
+ value=0.30,
118
+ step=0.01,
119
+ )
120
+
121
+ st.title("🔎 Поиск логотипов на изображении (Hugging Face • YOLO)")
122
+ st.write(
123
+ "Загрузите изображение. Модель найдёт логотипы "
124
+ "и отрисует bounding boxes."
125
+ )
126
+
127
+ # Token hint
128
+ hf_token = get_hf_token()
129
+ if not hf_token:
130
+ st.info(
131
+ (
132
+ "Опционально укажите токен `HF_TOKEN` через `st.secrets` "
133
+ "или переменную окружения."
134
+ )
135
+ )
136
+
137
+ uploaded = st.file_uploader(
138
+ "Выберите изображение",
139
+ type=["png", "jpg", "jpeg"],
140
+ accept_multiple_files=False,
141
+ )
142
+
143
+ if uploaded is not None:
144
+ try:
145
+ image = Image.open(uploaded).convert("RGB")
146
+ except Exception as exc:
147
+ st.error(f"Не удалось открыть изображение: {exc}")
148
+ st.stop()
149
+
150
+ cols = st.columns(2)
151
+ with cols[0]:
152
+ st.image(image, caption="Оригинал", use_column_width=True)
153
+
154
+ with st.spinner("Детекция логотипов…"):
155
+ try:
156
+ predictions = run_detection(
157
+ model_id,
158
+ hf_token,
159
+ uploaded.getvalue(),
160
+ )
161
+ except Exception as exc:
162
+ st.error(f"Ошибка инференса: {exc}")
163
+ st.stop()
164
+
165
+ if isinstance(predictions, dict) and predictions.get("error"):
166
+ err_msg = predictions.get("error")
167
+ st.error(f"Ошибка модели: {err_msg}")
168
+ st.stop()
169
+
170
+ annotated_image = draw_boxes(image, predictions, threshold)
171
+
172
+ with cols[1]:
173
+ st.image(
174
+ annotated_image,
175
+ caption="С найденными боксами",
176
+ use_column_width=True,
177
+ )
178
+
179
+ # Stats and download
180
+ shown = sum(
181
+ 1
182
+ for p in predictions
183
+ if float(p.get("score", 0.0)) >= threshold
184
+ )
185
+ total = len(predictions)
186
+ st.caption(
187
+ f"Показано боксов: {shown} из {total} "
188
+ f"(порог {threshold:.2f})"
189
+ )
190
+
191
+ buf = io.BytesIO()
192
+ annotated_image.save(buf, format="PNG")
193
+ st.download_button(
194
+ label="Скачать размеченное изображение",
195
+ data=buf.getvalue(),
196
+ file_name="detections.png",
197
+ mime="image/png",
198
+ type="primary",
199
+ )