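# NOTE: hosting-page residue kept as comments so the file parses as Python:
# avatar caption "momererkoc's picture"; commit message "Update app.py";
# commit 9d66407 (verified).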
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import gradio as gr
from torchvision import transforms
from lib.pvt import PolypPVT # senin repo'daki model
# ----------------------
# Model loading
# ----------------------
# Pick the device once; it is used both to map the checkpoint tensors while
# loading and to place the model afterwards.
device = "cuda" if torch.cuda.is_available() else "cpu"
pth_path = "./weights/PolypPVT.pth"

model = PolypPVT()
model.load_state_dict(torch.load(pth_path, map_location=device))
# Move to the selected device and switch to inference mode.
model.to(device)
model.eval()
# ----------------------
# Transform
# ----------------------
# Preprocessing applied to every input image before it is fed to the model:
# resize to 352x352, convert to a tensor, then normalise each channel.
transform = transforms.Compose(
    [
        transforms.Resize((352, 352)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ----------------------
# Prediction function
# ----------------------
def predict(image: Image.Image, mask: Image.Image = None):
    """Run the segmentation model on ``image`` and build the UI outputs.

    Args:
        image: Input picture as a PIL image. It is converted to RGB before
            inference so that RGBA / grayscale / palette inputs also work.
        mask: Optional ground-truth mask as a PIL image. When provided it is
            converted to grayscale, resized, thresholded at 127 and compared
            with the thresholded prediction.

    Returns:
        A 5-tuple ``(original, heatmap, overlay, gt_mask, message)``:
        the resized input, the JET-coloured prediction map, the blend of the
        two, the binarised ground-truth mask (all-black when ``mask`` is
        ``None``) -- all 352x352 PIL images -- and a text message holding the
        IoU score or a note that no ground truth was given.

    Raises:
        gr.Error: If ``image`` is ``None`` (nothing was uploaded).
    """
    if image is None:
        # Fail with a readable message instead of an opaque error raised
        # from inside the preprocessing pipeline.
        raise gr.Error("Please provide an input image.")

    size = (352, 352)

    # Normalize() is configured for 3 channels and addWeighted() needs both
    # arrays to have the same shape, so force a 3-channel RGB image.
    image = image.convert("RGB")

    # Convert and preprocess input
    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        P1, P2 = model(input_tensor)
        # The two model outputs are summed before resizing and squashing.
        res = F.interpolate(P1 + P2, size=size, mode="bilinear", align_corners=False)
        res = res.sigmoid().cpu().numpy().squeeze()
    # Min-max rescale to [0, 1]; the epsilon guards against a constant map.
    res_norm = (res - res.min()) / (res.max() - res.min() + 1e-8)

    # Predicted mask binary
    pred_mask = (res_norm > 0.5).astype(np.uint8)

    # Make colored mask (OpenCV colour maps are BGR, PIL expects RGB)
    pred_mask_color = cv2.applyColorMap((res_norm * 255).astype(np.uint8), cv2.COLORMAP_JET)
    pred_mask_color = cv2.cvtColor(pred_mask_color, cv2.COLOR_BGR2RGB)

    # Overlay on original
    image_resized = np.array(image.resize(size))
    overlay = cv2.addWeighted(image_resized, 0.6, pred_mask_color, 0.4, 0)

    # If ground truth mask is provided -> calculate IOU
    iou_score = None
    if mask is not None:
        mask_resized = mask.convert("L").resize(size)
        gt_mask_bin = (np.array(mask_resized) > 127).astype(np.uint8)
        intersection = np.logical_and(pred_mask, gt_mask_bin).sum()
        union = np.logical_or(pred_mask, gt_mask_bin).sum()
        # Epsilon avoids division by zero when both masks are empty.
        iou_score = intersection / (union + 1e-8)
        # GT mask to RGB
        gt_mask_rgb = np.stack([gt_mask_bin * 255] * 3, axis=-1)
    else:
        gt_mask_rgb = np.zeros_like(image_resized)

    return (
        Image.fromarray(image_resized),    # original
        Image.fromarray(pred_mask_color),  # predicted mask (heat map)
        Image.fromarray(overlay),          # overlay
        Image.fromarray(gt_mask_rgb),      # ground-truth mask (may be blank)
        f"IOU: {iou_score:.4f}" if iou_score is not None else "No GT mask provided",
    )
# ----------------------
# CSS styles
# ----------------------
# Custom stylesheet passed to gr.Blocks(css=...): centres the banner and the
# h1/h2/h3 headings and the description paragraphs. The comments inside the
# string are part of the runtime CSS text and are therefore left untouched.
css = """
/* Banner stilleri */
.banner {
text-align: center;
margin-bottom: 30px;
}
.banner img {
max-width: 100%;
height: auto;
}
/* Başlık stilleri */
.gradio-container h1 {
text-align: center !important;
font-size: 2.2rem !important;
font-weight: bold !important;
margin: 20px 0 !important;
}
/* Alt başlık stilleri */
.gradio-container h2 {
text-align: center !important;
font-size: 1.6rem !important;
margin: 15px 0 !important;
}
.gradio-container h3 {
text-align: center !important;
font-size: 1.3rem !important;
margin: 15px 0 !important;
}
/* Açıklama metni */
.gradio-container .gr-prose p {
text-align: center !important;
font-size: 1.1rem !important;
margin: 10px 0 !important;
}
"""
# ----------------------
# HTML Banner
# ----------------------
# HTML snippet holding the banner image wrapped in a ".banner" div.
# NOTE(review): the image path here is "tmp/vflai.png" -- confirm that this is
# the intended location and that the file is actually served from there.
banner_html = """
<div class="banner">
<img src="tmp/vflai.png" alt="VFLAI Banner">
</div>
"""
# ----------------------
# Gradio Interface
# ----------------------
# Example rows in [input image path, ground-truth mask] form. The mask slot is
# None for every row, i.e. the examples are run without a ground-truth mask.
examples = [
    [image_path, None]
    for image_path in (
        "examples/image1.jpg",
        "examples/image2.jpg",
        "examples/image3.jpg",
    )
]
# Build the Gradio UI: banner, titles, an input column, a results column,
# a set of clickable examples and the button that triggers `predict`.
with gr.Blocks(css=css, title="VFLAI Polip Segmentasyon") as demo:
    # Banner: show the image component if it can be created, otherwise fall
    # back to the plain HTML banner. Catch Exception (not a bare except) so
    # KeyboardInterrupt / SystemExit are never swallowed here.
    try:
        gr.Image("vflai.png", label="", show_label=False, interactive=False, height=200)
    except Exception:
        gr.HTML(banner_html)

    # Main titles
    gr.Markdown("# Validebağ Fen Lisesi Yapay Zeka Takımı")
    gr.Markdown("## Teknofest 2025 Sağlıkta Yapay Zeka Yarışması")
    gr.Markdown("### Polip Segmentasyonu Test Arayüzü")

    # Main interface
    with gr.Row():
        # Left column: inputs
        with gr.Column():
            gr.Markdown("## Giriş")
            input_image = gr.Image(type="pil", label="Giriş Görüntüsü")
            gt_mask = gr.Image(type="pil", label="Gerçek Maske (Opsiyonel)")
            predict_btn = gr.Button("Analiz Et", variant="primary")

        # Right column: results
        with gr.Column():
            gr.Markdown("## Sonuçlar")
            with gr.Row():
                original_output = gr.Image(label="Orijinal")
                pred_mask_output = gr.Image(label="Tahmin Maskesi")
            with gr.Row():
                overlay_output = gr.Image(label="Bindirme")
                gt_mask_output = gr.Image(label="Gerçek Maske")
            iou_output = gr.Label(label="IOU Skoru")

    # Components receiving the 5-tuple returned by `predict`, in order.
    # Defined once so the examples and the button cannot drift apart.
    result_components = [
        original_output,
        pred_mask_output,
        overlay_output,
        gt_mask_output,
        iou_output,
    ]

    # Examples
    gr.Markdown("## Örnek Görüntüler")
    gr.Examples(
        examples=examples,
        inputs=[input_image, gt_mask],
        outputs=result_components,
        fn=predict,
        cache_examples=True,
    )

    # Button handler
    predict_btn.click(
        fn=predict,
        inputs=[input_image, gt_mask],
        outputs=result_components,
    )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()