# BranchSBM / networks / mlp_base.py
# sophiat44 — "model upload" (5a87d8d)
import sys
sys.path.append("./BranchSBM")
import torch.nn as nn
import torch
from typing import List, Optional
class swish(nn.Module):
    """Parameter-free Swish activation: ``f(x) = x * sigmoid(x)``, applied elementwise."""

    def forward(self, x):
        """Return ``x`` scaled elementwise by its own sigmoid."""
        gate = torch.sigmoid(x)
        return torch.mul(x, gate)
# Lookup table from configuration strings to activation module classes.
# Values are classes, not instances; SimpleDenseNet instantiates the selected
# class with no arguments for each hidden layer.
ACTIVATION_MAP = dict(
    relu=nn.ReLU,
    sigmoid=nn.Sigmoid,
    tanh=nn.Tanh,
    selu=nn.SELU,
    elu=nn.ELU,
    lrelu=nn.LeakyReLU,
    softplus=nn.Softplus,
    silu=nn.SiLU,
    swish=swish,
)
class SimpleDenseNet(nn.Module):
    """Multilayer perceptron built as an ``nn.Sequential``.

    Each hidden layer is ``nn.Linear -> [nn.BatchNorm1d] -> activation``; the
    final layer is a plain ``nn.Linear`` with no normalization or activation.

    Args:
        input_size: Number of input features.
        target_size: Number of output features.
        activation: Key into ``ACTIVATION_MAP`` selecting the activation class
            instantiated after every hidden layer.
        batch_norm: If True, insert ``nn.BatchNorm1d`` after each hidden
            ``nn.Linear`` (before the activation).
        hidden_dims: Widths of the hidden layers, in order. ``None`` (the
            default) or an empty sequence yields a single
            ``nn.Linear(input_size, target_size)``.

    Raises:
        KeyError: If ``activation`` is not a key of ``ACTIVATION_MAP`` and at
            least one hidden layer is built.
    """

    def __init__(
        self,
        input_size: int,
        target_size: int,
        activation: str,
        batch_norm: bool = False,
        hidden_dims: Optional[List[int]] = None,
    ):
        super().__init__()
        # None means "no hidden layers"; unpacking None directly would raise TypeError.
        hidden = [] if hidden_dims is None else list(hidden_dims)
        dims = [input_size, *hidden, target_size]

        layers = []
        # Consecutive (in, out) width pairs for every hidden layer, i.e. all
        # pairs except the final output projection.
        for in_dim, out_dim in zip(dims[:-2], dims[1:-1]):
            layers.append(nn.Linear(in_dim, out_dim))
            if batch_norm:
                layers.append(nn.BatchNorm1d(out_dim))
            # Looked up per hidden layer so the activation name is only
            # required when a hidden layer actually uses it.
            layers.append(ACTIVATION_MAP[activation]())
        # Output projection: no normalization or activation.
        layers.append(nn.Linear(dims[-2], dims[-1]))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the result."""
        return self.model(x)