dl_course_hw3 / src /streamlit_app.py
katyan010's picture
fix names
5f44bc6
import io
from PIL import Image
import streamlit as st
import config
from utils import draw_boxes, run_detection
# App config
st.set_page_config(
page_title=config.PAGE_TITLE,
page_icon=config.PAGE_ICON,
layout=config.LAYOUT,
)
# Sidebar controls
st.sidebar.header("⚙️ Настройки")
model_label = st.sidebar.selectbox(
"Hugging Face модель",
options=list(config.MODEL_CATALOG.keys()),
index=0,
help="Например, YOLO модель для детекции",
)
model_id = config.MODEL_CATALOG[model_label]
threshold = st.sidebar.slider(
"Порог уверенности",
min_value=0.0,
max_value=1.0,
value=float(config.DEFAULT_THRESHOLD),
step=0.01,
)
st.title(config.PAGE_ICON + " " + config.PAGE_TITLE)
st.write(
"Загрузите изображение. Модель найдёт объекты "
"и отрисует bounding boxes."
)
uploaded = st.file_uploader(
"Выберите изображение",
type=config.UPLOADER_TYPES,
accept_multiple_files=False,
)
if uploaded is not None:
try:
image = Image.open(uploaded).convert("RGB")
except Exception as exc:
st.error(f"Не удалось открыть изображение: {exc}")
st.stop()
with st.spinner("Детекция логотипов…"):
try:
predictions = run_detection(
model_id,
image,
)
except Exception as exc:
st.error(f"Ошибка инференса: {exc}")
st.stop()
cols = st.columns(2)
with cols[0]:
st.image(
image,
caption="Оригинал",
use_container_width=True,
)
if isinstance(predictions, dict) and predictions.get("error"):
err_msg = predictions.get("error")
st.error(f"Ошибка модели: {err_msg}")
st.stop()
annotated_image = draw_boxes(image, predictions, threshold)
with cols[1]:
st.image(
annotated_image,
caption="С найденными боксами",
use_container_width=True,
)
# Stats and download
shown = sum(
1
for p in predictions # type: ignore[assignment]
if float(p.get("score", 0.0)) >= threshold
)
total = len(predictions) # type: ignore[arg-type]
st.caption(
f"Показано боксов: {shown} из {total} "
f"(порог {threshold:.2f})"
)
predictions_str = "\n".join(
[f"{p['label']}: {round(p['score'], 2)}" for p in predictions]
)
st.markdown(f"**{predictions_str}**")
buf = io.BytesIO()
annotated_image.save(buf, format="PNG")
st.download_button(
label="Скачать размеченное изображение",
data=buf.getvalue(),
file_name="detections.png",
mime="image/png",
type="primary",
)