import sys

# Make the local BranchSBM checkout importable (module-level side effect kept
# from the original file; must run before any project imports that rely on it).
sys.path.append("./BranchSBM")

import torch.nn as nn
import torch
from typing import List, Optional


class swish(nn.Module):
    """Swish activation, computed as ``x * sigmoid(x)``.

    Same formula as ``nn.SiLU`` (which is also registered in ACTIVATION_MAP).
    """

    def forward(self, x):
        return x * torch.sigmoid(x)


# Maps the ``activation`` string accepted by SimpleDenseNet to an nn.Module
# class. Each entry is instantiated with no arguments (library defaults).
ACTIVATION_MAP = {
    "relu": nn.ReLU,
    "sigmoid": nn.Sigmoid,
    "tanh": nn.Tanh,
    "selu": nn.SELU,
    "elu": nn.ELU,
    "lrelu": nn.LeakyReLU,
    "softplus": nn.Softplus,
    "silu": nn.SiLU,
    "swish": swish,
}


class SimpleDenseNet(nn.Module):
    """Fully connected network built as an ``nn.Sequential``.

    Each hidden layer is ``Linear -> [BatchNorm1d] -> activation``. The final
    layer is a plain ``Linear`` with no normalization or activation.

    Args:
        input_size: Number of input features (in_features of the first Linear).
        target_size: Number of output features (out_features of the last Linear).
        activation: Key into ACTIVATION_MAP selecting the hidden-layer
            activation. An unknown key raises ``KeyError`` when at least one
            hidden layer is built.
        batch_norm: If True, insert ``nn.BatchNorm1d`` after every hidden Linear.
        hidden_dims: Widths of the hidden layers, in order. ``None`` or an
            empty sequence means no hidden layers: the network is then a
            single ``Linear(input_size, target_size)``.
    """

    def __init__(
        self,
        input_size: int,
        target_size: int,
        activation: str,
        batch_norm: bool = False,
        hidden_dims: Optional[List[int]] = None,
    ):
        super().__init__()
        # Previously a None default crashed on ``*hidden_dims`` (TypeError);
        # treat it as "no hidden layers" instead.
        hidden = [] if hidden_dims is None else list(hidden_dims)
        dims = [input_size, *hidden, target_size]

        layers = []
        # Consecutive (in, out) width pairs for every layer except the last,
        # i.e. the hidden layers that get normalization + activation.
        for d_in, d_out in zip(dims[:-2], dims[1:-1]):
            layers.append(nn.Linear(d_in, d_out))
            if batch_norm:
                layers.append(nn.BatchNorm1d(d_out))
            layers.append(ACTIVATION_MAP[activation]())
        # Output projection: no normalization or activation.
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x``.

        Presumably ``x`` is (batch, input_size); with ``batch_norm=True``,
        BatchNorm1d requires a 2D/3D input. TODO confirm against callers.
        """
        return self.model(x)