arcanoXIII committed on
Commit 7e08bf1 · verified · 1 Parent(s): a5bb83c

Upload 13 files
README.md CHANGED
@@ -1,14 +1,130 @@
- ---
- title: RADAR Demo
- emoji: 🐨
- colorFrom: pink
- colorTo: pink
- sdk: gradio
- sdk_version: 5.49.1
- app_file: app.py
- pinned: false
- license: apache-2.0
- short_description: Demo release for RADAR.
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ <h1 align="center"> Towards Reliable Identification of Diffusion-based Image Manipulations (NeurIPS 2025) </h1>
+
+ :rotating_light: This repository contains code snippets and checkpoints of our work "**Towards Reliable Identification of Diffusion-based Image Manipulations**" :rotating_light:
+
+ by [Alex Costanzino](https://alex-costanzino.github.io/)*1, [Woody Bayliss](https://www.linkedin.com/in/woody-bayliss-750a27325/)*2, [Juil Sock](https://skf0321.wixsite.com/juil)*2, [Marc Gorriz Blanch](https://www.linkedin.com/in/marc-gorriz/)*2, [Danijela Horak](https://www.linkedin.com/in/danijela-horak-3b417a70/?originalSubdomain=uk)*2, [Ivan Laptev](https://mbzuai.ac.ae/study/faculty/ivan-laptev/)*3, [Philip Torr](https://eng.ox.ac.uk/people/philip-torr/)*4 and [Fabio Pizzati](https://fabvio.github.io/)*3.
+
+ *1 University of Bologna, *2 BBC R&D, *3 MBZUAI, *4 University of Oxford
+
+
+ <div class="alert alert-info">
+
+
+ <h2 align="center">
+
+ [Project Page](https://alex-costanzino.github.io/radar/) | [Paper (ArXiv)](https://www.arxiv.org/abs/2506.05466)
+ </h2>
+
+
+ ## :bookmark_tabs: Table of Contents
+
+ 1. [Introduction](#clapper-introduction)
+ 2. [Dataset](#file_cabinet)
+ 3. [Checkpoints](#inbox_tray)
+ 4. [Code](#memo-code)
+ 5. [Contacts](#envelope-contacts)
+
+ </div>
+
+ ## :clapper: Introduction
+
+ Changing facial expressions, gestures, or background details may dramatically alter the meaning conveyed by an image. Notably, recent advances in diffusion models greatly improve the quality of image manipulation while also opening the door to misuse. Identifying changes made to authentic images thus becomes an important task, constantly challenged by new diffusion-based editing tools.
+
+ To this end, we propose a novel approach for ReliAble iDentification of inpainted AReas (RADAR). RADAR builds on existing foundation models and combines features from different image modalities. It also incorporates an auxiliary contrastive loss that helps to isolate manipulated image patches.
+
+ We demonstrate that these techniques significantly improve both the accuracy of our method and its generalisation to a large number of diffusion models. To support realistic evaluation, we further introduce BBC-PAIR, a new comprehensive benchmark, with images tampered by 28 diffusion models. Our experiments show that RADAR achieves excellent results, outperforming the state-of-the-art in detecting and localising image edits made by both seen and unseen diffusion models.
+
+ <h4 align="center">
+
+ </h4>
+
+ <img src="./assets/architecture.jpg" alt="RADAR architecture" style="width: 800px;" title="architecture">
+
+ :fountain_pen: If you find this code useful in your research, please cite:
+
+ ```bibtex
+ @article{costanzino2025radar,
+     author  = {Costanzino, Alex and Bayliss, Woody and Sock, Juil and Gorriz Blanch, Marc and Horak, Danijela and Laptev, Ivan and Torr, Philip and Pizzati, Fabio},
+     title   = {Towards Reliable Identification of Diffusion-based Image Manipulations},
+     journal = {arXiv},
+     year    = {2025},
+ }
+ ```
+
+ <h2 id="file_cabinet"> :file_cabinet: Dataset </h2>
+
+ In our experiments, we employed our benchmark: [BBC-PAIR (Paired Authentic and Inpainted References)](https://github.com/bbc/PAIR).
+
+
+ <h2 id="inbox_tray"> :inbox_tray: Checkpoints </h2>
+
+ Here, you can download the weights of the networks employed to produce the results in our paper.
+
+ To use these weights, please follow these steps:
+
+ 1. Create a folder named `checkpoints` in the project directory;
+ 2. Download the weights [[Download]](https://drive.google.com/drive/folders/1VYbLoK3HM238AU1_xszBM4DYd2nai0jf?usp=sharing);
+ 3. Copy the downloaded weights into the `checkpoints` folder (the expected filenames are sketched below).
+
+
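The demo in `app.py` (added in this commit) assembles checkpoint filenames as `<component>_<label>_<epochs>ep_<batch_size>bs.pth`, and the two `.pth` files shipped here follow that pattern. A minimal sketch to verify that the expected files are in place (names and values are taken from `app.py` and the checkpoints added below; adjust them if you trained your own weights):

```python
import os

checkpoint_path = 'checkpoints'
model_name = 'multimodal_15k_10inp_120ep_4bs'  # label, epochs and batch size used by app.py

# Report whether each expected checkpoint file exists on disk.
for component in ('fusion_block', 'projector'):
    path = os.path.join(checkpoint_path, f'{component}_{model_name}.pth')
    print(path, '->', 'found' if os.path.isfile(path) else 'missing')
```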
+ ## :memo: Code
+
+ <div class="alert alert-info">
+
+ **Warning**:
+ - The code uses `wandb` during training to log results. Make sure you have a wandb account or, if you prefer not to use `wandb`, disable it in `A_train.py` with the flag `mode = 'disabled'`, as sketched below.
+
+ </div>
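Disabling `wandb` amounts to passing that flag when the run is initialised; a sketch, assuming `A_train.py` calls `wandb.init` in the usual way (the project name below is a placeholder):

```python
import wandb

# mode='disabled' turns every wandb call into a no-op, so no account or login is needed.
run = wandb.init(project='radar-demo', mode='disabled')
```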
+
+ ### :hammer_and_wrench: Setup Instructions
+
+ **Dependencies**: Ensure that all the necessary dependencies are installed; they are listed in the `requirements.txt` file.
+
+ ### :rocket: Single-image Inference
+
+ The `C_infer_single-image.py` script runs the framework on an image to generate a continuous tampering map.
+
+ You can specify the following arguments:
+ - `--checkpoint_savepath`: Path to the directory of the checkpoints, e.g., `checkpoints`;
+ - `--label`: A label to distinguish between experiments, e.g., `multimodal_15k_10inp`;
+ - `--image_path`: Path to the image to assess;
+ - `--epochs_no`: Number of epochs employed during the framework optimization;
+ - `--img_size`: Resolution employed during the framework optimization;
+ - `--batch_size`: Number of samples per batch employed during the framework optimization.
+
+ If you haven't downloaded the checkpoints yet, you can find the download links in the **Checkpoints** section above. A condensed sketch of the underlying pipeline is shown below.
+
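The demo in `app.py` (added in this commit) implements the same single-image pipeline; the following condensed sketch mirrors it, assuming the checkpoints from the section above sit in `checkpoints/` and reusing the demo's default label, epochs, batch size and resolution (`my_image.png` is a placeholder):

```python
import os
import torch
import torchvision
from PIL import Image

from models.feature_extractor import FeatureExtractor, FeatureExtractorDepth
from models.fuser import DoubleCrossAttentionFusion
from models.projector import SiameseProjector
from loaders.loader_utils import SquarePad

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_name = 'multimodal_15k_10inp_120ep_4bs'   # <label>_<epochs_no>ep_<batch_size>bs

# Same preprocessing as the demo: square-pad, resize to 896, ImageNet normalisation.
transform = torchvision.transforms.Compose([
    SquarePad(),
    torchvision.transforms.Resize((896, 896), interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Frozen backbones plus the two learned components restored from the checkpoints.
fe_rgb = FeatureExtractor().to(device).eval()
fe_depth = FeatureExtractorDepth().to(device).eval()
fusion = DoubleCrossAttentionFusion(hidden_dim=fe_rgb.embed_dim).to(device)
fusion.load_state_dict(torch.load(os.path.join('checkpoints', f'fusion_block_{model_name}.pth'), map_location=device, weights_only=False))
fusion.eval()
projector = SiameseProjector(inner_features=fe_rgb.embed_dim).to(device)
projector.load_state_dict(torch.load(os.path.join('checkpoints', f'projector_{model_name}.pth'), map_location=device, weights_only=False))
projector.eval()

rgb = transform(Image.open('my_image.png').convert('RGB')).unsqueeze(0).to(device)
with torch.no_grad():
    fused = fusion(fe_rgb(rgb), fe_depth(rgb))
    _, seg = projector(fused)
    tampering_map = torch.sigmoid(seg)   # continuous map; higher values = more likely inpainted
```

This mirrors `detect_manipulation` in `app.py`, which additionally upsamples the map back to the original resolution and renders it as a heatmap.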
+ ### :rocket: Full Inference [TBA]
+
+ <!-- The `B_infer.py` script tests the framework. It can be used to generate anomaly maps as well.
+
+ You can specify the following options:
+ - `--dataset_path`: Path to the root directory of the dataset.
+ - `--checkpoint_folder`: Path to the directory of the checkpoints, e.g., `checkpoints/checkpoints_visa`.
+ - `--class_name`: Class on which the framework was trained, e.g., `candle`.
+ - `--epochs_no`: Number of epochs used in framework optimization.
+ - `--batch_size`: Number of samples per batch employed for framework optimization.
+ - `--qualitative_folder`: Folder in which the anomaly maps are saved.
+ - `--quantitative_folder`: Folder in which the metrics are saved.
+ - `--visualize_plot`: Flag to visualize qualitative results during inference.
+ - `--produce_qualitatives`: Flag to save qualitative results during inference.
+
+ If you haven't downloaded the checkpoints yet, you can find the download links in the **Checkpoints** section above. -->
+
+ ### :rocket: Train [TBA]
+
+ <!-- The `A_train.py` script trains the framework.
+
+ You can specify the following options:
+ - `--dataset_path`: Path to the root directory of the dataset.
+ - `--checkpoint_savepath`: Path to the directory in which checkpoints will be saved, e.g., `checkpoints/checkpoints_visa`.
+ - `--class_name`: Class on which the framework is trained, e.g., `candle`.
+ - `--epochs_no`: Number of epochs for framework optimization.
+ - `--img_size`: Resolution employed during training.
+ - `--batch_size`: Number of samples per batch for framework optimization.
+ - `--label`: A label to distinguish between experiments. -->
+
+ ## :envelope: Contacts
+
+ For questions, please send an email to [email protected].
app.py ADDED
@@ -0,0 +1,250 @@
+ import gradio as gr
+ import torch
+ import torchvision
+ from PIL import Image
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import matplotlib
+ matplotlib.use('Agg')
+ import io
+ import os
+
+ # Import your models
+ from models.feature_extractor import FeatureExtractor, FeatureExtractorDepth
+ from models.projector import SiameseProjector
+ from models.fuser import DoubleCrossAttentionFusion
+ from loaders.loader_utils import SquarePad
+
+ # Configuration
+ CHECKPOINT_PATH = './checkpoints'
+ MODEL_LABEL = 'multimodal_15k_10inp'
+ EPOCHS = 120
+ BATCH_SIZE = 4
+ IMAGE_SIZE = 896
+
+ # Load models
+ print("Loading models...")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {device}")
+
+ model_name = f'{MODEL_LABEL}_{EPOCHS}ep_{BATCH_SIZE}bs'
+
+ rgb_transform = torchvision.transforms.Compose([
+     SquarePad(),
+     torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE),
+                                    interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
+     torchvision.transforms.ToTensor(),
+     torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                      std=[0.229, 0.224, 0.225]),
+     torchvision.transforms.Lambda(lambda img: img.unsqueeze(0)),
+ ])
+
+ fe_rgb = FeatureExtractor().to(device).eval()
+ fe_depth = FeatureExtractorDepth().to(device).eval()
+
+ fusion_block = DoubleCrossAttentionFusion(hidden_dim=fe_rgb.embed_dim).to(device)
+ fusion_block.load_state_dict(torch.load(
+     os.path.join(CHECKPOINT_PATH, f'fusion_block_{model_name}.pth'),
+     weights_only=False,
+     map_location=device
+ ))
+ fusion_block.eval()
+
+ projector = SiameseProjector(inner_features=fe_rgb.embed_dim).to(device)
+ projector.load_state_dict(torch.load(
+     os.path.join(CHECKPOINT_PATH, f'projector_{model_name}.pth'),
+     weights_only=False,
+     map_location=device
+ ))
+ projector.eval()
+
+ print("Models loaded successfully!")
+
+ def detect_manipulation(image):
+     """Process image and return heatmap"""
+     if image is None:
+         return None
+
+     # Convert to PIL if needed
+     if isinstance(image, np.ndarray):
+         rgb_input = Image.fromarray(image).convert('RGB')
+     else:
+         rgb_input = image.convert('RGB')
+
+     original_size = rgb_input.size
+
+     # Transform and process
+     rgb = rgb_transform(rgb_input)
+     rgb = rgb.to(device)
+
+     with torch.no_grad():
+         rgb_feat = fe_rgb(rgb)
+         depth_feat = fe_depth(rgb)
+         fused_feat = fusion_block(rgb_feat, depth_feat)
+         _, segmentation_map = projector(fused_feat)
+         segmentation_map = torch.sigmoid(segmentation_map)
+
+     # Resize back to original
+     segmentation_map = torch.nn.functional.interpolate(
+         segmentation_map,
+         size=[max(original_size), max(original_size)],
+         mode='bilinear'
+     ).squeeze()
+     segmentation_map = torchvision.transforms.functional.center_crop(
+         segmentation_map,
+         original_size[::-1]
+     )
+
+     heatmap = segmentation_map.cpu().detach().numpy()
+
+     # Create visualization
+     fig, ax = plt.subplots(figsize=(10, 10))
+     ax.imshow(heatmap, cmap='jet')
+     ax.axis('off')
+     plt.tight_layout(pad=0)
+
+     # Convert to image
+     buf = io.BytesIO()
+     plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=150)
+     buf.seek(0)
+     result_image = Image.open(buf)
+     plt.close(fig)
+
+     return result_image
+
+ # Custom CSS to match your design
+ custom_css = """
+ #component-0 {
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
+     padding: 20px;
+ }
+
+ .gradio-container {
+     max-width: 1200px !important;
+     margin: auto !important;
+ }
+
+ #title {
+     text-align: center;
+     color: white;
+     font-size: 3em;
+     margin-bottom: 10px;
+     text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
+ }
+
+ #subtitle {
+     text-align: center;
+     color: white;
+     font-size: 1.2em;
+     margin-bottom: 30px;
+     opacity: 0.95;
+ }
+
+ .main-card {
+     background: white;
+     border-radius: 20px;
+     padding: 40px;
+     box-shadow: 0 20px 60px rgba(0,0,0,0.3);
+ }
+
+ #info-box {
+     background: #e8f4fd;
+     border-left: 4px solid #2196F3;
+     padding: 15px;
+     border-radius: 5px;
+     margin-bottom: 20px;
+     color: #1976D2;
+ }
+
+ .upload-container .transition {
+     border: 3px dashed #667eea !important;
+     border-radius: 15px !important;
+     background: #f8f9ff !important;
+ }
+
+ .upload-container .transition:hover {
+     border-color: #764ba2 !important;
+     background: #f0f2ff !important;
+ }
+
+ button.primary {
+     background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
+     border: none !important;
+     color: white !important;
+     border-radius: 25px !important;
+     padding: 12px 30px !important;
+     font-size: 1em !important;
+ }
+
+ button.primary:hover {
+     transform: translateY(-2px);
+     box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4) !important;
+ }
+
+ .output-image {
+     border-radius: 15px;
+     box-shadow: 0 5px 15px rgba(0,0,0,0.1);
+ }
+ """
+
+ # Create Gradio interface
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
+     gr.HTML("""
+         <div id="title">🎯 RADAR</div>
+         <div id="subtitle">ReliAble iDentification of inpainted AReas</div>
+     """)
+
+     with gr.Column(elem_classes="main-card"):
+         gr.HTML("""
+             <div id="info-box">
+                 <strong>ℹ️ About RADAR:</strong> Upload an image to detect and localize regions
+                 that have been manipulated using diffusion-based inpainting models.
+                 The output shows a heatmap where red areas indicate detected manipulations.
+             </div>
+         """)
+
+         with gr.Row():
+             with gr.Column():
+                 input_image = gr.Image(
+                     label="Upload Image",
+                     type="pil",
+                     elem_classes="upload-container"
+                 )
+                 submit_btn = gr.Button("🔍 Detect Manipulations", variant="primary", size="lg")
+
+             with gr.Column():
+                 output_image = gr.Image(
+                     label="Manipulation Heatmap",
+                     type="pil",
+                     elem_classes="output-image"
+                 )
+
+         gr.Examples(
+             examples=[
+                 # Add paths to your example images here
+                 # ["examples/example1.png"],
+                 # ["examples/example2.png"],
+             ],
+             inputs=input_image,
+             outputs=output_image,
+             fn=detect_manipulation,
+             cache_examples=False,
+         )
+
+     # Connect the button
+     submit_btn.click(
+         fn=detect_manipulation,
+         inputs=input_image,
+         outputs=output_image
+     )
+
+     # Also trigger on image upload
+     input_image.change(
+         fn=detect_manipulation,
+         inputs=input_image,
+         outputs=output_image
+     )
+
+ # Launch
+ if __name__ == "__main__":
+     demo.launch()
checkpoints/fusion_block_multimodal_15k_10inp_120ep_4bs.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:70601e0d63edddbfe42da48674dc9053dc626176c5fcda273a2018b10b80bd10
+ size 23646170
checkpoints/projector_multimodal_15k_10inp_120ep_4bs.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e5ec7007ebeae7f08054de2ab223899386bb1b977e028545175241efae8ce388
+ size 14208328
loaders/.DS_Store ADDED
Binary file (6.15 kB).
 
loaders/dataloader.py ADDED
@@ -0,0 +1,167 @@
+ import os
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+ import torchvision
+
+ import json
+
+ from torch.utils.data import Dataset
+ from torch.utils.data import DataLoader
+
+ class SquarePad:
+     def __call__(self, image):
+         max_wh = max(image.size)
+         p_left, p_top = [(max_wh - s) // 2 for s in image.size]
+         p_right, p_bottom = [max_wh - (s+pad) for s, pad in zip(image.size, [p_left, p_top])]
+         padding = (p_left, p_top, p_right, p_bottom)
+         return transforms.functional.pad(image, padding, padding_mode = 'edge')
+
+ class BaseDataset(Dataset):
+     def __init__(self, img_size, dataset_path, inpainter):
+         self.IMAGENET_MEAN = [0.485, 0.456, 0.406]
+         self.IMAGENET_STD = [0.229, 0.224, 0.225]
+
+         self.img_size = img_size
+         self.dataset_path = dataset_path
+         self.inpainter = inpainter
+
+         self.json_path = os.path.join(dataset_path, 'DFDS_V2/DFDS_V2.0_2Percent.json')
+         # self.json_path = os.path.join(dataset_path, 'DFDS_V2.0_2Percent.json')
+
+         self.data = self.load_json()
+         self.data_train = self.data[0:500]
+
+         self.rgb_transform = transforms.Compose([
+             SquarePad(),
+             transforms.Resize((img_size, img_size), interpolation = transforms.InterpolationMode.BICUBIC),
+             transforms.ToTensor(),
+             transforms.Normalize(mean = self.IMAGENET_MEAN, std = self.IMAGENET_STD)
+         ])
+
+     def load_json(self):
+         with open(self.json_path, 'r') as file:
+             data = json.load(file)
+         return data
+
+ class TrainDataset(BaseDataset):
+     def __init__(self, img_size, dataset_path):
+         super().__init__(img_size = img_size, dataset_path = dataset_path, inpainter = None)
+
+         self.gt_transform = transforms.Compose([
+             SquarePad(),
+             transforms.Resize((img_size, img_size), interpolation = transforms.InterpolationMode.BICUBIC),
+             transforms.ToTensor()]
+         )
+
+         self.img_paths_pos, self.img_paths_neg, self.mask_paths_neg = self.load_dataset()
+
+     def load_dataset(self):
+         positive_imgs = [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data_train]
+
+         negative_imgs = [os.path.join(self.dataset_path, data['masks'][0]['inpainters']['SD1_Inpaint'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['inpainters']['SD1.5_Inpaint'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['inpainters']['SD2_Inpaint'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['inpainters']['SDXL_Inpaint'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['inpainters']['SD3_Inpaint'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['inpainters']['SD3.5_Inpaint'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['inpainters']['kadinsky2.2_Inpaint'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['inpainters']['kadinsky3.1_Inpaint'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['inpainters']['FLUX_SHNELL_Inpaint'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['inpainters']['FLUX_DEV_FILL'].lstrip('/')) for data in self.data_train]
+
+         negative_masks = [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data_train] + \
+             [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data_train]
+
+         return positive_imgs, negative_imgs, negative_masks
+
+     def __len__(self):
+         return len(self.data_train) * 3
+
+     def __getitem__(self, idx):
+         img_path_pos, img_path_neg, gt_neg = self.img_paths_pos[idx], self.img_paths_neg[idx], self.mask_paths_neg[idx]
+
+         img_pos = Image.open(img_path_pos.replace('/Open_V7/','')).convert('RGB')
+         img_neg = Image.open(img_path_neg.replace('/Open_V7/','')).convert('RGB')
+
+         rgb_pos = self.rgb_transform(img_pos)
+         rgb_neg = self.rgb_transform(img_neg)
+
+         gt_pos = torch.zeros([1, img_pos.size[1], img_pos.size[0]])
+         gt_pos = torchvision.transforms.functional.to_pil_image(gt_pos)
+         gt_pos = self.gt_transform(gt_pos)
+
+         gt_neg = Image.open(gt_neg.replace('/Open_V7/','')).convert('L')
+         gt_neg = self.gt_transform(gt_neg)
+         gt_neg = torch.where(gt_neg > 0.5, 1., .0)
+
+         return rgb_pos, gt_pos, rgb_neg, gt_neg
+
+ class TestDataset(BaseDataset):
+     def __init__(self, img_size, dataset_path, inpainter):
+         super().__init__(img_size = img_size, dataset_path = dataset_path, inpainter = inpainter)
+
+         self.gt_transform = transforms.Compose([
+             transforms.ToTensor()])
+
+         self.img_paths, self.mask_paths, self.labels = self.load_dataset()
+
+     def load_dataset(self):
+         positive_imgs = [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data[500:600]]
+         positive_masks = [None for data in self.data[500:600]]
+
+         negative_imgs = [os.path.join(self.dataset_path, data['masks'][0]['inpainters'][self.inpainter].lstrip('/')) for data in self.data[600:700]]
+         negative_masks = [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data[600:700]]
+
+         labels = [0.0 for data in self.data[500:600]] + [1.0 for data in self.data[600:700]]
+
+         return positive_imgs + negative_imgs, positive_masks + negative_masks, labels
+
+     def __len__(self):
+         return len(self.data[500:700])
+
+     def __getitem__(self, idx):
+         img_path, gt, label = self.img_paths[idx], self.mask_paths[idx], self.labels[idx]
+
+         img = Image.open(img_path.replace('/Open_V7/','').replace('data/', 'data/DFDS_V2/')).convert('RGB')
+
+         rgb = self.rgb_transform(img)
+
+         if gt == None:
+             gt = torch.zeros(
+                 [1, img.size[1], img.size[0]])
+         else:
+             gt = Image.open(gt.replace('/Open_V7/','').replace('data/', 'data/DFDS_V2/')).convert('L')
+             gt = self.gt_transform(gt)
+             gt = torch.where(gt > 0.5, 1., .0)
+
+         return rgb, label, gt, img_path
+
+
+ def get_data_loader(split, img_size, batch_size, dataset_path, inpainter = None):
+     if split == 'train':
+         dataset = TrainDataset(img_size, dataset_path)
+         data_loader = DataLoader(dataset = dataset, batch_size = batch_size, shuffle = True, num_workers = 8, drop_last = True, pin_memory = False)
+
+     elif split == 'test':
+         dataset = TestDataset(img_size, dataset_path, inpainter)
+         data_loader = DataLoader(dataset = dataset, batch_size = batch_size, shuffle = False, num_workers = 8, drop_last = False, pin_memory = False)
+
+     return data_loader
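A hedged usage sketch of the loaders above; the dataset root is a placeholder and must contain the `DFDS_V2` JSON and image layout that `BaseDataset` expects, so this only runs once the benchmark is downloaded:

```python
from loaders.dataloader import get_data_loader

train_loader = get_data_loader(split='train', img_size=896, batch_size=4,
                               dataset_path='/path/to/PAIR')   # placeholder path
rgb_pos, gt_pos, rgb_neg, gt_neg = next(iter(train_loader))
# rgb_pos / rgb_neg: (4, 3, 896, 896) normalised authentic / inpainted images
# gt_pos / gt_neg:   (4, 1, 896, 896) tampering masks (all zeros for authentic images)
```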
loaders/loader_utils.py ADDED
@@ -0,0 +1,14 @@
+ from torchvision import transforms
+
+ imagenet_denormalize = transforms.Compose([
+     transforms.Normalize(mean = [0., 0., 0.], std = [1/0.229, 1/0.224, 1/0.225]),
+     transforms.Normalize(mean = [-0.485, -0.456, -0.406], std = [1., 1., 1.])
+ ])
+
+ class SquarePad:
+     def __call__(self, image):
+         max_wh = max(image.size)
+         p_left, p_top = [(max_wh - s) // 2 for s in image.size]
+         p_right, p_bottom = [max_wh - (s+pad) for s, pad in zip(image.size, [p_left, p_top])]
+         padding = (p_left, p_top, p_right, p_bottom)
+         return transforms.functional.pad(image, padding, padding_mode = 'edge')
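`SquarePad` edge-pads the shorter side so that the subsequent square resize does not distort the aspect ratio; a small sketch:

```python
from PIL import Image
from loaders.loader_utils import SquarePad

img = Image.new('RGB', (640, 480))   # any non-square image
print(SquarePad()(img).size)         # (640, 640): 80 px of edge padding added above and below
```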
models/.DS_Store ADDED
Binary file (6.15 kB).
 
models/feature_extractor.py ADDED
@@ -0,0 +1,27 @@
+ import torch
+ from transformers import AutoModelForDepthEstimation
+
+ class FeatureExtractor(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+
+         self.fe = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')
+
+         self.patch_size = self.fe.patch_size
+         self.embed_dim = self.fe.embed_dim
+
+     def forward(self, x):
+         return self.fe.forward_features(x)['x_norm_patchtokens']
+
+ class FeatureExtractorDepth(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+
+         self.fe = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Base-hf")
+
+         self.patch_size = 14
+         self.embed_dim = 768
+
+     def forward(self, x):
+         x = self.fe(x, output_hidden_states=True).hidden_states
+         return x[-1][:,1:,:]
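Both extractors emit one token per 14x14 patch with a 768-dimensional embedding, which is what makes the cross-attention fusion below possible. A shape-check sketch for the RGB branch (it downloads the DINOv2 weights on first run):

```python
import torch
from models.feature_extractor import FeatureExtractor

fe = FeatureExtractor().eval()
x = torch.randn(1, 3, 896, 896)   # the resolution used by the demo in app.py
with torch.no_grad():
    tokens = fe(x)
print(tokens.shape)               # (1, 4096, 768): a 64x64 grid of patch tokens, since 896 / 14 = 64
```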
models/fuser.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+
+ class DoubleCrossAttentionFusion(torch.nn.Module):
+     def __init__(self, hidden_dim=768, num_heads=8, dropout=0.1):
+         super().__init__()
+
+         # 1. Per-modality normalization.
+         self.norm_rgb = torch.nn.LayerNorm(hidden_dim)
+         self.norm_depth = torch.nn.LayerNorm(hidden_dim)
+
+         # 2. Cross-attention.
+         self.cross_attn_depth = torch.nn.MultiheadAttention(
+             embed_dim=hidden_dim,
+             num_heads=num_heads,
+             dropout=dropout,
+             batch_first=True,
+         )
+
+         self.cross_attn_rgb = torch.nn.MultiheadAttention(
+             embed_dim=hidden_dim,
+             num_heads=num_heads,
+             dropout=dropout,
+             batch_first=True,
+         )
+
+         # 3. Mixing.
+         self.mixer = torch.nn.Sequential(
+             torch.nn.Linear(hidden_dim * 2, hidden_dim),
+             torch.nn.GELU(),
+             torch.nn.Dropout(dropout)
+         )
+
+         # 4. Output normalisation.
+         self.out_norm = torch.nn.LayerNorm(hidden_dim)
+
+     def forward(self, rgb_features, depth_features):
+         # 1. Normalize inputs.
+         rgb = self.norm_rgb(rgb_features)
+         depth = self.norm_depth(depth_features)
+
+         # 2a. Cross-attention (depth -> rgb).
+         attn_out_depth, _ = self.cross_attn_depth(
+             query=depth,
+             key=rgb,
+             value=rgb,
+             need_weights=False
+         )
+
+         # 2b. Cross-attention (rgb -> depth).
+         attn_out_rgb, _ = self.cross_attn_rgb(
+             query=rgb,
+             key=depth,
+             value=depth,
+             need_weights=False
+         )
+
+         # 3a. Residuals.
+         depth_attn = depth + attn_out_depth
+         rgb_attn = rgb + attn_out_rgb
+
+         # 3b. Mixing.
+         fused = self.mixer(torch.cat([depth_attn, rgb_attn], dim=-1))
+
+         # 4. Output normalisation.
+         return self.out_norm(fused)
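The fusion block is symmetric: depth tokens attend to RGB tokens and vice versa, and the two attended streams are concatenated and mixed back down to the input width. A quick sketch with random tokens:

```python
import torch
from models.fuser import DoubleCrossAttentionFusion

fusion = DoubleCrossAttentionFusion(hidden_dim=768, num_heads=8)
rgb_tokens = torch.randn(2, 4096, 768)    # patch tokens from the RGB backbone
depth_tokens = torch.randn(2, 4096, 768)  # patch tokens from the depth backbone
with torch.no_grad():
    fused = fusion(rgb_tokens, depth_tokens)
print(fused.shape)                        # (2, 4096, 768): one fused token per patch
```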
models/losses.py ADDED
@@ -0,0 +1,105 @@
+ """
+ Author: Yonglong Tian ([email protected])
+ Date: May 07, 2020
+ """
+ from __future__ import print_function
+
+ import torch
+ import torch.nn as nn
+
+
+ class SupervisedContrastiveLoss(nn.Module):
+     """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
+     It also supports the unsupervised contrastive loss in SimCLR"""
+     def __init__(self, temperature=0.1, contrast_mode='all', base_temperature=0.1):
+         super(SupervisedContrastiveLoss, self).__init__()
+         self.temperature = temperature
+         self.contrast_mode = contrast_mode
+         self.base_temperature = base_temperature
+
+     def forward(self, features, labels=None, mask=None):
+         """Compute loss for model. If both `labels` and `mask` are None,
+         it degenerates to SimCLR unsupervised loss:
+         https://arxiv.org/pdf/2002.05709.pdf
+
+         Args:
+             features: hidden vector of shape [bsz, n_views, ...].
+             labels: ground truth of shape [bsz].
+             mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
+                 has the same class as sample i. Can be asymmetric.
+         Returns:
+             A loss scalar.
+         """
+         device = (torch.device('cuda')
+                   if features.is_cuda
+                   else torch.device('cpu'))
+
+         if len(features.shape) < 3:
+             raise ValueError('`features` needs to be [bsz, n_views, ...],'
+                              'at least 3 dimensions are required')
+         if len(features.shape) > 3:
+             features = features.view(features.shape[0], features.shape[1], -1)
+
+         batch_size = features.shape[0]
+         if labels is not None and mask is not None:
+             raise ValueError('Cannot define both `labels` and `mask`')
+         elif labels is None and mask is None:
+             mask = torch.eye(batch_size, dtype=torch.float32).to(device)
+         elif labels is not None:
+             labels = labels.contiguous().view(-1, 1)
+             if labels.shape[0] != batch_size:
+                 raise ValueError('Num of labels does not match num of features')
+             mask = torch.eq(labels, labels.T).float().to(device)
+         else:
+             mask = mask.float().to(device)
+
+         contrast_count = features.shape[1]
+         contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
+         if self.contrast_mode == 'one':
+             anchor_feature = features[:, 0]
+             anchor_count = 1
+         elif self.contrast_mode == 'all':
+             anchor_feature = contrast_feature
+             anchor_count = contrast_count
+         else:
+             raise ValueError('Unknown mode: {}'.format(self.contrast_mode))
+
+         # compute logits
+         anchor_dot_contrast = torch.div(
+             torch.matmul(anchor_feature, contrast_feature.T),
+             self.temperature)
+         # for numerical stability
+         logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
+         logits = anchor_dot_contrast - logits_max.detach()
+
+         # tile mask
+         mask = mask.repeat(anchor_count, contrast_count)
+         # mask-out self-contrast cases
+         logits_mask = torch.scatter(
+             torch.ones_like(mask),
+             1,
+             torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
+             0
+         )
+         mask = mask * logits_mask
+
+         # compute log_prob
+         exp_logits = torch.exp(logits) * logits_mask
+         log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))
+
+         # compute mean of log-likelihood over positive
+         # modified to handle edge cases when there is no positive pair
+         # for an anchor point.
+         # Edge case e.g.:-
+         # features of shape: [4,1,...]
+         # labels: [0,1,1,2]
+         # loss before mean: [nan, ..., ..., nan]
+         mask_pos_pairs = mask.sum(1)
+         mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
+         mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs
+
+         # loss
+         loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
+         loss = loss.view(anchor_count, batch_size).mean()
+
+         return loss
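In RADAR this loss serves as the auxiliary contrastive objective mentioned in the README. A minimal sketch of the expected input layout, with random L2-normalised features and labels that could, for instance, mark authentic vs. inpainted patches:

```python
import torch
import torch.nn.functional as F
from models.losses import SupervisedContrastiveLoss

criterion = SupervisedContrastiveLoss(temperature=0.1)
features = F.normalize(torch.randn(8, 1, 128), dim=-1)   # [bsz, n_views, feature_dim]
labels = torch.tensor([0, 0, 1, 1, 0, 1, 0, 1])          # same label = positive pair
loss = criterion(features, labels=labels)
print(loss.item())                                       # scalar loss value
```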
models/projector.py ADDED
@@ -0,0 +1,58 @@
+ import torch
+
+ class SiameseProjector(torch.nn.Module):
+     def __init__(self, inner_features = None, act_layer = torch.nn.GELU):
+         super().__init__()
+
+         self.inner_features = inner_features
+         self.act_fcn = act_layer()
+
+         # Localisation branch.
+         self.input = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
+         self.projection = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
+         self.output = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
+
+         # Contrastive branch.
+         self.input_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
+         self.projection_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
+         self.output_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
+
+         # Localisation head.
+         self.probe = torch.nn.Conv2d(in_channels=inner_features, out_channels=1, kernel_size=3)
+
+     def forward(self, x):
+
+         # Localisation branch.
+         x = self.input(x)
+         x = self.act_fcn(x)
+         x = self.projection(x)
+         x = self.act_fcn(x)
+         x = self.output(x)
+
+         # Localisation head.
+         seg = self.probe(x.permute(0,2,1).reshape(x.shape[0], self.inner_features, int(x.shape[1]**0.5), int(x.shape[1]**0.5)))
+
+         # Contrastive branch.
+         y = self.input_con(x)
+         y = self.act_fcn(y)
+         y = self.projection_con(y)
+         y = self.act_fcn(y)
+
+         # Contrastive head.
+         feat = self.output_con(y)
+
+         return feat, seg
+
+     def forward_segmentation(self, x):
+
+         # Localisation branch.
+         x = self.input(x)
+         x = self.act_fcn(x)
+         x = self.projection(x)
+         x = self.act_fcn(x)
+         x = self.output(x)
+
+         # Localisation head.
+         seg = self.probe(x.permute(0,2,1).reshape(x.shape[0], self.inner_features, int(x.shape[1]**0.5), int(x.shape[1]**0.5)))
+
+         return seg
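The projector returns a per-patch contrastive embedding and a one-channel localisation map; since the 3x3 probe is applied without padding, a 64x64 token grid yields a 62x62 map. A shape sketch:

```python
import torch
from models.projector import SiameseProjector

projector = SiameseProjector(inner_features=768).eval()
fused = torch.randn(1, 4096, 768)   # a flattened 64x64 patch grid
with torch.no_grad():
    feat, seg = projector(fused)
print(feat.shape)   # (1, 4096, 768): contrastive embeddings
print(seg.shape)    # (1, 1, 62, 62): raw localisation logits (sigmoid is applied downstream)
```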
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ gradio
+ datasets==4.0.0
+ matplotlib==3.10.7
+ numpy==2.3.3
+ Pillow==11.3.0
+ scikit_learn==1.7.2
+ segmentation_models_pytorch==0.5.0
+ torch==2.8.0
+ torchvision==0.23.0
+ tqdm==4.67.1
+ transformers==4.57.1
+ wandb==0.22.2