Spaces:
Running
Running
from typing import Optional

import torch
class DoubleCrossAttentionFusion(torch.nn.Module):
    """Fuse RGB and depth token sequences with bidirectional cross-attention.

    Pipeline:
      1. Each modality is layer-normalised independently.
      2. Depth tokens attend to RGB tokens, and RGB tokens attend to depth
         tokens, via two separate ``torch.nn.MultiheadAttention`` layers.
      3. Each attention output is added back to its (normalised) query stream
         as a residual. The two streams are concatenated along the feature
         axis and projected back to ``hidden_dim`` by a
         Linear -> GELU -> Dropout mixer.
      4. The result is layer-normalised once more.

    Args:
        hidden_dim: Feature size of both inputs and of the output. Must be
            divisible by ``num_heads`` (required by
            ``torch.nn.MultiheadAttention``).
        num_heads: Number of heads in each cross-attention layer.
        dropout: Dropout probability used inside both attention layers and
            after the mixer's activation.
    """

    def __init__(self, hidden_dim: int = 768, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        # 1. Per-modality normalisation.
        self.norm_rgb = torch.nn.LayerNorm(hidden_dim)
        self.norm_depth = torch.nn.LayerNorm(hidden_dim)
        # 2. Cross-attention, one layer per query modality.
        #    batch_first=True: batched inputs are (batch, seq, hidden_dim).
        self.cross_attn_depth = torch.nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.cross_attn_rgb = torch.nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        # 3. Mixing: the two residual streams are concatenated (2 * hidden_dim)
        #    and projected back down to hidden_dim.
        self.mixer = torch.nn.Sequential(
            torch.nn.Linear(hidden_dim * 2, hidden_dim),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
        )
        # 4. Output normalisation.
        self.out_norm = torch.nn.LayerNorm(hidden_dim)

    def forward(
        self,
        rgb_features: torch.Tensor,
        depth_features: torch.Tensor,
        *,
        rgb_padding_mask: Optional[torch.Tensor] = None,
        depth_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Fuse one batch of RGB and depth features.

        Args:
            rgb_features: RGB tokens whose last dimension is ``hidden_dim``.
                A batched input is ``(batch, seq, hidden_dim)``
                (``batch_first=True``).
            depth_features: Depth tokens. Must have the same shape as
                ``rgb_features`` in every dimension except the last, because
                the two fused streams are concatenated along the feature axis.
            rgb_padding_mask: Optional ``key_padding_mask`` over the RGB
                tokens. It is used when depth queries attend to RGB, and
                follows ``torch.nn.MultiheadAttention`` semantics (boolean
                ``True`` = ignore that key). ``None`` (default) attends to
                every token.
            depth_padding_mask: Same as ``rgb_padding_mask``, but over the
                depth tokens. It is used when RGB queries attend to depth.

        Returns:
            Fused features with the same shape as the inputs.

        Raises:
            ValueError: If the inputs differ in any dimension other than the
                last one.

        Note:
            Padded positions are still emitted as output tokens; the masks
            only stop them from being attended to. A query whose keys are all
            masked yields NaN in PyTorch's attention.
        """
        # Fail fast with a clear message instead of an opaque torch.cat error
        # after both attention passes have already run.
        if rgb_features.shape[:-1] != depth_features.shape[:-1]:
            raise ValueError(
                "rgb_features and depth_features must match in every dimension "
                "except the last (their fused streams are concatenated along the "
                f"feature axis); got {tuple(rgb_features.shape)} and "
                f"{tuple(depth_features.shape)}"
            )

        # 1. Normalise inputs.
        rgb = self.norm_rgb(rgb_features)
        depth = self.norm_depth(depth_features)

        # 2a. Depth queries attend to RGB keys/values.
        attn_out_depth, _ = self.cross_attn_depth(
            query=depth,
            key=rgb,
            value=rgb,
            key_padding_mask=rgb_padding_mask,
            need_weights=False,
        )
        # 2b. RGB queries attend to depth keys/values.
        attn_out_rgb, _ = self.cross_attn_rgb(
            query=rgb,
            key=depth,
            value=depth,
            key_padding_mask=depth_padding_mask,
            need_weights=False,
        )

        # 3a. Residual connections onto the normalised query streams.
        depth_attn = depth + attn_out_depth
        rgb_attn = rgb + attn_out_rgb
        # 3b. Mix the concatenated streams (depth first, then RGB).
        fused = self.mixer(torch.cat([depth_attn, rgb_attn], dim=-1))

        # 4. Output normalisation.
        return self.out_norm(fused)