# BranchSBM / networks/flow_mlp.py
# Uploaded by sophiat44 ("model upload", commit 5a87d8d); 514 bytes.
import sys
sys.path.append("./BranchSBM")
import torch
from networks.mlp_base import SimpleDenseNet
class VelocityNet(SimpleDenseNet):
    """Dense network mapping a time ``t`` and a state ``x`` to a ``dim``-vector.

    The time value is concatenated to ``x`` as one extra leading feature
    column, which is why the underlying ``SimpleDenseNet`` is built with
    ``input_size=dim + 1`` and ``target_size=dim``. Judging by the name this
    is presumably a time-dependent velocity field ``v(t, x)``; confirm
    against callers.
    """

    def __init__(self, dim: int, *args, **kwargs):
        """Build the network for ``dim``-dimensional states.

        Args:
            dim: Dimensionality of the state ``x`` and of the network output.
            *args: Forwarded positionally to ``SimpleDenseNet.__init__``.
            **kwargs: Forwarded as keywords to ``SimpleDenseNet.__init__``.
        """
        # Positional *args are always bound before keywords regardless of
        # where they appear in the call; listing them first just makes that
        # explicit (unpacking after keywords is flagged by bugbear B026).
        super().__init__(*args, input_size=dim + 1, target_size=dim, **kwargs)

    def forward(self, t, x: torch.Tensor) -> torch.Tensor:
        """Evaluate the network at time ``t`` and state ``x``.

        Args:
            t: Time. A Python number or any single-element tensor (shape
                ``()``, ``(1,)``, ``(1, 1)``, ...) is broadcast to every row
                of ``x``. Otherwise ``t`` must be a tensor whose first
                dimension equals ``x.shape[0]``; a 1-D ``t`` gets a trailing
                feature axis added.
            x: Batch of states, rows indexed by the first dimension. Assumed
                2-D ``(batch, dim)`` so the last-axis concatenation lines up
                (TODO confirm against callers).

        Returns:
            ``self.model`` applied to ``[t, x]`` concatenated along the last
            axis. ``self.model`` is not set in this class and is presumably
            built by ``SimpleDenseNet``.

        Raises:
            ValueError: If ``t`` has more than one element and its first
                dimension does not match ``x.shape[0]``.
        """
        if not isinstance(t, torch.Tensor):
            # Allow plain Python numbers, e.g. from an ODE solver loop.
            t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        # torch.cat needs both operands on the same device.
        t = t.to(x.device)

        batch = x.shape[0]
        if t.numel() == 1:
            # Scalar-like time of any rank: one copy per sample, shape (B, 1).
            t = t.reshape(1, 1).expand(batch, 1)
        elif t.shape[0] != batch:
            raise ValueError(
                f"t has shape {tuple(t.shape)}, which cannot be matched to a "
                f"batch of {batch} samples in x"
            )
        elif t.dim() < 2:
            # Per-sample times of shape (B,) -> (B, 1) feature column.
            t = t[:, None]

        return self.model(torch.cat([t, x], dim=-1))