import math
from inspect import isfunction

import torch
import torch.nn.functional as F
from torch import nn, einsum


def exists(val):
    return val is not None


def uniq(arr):
    # De-duplicate while preserving order (dict keys keep insertion order).
    return {el: True for el in arr}.keys()


def default(val, d):
    # Return val if it is set, otherwise the default d (evaluated if it is a function).
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    # Most negative finite value representable by t's dtype.
    return -torch.finfo(t.dtype).max


def init_(tensor):
    # In-place uniform init in [-1/sqrt(dim), 1/sqrt(dim)], scaled by the last dimension.
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


class GEGLU(nn.Module):
    # GELU-gated linear unit: project to twice the width, gate one half by GELU of the other.
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=True, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


class SelfAttention(nn.Module):
    def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(query_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # Split heads: (B, N, H*C) -> (B*H, N, C)
        B, N, HC = q.shape
        H = self.heads
        C = HC // H

        q = q.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)
        k = k.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)
        v = v.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)

        # Scaled dot-product attention, computed independently per head.
        sim = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale
        attn = sim.softmax(dim=-1)

        out = torch.einsum('b i j, b j c -> b i c', attn, v)
        # Merge heads back: (B*H, N, C) -> (B, N, H*C)
        out = out.view(B, H, N, C).permute(0, 2, 1, 3).reshape(B, N, H * C)

        return self.to_out(out)


class Resampler(nn.Module):
    # Pre-norm transformer block: self-attention and feed-forward, each with a residual connection.
    def __init__(self, query_dim=1024, n_heads=8, d_head=64):
        super().__init__()

        self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)

        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x
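

# Minimal usage sketch (illustration only, not part of the original module).
# The (batch, num_tokens, query_dim) input shape below is an assumption; the
# block is shape-preserving, so the output matches the input shape.
if __name__ == "__main__":
    resampler = Resampler(query_dim=1024, n_heads=8, d_head=64)
    tokens = torch.randn(2, 77, 1024)  # hypothetical (batch, tokens, dim) input
    out = resampler(tokens)
    assert out.shape == tokens.shape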