import torch
from torch import nn
from transformers import WhisperModel

from jyutping import jyutping_initials, jyutping_nuclei, jyutping_codas
class WhisperAudioClassifier(nn.Module):
    def __init__(self):
        super(WhisperAudioClassifier, self).__init__()
        # Load only the encoder of the Cantonese-finetuned Whisper model
        self.whisper_encoder = WhisperModel.from_pretrained(
            "alvanlii/whisper-small-cantonese", device_map="auto"
        ).get_encoder()
        self.whisper_encoder.eval()  # Keep the Whisper encoder in evaluation mode

        # Hidden size (d_model) of the Whisper-small encoder
        whisper_output_size = 768

        # One self-attention block per prediction target
        self.tone_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.initial_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.nucleus_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)
        self.coda_attention = nn.MultiheadAttention(whisper_output_size, 8, dropout=0.1, batch_first=True)

        self.pool = nn.AdaptiveAvgPool1d(1)

        # Separate output layers for each class set, with two heads per target
        self.initial_fc1 = nn.Linear(whisper_output_size, len(jyutping_initials))
        self.nucleus_fc1 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
        self.coda_fc1 = nn.Linear(whisper_output_size, len(jyutping_codas))
        self.tone_fc1 = nn.Linear(whisper_output_size, 6)
        self.initial_fc2 = nn.Linear(whisper_output_size, len(jyutping_initials))
        self.nucleus_fc2 = nn.Linear(whisper_output_size, len(jyutping_nuclei))
        self.coda_fc2 = nn.Linear(whisper_output_size, len(jyutping_codas))
        self.tone_fc2 = nn.Linear(whisper_output_size, 6)

        self.dropout = nn.Dropout(0.1)
    def forward(self, x):
        # Encode the audio input with the frozen Whisper encoder
        with torch.no_grad():  # No need to track gradients for the encoder
            x = self.whisper_encoder(x).last_hidden_state

        # Initial: self-attention, mean-pool over time, then two linear heads
        initial, _ = self.initial_attention(x, x, x, need_weights=False)
        initial = initial.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        initial = self.pool(initial)        # [batch_size, channels, 1]
        initial = initial.squeeze(-1)       # [batch_size, channels]
        initial_out1 = self.initial_fc1(initial)
        initial_out2 = self.initial_fc2(initial)

        # Nucleus
        nucleus, _ = self.nucleus_attention(x, x, x, need_weights=False)
        nucleus = nucleus.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        nucleus = self.pool(nucleus)        # [batch_size, channels, 1]
        nucleus = nucleus.squeeze(-1)       # [batch_size, channels]
        nucleus_out1 = self.nucleus_fc1(nucleus)
        nucleus_out2 = self.nucleus_fc2(nucleus)

        # Coda
        coda, _ = self.coda_attention(x, x, x, need_weights=False)
        coda = coda.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        coda = self.pool(coda)        # [batch_size, channels, 1]
        coda = coda.squeeze(-1)       # [batch_size, channels]
        coda_out1 = self.coda_fc1(coda)
        coda_out2 = self.coda_fc2(coda)

        # Tone
        tone, _ = self.tone_attention(x, x, x, need_weights=False)
        tone = tone.permute(0, 2, 1)  # [batch_size, channels, seq_len]
        tone = self.pool(tone)        # [batch_size, channels, 1]
        tone = tone.squeeze(-1)       # [batch_size, channels]
        tone_out1 = self.tone_fc1(tone)
        tone_out2 = self.tone_fc2(tone)

        return (initial_out1, nucleus_out1, coda_out1, tone_out1,
                initial_out2, nucleus_out2, coda_out2, tone_out2)
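For context, a minimal usage sketch (not part of the Space's code) follows. It assumes 16 kHz mono audio, that the repo ships a Whisper preprocessor config so WhisperFeatureExtractor can be loaded from it, and a batch of one; the silent waveform is only a stand-in input.

import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("alvanlii/whisper-small-cantonese")
model = WhisperAudioClassifier()
model.eval()

# Keep the classification heads on the same device the encoder was placed on
device = next(model.whisper_encoder.parameters()).device
model.to(device)

audio = np.zeros(16000, dtype=np.float32)  # one second of silence as a placeholder waveform
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="pt")  # log-mel features [1, 80, 3000]

with torch.no_grad():
    outputs = model(inputs.input_features.to(device))  # tuple of eight logit tensors

initial1, nucleus1, coda1, tone1, initial2, nucleus2, coda2, tone2 = outputs
print(tone1.shape)  # torch.Size([1, 6])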