Spaces:
Running
Running
File size: 4,050 Bytes
fc7b4a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import torch
import numpy as np
import torch.nn as nn
# Pick whichever autocast entry point this PyTorch build provides.
# `torch_amp_new` records which one was imported so call sites can pass the
# right arguments: the `torch.amp` variant is called with a device-type string
# as its first argument, the `torch.cuda.amp` variant is not.
try:
    from torch.amp import autocast

    torch_amp_new = True
except ImportError:
    # Only a missing module/name should trigger the fallback; anything else
    # (e.g. KeyboardInterrupt) must propagate instead of being swallowed.
    from torch.cuda.amp import autocast

    torch_amp_new = False
from torchaudio.transforms import AmplitudeToDB, MelSpectrogram
class FeatureExtractor(nn.Module):
    """
    Turn raw audio waveforms into normalized mel-spectrogram features.

    The pipeline is: mel spectrogram -> amplitude-to-dB -> normalization.

    Args:
        cfg (object): Configuration object. The fields read here are
            ``cfg.audio.sample_rate`` and, under ``cfg.melspec``:
            ``n_fft``, ``hop_length``, ``win_length``, ``n_mels``, ``f_min``,
            ``f_max``, ``power``, ``top_db`` and ``norm``.
    """

    def __init__(
        self,
        cfg,
    ):
        super().__init__()
        mel_cfg = cfg.melspec

        # Submodules are registered in the same order as before so that
        # children()/state_dict() ordering is unaffected.
        mel_kwargs = dict(
            n_fft=mel_cfg.n_fft,
            hop_length=mel_cfg.hop_length,
            win_length=mel_cfg.win_length,
            n_mels=mel_cfg.n_mels,
            sample_rate=cfg.audio.sample_rate,
            f_min=mel_cfg.f_min,
            f_max=mel_cfg.f_max,
            power=mel_cfg.power,
        )
        self.audio2melspec = MelSpectrogram(**mel_kwargs)
        self.amplitude_to_db = AmplitudeToDB(top_db=mel_cfg.top_db)
        self.normalizer = self._make_normalizer(mel_cfg.norm)

    @staticmethod
    def _make_normalizer(kind):
        """Return the normalization module selected by ``kind``."""
        if kind == "mean_std":
            return MeanStdNorm()
        if kind == "min_max":
            return MinMaxNorm()
        if kind == "simple":
            return SimpleNorm()
        # Any other value means "no normalization": pass the input through.
        return nn.Identity()

    def forward(self, x):
        """
        Compute normalized mel-spectrogram features for a batch of audio.

        Args:
            x (torch.Tensor): Raw audio input of shape (batch_size, num_samples).

        Returns:
            torch.Tensor: Normalized mel-spectrogram features of shape
            (batch_size, n_mels, time).
        """
        # The two autocast entry points take different arguments; build the
        # "autocast disabled" context that matches the one that was imported.
        if torch_amp_new:
            amp_disabled = autocast("cuda", enabled=False)
        else:
            amp_disabled = autocast(enabled=False)

        # Autocast is switched off here and the input is cast with
        # ``x.float()`` before the spectrogram transform.
        # NOTE(review): normalization is kept inside the autocast-disabled
        # region as well; confirm this matches the intended scope.
        with amp_disabled:
            features = self.audio2melspec(x.float())
            features = self.amplitude_to_db(features)
            features = self.normalizer(features)
        return features
class MinMaxNorm(nn.Module):
    """
    Min-max normalization computed independently for each batch element.

    Args:
        eps (float, optional): Small constant added to the denominator to
            prevent division by zero. Defaults to 1e-6.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, X):
        """
        Rescale ``X`` using the minimum and maximum taken over dims 1 and 2.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, n_mels, time).

        Returns:
            torch.Tensor: Min-max normalized tensor of the same shape.
        """
        reduce_dims = (1, 2)
        # keepdim=True keeps the reduced axes as size 1 so the statistics
        # broadcast against X.
        lowest = X.amin(dim=reduce_dims, keepdim=True)
        highest = X.amax(dim=reduce_dims, keepdim=True)
        value_range = highest - lowest + self.eps
        return (X - lowest) / value_range
class SimpleNorm(nn.Module):
    """
    Fixed affine normalization with no data-dependent statistics.

    Every value is shifted and scaled with constants: (x - 40) / 80.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """
        Apply the fixed shift-and-scale to ``x``.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, n_mels, time).

        Returns:
            torch.Tensor: Normalized tensor of the same shape.
        """
        offset = 40
        scale = 80
        shifted = x - offset
        return shifted / scale
class MeanStdNorm(nn.Module):
    """
    Applies mean-std normalization to input tensors.

    Mean and standard deviation are computed separately for every batch
    element over all of its remaining values.

    Args:
        eps (float, optional): Small constant added to the standard deviation
            to prevent division by zero. Defaults to 1e-6.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, X):
        """
        Forward pass of mean-std normalization.

        Args:
            X (torch.Tensor): Input tensor of shape (batch_size, n_mels, time).
                A zero-size batch is accepted and yields an empty result.

        Returns:
            torch.Tensor: Normalized tensor of the same shape.
        """
        mean = X.mean((1, 2), keepdim=True)
        # flatten(start_dim=1) collapses each sample to one row, exactly like
        # reshape(batch, -1), but also works when batch_size == 0 (where the
        # -1 in reshape is ambiguous and raises a RuntimeError).
        per_sample = X.flatten(start_dim=1)
        # (batch, 1) -> (batch, 1, 1) so the statistic broadcasts against X.
        std = per_sample.std(1, keepdim=True).unsqueeze(-1)
        return (X - mean) / (std + self.eps)
|