WizardWang01 committed
Commit d6deb62 · verified · Parent: 4f606ba

Add app.py

Files changed (1)
app.py +249 -0
app.py ADDED
@@ -0,0 +1,249 @@
+ import numpy as np
+ from PIL import Image, ImageFilter
+ import torch
+ from transformers import AutoImageProcessor, AutoModelForSemanticSegmentation, AutoModelForDepthEstimation
+ from scipy.ndimage import gaussian_filter
+ import gradio as gr
+
+ # Global models (loaded once at startup)
+ segmentation_model = None
+ segmentation_processor = None
+ depth_model = None
+ depth_processor = None
+ device = None
+
+ def load_models():
+     """Load all required models at startup."""
+     global segmentation_model, segmentation_processor, depth_model, depth_processor, device
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     print(f"Using device: {device}")
+
+     # Load segmentation model
+     print("Loading segmentation model...")
+     seg_model_id = "nvidia/segformer-b0-finetuned-ade-512-512"
+     segmentation_processor = AutoImageProcessor.from_pretrained(seg_model_id)
+     segmentation_model = AutoModelForSemanticSegmentation.from_pretrained(seg_model_id)
+     segmentation_model.eval()
+     segmentation_model.to(device)
+
+     # Load depth estimation model
+     print("Loading depth estimation model...")
+     depth_model_id = "depth-anything/Depth-Anything-V2-Base-hf"
+     depth_processor = AutoImageProcessor.from_pretrained(depth_model_id)
+     depth_model = AutoModelForDepthEstimation.from_pretrained(depth_model_id)
+     depth_model.eval()
+     depth_model.to(device)
+
+     print("Models loaded successfully!")
+
+ def get_person_mask(image):
+     """Extract a person mask from an image using semantic segmentation."""
+     # Resize to 512x512 for processing
+     img_512 = image.resize((512, 512), Image.BILINEAR)
+
+     # Run segmentation
+     inputs = segmentation_processor(images=img_512, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = segmentation_model(**inputs)
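+     # SegFormer returns logits at reduced resolution (1/4 of the input side
+     # for this model), so upsample back to 512x512 before the per-pixel argmax.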
+     logits = torch.nn.functional.interpolate(
+         outputs.logits, size=(512, 512), mode="bilinear", align_corners=False
+     )
+     pred = logits.argmax(dim=1)[0].cpu().numpy()
+
+     # Find the person class ID in the model's label map
+     id2label = segmentation_model.config.id2label
+     label2id = {v.lower(): int(k) for k, v in id2label.items()}
+     person_key = next((k for k in label2id.keys() if k in ["person", "people", "human"]), None)
+
+     if person_key is None:
+         # No person class in the label map: return an empty mask
+         return Image.new("L", (512, 512), 0)
+
+     person_id = label2id[person_key]
+     mask = (pred == person_id).astype(np.uint8) * 255
+
+     return Image.fromarray(mask, mode="L")
+
+ def gaussian_blur_effect(image, blur_radius=15):
+     """Apply Gaussian blur to the background, keeping the person sharp."""
+     if image is None:
+         return None
+
+     # Convert to RGB if needed
+     if image.mode != "RGB":
+         image = image.convert("RGB")
+
+     # Resize to 512x512
+     img_512 = image.resize((512, 512), Image.BILINEAR)
+
+     # Get person mask
+     mask_img = get_person_mask(img_512)
+
+     # Apply Gaussian blur to the entire image
+     blurred_img = img_512.filter(ImageFilter.GaussianBlur(radius=blur_radius))
+
+     # Composite: person (sharp) over background (blurred)
+     input_array = np.array(img_512)
+     blurred_array = np.array(blurred_img)
+     mask_array = np.array(mask_img) / 255.0
+     mask_3ch = np.stack([mask_array] * 3, axis=-1)
+
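+     # Alpha-composite per pixel: out = sharp * mask + blurred * (1 - mask),
+     # where mask is 1.0 on the person and 0.0 on the background.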
+     output_array = (input_array * mask_3ch + blurred_array * (1 - mask_3ch)).astype(np.uint8)
+     output_img = Image.fromarray(output_array)
+
+     return output_img
+
+ def get_depth_map(image):
+     """Estimate a relative depth map for an image."""
+     # Resize to 512x512
+     img_512 = image.resize((512, 512), Image.BILINEAR)
+
+     # Run depth estimation
+     inputs = depth_processor(images=img_512, return_tensors="pt").to(device)
+     with torch.no_grad():
+         outputs = depth_model(**inputs)
+         predicted_depth = outputs.predicted_depth
+
+     # Interpolate to 512x512
+     prediction = torch.nn.functional.interpolate(
+         predicted_depth.unsqueeze(1),
+         size=(512, 512),
+         mode="bicubic",
+         align_corners=False,
+     )
+
+     depth_map = prediction.squeeze().cpu().numpy()
+     return depth_map
+
+ def lens_blur_effect(image, max_blur=15, focus_threshold=5.0):
+     """Apply depth-based lens blur (foreground sharp, background blurred)."""
+     if image is None:
+         return None
+
+     # Convert to RGB if needed
+     if image.mode != "RGB":
+         image = image.convert("RGB")
+
+     # Resize to 512x512
+     img_512 = image.resize((512, 512), Image.BILINEAR)
+
+     # Get depth map
+     depth_map = get_depth_map(img_512)
+
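+     # Depth-Anything-style models predict relative depth where larger values
+     # generally mean closer to the camera, hence the flip before scaling.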
+     # Invert and normalize so that 0 = nearest and max_blur = farthest; guard
+     # against a zero range when the depth map is constant.
+     depth_range = max(depth_map.max() - depth_map.min(), 1e-6)
+     depth_normalized = (depth_map.max() - depth_map) / depth_range
+     depth_normalized = depth_normalized * max_blur
+
+     # Create blur map: pixels within the focus threshold stay sharp; farther
+     # pixels ramp linearly up to max_blur. Guard the divisor so that
+     # focus_threshold == max_blur cannot divide by zero.
+     blur_map = np.zeros_like(depth_normalized)
+     close_mask = depth_normalized <= focus_threshold
+     blur_map[close_mask] = 0.0
+
+     far_mask = depth_normalized > focus_threshold
+     denom = max(max_blur - focus_threshold, 1e-6)
+     blur_map[far_mask] = ((depth_normalized[far_mask] - focus_threshold) / denom) * max_blur
+
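+     # A true lens blur would vary the kernel size at every pixel. As a cheaper
+     # approximation, quantize the sigma range into discrete bands and blend
+     # each band of pixels with a uniformly blurred copy of the image.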
+     # Apply variable blur
+     img_array = np.array(img_512).astype(np.float32)
+     output_array = img_array.copy()
+
+     num_blur_levels = 20
+     for level in range(1, num_blur_levels + 1):
+         sigma_min = (level - 1) * max_blur / num_blur_levels
+         sigma_max = level * max_blur / num_blur_levels
+         sigma_avg = (sigma_min + sigma_max) / 2.0
+
+         # Use an inclusive upper edge on the last band so pixels with
+         # blur_map == max_blur (the farthest points) are not skipped.
+         if level == num_blur_levels:
+             mask = ((blur_map >= sigma_min) & (blur_map <= sigma_max)).astype(np.float32)
+         else:
+             mask = ((blur_map >= sigma_min) & (blur_map < sigma_max)).astype(np.float32)
+
+         if mask.sum() > 0 and sigma_avg > 0.1:
+             blurred = np.zeros_like(img_array)
+             for c in range(3):
+                 blurred[:, :, c] = gaussian_filter(img_array[:, :, c], sigma=sigma_avg)
+
+             mask_3ch = np.stack([mask] * 3, axis=-1)
+             output_array = output_array * (1 - mask_3ch) + blurred * mask_3ch
+
+     output_array = np.clip(output_array, 0, 255).astype(np.uint8)
+     output_img = Image.fromarray(output_array)
+
+     return output_img
+
+ # Load models at startup
+ load_models()
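+ # Note: from_pretrained() downloads both checkpoints from the Hugging Face Hub
+ # on first startup (cached afterward), so the Space needs network access and a
+ # warm-up period before the UI is responsive.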
+
+ # Create Gradio interface
+ with gr.Blocks(title="Image Blur Effects Demo") as demo:
+     gr.Markdown("""
+     # 🎨 Image Blur Effects Demo
+
+     Upload an image to apply **Gaussian Blur** or **Lens Blur** effects.
+
+     - **Gaussian Blur**: detects people and blurs the background, keeping the person sharp.
+     - **Lens Blur**: uses depth estimation to simulate a camera lens bokeh effect (foreground sharp, background blurred).
+     """)
+
+     with gr.Tab("Gaussian Blur"):
+         gr.Markdown("### Background blur with person detection")
+         with gr.Row():
+             with gr.Column():
+                 gaussian_input = gr.Image(type="pil", label="Input Image")
+                 gaussian_radius = gr.Slider(
+                     minimum=5, maximum=30, value=15, step=1,
+                     label="Blur Radius (σ)"
+                 )
+                 gaussian_btn = gr.Button("Apply Gaussian Blur", variant="primary")
+             with gr.Column():
+                 gaussian_output = gr.Image(type="pil", label="Output Image")
+
+         gaussian_btn.click(
+             fn=gaussian_blur_effect,
+             inputs=[gaussian_input, gaussian_radius],
+             outputs=gaussian_output
+         )
+
+         gr.Examples(
+             examples=[["self.jpg"], ["self-pic.jpg"]],
+             inputs=gaussian_input,
+             label="Example Images"
+         )
+
+     with gr.Tab("Lens Blur (Depth-Based)"):
+         gr.Markdown("### Depth-based bokeh effect simulation")
+         with gr.Row():
+             with gr.Column():
+                 lens_input = gr.Image(type="pil", label="Input Image")
+                 lens_max_blur = gr.Slider(
+                     minimum=5, maximum=25, value=15, step=1,
+                     label="Max Blur Intensity"
+                 )
+                 lens_focus = gr.Slider(
+                     minimum=0, maximum=10, value=5.0, step=0.5,
+                     label="Focus Threshold (lower = more blur)"
+                 )
+                 lens_btn = gr.Button("Apply Lens Blur", variant="primary")
+             with gr.Column():
+                 lens_output = gr.Image(type="pil", label="Output Image")
+
+         lens_btn.click(
+             fn=lens_blur_effect,
+             inputs=[lens_input, lens_max_blur, lens_focus],
+             outputs=lens_output
+         )
+
+         gr.Examples(
+             examples=[["self.jpg"], ["self-pic.jpg"]],
+             inputs=lens_input,
+             label="Example Images"
+         )
+
+     gr.Markdown("""
+     ---
+     **Technical Details:**
+     - Segmentation: NVIDIA SegFormer (ADE20K)
+     - Depth Estimation: Depth Anything V2
+     - All images are resized to 512×512 for processing
+     """)
+
+ if __name__ == "__main__":
+     demo.launch()
+
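For anyone reproducing this Space, a minimal requirements.txt sketch follows. The package list is derived directly from the imports in app.py; it is an illustrative sketch, not part of this commit, and no versions are pinned.

    # requirements.txt (sketch; derived from app.py's imports, unpinned)
    numpy
    Pillow
    torch
    transformers
    scipy
    gradio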