File size: 5,586 Bytes
37f7a06
 
 
 
 
 
6673b6b
37f7a06
 
 
 
 
cf8d6e3
37f7a06
 
 
 
 
 
 
 
 
 
 
 
6673b6b
37f7a06
 
 
 
 
 
 
 
6673b6b
37f7a06
 
 
 
 
6673b6b
37f7a06
 
6673b6b
37f7a06
 
 
6673b6b
37f7a06
 
 
6673b6b
37f7a06
 
 
 
 
 
 
 
 
 
 
 
6673b6b
37f7a06
6673b6b
 
 
 
37f7a06
 
 
6673b6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c37a86
6673b6b
 
 
 
7c37a86
 
6673b6b
7c37a86
6673b6b
 
 
7c37a86
6673b6b
7c37a86
 
6673b6b
 
7c37a86
 
6673b6b
 
7c37a86
6673b6b
 
 
 
 
 
 
 
9d66407
6673b6b
 
 
37f7a06
 
 
 
 
 
 
 
 
7c37a86
6673b6b
9d66407
 
 
 
6673b6b
 
7c37a86
 
 
6673b6b
 
 
7c37a86
 
 
 
 
6673b6b
7c37a86
 
6673b6b
7c37a86
 
6673b6b
7c37a86
 
 
6673b6b
 
7c37a86
6673b6b
 
 
 
 
 
 
 
 
 
 
 
 
 
37f7a06
 
6673b6b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import os

import torch
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
import gradio as gr
from torchvision import transforms

from lib.pvt import PolypPVT  # model from this repo

# ----------------------
# Model loading
# ----------------------
# Select the device once and reuse it both for mapping the checkpoint tensors
# and for placing the model, so the two choices can never diverge.
device = "cuda" if torch.cuda.is_available() else "cpu"

pth_path = "./weights/PolypPVT.pth"
model = PolypPVT()
model.load_state_dict(torch.load(pth_path, map_location=device))
model.to(device)
# Switch to inference mode (affects layers such as dropout / batch-norm, if
# the model contains any).
model.eval()

