| import sys | |
| sys.path.append("./BranchSBM") | |
| import torch | |
| from networks.mlp_base import SimpleDenseNet | |
class VelocityNet(SimpleDenseNet):
    """Time-conditioned network built on ``SimpleDenseNet``.

    The time ``t`` is concatenated in front of the state ``x`` along the last
    axis before being passed to ``self.model``. This is why the underlying
    dense net is constructed with ``input_size = dim + 1`` and
    ``target_size = dim``.

    Args:
        dim: Size of the last axis of ``x`` (and of the network output).
        *args, **kwargs: Forwarded unchanged to ``SimpleDenseNet.__init__``.
    """

    def __init__(self, dim: int, *args, **kwargs):
        # Positional ``*args`` are bound to SimpleDenseNet's leading positional
        # parameters *before* the keywords below are applied, wherever ``*args``
        # is written in the call. Writing it first just makes that explicit.
        # NOTE(review): if SimpleDenseNet's first positional parameter is
        # ``input_size``, any positional extra raises "got multiple values";
        # pass extra options as keywords — confirm against networks.mlp_base.
        super().__init__(*args, input_size=dim + 1, target_size=dim, **kwargs)

    def forward(self, t, x):
        """Evaluate the network at time ``t`` and state ``x``.

        Args:
            t: Time. Either a single value (Python number, 0-d tensor, or any
                tensor holding one element), which is broadcast to every
                position of ``x``; or one value per sample, i.e. exactly
                ``x.shape[0]`` elements (e.g. shape ``(B,)`` or ``(B, 1)``).
            x: State tensor whose last axis has size ``dim``; its first axis
                is treated as the batch axis.

        Returns:
            ``self.model`` applied to ``torch.cat([t, x], dim=-1)``.

        Raises:
            RuntimeError: if ``t`` has neither one element nor ``x.shape[0]``
                elements.
        """
        # Put t on x's dtype/device so the concatenated input has a single
        # dtype; as_tensor keeps autograd history when t is already a tensor.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        time_shape = (*x.shape[:-1], 1)
        if t.numel() == 1:
            # One shared time value: broadcast it to every position of x.
            t = t.reshape(1).expand(time_shape)
        else:
            # One time value per sample along x's first (batch) axis; reshape
            # raises if the element count does not match instead of silently
            # tiling t into the wrong length.
            per_sample = (x.shape[0],) + (1,) * (x.dim() - 1)
            t = t.reshape(per_sample).expand(time_shape)
        return self.model(torch.cat([t, x], dim=-1))