from __future__ import annotations
import torch
import torchaudio
import gradio as gr
import spaces
from transformers import AutoModel, AutoModelForAudioClassification, Wav2Vec2FeatureExtractor

DESCRIPTION = "Indic Speech-to-Text with Automatic Language Detection"
device = "cuda" if torch.cuda.is_available() else "cpu"

# --- Model Loading ---
print("Loading ASR model (IndicConformer)...")
asr_model_id = "ai4bharat/indic-conformer-600m-multilingual"
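# IndicConformer ships custom modeling code on the Hub, so trust_remote_code=True is required.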
asr_model = AutoModel.from_pretrained(asr_model_id, trust_remote_code=True).to(device)
asr_model.eval()
print(" ASR Model loaded.")

print("\nLoading Language ID model (MMS-LID-1024)...")
lid_model_id = "facebook/mms-lid-1024"
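# MMS-LID-1024 is an audio classifier that predicts one of 1024 spoken-language labels.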
lid_processor = Wav2Vec2FeatureExtractor.from_pretrained(lid_model_id)
lid_model = AutoModelForAudioClassification.from_pretrained(lid_model_id).to(device)
lid_model.eval()
print(" Language ID Model loaded.")


# --- Language Mappings ---
LID_TO_ASR_LANG_MAP = {
    # FLORES-style codes with script suffix (e.g., hin_Deva)
    "asm_Beng": "as", "ben_Beng": "bn", "brx_Deva": "br", "doi_Deva": "doi",
    "guj_Gujr": "gu", "hin_Deva": "hi", "kan_Knda": "kn", "kas_Arab": "ks",
    "kas_Deva": "ks", "gom_Deva": "kok", "mai_Deva": "mai", "mal_Mlym": "ml",
    "mni_Beng": "mni", "mar_Deva": "mr", "nep_Deva": "ne", "ory_Orya": "or",
    "pan_Guru": "pa", "san_Deva": "sa", "sat_Olck": "sat", "snd_Arab": "sd",
    "tam_Taml": "ta", "tel_Telu": "te", "urd_Arab": "ur",
    "asm": "as", "ben": "bn", "brx": "br", "doi": "doi", "guj": "gu", "hin": "hi",
    "kan": "kn", "kas": "ks", "gom": "kok", "mai": "mai", "mal": "ml", "mni": "mni",
    "mar": "mr", "npi": "ne", "ory": "or", "pan": "pa", "san": "sa", "sat": "sat",
    "snd": "sd", "tam": "ta", "tel": "te", "urd": "ur", "eng": "en"
}

ASR_CODE_TO_NAME = {
    "as": "Assamese", "bn": "Bengali", "br": "Bodo", "doi": "Dogri",
    "gu": "Gujarati", "hi": "Hindi", "kn": "Kannada", "ks": "Kashmiri",
    "kok": "Konkani", "mai": "Maithili", "ml": "Malayalam", "mni": "Manipuri",
    "mr": "Marathi", "ne": "Nepali", "or": "Odia", "pa": "Punjabi",
    "sa": "Sanskrit", "sat": "Santali", "sd": "Sindhi", "ta": "Tamil",
    "te": "Telugu", "ur": "Urdu", "en": "English",
}


@spaces.GPU
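# On ZeroGPU Spaces, the decorator above allocates a GPU for the duration of each call.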
def transcribe_audio_with_lid(audio_path):
    if not audio_path:
        return "Please provide an audio file.", "", ""

    try:
        waveform, sr = torchaudio.load(audio_path)
        # Downmix multi-channel recordings to mono; both models expect a single 16 kHz channel,
        # and a stereo tensor would otherwise be treated as a batch by the LID feature extractor.
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        waveform_16k = torchaudio.functional.resample(waveform, sr, 16000)
    except Exception as e:
        return f"Error loading audio: {e}", "", ""

    try:
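        # Step 1: identify the spoken language with MMS-LID.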
        inputs = lid_processor(waveform_16k.squeeze(), sampling_rate=16000, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = lid_model(**inputs)

        logits = outputs.logits
        predicted_lid_id = logits.argmax(-1).item()
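        # id2label maps the predicted class index back to its language-code string.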
        detected_lid_code = lid_model.config.id2label[predicted_lid_id]

        asr_lang_code = LID_TO_ASR_LANG_MAP.get(detected_lid_code)
        
        if not asr_lang_code:
            detected_lang_str = f"Detected '{detected_lid_code}', which is not supported by the ASR model."
            return detected_lang_str, "N/A", "N/A"

        detected_lang_str = f"Detected Language: {ASR_CODE_TO_NAME.get(asr_lang_code, 'Unknown')}"

        with torch.no_grad():
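            # Step 2: transcribe with both decoder heads. CTC decoding is
            # non-autoregressive and fast; RNN-T is autoregressive and typically
            # more accurate, so both outputs are shown for comparison.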
            transcription_ctc = asr_model(waveform_16k.to(device), asr_lang_code, "ctc")
            transcription_rnnt = asr_model(waveform_16k.to(device), asr_lang_code, "rnnt")

    except Exception as e:
        return f"Error during processing: {str(e)}", "", ""

    return detected_lang_str, transcription_ctc.strip(), transcription_rnnt.strip()


# --- Gradio UI ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"## {DESCRIPTION}")
    gr.Markdown("Upload or record audio in any of the 22 supported Indian languages. The app will automatically detect the language and provide the transcription.")
    
    with gr.Row():
        with gr.Column(scale=1):
            audio = gr.Audio(label="Upload or Record Audio", type="filepath")
            transcribe_btn = gr.Button("Transcribe", variant="primary")
        
        with gr.Column(scale=2):
            detected_lang_output = gr.Label(label="Language Detection Result")
            gr.Markdown("### RNNT Transcription")
            rnnt_output = gr.Textbox(lines=3, label="RNNT Output")
            gr.Markdown("### CTC Transcription")
            ctc_output = gr.Textbox(lines=3, label="CTC Output")

    transcribe_btn.click(
        fn=transcribe_audio_with_lid, 
        inputs=[audio], 
        outputs=[detected_lang_output, ctc_output, rnnt_output],
        api_name="transcribe"
    )

if __name__ == "__main__":
    demo.queue().launch()