# nemo-tts-api / app.py
# cuhgrel's picture
# Update app.py
# fa7ae36 verified
import torch
import soundfile as sf
import io
import logging
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel
from fastapi.responses import StreamingResponse
# --- Library Imports ---
from nemo.collections.tts.models import FastPitchModel, HifiGanModel
# BaseCharsTokenizer is no longer needed since we aren't using the old NeMo Bikol model
# from nemo.collections.tts.torch.tts_tokenizers import BaseCharsTokenizer
from transformers import VitsModel, AutoTokenizer
# Logging: INFO-level root configuration plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- 1. Initialize FastAPI App ---
app = FastAPI(
    title="Multilingual TTS API",
    description="A backend service to convert text to speech in English, Bikol, and Tagalog.",
)

# --- 2. Load Models on Startup ---
# Registry filled by the startup hook; keys are model/tokenizer names.
models = {}
# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
@app.on_event("startup")
def load_models():
    """Load all models into the module-level ``models`` registry at startup.

    Populates these keys, each model moved to ``device``:
      * ``'hifigan'`` / ``'en'``: NeMo HiFi-GAN vocoder and FastPitch model
        restored from local ``.nemo`` files and switched to eval mode.
      * ``'bikol_tokenizer'`` / ``'bikol_model'``: tokenizer and VITS model
        from ``cuhgrel/bikol-mms-finetuned-v2``.
      * ``'tgl_tokenizer'`` / ``'tgl_model'``: tokenizer and VITS model
        from ``facebook/mms-tts-tgl``.

    Raises:
        Exception: whatever the failing loader raised; it is logged with its
            traceback and re-raised unchanged.
    """
    logger.info("Loading models...")
    try:
        # --- NeMo Models (English Only now) ---
        logger.info("Loading HiFi-GAN vocoder (for English)...")
        models['hifigan'] = HifiGanModel.restore_from("models/hifigan_en.nemo").to(device)
        models['hifigan'].eval()

        logger.info("Loading English FastPitch model...")
        models['en'] = FastPitchModel.restore_from("models/fastpitch_en.nemo").to(device)
        models['en'].eval()

        # --- Transformers MMS-TTS Models (Bikol & Tagalog) ---
        # 1. NEW BIKOL MODEL (Fine-tuned VITS)
        logger.info("Loading Bikol (bcl) MMS-TTS model from Hub...")
        models['bikol_tokenizer'] = AutoTokenizer.from_pretrained("cuhgrel/bikol-mms-finetuned-v2")
        models['bikol_model'] = VitsModel.from_pretrained("cuhgrel/bikol-mms-finetuned-v2").to(device)

        # 2. TAGALOG MODEL (Base VITS)
        logger.info("Loading Tagalog (tgl) MMS-TTS model from Hub...")
        models['tgl_tokenizer'] = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
        models['tgl_model'] = VitsModel.from_pretrained("facebook/mms-tts-tgl").to(device)
    except Exception as e:
        # Drop anything loaded before the failure so the registry is never
        # left half-populated (an empty registry is what the synthesis
        # endpoint checks for before serving).
        models.clear()
        # logger.exception records the message at ERROR level together with
        # the active traceback, so it goes through the configured logging
        # handlers instead of being printed separately.
        logger.exception("FATAL: Could not load models. Error: %s", e)
        # Bare raise re-raises the original exception with its traceback intact.
        raise
    logger.info("Model loading complete.")
# --- 3. Define API Request and Response Models ---
class TTSRequest(BaseModel):
    """JSON request body for a text-to-speech synthesis call."""

    # The text that should be turned into speech.
    text: str
    # Language selector for the model to use.
    language: str
# --- 4. Define the TTS API Endpoint ---
@app.post("/synthesize/")
def synthesize_speech(request: TTSRequest):
    """Generate speech from text using the selected language model.

    ``request.language`` is matched case-insensitively against ``'en'``
    (NeMo FastPitch + HiFi-GAN), ``'bikol'`` and ``'tgl'`` (Hugging Face
    VITS). The synthesized audio is encoded as WAV into an in-memory buffer.

    Returns:
        StreamingResponse: the WAV bytes with media type ``audio/wav``.

    Raises:
        HTTPException: 503 if no models are loaded or the models needed for
            the requested language are missing; 400 if the language is not
            supported or the text is blank; 500 if synthesis itself fails.
    """
    if not models:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Models are not loaded.")

    lang = request.language.lower()
    valid_langs = ['en', 'bikol', 'tgl']
    if lang not in valid_langs:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid language. Use one of {valid_langs}")

    # Blank text is a client error; report it as such rather than handing it
    # to the models.
    if not request.text.strip():
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Text must not be empty.")

    # Make sure the registry entries this language needs are present, so a
    # missing model is reported as "unavailable" instead of a KeyError-based 500.
    if lang == 'en':
        required_keys = ('en', 'hifigan')
    else:
        required_keys = (f'{lang}_tokenizer', f'{lang}_model')
    if any(key not in models for key in required_keys):
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=f"Models for '{lang}' are not loaded.")

    try:
        logger.info(f"--- STARTING SYNTHESIS for '{lang}' ---")

        # --- CASE A: English (Uses NeMo) ---
        if lang == 'en':
            sample_rate = 22050
            spectrogram_generator = models['en']
            vocoder = models['hifigan']
            with torch.no_grad():
                parsed = spectrogram_generator.parse(request.text)
                spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)
                audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)
            audio_numpy = audio.to('cpu').detach().numpy().squeeze()

        # --- CASE B: Bikol & Tagalog (Uses Hugging Face VITS) ---
        # `lang` was validated above, so anything that is not 'en' is one of
        # the VITS languages; a plain else keeps `audio_numpy` always bound.
        else:
            # Dynamically select the correct tokenizer and model from the dictionary
            tokenizer = models[f'{lang}_tokenizer']
            model = models[f'{lang}_model']
            # Use the rate declared by the model's config when it has one;
            # otherwise keep the previously hard-coded 16 kHz.
            sample_rate = getattr(model.config, "sampling_rate", 16000)
            with torch.no_grad():
                inputs = tokenizer(request.text, return_tensors="pt").to(device)
                output = model(**inputs).waveform
            audio_numpy = output.cpu().numpy().squeeze()

        # Encode to WAV in memory and rewind so the response streams from the start.
        buffer = io.BytesIO()
        sf.write(buffer, audio_numpy, samplerate=sample_rate, format='WAV')
        buffer.seek(0)

        logger.info("--- SYNTHESIS COMPLETE ---")
        return StreamingResponse(buffer, media_type="audio/wav")
    except Exception as e:
        # Log message and traceback through the logging system, then surface
        # the failure to the client as a 500 chained to the original error.
        logger.exception("Error during synthesis: %s", e)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"An error occurred during audio synthesis: {str(e)}") from e
# --- 5. Add a Root Endpoint for Health Check ---
@app.get("/")
def read_root():
    """Health check: report a running-status message, the language codes, and the device."""
    return {
        "status": "Multilingual TTS Backend is running",
        "available_languages": ['en', 'bikol', 'tgl'],
        "device": device,
    }
# --- 6. Add Model Status Endpoint ---
@app.get("/status")
def get_status():
    """Report which registry entries are loaded and per-language availability."""
    report = {
        "models_loaded": list(models),
        "device": device,
    }
    # Availability is decided by the presence of each language's model key.
    report["english_available"] = 'en' in models
    report["bikol_available"] = 'bikol_model' in models
    report["tagalog_available"] = 'tgl_model' in models
    return report