import io
import logging
import traceback

import torch
import soundfile as sf
from fastapi import FastAPI, HTTPException, status
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

# --- Library Imports ---
from nemo.collections.tts.models import FastPitchModel, HifiGanModel
# BaseCharsTokenizer is no longer needed since we aren't using the old NeMo Bikol model
# from nemo.collections.tts.torch.tts_tokenizers import BaseCharsTokenizer
from transformers import VitsModel, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- 1. Initialize FastAPI App ---
app = FastAPI(
    title="Multilingual TTS API",
    description="A backend service to convert text to speech in English, Bikol, and Tagalog.",
)

# --- 2. Load Models on Startup ---
models = {}
device = "cuda" if torch.cuda.is_available() else "cpu"
# Register as a FastAPI startup hook so the models are loaded before requests are served.
@app.on_event("startup")
def load_models():
    """Load all models into memory when the application starts."""
    logger.info("Loading models...")
    try:
        # --- NeMo Models (English only now) ---
        logger.info("Loading HiFi-GAN vocoder (for English)...")
        models['hifigan'] = HifiGanModel.restore_from("models/hifigan_en.nemo").to(device)
        models['hifigan'].eval()

        logger.info("Loading English FastPitch model...")
        models['en'] = FastPitchModel.restore_from("models/fastpitch_en.nemo").to(device)
        models['en'].eval()

        # --- Transformers MMS-TTS Models (Bikol & Tagalog) ---
        # 1. New Bikol model (fine-tuned VITS)
        logger.info("Loading Bikol (bcl) MMS-TTS model from Hub...")
        models['bikol_tokenizer'] = AutoTokenizer.from_pretrained("cuhgrel/bikol-mms-finetuned-v2")
        models['bikol_model'] = VitsModel.from_pretrained("cuhgrel/bikol-mms-finetuned-v2").to(device)

        # 2. Tagalog model (base VITS)
        logger.info("Loading Tagalog (tgl) MMS-TTS model from Hub...")
        models['tgl_tokenizer'] = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
        models['tgl_model'] = VitsModel.from_pretrained("facebook/mms-tts-tgl").to(device)
    except Exception as e:
        logger.error(f"FATAL: Could not load models. Error: {e}")
        traceback.print_exc()
        raise
    logger.info("Model loading complete.")
# --- 3. Define API Request and Response Models ---
class TTSRequest(BaseModel):
    text: str
    language: str
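# Example request body matching the schema above:
#   {"text": "Hello, world!", "language": "en"}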
# --- 4. Define the TTS API Endpoint ---
@app.post("/synthesize")  # route path is an assumption; adjust to match your frontend
def synthesize_speech(request: TTSRequest):
    """Generates speech from text using the selected language model."""
    if not models:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Models are not loaded.")

    lang = request.language.lower()
    valid_langs = ['en', 'bikol', 'tgl']
    if lang not in valid_langs:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid language. Use one of {valid_langs}")

    try:
        logger.info(f"--- STARTING SYNTHESIS for '{lang}' ---")

        # --- CASE A: English (uses NeMo FastPitch + HiFi-GAN) ---
        if lang == 'en':
            sample_rate = 22050  # NeMo's pretrained English FastPitch/HiFi-GAN operate at 22.05 kHz
            spectrogram_generator = models['en']
            vocoder = models['hifigan']
            with torch.no_grad():
                parsed = spectrogram_generator.parse(request.text)
                spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)
                audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)
            audio_numpy = audio.to('cpu').detach().numpy().squeeze()

        # --- CASE B: Bikol & Tagalog (uses Hugging Face VITS) ---
        else:  # lang is 'tgl' or 'bikol'; already validated above
            sample_rate = 16000  # MMS-TTS models generate 16 kHz audio
            # Dynamically select the matching tokenizer and model from the dictionary
            tokenizer = models[f'{lang}_tokenizer']
            model = models[f'{lang}_model']
            with torch.no_grad():
                inputs = tokenizer(request.text, return_tensors="pt").to(device)
                output = model(**inputs).waveform
            audio_numpy = output.cpu().numpy().squeeze()

        # Encode the waveform as a WAV file in memory and stream it back
        buffer = io.BytesIO()
        sf.write(buffer, audio_numpy, samplerate=sample_rate, format='WAV')
        buffer.seek(0)

        logger.info("--- SYNTHESIS COMPLETE ---")
        return StreamingResponse(buffer, media_type="audio/wav")
    except Exception as e:
        logger.error(f"Error during synthesis: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"An error occurred during audio synthesis: {str(e)}")
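# Example client call for the endpoint above (assumes the route path "/synthesize"
# defined above and uvicorn's default port 8000):
#   curl -X POST http://localhost:8000/synthesize \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello, world!", "language": "en"}' \
#        --output speech.wav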
# --- 5. Add a Root Endpoint for Health Check ---
@app.get("/")
def read_root():
    available_languages = ['en', 'bikol', 'tgl']
    return {"status": "Multilingual TTS Backend is running", "available_languages": available_languages, "device": device}
# --- 6. Add Model Status Endpoint ---
@app.get("/status")  # route path is an assumption
def get_status():
    """Get the status of all loaded models."""
    return {
        "models_loaded": list(models.keys()),
        "device": device,
        "english_available": 'en' in models,
        "bikol_available": 'bikol_model' in models,  # updated check for the new HF Bikol model
        "tagalog_available": 'tgl_model' in models,
    }
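# Minimal local-run sketch (assumptions: uvicorn is installed and port 8000 is
# free; hosted platforms such as Hugging Face Spaces typically launch the app
# themselves, so this block is only a convenience for local testing).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)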