import sys

sys.path.append("./BranchSBM")

import torch
import torch.nn as nn
from typing import List, Optional


class Swish(nn.Module):
    """Swish activation, x * sigmoid(x); equivalent to nn.SiLU."""

    def forward(self, x):
        return x * torch.sigmoid(x)


ACTIVATION_MAP = {
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
    "selu": nn.SELU,
    "elu": nn.ELU,
    "lrelu": nn.LeakyReLU,
    "softplus": nn.Softplus,
    "silu": nn.SiLU,
    "swish": Swish,
}
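# Note: ACTIVATION_MAP stores classes, not instances; SimpleDenseNet below calls
# ACTIVATION_MAP[activation]() so each layer gets its own activation module.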


class SimpleDenseNet(nn.Module):
    """A plain MLP: Linear -> [BatchNorm1d] -> activation blocks, then a final Linear."""

    def __init__(
        self,
        input_size: int,
        target_size: int,
        activation: str,
        batch_norm: bool = False,
        hidden_dims: Optional[List[int]] = None,
    ):
        super().__init__()
        # Guard the None default; unpacking None below would raise a TypeError.
        hidden_dims = hidden_dims or []
        dims = [input_size, *hidden_dims, target_size]
        layers = []
        # Hidden blocks: Linear -> (optional) BatchNorm1d -> activation.
        for i in range(len(dims) - 2):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            if batch_norm:
                layers.append(nn.BatchNorm1d(dims[i + 1]))
            layers.append(ACTIVATION_MAP[activation]())
        # Final projection carries no activation, leaving the output unbounded.
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
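

# Minimal usage sketch (illustrative only, not part of the original module):
# the layer sizes and the activation choice below are arbitrary assumptions.
if __name__ == "__main__":
    net = SimpleDenseNet(
        input_size=16,
        target_size=4,
        activation="swish",
        batch_norm=True,
        hidden_dims=[64, 64],
    )
    out = net(torch.randn(8, 16))  # batch of 8 inputs -> output shape (8, 4)
    print(out.shape)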