Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import gradio as gr | |
| from torchvision import transforms | |
| from lib.pvt import PolypPVT # senin repo'daki model | |
# ----------------------
# Model loading
# ----------------------
pth_path = "./weights/PolypPVT.pth"
# Choose the device once and use it both for deserialising the checkpoint and
# for placing the model, so the two can never disagree.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = PolypPVT()
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load(pth_path, map_location=device))
model.to(device)
model.eval()  # inference mode for the demo
# ----------------------
# Preprocessing transform
# ----------------------
# Resize to the 352x352 network input, convert the PIL image to a float tensor,
# then standardise each channel with the standard ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize((352, 352)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# ----------------------
# Prediction function
# ----------------------
def predict(image: Image.Image, mask: Image.Image = None):
    """Segment polyps in *image* and optionally score the result against *mask*.

    Args:
        image: Input image from the ``gr.Image(type="pil")`` component. It is
            ``None`` when the button is pressed without an upload.
        mask: Optional ground-truth mask in any PIL mode. It is converted to
            grayscale, resized to 352x352 and binarised at a threshold of 127.

    Returns:
        A 5-tuple in the order of the UI outputs:
        ``(original, heat_map, overlay, gt_mask, iou_text)``.
        The first four are 352x352 RGB PIL images. ``gt_mask`` is all black
        when no mask is given. ``iou_text`` is a human-readable string.
        When *image* is ``None`` the four images are ``None``.
    """
    # Must match transforms.Resize in the module-level `transform`.
    size = (352, 352)

    if image is None:
        return None, None, None, None, "No input image provided"

    # Uploads may be RGBA / grayscale / palette images. The 3-channel Normalize
    # in `transform` and cv2.addWeighted (identical shapes required) need RGB.
    image = image.convert("RGB")

    input_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # The model returns two prediction maps; they are summed before the sigmoid.
        p1, p2 = model(input_tensor)
        logits = F.interpolate(p1 + p2, size=size, mode="bilinear", align_corners=False)
    prob = logits.sigmoid().cpu().numpy().squeeze()
    # Min-max normalise to [0, 1]. The epsilon avoids division by zero on a
    # constant map.
    res_norm = (prob - prob.min()) / (prob.max() - prob.min() + 1e-8)

    # Binary prediction, used for the IoU computation below.
    pred_mask = res_norm > 0.5

    # Colour heat map of the normalised prediction. OpenCV returns BGR, so
    # convert it to RGB for display.
    heat_map = cv2.applyColorMap((res_norm * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heat_map = cv2.cvtColor(heat_map, cv2.COLOR_BGR2RGB)

    # Blend the heat map over the resized original (60% image, 40% heat map).
    image_resized = np.array(image.resize(size))
    overlay = cv2.addWeighted(image_resized, 0.6, heat_map, 0.4, 0)

    # If a ground-truth mask is provided, compute the IoU and show the mask.
    if mask is not None:
        gt_mask_bin = np.array(mask.convert("L").resize(size)) > 127
        intersection = np.logical_and(pred_mask, gt_mask_bin).sum()
        union = np.logical_or(pred_mask, gt_mask_bin).sum()
        iou_text = f"IOU: {intersection / (union + 1e-8):.4f}"
        gt_mask_u8 = gt_mask_bin.astype(np.uint8) * 255
        gt_mask_rgb = np.stack([gt_mask_u8] * 3, axis=-1)
    else:
        iou_text = "No GT mask provided"
        gt_mask_rgb = np.zeros_like(image_resized)

    return (
        Image.fromarray(image_resized),  # original
        Image.fromarray(heat_map),  # predicted mask (heat map)
        Image.fromarray(overlay),  # overlay
        Image.fromarray(gt_mask_rgb),  # ground-truth mask (black if none)
        iou_text,
    )
# ----------------------
# CSS styles
# ----------------------
# Custom stylesheet passed to gr.Blocks(css=...). It centres the banner, the
# h1/h2/h3 headings and the descriptive paragraph text. The /* ... */ comments
# below are part of the CSS string itself (runtime text), not Python comments.
css = """
/* Banner stilleri */
.banner {
text-align: center;
margin-bottom: 30px;
}
.banner img {
max-width: 100%;
height: auto;
}
/* Başlık stilleri */
.gradio-container h1 {
text-align: center !important;
font-size: 2.2rem !important;
font-weight: bold !important;
margin: 20px 0 !important;
}
/* Alt başlık stilleri */
.gradio-container h2 {
text-align: center !important;
font-size: 1.6rem !important;
margin: 15px 0 !important;
}
.gradio-container h3 {
text-align: center !important;
font-size: 1.3rem !important;
margin: 15px 0 !important;
}
/* Açıklama metni */
.gradio-container .gr-prose p {
text-align: center !important;
font-size: 1.1rem !important;
margin: 10px 0 !important;
}
"""
# ----------------------
# HTML banner
# ----------------------
# Fallback banner, rendered with gr.HTML when the gr.Image banner cannot be created.
# NOTE(review): the browser resolves a relative <img src> against the page URL,
# not the app's working directory. Gradio usually serves local files only through
# its file route (e.g. "/file=..." plus allowed_paths). Confirm this image actually
# loads with the installed Gradio version.
# NOTE(review): this points at "tmp/vflai.png", while the gr.Image banner uses
# "vflai.png". Confirm which path is intended.
banner_html = """
<div class="banner">
<img src="tmp/vflai.png" alt="VFLAI Banner">
</div>
"""
# ----------------------
# Gradio Interface
# ----------------------
# Example rows for gr.Examples: [input image path, ground-truth mask].
# A mask of None means "no ground truth", so the demo also works without masks.
examples = [
    [example_path, None]
    for example_path in (
        "examples/image1.jpg",
        "examples/image2.jpg",
        "examples/image3.jpg",
    )
]
with gr.Blocks(css=css, title="VFLAI Polip Segmentasyon") as demo:
    # Banner, best-effort: try the image component first and fall back to the HTML
    # banner if it cannot be created. `except Exception` (instead of a bare
    # `except:`) keeps the fallback without swallowing KeyboardInterrupt/SystemExit.
    # NOTE(review): this uses "vflai.png", while banner_html points at
    # "tmp/vflai.png". Confirm which path is intended.
    try:
        gr.Image("vflai.png", label="", show_label=False, interactive=False, height=200)
    except Exception:
        gr.HTML(banner_html)

    # Main headings
    gr.Markdown("# Validebağ Fen Lisesi Yapay Zeka Takımı")
    gr.Markdown("## Teknofest 2025 Sağlıkta Yapay Zeka Yarışması")
    gr.Markdown("### Polip Segmentasyonu Test Arayüzü")

    # Main interface: inputs in the left column, results in the right column.
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Giriş")
            input_image = gr.Image(type="pil", label="Giriş Görüntüsü")
            gt_mask = gr.Image(type="pil", label="Gerçek Maske (Opsiyonel)")
            predict_btn = gr.Button("Analiz Et", variant="primary")
        with gr.Column():
            gr.Markdown("## Sonuçlar")
            with gr.Row():
                original_output = gr.Image(label="Orijinal")
                pred_mask_output = gr.Image(label="Tahmin Maskesi")
            with gr.Row():
                overlay_output = gr.Image(label="Bindirme")
                gt_mask_output = gr.Image(label="Gerçek Maske")
            iou_output = gr.Label(label="IOU Skoru")

    # Shared by the examples and the button.
    # Order must match predict's 5-tuple return value.
    prediction_inputs = [input_image, gt_mask]
    prediction_outputs = [
        original_output,
        pred_mask_output,
        overlay_output,
        gt_mask_output,
        iou_output,
    ]

    # Examples; with cache_examples=True they are run through predict ahead of time.
    gr.Markdown("## Örnek Görüntüler")
    gr.Examples(
        examples=examples,
        inputs=prediction_inputs,
        outputs=prediction_outputs,
        fn=predict,
        cache_examples=True,
    )

    # Button handler
    predict_btn.click(
        fn=predict,
        inputs=prediction_inputs,
        outputs=prediction_outputs,
    )

# Start the web server only when run as a script (not on import).
if __name__ == "__main__":
    demo.launch()