Upload 13 files

- README.md +130 -14
- app.py +250 -0
- checkpoints/fusion_block_multimodal_15k_10inp_120ep_4bs.pth +3 -0
- checkpoints/projector_multimodal_15k_10inp_120ep_4bs.pth +3 -0
- loaders/.DS_Store +0 -0
- loaders/dataloader.py +167 -0
- loaders/loader_utils.py +14 -0
- models/.DS_Store +0 -0
- models/feature_extractor.py +27 -0
- models/fuser.py +65 -0
- models/losses.py +105 -0
- models/projector.py +58 -0
- requirements.txt +12 -0
README.md
CHANGED
@@ -1,14 +1,130 @@
<h1 align="center"> Towards Reliable Identification of Diffusion-based Image Manipulations (NeurIPS 2025) </h1>

:rotating_light: This repository contains code snippets and checkpoints of our work "**Towards Reliable Identification of Diffusion-based Image Manipulations**" :rotating_light:

by [Alex Costanzino](https://alex-costanzino.github.io/)*1, [Woody Bayliss](https://www.linkedin.com/in/woody-bayliss-750a27325/)*2, [Juil Sock](https://skf0321.wixsite.com/juil)*2, [Marc Gorriz Blanch](https://www.linkedin.com/in/marc-gorriz/)*2, [Danijela Horak](https://www.linkedin.com/in/danijela-horak-3b417a70/?originalSubdomain=uk)*2, [Ivan Laptev](https://mbzuai.ac.ae/study/faculty/ivan-laptev/)*3, [Philip Torr](https://eng.ox.ac.uk/people/philip-torr/)*4 and [Fabio Pizzati](https://fabvio.github.io/)*3.

*1 University of Bologna, *2 BBC R&D, *3 MBZUAI, *4 University of Oxford

<div class="alert alert-info">

<h2 align="center">

[Project Page](https://alex-costanzino.github.io/radar/) | [Paper (ArXiv)](https://www.arxiv.org/abs/2506.05466)

</h2>

## :bookmark_tabs: Table of Contents

1. [Introduction](#clapper-introduction)
2. [Dataset](#file_cabinet)
3. [Checkpoints](#inbox_tray)
4. [Code](#memo-code)
5. [Contacts](#envelope-contacts)

</div>

## :clapper: Introduction

Changing facial expressions, gestures, or background details may dramatically alter the meaning conveyed by an image. Notably, recent advances in diffusion models greatly improve the quality of image manipulation while also opening the door to misuse. Identifying changes made to authentic images thus becomes an important task, constantly challenged by new diffusion-based editing tools.

To this end, we propose a novel approach for ReliAble iDentification of inpainted AReas (RADAR). RADAR builds on existing foundation models and combines features from different image modalities. It also incorporates an auxiliary contrastive loss that helps to isolate manipulated image patches.

We demonstrate that these techniques significantly improve both the accuracy of our method and its generalisation to a large number of diffusion models. To support realistic evaluation, we further introduce BBC-PAIR, a new comprehensive benchmark, with images tampered by 28 diffusion models. Our experiments show that RADAR achieves excellent results, outperforming the state-of-the-art in detecting and localising image edits made by both seen and unseen diffusion models.

<h4 align="center">

</h4>

<img src="./assets/architecture.jpg" alt="RADAR architecture" style="width: 800px;" title="architecture">

:fountain_pen: If you find this code useful in your research, please cite:

```bibtex
@article{costanzino2025radar,
  author  = {Costanzino, Alex and Bayliss, Woody and Sock, Juil and Gorriz Blanch, Marc and Horak, Danijela and Laptev, Ivan and Torr, Philip and Pizzati, Fabio},
  title   = {Towards Reliable Identification of Diffusion-based Image Manipulations},
  journal = {arXiv},
  year    = {2025},
}
```

<h2 id="file_cabinet"> :file_cabinet: Dataset </h2>

In our experiments, we employed our benchmark: [BBC - PAIR (Paired Authentic and Inpainted References)](https://github.com/bbc/PAIR).

<h2 id="inbox_tray"> :inbox_tray: Checkpoints </h2>

Here, you can download the weights of the networks employed in the results of our paper.

To use these weights, please follow these steps:

1. Create a folder named `checkpoints` in the project directory;
2. Download the weights [[Download]](https://drive.google.com/drive/folders/1VYbLoK3HM238AU1_xszBM4DYd2nai0jf?usp=sharing);
3. Copy the downloaded weights into the `checkpoints` folder. A minimal loading sketch is shown below.
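For reference, this is a minimal sketch of how the downloaded weights are restored, mirroring the loading code in `app.py` from this commit (the `multimodal_15k_10inp_120ep_4bs` name matches the shipped checkpoint files):

```python
import os
import torch

from models.feature_extractor import FeatureExtractor, FeatureExtractorDepth
from models.fuser import DoubleCrossAttentionFusion
from models.projector import SiameseProjector

CHECKPOINT_PATH = './checkpoints'
MODEL_NAME = 'multimodal_15k_10inp_120ep_4bs'  # label_epochs_batchsize, as in app.py
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Frozen backbones: DINOv2 for RGB features, Depth-Anything-V2 for depth features.
fe_rgb = FeatureExtractor().to(device).eval()
fe_depth = FeatureExtractorDepth().to(device).eval()

# Trained modules restored from the downloaded checkpoints.
fusion_block = DoubleCrossAttentionFusion(hidden_dim=fe_rgb.embed_dim).to(device)
fusion_block.load_state_dict(torch.load(
    os.path.join(CHECKPOINT_PATH, f'fusion_block_{MODEL_NAME}.pth'),
    weights_only=False, map_location=device))
fusion_block.eval()

projector = SiameseProjector(inner_features=fe_rgb.embed_dim).to(device)
projector.load_state_dict(torch.load(
    os.path.join(CHECKPOINT_PATH, f'projector_{MODEL_NAME}.pth'),
    weights_only=False, map_location=device))
projector.eval()
```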

## :memo: Code

<div class="alert alert-info">

**Warning**:
- The code uses `wandb` during training to log results, so make sure you have a `wandb` account. If you prefer not to use `wandb`, disable it in `A_train.py` with the flag `mode = 'disabled'`.

</div>

### :hammer_and_wrench: Setup Instructions

**Dependencies**: Ensure that you have installed all the necessary dependencies. The list of dependencies can be found in the `requirements.txt` file.

### :rocket: Single-image Inference

The `C_infer_single-image.py` script runs the framework on an image to generate a continuous tampering map.

You can specify the following arguments:
- `--checkpoint_savepath`: Path to the directory of the checkpoints, e.g., `checkpoints`;
- `--label`: A label to tell the experiments apart, e.g., `multimodal_15k_10inp`;
- `--image_path`: Path to the image to assess;
- `--epochs_no`: Number of epochs employed during the framework optimization;
- `--img_size`: Resolution employed during the framework optimization;
- `--batch_size`: Number of samples per batch employed during the framework optimization.

An equivalent Python sketch of the inference pipeline is shown below. If you haven't downloaded the checkpoints yet, you can find the download links in the **Checkpoints** section above.
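A minimal sketch of single-image inference, reusing the models and `device` from the loading sketch in the **Checkpoints** section and following the same steps as `detect_manipulation` in `app.py` (the 896-pixel resolution and ImageNet statistics match the shipped checkpoints; the input path is a placeholder):

```python
import torch
import torchvision
from PIL import Image

from loaders.loader_utils import SquarePad

IMAGE_SIZE = 896  # resolution used to train the released checkpoints

rgb_transform = torchvision.transforms.Compose([
    SquarePad(),
    torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE),
                                  interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
])

image = Image.open('path/to/image.png').convert('RGB')  # placeholder path
rgb = rgb_transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    rgb_feat = fe_rgb(rgb)                     # DINOv2 patch tokens
    depth_feat = fe_depth(rgb)                 # Depth-Anything-V2 patch tokens
    fused_feat = fusion_block(rgb_feat, depth_feat)
    _, seg_logits = projector(fused_feat)      # contrastive features, tampering logits
    tampering_map = torch.sigmoid(seg_logits)  # continuous map; higher = more likely inpainted
```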

### :rocket: Full Inference [TBA]

<!-- The `B_infer.py` script tests the framework. It can be used to generate anomaly maps as well.

You can specify the following options:
- `--dataset_path`: Path to the root directory of the dataset.
- `--checkpoint_folder`: Path to the directory of the checkpoints, e.g., `checkpoints/checkpoints_visa`.
- `--class_name`: Class on which the framework was trained, e.g., `candle`.
- `--epochs_no`: Number of epochs used in framework optimization.
- `--batch_size`: Number of samples per batch employed for framework optimization.
- `--qualitative_folder`: Folder in which the anomaly maps are saved.
- `--quantitative_folder`: Folder in which the metrics are saved.
- `--visualize_plot`: Flag to visualize qualitatives during inference.
- `--produce_qualitatives`: Flag to save qualitatives during inference.

If you haven't downloaded the checkpoints yet, you can find the download links in the **Checkpoints** section above. -->

### :rocket: Train [TBA]

<!-- The `A_train.py` script trains the framework.

You can specify the following options:
- `--dataset_path`: Path to the root directory of the dataset.
- `--checkpoint_savepath`: Path to the directory in which checkpoints will be saved, e.g., `checkpoints/checkpoints_visa`.
- `--class_name`: Class on which the framework is trained, e.g., `candle`.
- `--epochs_no`: Number of epochs for framework optimization.
- `--img_size`: Resolution employed during training.
- `--batch_size`: Number of samples per batch for framework optimization.
- `--label`: A label to tell the experiments apart. -->

## :envelope: Contacts

For questions, please send an email to [email protected].

app.py
ADDED
@@ -0,0 +1,250 @@
```python
import gradio as gr
import torch
import torchvision
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')
import io
import os

# Import your models
from models.feature_extractor import FeatureExtractor, FeatureExtractorDepth
from models.projector import SiameseProjector
from models.fuser import DoubleCrossAttentionFusion
from loaders.loader_utils import SquarePad

# Configuration
CHECKPOINT_PATH = './checkpoints'
MODEL_LABEL = 'multimodal_15k_10inp'
EPOCHS = 120
BATCH_SIZE = 4
IMAGE_SIZE = 896

# Load models
print("Loading models...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model_name = f'{MODEL_LABEL}_{EPOCHS}ep_{BATCH_SIZE}bs'

rgb_transform = torchvision.transforms.Compose([
    SquarePad(),
    torchvision.transforms.Resize((IMAGE_SIZE, IMAGE_SIZE),
                                  interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
    torchvision.transforms.Lambda(lambda img: img.unsqueeze(0)),
])

fe_rgb = FeatureExtractor().to(device).eval()
fe_depth = FeatureExtractorDepth().to(device).eval()

fusion_block = DoubleCrossAttentionFusion(hidden_dim=fe_rgb.embed_dim).to(device)
fusion_block.load_state_dict(torch.load(
    os.path.join(CHECKPOINT_PATH, f'fusion_block_{model_name}.pth'),
    weights_only=False,
    map_location=device
))
fusion_block.eval()

projector = SiameseProjector(inner_features=fe_rgb.embed_dim).to(device)
projector.load_state_dict(torch.load(
    os.path.join(CHECKPOINT_PATH, f'projector_{model_name}.pth'),
    weights_only=False,
    map_location=device
))
projector.eval()

print("Models loaded successfully!")

def detect_manipulation(image):
    """Process image and return heatmap"""
    if image is None:
        return None

    # Convert to PIL if needed
    if isinstance(image, np.ndarray):
        rgb_input = Image.fromarray(image).convert('RGB')
    else:
        rgb_input = image.convert('RGB')

    original_size = rgb_input.size

    # Transform and process
    rgb = rgb_transform(rgb_input)
    rgb = rgb.to(device)

    with torch.no_grad():
        rgb_feat = fe_rgb(rgb)
        depth_feat = fe_depth(rgb)
        fused_feat = fusion_block(rgb_feat, depth_feat)
        _, segmentation_map = projector(fused_feat)
        segmentation_map = torch.sigmoid(segmentation_map)

    # Resize back to original
    segmentation_map = torch.nn.functional.interpolate(
        segmentation_map,
        size=[max(original_size), max(original_size)],
        mode='bilinear'
    ).squeeze()
    segmentation_map = torchvision.transforms.functional.center_crop(
        segmentation_map,
        original_size[::-1]
    )

    heatmap = segmentation_map.cpu().detach().numpy()

    # Create visualization
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(heatmap, cmap='jet')
    ax.axis('off')
    plt.tight_layout(pad=0)

    # Convert to image
    buf = io.BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0, dpi=150)
    buf.seek(0)
    result_image = Image.open(buf)
    plt.close(fig)

    return result_image

# Custom CSS to match your design
custom_css = """
#component-0 {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 20px;
}

.gradio-container {
    max-width: 1200px !important;
    margin: auto !important;
}

#title {
    text-align: center;
    color: white;
    font-size: 3em;
    margin-bottom: 10px;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.2);
}

#subtitle {
    text-align: center;
    color: white;
    font-size: 1.2em;
    margin-bottom: 30px;
    opacity: 0.95;
}

.main-card {
    background: white;
    border-radius: 20px;
    padding: 40px;
    box-shadow: 0 20px 60px rgba(0,0,0,0.3);
}

#info-box {
    background: #e8f4fd;
    border-left: 4px solid #2196F3;
    padding: 15px;
    border-radius: 5px;
    margin-bottom: 20px;
    color: #1976D2;
}

.upload-container .transition {
    border: 3px dashed #667eea !important;
    border-radius: 15px !important;
    background: #f8f9ff !important;
}

.upload-container .transition:hover {
    border-color: #764ba2 !important;
    background: #f0f2ff !important;
}

button.primary {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    border: none !important;
    color: white !important;
    border-radius: 25px !important;
    padding: 12px 30px !important;
    font-size: 1em !important;
}

button.primary:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4) !important;
}

.output-image {
    border-radius: 15px;
    box-shadow: 0 5px 15px rgba(0,0,0,0.1);
}
"""

# Create Gradio interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.HTML("""
        <div id="title">🎯 RADAR</div>
        <div id="subtitle">ReliAble iDentification of inpainted AReas</div>
    """)

    with gr.Column(elem_classes="main-card"):
        gr.HTML("""
            <div id="info-box">
                <strong>ℹ️ About RADAR:</strong> Upload an image to detect and localize regions
                that have been manipulated using diffusion-based inpainting models.
                The output shows a heatmap where red areas indicate detected manipulations.
            </div>
        """)

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Upload Image",
                    type="pil",
                    elem_classes="upload-container"
                )
                submit_btn = gr.Button("🔍 Detect Manipulations", variant="primary", size="lg")

            with gr.Column():
                output_image = gr.Image(
                    label="Manipulation Heatmap",
                    type="pil",
                    elem_classes="output-image"
                )

        gr.Examples(
            examples=[
                # Add paths to your example images here
                # ["examples/example1.png"],
                # ["examples/example2.png"],
            ],
            inputs=input_image,
            outputs=output_image,
            fn=detect_manipulation,
            cache_examples=False,
        )

    # Connect the button
    submit_btn.click(
        fn=detect_manipulation,
        inputs=input_image,
        outputs=output_image
    )

    # Also trigger on image upload
    input_image.change(
        fn=detect_manipulation,
        inputs=input_image,
        outputs=output_image
    )

# Launch
if __name__ == "__main__":
    demo.launch()
```
checkpoints/fusion_block_multimodal_15k_10inp_120ep_4bs.pth
ADDED
@@ -0,0 +1,3 @@

```
version https://git-lfs.github.com/spec/v1
oid sha256:70601e0d63edddbfe42da48674dc9053dc626176c5fcda273a2018b10b80bd10
size 23646170
```
checkpoints/projector_multimodal_15k_10inp_120ep_4bs.pth
ADDED
@@ -0,0 +1,3 @@

```
version https://git-lfs.github.com/spec/v1
oid sha256:e5ec7007ebeae7f08054de2ab223899386bb1b977e028545175241efae8ce388
size 14208328
```
loaders/.DS_Store
ADDED
Binary file (6.15 kB)
loaders/dataloader.py
ADDED
@@ -0,0 +1,167 @@
```python
import os
import torch
from PIL import Image
from torchvision import transforms
import torchvision

import json

from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class SquarePad:
    def __call__(self, image):
        # Pad the shorter side (replicating edge pixels) so the image becomes square.
        max_wh = max(image.size)
        p_left, p_top = [(max_wh - s) // 2 for s in image.size]
        p_right, p_bottom = [max_wh - (s + pad) for s, pad in zip(image.size, [p_left, p_top])]
        padding = (p_left, p_top, p_right, p_bottom)
        return transforms.functional.pad(image, padding, padding_mode = 'edge')

class BaseDataset(Dataset):
    def __init__(self, img_size, dataset_path, inpainter):
        self.IMAGENET_MEAN = [0.485, 0.456, 0.406]
        self.IMAGENET_STD = [0.229, 0.224, 0.225]

        self.img_size = img_size
        self.dataset_path = dataset_path
        self.inpainter = inpainter

        self.json_path = os.path.join(dataset_path, 'DFDS_V2/DFDS_V2.0_2Percent.json')
        # self.json_path = os.path.join(dataset_path, 'DFDS_V2.0_2Percent.json')

        self.data = self.load_json()
        self.data_train = self.data[0:500]

        self.rgb_transform = transforms.Compose([
            SquarePad(),
            transforms.Resize((img_size, img_size), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean = self.IMAGENET_MEAN, std = self.IMAGENET_STD)
        ])

    def load_json(self):
        with open(self.json_path, 'r') as file:
            data = json.load(file)
        return data

class TrainDataset(BaseDataset):
    # Inpainters used to produce the negative (tampered) training samples.
    INPAINTERS = ['SD1_Inpaint', 'SD1.5_Inpaint', 'SD2_Inpaint', 'SDXL_Inpaint', 'SD3_Inpaint',
                  'SD3.5_Inpaint', 'kadinsky2.2_Inpaint', 'kadinsky3.1_Inpaint',
                  'FLUX_SHNELL_Inpaint', 'FLUX_DEV_FILL']

    def __init__(self, img_size, dataset_path):
        super().__init__(img_size = img_size, dataset_path = dataset_path, inpainter = None)

        self.gt_transform = transforms.Compose([
            SquarePad(),
            transforms.Resize((img_size, img_size), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor()]
        )

        self.img_paths_pos, self.img_paths_neg, self.mask_paths_neg = self.load_dataset()

    def load_dataset(self):
        # For each inpainter, pair every authentic image with its inpainted counterpart and mask.
        positive_imgs, negative_imgs, negative_masks = [], [], []
        for inpainter in self.INPAINTERS:
            positive_imgs += [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data_train]
            negative_imgs += [os.path.join(self.dataset_path, data['masks'][0]['inpainters'][inpainter].lstrip('/')) for data in self.data_train]
            negative_masks += [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data_train]

        return positive_imgs, negative_imgs, negative_masks

    def __len__(self):
        return len(self.data_train) * 3

    def __getitem__(self, idx):
        img_path_pos, img_path_neg, gt_neg = self.img_paths_pos[idx], self.img_paths_neg[idx], self.mask_paths_neg[idx]

        img_pos = Image.open(img_path_pos.replace('/Open_V7/','')).convert('RGB')
        img_neg = Image.open(img_path_neg.replace('/Open_V7/','')).convert('RGB')

        rgb_pos = self.rgb_transform(img_pos)
        rgb_neg = self.rgb_transform(img_neg)

        # Authentic images get an all-zero ground-truth mask.
        gt_pos = torch.zeros([1, img_pos.size[1], img_pos.size[0]])
        gt_pos = torchvision.transforms.functional.to_pil_image(gt_pos)
        gt_pos = self.gt_transform(gt_pos)

        gt_neg = Image.open(gt_neg.replace('/Open_V7/','')).convert('L')
        gt_neg = self.gt_transform(gt_neg)
        gt_neg = torch.where(gt_neg > 0.5, 1., .0)

        return rgb_pos, gt_pos, rgb_neg, gt_neg

class TestDataset(BaseDataset):
    def __init__(self, img_size, dataset_path, inpainter):
        super().__init__(img_size = img_size, dataset_path = dataset_path, inpainter = inpainter)

        self.gt_transform = transforms.Compose([
            transforms.ToTensor()])

        self.img_paths, self.mask_paths, self.labels = self.load_dataset()

    def load_dataset(self):
        positive_imgs = [os.path.join(self.dataset_path, data['base_image_location'].lstrip('/')) for data in self.data[500:600]]
        positive_masks = [None for data in self.data[500:600]]

        negative_imgs = [os.path.join(self.dataset_path, data['masks'][0]['inpainters'][self.inpainter].lstrip('/')) for data in self.data[600:700]]
        negative_masks = [os.path.join(self.dataset_path, data['masks'][0]['edited_mask_location'].lstrip('/')) for data in self.data[600:700]]

        labels = [0.0 for data in self.data[500:600]] + [1.0 for data in self.data[600:700]]

        return positive_imgs + negative_imgs, positive_masks + negative_masks, labels

    def __len__(self):
        return len(self.data[500:700])

    def __getitem__(self, idx):
        img_path, gt, label = self.img_paths[idx], self.mask_paths[idx], self.labels[idx]

        img = Image.open(img_path.replace('/Open_V7/','').replace('data/', 'data/DFDS_V2/')).convert('RGB')

        rgb = self.rgb_transform(img)

        if gt is None:
            gt = torch.zeros(
                [1, img.size[1], img.size[0]])
        else:
            gt = Image.open(gt.replace('/Open_V7/','').replace('data/', 'data/DFDS_V2/')).convert('L')
            gt = self.gt_transform(gt)
            gt = torch.where(gt > 0.5, 1., .0)

        return rgb, label, gt, img_path


def get_data_loader(split, img_size, batch_size, dataset_path, inpainter = None):
    if split == 'train':
        dataset = TrainDataset(img_size, dataset_path)
        data_loader = DataLoader(dataset = dataset, batch_size = batch_size, shuffle = True, num_workers = 8, drop_last = True, pin_memory = False)

    elif split == 'test':
        dataset = TestDataset(img_size, dataset_path, inpainter)
        data_loader = DataLoader(dataset = dataset, batch_size = batch_size, shuffle = False, num_workers = 8, drop_last = False, pin_memory = False)

    return data_loader
```
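A minimal usage sketch for `get_data_loader`, assuming the DFDS_V2 / BBC-PAIR data referenced above is available locally; the path, resolution, and inpainter name are illustrative placeholders:

```python
from loaders.dataloader import get_data_loader

# Hypothetical dataset root; adjust to where the BBC-PAIR / DFDS_V2 data lives.
DATASET_PATH = '/data/PAIR'

train_loader = get_data_loader(split='train', img_size=896, batch_size=4,
                               dataset_path=DATASET_PATH)
test_loader = get_data_loader(split='test', img_size=896, batch_size=4,
                              dataset_path=DATASET_PATH, inpainter='SD1_Inpaint')

# Training batches pair an authentic image with an inpainted one, plus their masks.
rgb_pos, gt_pos, rgb_neg, gt_neg = next(iter(train_loader))
print(rgb_pos.shape, rgb_neg.shape)  # e.g. torch.Size([4, 3, 896, 896]) twice
```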
loaders/loader_utils.py
ADDED
@@ -0,0 +1,14 @@
```python
from torchvision import transforms

# Inverse of the ImageNet normalisation applied by the RGB transform.
imagenet_denormalize = transforms.Compose([
    transforms.Normalize(mean = [0., 0., 0.], std = [1/0.229, 1/0.224, 1/0.225]),
    transforms.Normalize(mean = [-0.485, -0.456, -0.406], std = [1., 1., 1.])
])

class SquarePad:
    def __call__(self, image):
        # Pad the shorter side (replicating edge pixels) so the image becomes square.
        max_wh = max(image.size)
        p_left, p_top = [(max_wh - s) // 2 for s in image.size]
        p_right, p_bottom = [max_wh - (s + pad) for s, pad in zip(image.size, [p_left, p_top])]
        padding = (p_left, p_top, p_right, p_bottom)
        return transforms.functional.pad(image, padding, padding_mode = 'edge')
```
models/.DS_Store
ADDED
Binary file (6.15 kB)
models/feature_extractor.py
ADDED
@@ -0,0 +1,27 @@
```python
import torch
from transformers import AutoModelForDepthEstimation

class FeatureExtractor(torch.nn.Module):
    """DINOv2 ViT-B/14 (with registers) backbone returning RGB patch tokens."""
    def __init__(self):
        super().__init__()

        self.fe = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')

        self.patch_size = self.fe.patch_size
        self.embed_dim = self.fe.embed_dim

    def forward(self, x):
        return self.fe.forward_features(x)['x_norm_patchtokens']

class FeatureExtractorDepth(torch.nn.Module):
    """Depth-Anything-V2 backbone; its last hidden state is used as depth patch tokens."""
    def __init__(self):
        super().__init__()

        self.fe = AutoModelForDepthEstimation.from_pretrained("depth-anything/Depth-Anything-V2-Base-hf")

        self.patch_size = 14
        self.embed_dim = 768

    def forward(self, x):
        x = self.fe(x, output_hidden_states=True).hidden_states
        # Drop the CLS token, keep the patch tokens of the last layer.
        return x[-1][:, 1:, :]
```
models/fuser.py
ADDED
@@ -0,0 +1,65 @@
```python
import torch

class DoubleCrossAttentionFusion(torch.nn.Module):
    def __init__(self, hidden_dim=768, num_heads=8, dropout=0.1):
        super().__init__()

        # 1. Per-modality normalization.
        self.norm_rgb = torch.nn.LayerNorm(hidden_dim)
        self.norm_depth = torch.nn.LayerNorm(hidden_dim)

        # 2. Cross-attention.
        self.cross_attn_depth = torch.nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        self.cross_attn_rgb = torch.nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # 3. Mixing.
        self.mixer = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout)
        )

        # 4. Output normalisation.
        self.out_norm = torch.nn.LayerNorm(hidden_dim)

    def forward(self, rgb_features, depth_features):
        # 1. Normalize inputs.
        rgb = self.norm_rgb(rgb_features)
        depth = self.norm_depth(depth_features)

        # 2a. Cross-attention (depth -> rgb).
        attn_out_depth, _ = self.cross_attn_depth(
            query=depth,
            key=rgb,
            value=rgb,
            need_weights=False
        )

        # 2b. Cross-attention (rgb -> depth).
        attn_out_rgb, _ = self.cross_attn_rgb(
            query=rgb,
            key=depth,
            value=depth,
            need_weights=False
        )

        # 3a. Residuals.
        depth_attn = depth + attn_out_depth
        rgb_attn = rgb + attn_out_rgb

        # 3b. Mixing.
        fused = self.mixer(torch.cat([depth_attn, rgb_attn], dim=-1))

        # 4. Output normalisation.
        return self.out_norm(fused)
```
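A quick shape check of the fusion block, using random patch-token tensors with the DINOv2/Depth-Anything embedding size; the batch size and token count are arbitrary illustrations:

```python
import torch

from models.fuser import DoubleCrossAttentionFusion

fusion = DoubleCrossAttentionFusion(hidden_dim=768).eval()

# Two token sequences of shape [batch, num_patches, embed_dim], one per modality.
rgb_tokens = torch.randn(2, 4096, 768)    # e.g. (896 / 14)^2 = 4096 patches per image
depth_tokens = torch.randn(2, 4096, 768)

with torch.no_grad():
    fused = fusion(rgb_tokens, depth_tokens)

print(fused.shape)  # torch.Size([2, 4096, 768]) — same token grid, fused features
```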
models/losses.py
ADDED
@@ -0,0 +1,105 @@
```python
"""
Author: Yonglong Tian ([email protected])
Date: May 07, 2020
"""
from __future__ import print_function

import torch
import torch.nn as nn


class SupervisedContrastiveLoss(nn.Module):
    """Supervised Contrastive Learning: https://arxiv.org/pdf/2004.11362.pdf.
    It also supports the unsupervised contrastive loss in SimCLR"""
    def __init__(self, temperature=0.1, contrast_mode='all', base_temperature=0.1):
        super(SupervisedContrastiveLoss, self).__init__()
        self.temperature = temperature
        self.contrast_mode = contrast_mode
        self.base_temperature = base_temperature

    def forward(self, features, labels=None, mask=None):
        """Compute loss for model. If both `labels` and `mask` are None,
        it degenerates to SimCLR unsupervised loss:
        https://arxiv.org/pdf/2002.05709.pdf

        Args:
            features: hidden vector of shape [bsz, n_views, ...].
            labels: ground truth of shape [bsz].
            mask: contrastive mask of shape [bsz, bsz], mask_{i,j}=1 if sample j
                has the same class as sample i. Can be asymmetric.
        Returns:
            A loss scalar.
        """
        device = (torch.device('cuda')
                  if features.is_cuda
                  else torch.device('cpu'))

        if len(features.shape) < 3:
            raise ValueError('`features` needs to be [bsz, n_views, ...],'
                             'at least 3 dimensions are required')
        if len(features.shape) > 3:
            features = features.view(features.shape[0], features.shape[1], -1)

        batch_size = features.shape[0]
        if labels is not None and mask is not None:
            raise ValueError('Cannot define both `labels` and `mask`')
        elif labels is None and mask is None:
            mask = torch.eye(batch_size, dtype=torch.float32).to(device)
        elif labels is not None:
            labels = labels.contiguous().view(-1, 1)
            if labels.shape[0] != batch_size:
                raise ValueError('Num of labels does not match num of features')
            mask = torch.eq(labels, labels.T).float().to(device)
        else:
            mask = mask.float().to(device)

        contrast_count = features.shape[1]
        contrast_feature = torch.cat(torch.unbind(features, dim=1), dim=0)
        if self.contrast_mode == 'one':
            anchor_feature = features[:, 0]
            anchor_count = 1
        elif self.contrast_mode == 'all':
            anchor_feature = contrast_feature
            anchor_count = contrast_count
        else:
            raise ValueError('Unknown mode: {}'.format(self.contrast_mode))

        # compute logits
        anchor_dot_contrast = torch.div(
            torch.matmul(anchor_feature, contrast_feature.T),
            self.temperature)
        # for numerical stability
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # tile mask
        mask = mask.repeat(anchor_count, contrast_count)
        # mask-out self-contrast cases
        logits_mask = torch.scatter(
            torch.ones_like(mask),
            1,
            torch.arange(batch_size * anchor_count).view(-1, 1).to(device),
            0
        )
        mask = mask * logits_mask

        # compute log_prob
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

        # compute mean of log-likelihood over positive
        # modified to handle edge cases when there is no positive pair
        # for an anchor point.
        # Edge case e.g.:-
        # features of shape: [4,1,...]
        # labels: [0,1,1,2]
        # loss before mean: [nan, ..., ..., nan]
        mask_pos_pairs = mask.sum(1)
        mask_pos_pairs = torch.where(mask_pos_pairs < 1e-6, 1, mask_pos_pairs)
        mean_log_prob_pos = (mask * log_prob).sum(1) / mask_pos_pairs

        # loss
        loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
        loss = loss.view(anchor_count, batch_size).mean()

        return loss
```
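A small, self-contained usage sketch of the loss above; the shapes follow its docstring ([bsz, n_views, dim] features plus per-sample labels), and the concrete numbers are illustrative:

```python
import torch

from models.losses import SupervisedContrastiveLoss

criterion = SupervisedContrastiveLoss(temperature=0.1)

# 8 samples, 2 augmented views each, 128-dim L2-normalised embeddings.
features = torch.nn.functional.normalize(torch.randn(8, 2, 128), dim=-1)
labels = torch.randint(0, 2, (8,))  # e.g. 0 = authentic patch, 1 = inpainted patch

loss = criterion(features, labels=labels)
print(loss.item())
```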
models/projector.py
ADDED
@@ -0,0 +1,58 @@
```python
import torch

class SiameseProjector(torch.nn.Module):
    def __init__(self, inner_features = None, act_layer = torch.nn.GELU):
        super().__init__()

        self.inner_features = inner_features
        self.act_fcn = act_layer()

        # Localisation branch.
        self.input = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.projection = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.output = torch.nn.Linear(in_features=inner_features, out_features=inner_features)

        # Contrastive branch.
        self.input_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.projection_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)
        self.output_con = torch.nn.Linear(in_features=inner_features, out_features=inner_features)

        # Localisation head.
        self.probe = torch.nn.Conv2d(in_channels=inner_features, out_channels=1, kernel_size=3)

    def forward(self, x):

        # Localisation branch.
        x = self.input(x)
        x = self.act_fcn(x)
        x = self.projection(x)
        x = self.act_fcn(x)
        x = self.output(x)

        # Localisation head.
        seg = self.probe(x.permute(0, 2, 1).reshape(x.shape[0], self.inner_features, int(x.shape[1]**0.5), int(x.shape[1]**0.5)))

        # Contrastive branch.
        y = self.input_con(x)
        y = self.act_fcn(y)
        y = self.projection_con(y)
        y = self.act_fcn(y)

        # Contrastive head.
        feat = self.output_con(y)

        return feat, seg

    def forward_segmentation(self, x):

        # Localisation branch.
        x = self.input(x)
        x = self.act_fcn(x)
        x = self.projection(x)
        x = self.act_fcn(x)
        x = self.output(x)

        # Localisation head.
        seg = self.probe(x.permute(0, 2, 1).reshape(x.shape[0], self.inner_features, int(x.shape[1]**0.5), int(x.shape[1]**0.5)))

        return seg
```
requirements.txt
ADDED
@@ -0,0 +1,12 @@
```
gradio
datasets==4.0.0
matplotlib==3.10.7
numpy==2.3.3
Pillow==11.3.0
scikit_learn==1.7.2
segmentation_models_pytorch==0.5.0
torch==2.8.0
torchvision==0.23.0
tqdm==4.67.1
transformers==4.57.1
wandb==0.22.2
```