import sys

# Make the BranchSBM checkout importable. This has to run before the
# ``networks`` import below, which is why the imports are not grouped.
sys.path.append("./BranchSBM")

import torch  # noqa: E402
from networks.mlp_base import SimpleDenseNet  # noqa: E402


class VelocityNet(SimpleDenseNet):
    """Dense network evaluated on a time value concatenated with a state.

    The base class is configured with ``input_size = dim + 1`` (one extra
    column for the time value) and ``target_size = dim``.
    NOTE(review): presumably the output is a per-sample vector of size
    ``dim``; confirm against ``SimpleDenseNet``.
    """

    def __init__(self, dim: int, *args, **kwargs):
        """Configure the base network for a ``dim``-sized state.

        Args:
            dim: Size of the state vector ``x`` passed to ``forward``.
            *args: Forwarded unchanged to ``SimpleDenseNet.__init__``.
            **kwargs: Forwarded unchanged to ``SimpleDenseNet.__init__``.
        """
        super().__init__(*args, input_size=dim + 1, target_size=dim, **kwargs)

    def forward(self, t: torch.Tensor, x: torch.Tensor):
        """Prepend the time value to ``x`` and run the underlying model.

        Args:
            t: Time tensor. It is used as-is when its first dimension
                equals ``x.shape[0]`` (a 1-D tensor gets a trailing axis
                added). Otherwise it is flattened and tiled
                ``x.shape[0]`` times, which turns a 0-D, ``(1,)`` or
                ``(1, 1)`` tensor into one time value per row of ``x``.
            x: State tensor; its first dimension is treated as the batch
                dimension and ``t`` is concatenated along its last one.

        Returns:
            Whatever ``self.model`` returns for the concatenated input.
        """
        batch_size = x.shape[0]
        if t.dim() < 1 or t.shape[0] != batch_size:
            # Flatten before tiling: Tensor.repeat rejects a repeat spec
            # with fewer entries than the tensor has dimensions, so a
            # (1, 1)-shaped time would otherwise raise for batch_size > 1.
            t = t.reshape(-1).repeat(batch_size)[:, None]
        if t.dim() < 2:
            # Per-sample times given as a flat vector: add the feature axis.
            t = t[:, None]
        return self.model(torch.cat([t, x], dim=-1))