import torch
import torch.nn as nn

#################################################################################
#                                Length Estimator                               #
#################################################################################


class LengthEstimator(nn.Module):
    """MLP head mapping ``input_size`` features to ``output_size`` outputs.

    Architecture, with ``H = hidden_dim`` and hidden widths halving per stage::

        Linear(input_size, H)   -> LayerNorm -> LeakyReLU(0.2) -> Dropout
        Linear(H, H // 2)       -> LayerNorm -> LeakyReLU(0.2) -> Dropout
        Linear(H // 2, H // 4)  -> LayerNorm -> LeakyReLU(0.2)
        Linear(H // 4, output_size)

    Every layer acts on the last dimension, so an input of shape
    ``(..., input_size)`` yields ``(..., output_size)``. The final Linear has
    no activation, so outputs are raw, unnormalised scores.
    NOTE(review): judging by the class name and the ``text_emb`` argument,
    these are presumably logits over discretised lengths predicted from a
    text embedding -- confirm against the caller / loss function.

    Args:
        input_size: Size of the input's last dimension.
        output_size: Size of the output's last dimension.
        hidden_dim: Width of the first hidden layer (default 512). Must be
            at least 4 so that the ``hidden_dim // 4`` layer is non-empty.
        dropout: Dropout probability after the first two hidden stages
            (default 0.2).

    Raises:
        ValueError: If ``hidden_dim < 4``.
    """

    def __init__(self, input_size, output_size, *, hidden_dim=512, dropout=0.2):
        super().__init__()
        if hidden_dim < 4:
            raise ValueError(f"hidden_dim must be >= 4, got {hidden_dim}")

        h1, h2, h3 = hidden_dim, hidden_dim // 2, hidden_dim // 4
        # Layer order (and therefore the "output.<idx>.*" state_dict keys) is
        # deliberately unchanged so previously saved checkpoints still load.
        self.output = nn.Sequential(
            nn.Linear(input_size, h1),
            nn.LayerNorm(h1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(dropout),
            nn.Linear(h1, h2),
            nn.LayerNorm(h2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(dropout),
            nn.Linear(h2, h3),
            nn.LayerNorm(h3),
            nn.LeakyReLU(0.2, inplace=True),
            # No dropout/activation before the output projection.
            nn.Linear(h3, output_size),
        )
        self.output.apply(self.__init_weights)

    def __init_weights(self, module):
        """Initialise a single submodule in place (called via ``Module.apply``).

        - Linear / Embedding weights are drawn from N(0, 0.02).
        - Linear biases are set to 0.
        - LayerNorm weight is set to 1 and bias to 0.
        - Any other module is left untouched.

        nn.Embedding is handled only for completeness; this module builds none.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            # LayerNorm built with elementwise_affine=False (or bias=False)
            # has weight/bias set to None, so guard before initialising.
            if module.weight is not None:
                nn.init.ones_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, text_emb: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``text_emb`` (last dim == ``input_size``)."""
        return self.output(text_emb)