import io
import os

import gradio as gr
import torch
import torchvision
from PIL import Image
import numpy as np
import matplotlib

# Select the non-interactive backend before pyplot is imported so no GUI
# toolkit is ever initialised in the server process.
matplotlib.use('Agg')
import matplotlib.pyplot as plt  # noqa: E402

# Import your models
from models.feature_extractor import FeatureExtractor, FeatureExtractorDepth
from models.projector import SiameseProjector
from models.fuser import DoubleCrossAttentionFusion
from loaders.loader_utils import SquarePad

# Configuration
CHECKPOINT_PATH = './checkpoints'
MODEL_LABEL = 'multimodal_15k_10inp'
EPOCHS = 120
BATCH_SIZE = 4
IMAGE_SIZE = 896

# Load models
print("Loading models...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Checkpoint file names are derived from the training configuration above.
model_name = f'{MODEL_LABEL}_{EPOCHS}ep_{BATCH_SIZE}bs'

# Pre-processing: pad to a square, resize to IMAGE_SIZE x IMAGE_SIZE,
# normalise, and add a leading batch dimension.
rgb_transform = torchvision.transforms.Compose([
    SquarePad(),
    torchvision.transforms.Resize(
        (IMAGE_SIZE, IMAGE_SIZE),
        interpolation=torchvision.transforms.InterpolationMode.BICUBIC,
    ),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    ),
    torchvision.transforms.Lambda(lambda img: img.unsqueeze(0)),
])

fe_rgb = FeatureExtractor().to(device).eval()
fe_depth = FeatureExtractorDepth().to(device).eval()

# NOTE(security): weights_only=False lets torch.load unpickle arbitrary
# objects. Only load checkpoints from a trusted source; if the files hold
# plain state dicts, consider switching to weights_only=True.
fusion_block = DoubleCrossAttentionFusion(hidden_dim=fe_rgb.embed_dim).to(device)
fusion_block.load_state_dict(torch.load(
    os.path.join(CHECKPOINT_PATH, f'fusion_block_{model_name}.pth'),
    weights_only=False,
    map_location=device,
))
fusion_block.eval()

projector = SiameseProjector(inner_features=fe_rgb.embed_dim).to(device)
projector.load_state_dict(torch.load(
    os.path.join(CHECKPOINT_PATH, f'projector_{model_name}.pth'),
    weights_only=False,
    map_location=device,
))
projector.eval()

print("Models loaded successfully!")


def _render_heatmap(heatmap, size):
    """Render a 2-D score map as a 'jet' colour image of exactly ``size``.

    Args:
        heatmap: Array of scores to display with ``imshow``.
        size: Target ``(width, height)`` in pixels (PIL ordering).

    Returns:
        NumPy array decoded from the PNG that matplotlib produced, resized
        to ``size`` if the saved image came out at a different size.
    """
    dpi = 100
    fig_width = size[0] / dpi
    fig_height = size[1] / dpi
    fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)
    try:
        ax = fig.add_axes([0, 0, 1, 1])  # No margins
        ax.imshow(heatmap, cmap='jet')
        ax.axis('off')

        with io.BytesIO() as buf:
            # Save this specific figure rather than pyplot's implicit
            # "current" figure.
            fig.savefig(buf, format='png', bbox_inches='tight',
                        pad_inches=0, dpi=dpi)
            buf.seek(0)
            result_image = Image.open(buf)

            # Ensure exact size match by resizing if needed
            if result_image.size != size:
                result_image = result_image.resize(size, Image.LANCZOS)

            # Materialise the pixels while the buffer is still open:
            # Image.open() is lazy.
            return np.array(result_image)
    finally:
        # Always release the figure, even if rendering or decoding failed,
        # so a long-running server does not accumulate open figures.
        plt.close(fig)


def detect_manipulation(image):
    """Process image and return heatmap.

    Args:
        image: Input picture, either a NumPy array or an object exposing
            ``convert('RGB')`` (e.g. a PIL image). ``None`` means no input.

    Returns:
        ``None`` when ``image`` is ``None``; otherwise a NumPy array holding
        the colour-mapped heatmap at the input image's width and height.
    """
    if image is None:
        return None

    # Convert to PIL
    if isinstance(image, np.ndarray):
        rgb_input = Image.fromarray(image.astype('uint8')).convert('RGB')
    else:
        rgb_input = image.convert('RGB')

    original_size = rgb_input.size  # PIL ordering: (width, height)

    # Transform and process
    rgb = rgb_transform(rgb_input)
    rgb = rgb.to(device)

    with torch.no_grad():
        rgb_feat = fe_rgb(rgb)
        # The depth branch is fed the same RGB tensor.
        depth_feat = fe_depth(rgb)
        fused_feat = fusion_block(rgb_feat, depth_feat)
        _, segmentation_map = projector(fused_feat)
        segmentation_map = torch.sigmoid(segmentation_map)

    # Resize back to original: scale to the square whose side is the longer
    # image edge, then centre-crop to (height, width).
    # NOTE(review): this presumes SquarePad pads symmetrically - confirm.
    longest_side = max(original_size)
    segmentation_map = torch.nn.functional.interpolate(
        segmentation_map,
        size=[longest_side, longest_side],
        mode='bilinear',
    ).squeeze()
    segmentation_map = torchvision.transforms.functional.center_crop(
        segmentation_map, original_size[::-1]
    )

    heatmap = segmentation_map.cpu().detach().numpy()

    # Create visualization with exact size
    return _render_heatmap(heatmap, original_size)


# Custom CSS for styling
custom_css = (
    " .gradio-container { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; }"
    " #title { text-align: center; font-size: 2.5em; font-weight: bold; margin-bottom: 0.5em;"
    " background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);"
    " -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; }"
    " #subtitle { text-align: center; font-size: 1.2em; color: #666; margin-bottom: 1em; }"
    " #info { background: #e8f4fd; border-left: 4px solid #2196F3; padding: 15px;"
    " border-radius: 5px; margin-bottom: 20px; color: #1976D2; }"
    " "
)

# Create interface using
Gradio 4.x Blocks with gr.Blocks(css=custom_css, title="RADAR - Image Manipulation Detection") as demo: gr.HTML('
ReliAble iDentification of inpainted AReas
') gr.HTML('''