File size: 3,544 Bytes
37f7a06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import gradio as gr
from torchvision import transforms

from lib.pvt import PolypPVT  # senin repo'daki model

# ----------------------
# Model loading
# ----------------------
# Checkpoint location, relative to the current working directory.
# Forward slashes work on Windows as well as on POSIX systems; the previous
# backslash-separated literal only resolved on Windows and relied on "\w" and
# "\P" being (warned-about) invalid escape sequences.
pth_path = "polyp_segmentation/weights/PolypPVT.pth"

# Choose the device once and reuse it for checkpoint loading and inference.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = PolypPVT()
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted files;
# consider weights_only=True if the installed torch version supports it.
model.load_state_dict(torch.load(pth_path, map_location=device))
model.to(device)
# Switch to inference mode (affects layers such as dropout / batch-norm, if present).
model.eval()

# ----------------------
# Transform
# ----------------------
# Preprocessing applied to every input image before it reaches the model:
# resize to 352x352 (the resolution predict() also works at), convert to a
# float tensor, then normalise each channel with the standard ImageNet
# mean / standard deviation.
transform = transforms.Compose(
    [
        transforms.Resize((352, 352)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# ----------------------
# Prediction function
# ----------------------
def _binary_iou(pred_mask: np.ndarray, gt_mask: np.ndarray) -> float:
    """Return the intersection-over-union (Jaccard index) of two binary masks.

    Both masks must have the same shape; non-zero entries count as foreground.
    If both masks are empty the prediction matches the ground truth exactly,
    so 1.0 is returned instead of the 0.0 that ``0 / (0 + eps)`` would give.
    """
    intersection = np.logical_and(pred_mask, gt_mask).sum()
    union = np.logical_or(pred_mask, gt_mask).sum()
    if union == 0:
        return 1.0
    return float(intersection) / float(union)


def predict(image: Image.Image, mask: Image.Image = None):
    """Predict a polyp segmentation for ``image`` and optionally score it.

    Args:
        image: Input endoscopic image (PIL). It is converted to RGB first, so
            grayscale, palette and RGBA images are accepted.
        mask: Optional ground-truth mask (PIL). It is converted to grayscale
            and pixels brighter than 127 are treated as foreground.

    Returns:
        A 5-tuple matching the Gradio outputs:
        ``(original, predicted_mask, overlay, ground_truth, iou_text)``.
        The first four are 352x352 RGB PIL images (``ground_truth`` is all
        black when no mask is given); ``iou_text`` is a display string.

    Raises:
        gr.Error: If no input image was provided.
    """
    if image is None:
        # An empty image input reaches the function as None; report it
        # clearly instead of failing inside the transform.
        raise gr.Error("Please upload an input image.")

    # The 3-channel normalisation in ``transform`` and the overlay blend below
    # both require exactly three channels.
    image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        # The model returns two prediction maps; their sum is used as the
        # logit map, upsampled to the 352x352 working resolution.
        p1, p2 = model(input_tensor)
        logits = F.interpolate(p1 + p2, size=(352, 352), mode="bilinear", align_corners=False)
        prob = logits.sigmoid().cpu().numpy().squeeze()

    # Min-max rescale to [0, 1]; the epsilon avoids division by zero for a
    # constant map.
    # NOTE(review): the binary mask is thresholded on this rescaled map, so it
    # is relative to the image's own maximum and is essentially never empty;
    # thresholding ``prob`` directly may be preferable for IoU scoring.
    res_norm = (prob - prob.min()) / (prob.max() - prob.min() + 1e-8)
    pred_mask = (res_norm > 0.5).astype(np.uint8)

    # JET colour map of the rescaled map (OpenCV produces BGR, convert to RGB).
    heatmap = cv2.applyColorMap((res_norm * 255).astype(np.uint8), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Blend the heat map over the resized input: 60% image, 40% heat map.
    image_resized = np.array(image.resize((352, 352)))
    overlay = cv2.addWeighted(image_resized, 0.6, heatmap, 0.4, 0)

    iou_text = "No GT mask provided"
    if mask is not None:
        mask_resized = mask.convert("L").resize((352, 352))
        gt_mask_bin = (np.array(mask_resized) > 127).astype(np.uint8)
        iou_text = f"IOU: {_binary_iou(pred_mask, gt_mask_bin):.4f}"
        # Binary ground truth as a black/white RGB image for display.
        gt_mask_rgb = np.stack([gt_mask_bin * 255] * 3, axis=-1)
    else:
        gt_mask_rgb = np.zeros_like(image_resized)

    return (
        Image.fromarray(image_resized),  # original (resized)
        Image.fromarray(heatmap),  # predicted mask (colour-mapped)
        Image.fromarray(overlay),  # overlay
        Image.fromarray(gt_mask_rgb),  # ground-truth mask (black if absent)
        iou_text,
    )

# ----------------------
# Gradio Interface
# ----------------------
# Example rows for the demo, one per input component: [input image, ground
# truth mask]. The mask is None in every row, so the examples also show that
# the demo runs without a ground-truth mask.
examples = [
    [path, None]
    for path in (
        "examples/image1.jpg",
        "examples/image2.jpg",
        "examples/image3.jpg",
    )
]

# Web UI wiring predict() to two image inputs and five outputs.
# Both inputs are plain gr.Image components: an input left empty is passed to
# predict() as None, which it already treats as "no ground-truth mask".
# The legacy ``optional=True`` keyword is not a gr.Image parameter (ignored
# with a warning in Gradio 3.x, a TypeError in Gradio 4.x), so it is omitted.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Image(type="pil", label="Ground Truth Mask (Optional)"),
    ],
    outputs=[
        gr.Image(label="Original"),
        gr.Image(label="Predicted Mask"),
        gr.Image(label="Overlay"),
        gr.Image(label="Ground Truth Mask"),
        gr.Label(label="IOU Score"),
    ],
    title="Polyp Segmentation - PolypPVT",
    description="Upload an endoscopic image to predict polyp segmentation mask. Optionally, provide a ground truth mask to calculate IOU.",
    examples=examples,
)

if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    demo.launch()