Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,8 +8,8 @@ from fastapi.responses import StreamingResponse
|
|
| 8 |
|
| 9 |
# --- Library Imports ---
|
| 10 |
from nemo.collections.tts.models import FastPitchModel, HifiGanModel
|
| 11 |
-
|
| 12 |
-
#
|
| 13 |
from transformers import VitsModel, AutoTokenizer
|
| 14 |
|
| 15 |
# Configure logging
|
|
@@ -31,8 +31,8 @@ def load_models():
|
|
| 31 |
"""Load all models into memory when the application starts."""
|
| 32 |
logger.info("Loading models...")
|
| 33 |
try:
|
| 34 |
-
# --- NeMo Models ---
|
| 35 |
-
logger.info("Loading HiFi-GAN vocoder...")
|
| 36 |
models['hifigan'] = HifiGanModel.restore_from("models/hifigan_en.nemo").to(device)
|
| 37 |
models['hifigan'].eval()
|
| 38 |
|
|
@@ -40,19 +40,14 @@ def load_models():
|
|
| 40 |
models['en'] = FastPitchModel.restore_from("models/fastpitch_en.nemo").to(device)
|
| 41 |
models['en'].eval()
|
| 42 |
|
| 43 |
-
|
| 44 |
-
models['bikol'] = FastPitchModel.restore_from("models/fastpitch_bikol_corrected.nemo").to(device)
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
'y', 'z', 'à', 'á', 'â', 'é', 'ì', 'í', 'î', 'ñ', 'ò', 'ó', 'ô', 'ú', '’'
|
| 51 |
-
]
|
| 52 |
-
models['bikol'].tokenizer = BaseCharsTokenizer(chars=BIKOL_CHARS)
|
| 53 |
-
models['bikol'].eval()
|
| 54 |
|
| 55 |
-
#
|
| 56 |
logger.info("Loading Tagalog (tgl) MMS-TTS model from Hub...")
|
| 57 |
models['tgl_tokenizer'] = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
|
| 58 |
models['tgl_model'] = VitsModel.from_pretrained("facebook/mms-tts-tgl").to(device)
|
|
@@ -84,9 +79,10 @@ def synthesize_speech(request: TTSRequest):
|
|
| 84 |
try:
|
| 85 |
logger.info(f"--- STARTING SYNTHESIS for '{lang}' ---")
|
| 86 |
|
| 87 |
-
|
|
|
|
| 88 |
sample_rate = 22050
|
| 89 |
-
spectrogram_generator = models[
|
| 90 |
vocoder = models['hifigan']
|
| 91 |
with torch.no_grad():
|
| 92 |
parsed = spectrogram_generator.parse(request.text)
|
|
@@ -94,10 +90,14 @@ def synthesize_speech(request: TTSRequest):
|
|
| 94 |
audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)
|
| 95 |
audio_numpy = audio.to('cpu').detach().numpy().squeeze()
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
with torch.no_grad():
|
| 102 |
inputs = tokenizer(request.text, return_tensors="pt").to(device)
|
| 103 |
output = model(**inputs).waveform
|
|
@@ -126,4 +126,10 @@ def read_root():
|
|
| 126 |
@app.get("/status")
|
| 127 |
def get_status():
|
| 128 |
"""Get the status of all loaded models."""
|
| 129 |
-
return {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# --- Library Imports ---
|
| 10 |
from nemo.collections.tts.models import FastPitchModel, HifiGanModel
|
| 11 |
+
# BaseCharsTokenizer is no longer needed since we aren't using the old NeMo Bikol model
|
| 12 |
+
# from nemo.collections.tts.torch.tts_tokenizers import BaseCharsTokenizer
|
| 13 |
from transformers import VitsModel, AutoTokenizer
|
| 14 |
|
| 15 |
# Configure logging
|
|
|
|
| 31 |
"""Load all models into memory when the application starts."""
|
| 32 |
logger.info("Loading models...")
|
| 33 |
try:
|
| 34 |
+
# --- NeMo Models (English Only now) ---
|
| 35 |
+
logger.info("Loading HiFi-GAN vocoder (for English)...")
|
| 36 |
models['hifigan'] = HifiGanModel.restore_from("models/hifigan_en.nemo").to(device)
|
| 37 |
models['hifigan'].eval()
|
| 38 |
|
|
|
|
| 40 |
models['en'] = FastPitchModel.restore_from("models/fastpitch_en.nemo").to(device)
|
| 41 |
models['en'].eval()
|
| 42 |
|
| 43 |
+
# --- Transformers MMS-TTS Models (Bikol & Tagalog) ---
|
|
|
|
| 44 |
|
| 45 |
+
# 1. NEW BIKOL MODEL (Fine-tuned VITS)
|
| 46 |
+
logger.info("Loading Bikol (bcl) MMS-TTS model from Hub...")
|
| 47 |
+
models['bikol_tokenizer'] = AutoTokenizer.from_pretrained("cuhgrel/bikol-mms-finetuned")
|
| 48 |
+
models['bikol_model'] = VitsModel.from_pretrained("cuhgrel/bikol-mms-finetuned").to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
+
# 2. TAGALOG MODEL (Base VITS)
|
| 51 |
logger.info("Loading Tagalog (tgl) MMS-TTS model from Hub...")
|
| 52 |
models['tgl_tokenizer'] = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
|
| 53 |
models['tgl_model'] = VitsModel.from_pretrained("facebook/mms-tts-tgl").to(device)
|
|
|
|
| 79 |
try:
|
| 80 |
logger.info(f"--- STARTING SYNTHESIS for '{lang}' ---")
|
| 81 |
|
| 82 |
+
# --- CASE A: English (Uses NeMo) ---
|
| 83 |
+
if lang == 'en':
|
| 84 |
sample_rate = 22050
|
| 85 |
+
spectrogram_generator = models['en']
|
| 86 |
vocoder = models['hifigan']
|
| 87 |
with torch.no_grad():
|
| 88 |
parsed = spectrogram_generator.parse(request.text)
|
|
|
|
| 90 |
audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)
|
| 91 |
audio_numpy = audio.to('cpu').detach().numpy().squeeze()
|
| 92 |
|
| 93 |
+
# --- CASE B: Bikol & Tagalog (Uses Hugging Face VITS) ---
|
| 94 |
+
elif lang in ['tgl', 'bikol']:
|
| 95 |
+
sample_rate = 16000 # MMS models are usually 16kHz
|
| 96 |
+
|
| 97 |
+
# Dynamically select the correct tokenizer and model from the dictionary
|
| 98 |
+
tokenizer = models[f'{lang}_tokenizer']
|
| 99 |
+
model = models[f'{lang}_model']
|
| 100 |
+
|
| 101 |
with torch.no_grad():
|
| 102 |
inputs = tokenizer(request.text, return_tensors="pt").to(device)
|
| 103 |
output = model(**inputs).waveform
|
|
|
|
| 126 |
@app.get("/status")
|
| 127 |
def get_status():
|
| 128 |
"""Get the status of all loaded models."""
|
| 129 |
+
return {
|
| 130 |
+
"models_loaded": list(models.keys()),
|
| 131 |
+
"device": device,
|
| 132 |
+
"english_available": 'en' in models,
|
| 133 |
+
"bikol_available": 'bikol_model' in models, # Updated check
|
| 134 |
+
"tagalog_available": 'tgl_model' in models
|
| 135 |
+
}
|