# RADAR-demo/models/fuser.py
import torch


class DoubleCrossAttentionFusion(torch.nn.Module):
    """Fuses RGB and depth token features via bidirectional cross-attention."""

    def __init__(self, hidden_dim=768, num_heads=8, dropout=0.1):
        super().__init__()
        # 1. Per-modality normalization.
        self.norm_rgb = torch.nn.LayerNorm(hidden_dim)
        self.norm_depth = torch.nn.LayerNorm(hidden_dim)
        # 2. Cross-attention (one block per direction).
        self.cross_attn_depth = torch.nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.cross_attn_rgb = torch.nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        # 3. Mixing of the two attended streams.
        self.mixer = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
        )
        # 4. Output normalization.
        self.out_norm = torch.nn.LayerNorm(hidden_dim)

    def forward(self, rgb_features, depth_features):
        # 1. Normalize inputs.
        rgb = self.norm_rgb(rgb_features)
        depth = self.norm_depth(depth_features)
        # 2a. Cross-attention (depth queries attend to rgb).
        attn_out_depth, _ = self.cross_attn_depth(
            query=depth,
            key=rgb,
            value=rgb,
            need_weights=False,
        )
        # 2b. Cross-attention (rgb queries attend to depth).
        attn_out_rgb, _ = self.cross_attn_rgb(
            query=rgb,
            key=depth,
            value=depth,
            need_weights=False,
        )
        # 3a. Residual connections.
        depth_attn = depth + attn_out_depth
        rgb_attn = rgb + attn_out_rgb
        # 3b. Mixing.
        fused = self.mixer(torch.cat([depth_attn, rgb_attn], dim=-1))
        # 4. Output normalization.
        return self.out_norm(fused)
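

# Minimal usage sketch (not part of the uploaded module): the batch size,
# token count, and the assumption of ViT-style (batch, tokens, hidden_dim)
# features from separate RGB and depth encoders are illustrative only.
if __name__ == "__main__":
    fuser = DoubleCrossAttentionFusion(hidden_dim=768, num_heads=8, dropout=0.1)
    rgb_feats = torch.randn(2, 197, 768)    # hypothetical RGB encoder tokens
    depth_feats = torch.randn(2, 197, 768)  # hypothetical depth encoder tokens
    fused = fuser(rgb_feats, depth_feats)
    print(fused.shape)  # torch.Size([2, 197, 768])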