Spaces:
Runtime error
Runtime error
| import torch | |
| import jyutping | |
| from whisper_audio_classifier import WhisperAudioClassifier | |
| import librosa | |
| from transformers import WhisperFeatureExtractor | |
# Feature extractor from the Cantonese Whisper checkpoint the encoder was fine-tuned from.
feature_extractor = WhisperFeatureExtractor.from_pretrained("alvanlii/whisper-small-cantonese")
# NOTE(review): WhisperFeatureExtractor appears to derive its padding length
# (n_samples / nb_max_frames) from chunk_length in __init__, so assigning the
# attribute afterwards may not shorten the padded input. Confirm this matches
# how the checkpoint below was trained before "fixing" it.
feature_extractor.chunk_length = 3

# Inference runs on CPU.
device = torch.device("cpu")
model = WhisperAudioClassifier().to(device)

# The checkpoint is consumed by load_state_dict, i.e. it is a plain state dict.
# weights_only=True restricts unpickling to tensors/primitive containers instead
# of allowing arbitrary pickled objects (and their code) to be loaded.
state_dict = torch.load(
    "whisper-small-encoder-bisyllabic-jyutping/checkpoints/model_epoch_1_step_1800.pth",
    map_location=device,
    weights_only=True,
)
model.load_state_dict(state_dict)
# Switch to evaluation mode (affects layers such as dropout) for inference.
model.eval()
def predict(audio):
    """Classify a 16 kHz waveform and return one probability list per model output.

    Args:
        audio: Waveform sampled at 16 kHz, in any form ``feature_extractor``
            accepts (e.g. a NumPy array or a torch tensor).

    Returns:
        A list with one entry per tensor yielded by ``model``. Each entry is
        the softmax of that squeezed tensor, converted to a plain Python list.
    """
    extracted = feature_extractor(audio, sampling_rate=16000)
    with torch.no_grad():
        # The extractor returns a batch of one clip; take it and re-add a
        # leading batch axis of size 1 as a tensor on the model's device.
        clip_features = torch.from_numpy(extracted['input_features'][0])
        batch = clip_features.to(device)[None]
        head_outputs = model(batch)
        return [
            torch.softmax(head.squeeze(), dim=0).tolist()
            for head in head_outputs
        ]
| import gradio as gr | |
| import numpy as np | |
# Label shown for a class whose jyutping spelling is the empty string
# (presumably the "no initial" / "no coda" class; confirm against jyutping).
_EMPTY_LABEL = "\u2205"  # U+2205 EMPTY SET


def _rank_by_label(label_of, preds, k):
    """Return the ``k`` most probable classes as an ordered ``{label: prob}`` dict.

    Args:
        label_of: Maps a class index to its display string. An empty result is
            shown as ``_EMPTY_LABEL``.
        preds: Per-class probabilities, indexed by class id.
        k: Maximum number of entries to return.

    Returns:
        Up to ``k`` distinct labels in descending probability order. Ties keep
        class-index order because ``sorted`` is stable. If two classes share a
        label, the more probable one is kept. Building a dict straight from the
        ranked pairs would let the less probable one overwrite it.
    """
    ranked = sorted(enumerate(preds), key=lambda item: item[1], reverse=True)
    top = {}
    for index, prob in ranked:
        if len(top) >= k:
            break
        # Label lookup happens once per index, and only for indices we need.
        top.setdefault(label_of(index) or _EMPTY_LABEL, prob)
    return top


def rank_initials(preds, k=3):
    """Return the top-``k`` syllable initials as ``{jyutping initial: probability}``."""
    return _rank_by_label(jyutping.inflate_initial, preds, k)


def rank_nucli(preds, k=3):
    """Return the top-``k`` syllable nuclei as ``{jyutping nucleus: probability}``."""
    return _rank_by_label(jyutping.inflate_nucleus, preds, k)


def rank_codas(preds, k=3):
    """Return the top-``k`` syllable codas as ``{jyutping coda: probability}``."""
    return _rank_by_label(jyutping.inflate_coda, preds, k)


def rank_tones(preds, k=3):
    """Return the top-``k`` tones as ``{tone number: probability}``.

    The tone number is the class index + 1, as a string.
    """
    return _rank_by_label(lambda index: str(index + 1), preds, k)
def classify_audio(audio):
    """Rank syllable-component predictions for a recorded two-syllable word.

    Args:
        audio: The value ``gr.Audio(type="numpy")`` passes in: a
            ``(sampling_rate, samples)`` tuple, or ``None`` when nothing was
            recorded.

    Returns:
        Eight ``{label: probability}`` dicts in the order the UI outputs are
        wired: initial, nucleus, coda, tone for syllable 1, then the same four
        for syllable 2.

    Raises:
        gr.Error: If no audio was recorded or the recording is empty, so the
            user sees a readable message instead of a stack trace.
    """
    if audio is None:
        raise gr.Error("No audio received. Please record a word before clicking Submit.")
    sampling_rate, samples = audio
    # astype-equivalent copy to float32; later ops are out-of-place, so the
    # caller's array is never modified.
    samples = np.asarray(samples, dtype=np.float32)
    if samples.ndim > 1:
        # Multi-channel audio arrives as (samples, channels). librosa.resample
        # works along the last axis, so collapse to mono before resampling.
        samples = samples.mean(axis=1)
    if samples.size == 0:
        raise gr.Error("The recording is empty. Please try again.")
    peak = np.max(np.abs(samples))
    if peak > 0:
        # Peak-normalise. Skipping all-zero clips avoids 0/0, which would fill
        # the model input with NaN.
        samples = samples / peak
    audio_resampled = librosa.resample(samples, orig_sr=sampling_rate, target_sr=16000)
    preds = predict(audio_resampled)
    # predict yields 8 heads: four components for syllable 1, then four for syllable 2.
    rankers = (rank_initials, rank_nucli, rank_codas, rank_tones) * 2
    return [ranker(preds[i]) for i, ranker in enumerate(rankers)]
# Names of the four per-syllable predictions, in the order classify_audio returns them.
_SLOT_NAMES = ("Initial", "Nucleus", "Coda", "Tone")

with gr.Blocks() as demo:
    with gr.Row():
        # Static instructions: gr.Markdown renders plain text, whereas gr.Label
        # is the component for classification results.
        gr.Markdown(
            "Please say a Cantonese word with exactly 2 characters, like 你好, "
            "into the microphone and click submit to see model predictions.\n\n"
            "Note that the predictions are not very reliable currently."
        )
    with gr.Row():
        inputs = gr.Audio(sources=["microphone"], type="numpy", label="Input Audio")
        submit_btn = gr.Button("Submit")
    with gr.Row():
        with gr.Column():
            outputs_left = [gr.Label(label=f"{slot} 1") for slot in _SLOT_NAMES]
        with gr.Column():
            outputs_right = [gr.Label(label=f"{slot} 2") for slot in _SLOT_NAMES]
    # classify_audio returns syllable 1's four rankings followed by syllable 2's.
    submit_btn.click(fn=classify_audio, inputs=inputs, outputs=outputs_left + outputs_right)

if __name__ == "__main__":
    demo.launch()