"""Gradio demo for Polyp-PVT polyp segmentation.

Loads a PolypPVT checkpoint, runs it on an uploaded image, and shows the
prediction as a colour map, an overlay, and (when a ground-truth mask is
supplied) the IoU against that mask.
"""
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from lib.pvt import PolypPVT  # model definition from the project repository

# ----------------------
# Model loading
# ----------------------
pth_path = "./weights/PolypPVT.pth"
# Square resolution the network is fed at; all displayed outputs use it too.
IMG_SIZE = 352

# Resolve the device once and reuse it for weight loading and inference.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = PolypPVT()
model.load_state_dict(torch.load(pth_path, map_location=device))
model.to(device)
model.eval()

# ----------------------
# Preprocessing transform
# ----------------------
# Resize to IMG_SIZE x IMG_SIZE, convert to a tensor, then normalize with the
# standard ImageNet channel mean/std (3 channels -> input must be RGB).
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# ----------------------
# Prediction
# ----------------------
def _compute_iou(pred_mask: np.ndarray, gt_mask: np.ndarray) -> float:
    """Return the intersection-over-union of two binary masks.

    Non-zero pixels count as foreground. When both masks are empty they agree
    perfectly, so 1.0 is returned rather than a misleading 0.0.
    """
    pred = pred_mask.astype(bool)
    gt = gt_mask.astype(bool)
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        return 1.0
    intersection = np.logical_and(pred, gt).sum()
    return float(intersection / union)


def predict(image: Image.Image, mask: Image.Image = None):
    """Segment polyps in *image* and optionally score the result against *mask*.

    Args:
        image: Input image. It is converted to RGB, so grayscale and RGBA
            uploads are accepted.
        mask: Optional ground-truth mask. After grayscale conversion, pixels
            with a value > 127 count as foreground.

    Returns:
        A 5-tuple of (resized original, colour-mapped prediction, overlay,
        ground-truth mask as RGB or an all-black image, IoU text).

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        raise gr.Error("Lütfen bir giriş görüntüsü yükleyin.")
    # The 3-channel Normalize and cv2.addWeighted below both require RGB;
    # RGBA PNGs or grayscale images would otherwise crash.
    image = image.convert("RGB")

    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        P1, P2 = model(input_tensor)
        # The two model outputs are summed before the sigmoid.
        res = F.interpolate(P1 + P2, size=(IMG_SIZE, IMG_SIZE),
                            mode="bilinear", align_corners=False)
    res = res.sigmoid().cpu().numpy().squeeze()
    # Min-max rescale to [0, 1]. The epsilon guards against a constant map.
    # NOTE(review): this makes the 0.5 threshold relative to each image's own
    # score range, presumably matching the upstream test script — confirm.
    res_norm = (res - res.min()) / (res.max() - res.min() + 1e-8)

    # Binary prediction used for the IoU computation.
    pred_mask = (res_norm > 0.5).astype(np.uint8)

    # JET colour map of the soft prediction. OpenCV returns BGR, so convert
    # to RGB for display.
    heatmap = (res_norm * 255).astype(np.uint8)
    pred_mask_color = cv2.cvtColor(
        cv2.applyColorMap(heatmap, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB
    )

    # Blend the colour map over the resized original (60% image, 40% map).
    image_resized = np.array(image.resize((IMG_SIZE, IMG_SIZE)))
    overlay = cv2.addWeighted(image_resized, 0.6, pred_mask_color, 0.4, 0)

    iou_score = None
    if mask is not None:
        mask_resized = mask.convert("L").resize((IMG_SIZE, IMG_SIZE))
        gt_mask_bin = (np.array(mask_resized) > 127).astype(np.uint8)
        iou_score = _compute_iou(pred_mask, gt_mask_bin)
        # Replicate the 0/255 mask into 3 channels for display.
        gt_mask_rgb = np.stack([gt_mask_bin * 255] * 3, axis=-1)
    else:
        gt_mask_rgb = np.zeros_like(image_resized)

    iou_text = (
        f"IOU: {iou_score:.4f}" if iou_score is not None else "No GT mask provided"
    )
    return (
        Image.fromarray(image_resized),    # original
        Image.fromarray(pred_mask_color),  # predicted mask (colour map)
        Image.fromarray(overlay),          # overlay
        Image.fromarray(gt_mask_rgb),      # ground-truth mask (may be blank)
        iou_text,
    )


# ----------------------
# CSS styles
# ----------------------
css = """
/* Banner styles */
.banner { text-align: center; margin-bottom: 30px; }
.banner img { max-width: 100%; height: auto; }

/* Title styles */
.gradio-container h1 {
    text-align: center !important;
    font-size: 2.2rem !important;
    font-weight: bold !important;
    margin: 20px 0 !important;
}

/* Subtitle styles */
.gradio-container h2 {
    text-align: center !important;
    font-size: 1.6rem !important;
    margin: 15px 0 !important;
}
.gradio-container h3 {
    text-align: center !important;
    font-size: 1.3rem !important;
    margin: 15px 0 !important;
}

/* Description text */
.gradio-container .gr-prose p {
    text-align: center !important;
    font-size: 1.1rem !important;
    margin: 10px 0 !important;
}
"""

# ----------------------
# HTML banner (fallback when the banner image file is missing)
# ----------------------
# NOTE(review): this markup is empty, so the fallback currently renders nothing.
banner_html = """ """
BANNER_PATH = "vflai.png"

# ----------------------
# Gradio interface
# ----------------------
examples = [
    ["examples/image1.jpg", None],
    ["examples/image2.jpg", None],
    ["examples/image3.jpg", None],  # can also be tested without a mask
]

with gr.Blocks(css=css, title="VFLAI Polip Segmentasyon") as demo:
    # Banner: use the image when it exists rather than swallowing every
    # exception (the previous bare except also caught KeyboardInterrupt).
    if Path(BANNER_PATH).is_file():
        gr.Image(BANNER_PATH, label="", show_label=False, interactive=False, height=200)
    else:
        gr.HTML(banner_html)

    # Main headings
    gr.Markdown("# Validebağ Fen Lisesi Yapay Zeka Takımı")
    gr.Markdown("## Teknofest 2025 Sağlıkta Yapay Zeka Yarışması")
    gr.Markdown("### Polip Segmentasyonu Test Arayüzü")

    # Main interface
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Giriş")
            input_image = gr.Image(type="pil", label="Giriş Görüntüsü")
            gt_mask = gr.Image(type="pil", label="Gerçek Maske (Opsiyonel)")
            predict_btn = gr.Button("Analiz Et", variant="primary")
        with gr.Column():
            gr.Markdown("## Sonuçlar")
            with gr.Row():
                original_output = gr.Image(label="Orijinal")
                pred_mask_output = gr.Image(label="Tahmin Maskesi")
            with gr.Row():
                overlay_output = gr.Image(label="Bindirme")
                gt_mask_output = gr.Image(label="Gerçek Maske")
            iou_output = gr.Label(label="IOU Skoru")

    # Shared by the examples and the button so the two cannot drift apart.
    predict_inputs = [input_image, gt_mask]
    predict_outputs = [
        original_output, pred_mask_output, overlay_output, gt_mask_output, iou_output,
    ]

    # Examples
    gr.Markdown("## Örnek Görüntüler")
    gr.Examples(
        examples=examples,
        inputs=predict_inputs,
        outputs=predict_outputs,
        fn=predict,
        cache_examples=True,
    )

    # Button handler
    predict_btn.click(fn=predict, inputs=predict_inputs, outputs=predict_outputs)


if __name__ == "__main__":
    demo.launch()