# ----------------------
# Transform
# ----------------------
# Preprocessing applied to every input image before inference:
#   1. resize to the fixed 352x352 network input,
#   2. convert the PIL image to a float tensor in [0, 1],
#   3. normalize each channel with the standard ImageNet mean / std values.
transform = transforms.Compose(
    [
        transforms.Resize((352, 352)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)

# ----------------------
# Prediction function
# ----------------------
def predict(image: Image.Image, mask: Image.Image = None):
    """Run polyp segmentation on *image* and optionally score it against *mask*.

    Args:
        image: Input image as a PIL image. Any mode is accepted; it is
            converted to RGB before preprocessing.
        mask: Optional ground-truth mask as a PIL image. It is converted to
            grayscale and resized to 352x352; pixels > 127 count as foreground.

    Returns:
        A 5-tuple ``(original, prediction, overlay, gt_mask, iou_text)``:
        the input resized to 352x352, the JET-colormapped prediction map, a
        60/40 blend of those two, the binarized ground truth as an RGB image
        (all black when no mask is given), and an IoU text string.

    Raises:
        gr.Error: If no input image was supplied (e.g. the button was pressed
            before uploading), so the UI shows a message instead of a traceback.
    """
    if image is None:
        raise gr.Error("No input image provided")

    # Normalize expects exactly 3 channels and the overlay below blends against
    # a 3-channel heatmap, so RGBA / grayscale uploads must become RGB first.
    image = image.convert("RGB")
    input_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        P1, P2 = model(input_tensor)
        # Sum the model's two raw outputs and resize them to 352x352 before
        # applying the sigmoid.
        res = F.interpolate(P1 + P2, size=(352, 352), mode="bilinear", align_corners=False)
        res = res.sigmoid().cpu().numpy().squeeze()
    # Min-max rescale to [0, 1]; the epsilon guards against a constant map.
    res_norm = (res - res.min()) / (res.max() - res.min() + 1e-8)

    # Binary prediction, used only for the IoU computation.
    pred_mask = (res_norm > 0.5).astype(np.uint8)

    # Colour-mapped heatmap; OpenCV produces BGR, PIL/Gradio expect RGB.
    pred_mask_color = cv2.applyColorMap((res_norm * 255).astype(np.uint8), cv2.COLORMAP_JET)
    pred_mask_color = cv2.cvtColor(pred_mask_color, cv2.COLOR_BGR2RGB)

    # Blend the heatmap over the resized input image.
    image_resized = np.array(image.resize((352, 352)))
    overlay = cv2.addWeighted(image_resized, 0.6, pred_mask_color, 0.4, 0)

    # If a ground-truth mask is provided, compute IoU against the prediction.
    iou_score = None
    if mask is not None:
        mask_resized = mask.convert("L").resize((352, 352))
        gt_mask_bin = (np.array(mask_resized) > 127).astype(np.uint8)
        intersection = np.logical_and(pred_mask, gt_mask_bin).sum()
        union = np.logical_or(pred_mask, gt_mask_bin).sum()
        # Two empty masks agree perfectly; dividing by a tiny epsilon would
        # otherwise report that case as 0.
        iou_score = intersection / union if union > 0 else 1.0
        # Ground-truth mask as a 3-channel 0/255 image for display.
        gt_mask_rgb = np.stack([gt_mask_bin * 255] * 3, axis=-1)
    else:
        gt_mask_rgb = np.zeros_like(image_resized)

    return (
        Image.fromarray(image_resized),  # Original
        Image.fromarray(pred_mask_color),  # Prediction heatmap
        Image.fromarray(overlay),  # Overlay
        Image.fromarray(gt_mask_rgb),  # Ground-truth mask (may be blank)
        f"IOU: {iou_score:.4f}" if iou_score is not None else "No GT mask provided"
    )

# ----------------------
# CSS styles
# ----------------------
# Custom stylesheet passed to gr.Blocks(css=...) when the layout is built.
# It centers the banner, the h1/h2/h3 headings and the prose paragraphs, and
# sets their font sizes / margins. The /* ... */ comments inside the string
# are part of the CSS payload itself.
css = """
/* Banner stilleri */
.banner {
    text-align: center;
    margin-bottom: 30px;
}

.banner img {
    max-width: 100%;
    height: auto;
}

/* Başlık stilleri */
.gradio-container h1 {
    text-align: center !important;
    font-size: 2.2rem !important;
    font-weight: bold !important;
    margin: 20px 0 !important;
}

/* Alt başlık stilleri */
.gradio-container h2 {
    text-align: center !important;
    font-size: 1.6rem !important;
    margin: 15px 0 !important;
}

.gradio-container h3 {
    text-align: center !important;
    font-size: 1.3rem !important;
    margin: 15px 0 !important;
}

/* Açıklama metni */
.gradio-container .gr-prose p {
    text-align: center !important;
    font-size: 1.1rem !important;
    margin: 10px 0 !important;
}
"""

# ----------------------
# HTML banner
# ----------------------
# Fallback banner rendered via gr.HTML when the image-based banner cannot be
# shown in the layout. The .banner class is styled by the `css` string.
# NOTE(review): the relative src "tmp/vflai.png" is resolved by the browser
# against the app URL; Gradio may not serve local files at that path without
# its file route / allowed_paths — confirm the image actually loads. It also
# differs from the "vflai.png" path used for the primary banner.
banner_html = """
<div class="banner">
    <img src="tmp/vflai.png" alt="VFLAI Banner">
</div>
"""

# ----------------------
# Gradio Interface
# ----------------------
# Example rows: [input image path, ground-truth mask path or None].
examples = [
    ["examples/image1.jpg", None],
    ["examples/image2.jpg", None],
    ["examples/image3.jpg", None],  # can also be tested without a mask
]


def _add_banner(image_path="vflai.png"):
    """Add the banner to the currently open Blocks layout.

    Shows *image_path* with gr.Image when the file exists. If the file is
    missing, or Gradio fails to build the image component, the HTML banner
    (``banner_html``) is rendered instead.
    """
    if os.path.isfile(image_path):
        try:
            gr.Image(image_path, label="", show_label=False, interactive=False, height=200)
            return
        except Exception:  # deliberate best-effort: any UI failure -> HTML fallback
            pass
    gr.HTML(banner_html)


with gr.Blocks(css=css, title="VFLAI Polip Segmentasyon") as demo:
    # Banner (image if available, HTML otherwise)
    _add_banner()

    # Main titles
    gr.Markdown("# Validebağ Fen Lisesi Yapay Zeka Takımı")
    gr.Markdown("## Teknofest 2025 Sağlıkta Yapay Zeka Yarışması")
    gr.Markdown("### Polip Segmentasyonu Test Arayüzü")

    # Main interface: inputs on the left, results on the right
    with gr.Row():
        with gr.Column():
            gr.Markdown("## Giriş")
            input_image = gr.Image(type="pil", label="Giriş Görüntüsü")
            gt_mask = gr.Image(type="pil", label="Gerçek Maske (Opsiyonel)")
            predict_btn = gr.Button("Analiz Et", variant="primary")

        with gr.Column():
            gr.Markdown("## Sonuçlar")
            with gr.Row():
                original_output = gr.Image(label="Orijinal")
                pred_mask_output = gr.Image(label="Tahmin Maskesi")
            with gr.Row():
                overlay_output = gr.Image(label="Bindirme")
                gt_mask_output = gr.Image(label="Gerçek Maske")
            iou_output = gr.Label(label="IOU Skoru")

    # Shared wiring: order must match the 5-tuple returned by predict().
    predict_inputs = [input_image, gt_mask]
    predict_outputs = [original_output, pred_mask_output, overlay_output, gt_mask_output, iou_output]

    # Examples (outputs are pre-computed at startup because of cache_examples)
    gr.Markdown("## Örnek Görüntüler")
    gr.Examples(
        examples=examples,
        inputs=predict_inputs,
        outputs=predict_outputs,
        fn=predict,
        cache_examples=True
    )

    # Button handler
    predict_btn.click(
        fn=predict,
        inputs=predict_inputs,
        outputs=predict_outputs
    )

if __name__ == "__main__":
    # Start the Gradio web server only when this file is executed directly,
    # not when it is imported as a module.
    demo.launch()