import torch
import soundfile as sf
import io
import logging
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel
from fastapi.responses import StreamingResponse

# --- Library Imports ---
from nemo.collections.tts.models import FastPitchModel, HifiGanModel
# BaseCharsTokenizer is no longer needed since we aren't using the old NeMo Bikol model
# from nemo.collections.tts.torch.tts_tokenizers import BaseCharsTokenizer
from transformers import VitsModel, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- 1. Initialize FastAPI App ---
app = FastAPI(
    title="Multilingual TTS API",
    description="A backend service to convert text to speech in English, Bikol, and Tagalog.",
)

# --- 2. Load Models on Startup ---
models = {}
device = "cuda" if torch.cuda.is_available() else "cpu"


# Note: newer FastAPI versions prefer lifespan handlers, but on_event still works.
@app.on_event("startup")
def load_models():
    """Load all models into memory when the application starts."""
    logger.info("Loading models...")
    try:
        # --- NeMo Models (English only now) ---
        logger.info("Loading HiFi-GAN vocoder (for English)...")
        models['hifigan'] = HifiGanModel.restore_from("models/hifigan_en.nemo").to(device)
        models['hifigan'].eval()

        logger.info("Loading English FastPitch model...")
        models['en'] = FastPitchModel.restore_from("models/fastpitch_en.nemo").to(device)
        models['en'].eval()

        # --- Transformers MMS-TTS Models (Bikol & Tagalog) ---
        # 1. New Bikol model (fine-tuned VITS)
        logger.info("Loading Bikol (bcl) MMS-TTS model from the Hub...")
        models['bikol_tokenizer'] = AutoTokenizer.from_pretrained("cuhgrel/bikol-mms-finetuned-v2")
        models['bikol_model'] = VitsModel.from_pretrained("cuhgrel/bikol-mms-finetuned-v2").to(device)

        # 2. Tagalog model (base VITS)
        logger.info("Loading Tagalog (tgl) MMS-TTS model from the Hub...")
        models['tgl_tokenizer'] = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
        models['tgl_model'] = VitsModel.from_pretrained("facebook/mms-tts-tgl").to(device)
    except Exception as e:
        # logger.exception records the full traceback along with the message.
        logger.exception(f"FATAL: Could not load models. Error: {e}")
        raise

    logger.info("Model loading complete.")


# --- 3. Define API Request and Response Models ---
class TTSRequest(BaseModel):
    text: str
    language: str


# --- 4. Define the TTS API Endpoint ---
@app.post("/synthesize/")
def synthesize_speech(request: TTSRequest):
    """Generate speech from text using the selected language model."""
    if not models:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Models are not loaded.",
        )

    lang = request.language.lower()
    valid_langs = ['en', 'bikol', 'tgl']
    if lang not in valid_langs:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Invalid language. Use one of {valid_langs}",
        )

    try:
        logger.info(f"--- STARTING SYNTHESIS for '{lang}' ---")

        # --- CASE A: English (uses NeMo) ---
        if lang == 'en':
            sample_rate = 22050  # Standard rate for these 22.05 kHz NeMo English models
            spectrogram_generator = models['en']
            vocoder = models['hifigan']
            with torch.no_grad():
                # Text -> tokens -> mel spectrogram -> waveform
                parsed = spectrogram_generator.parse(request.text)
                spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)
                audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)
            audio_numpy = audio.to('cpu').detach().numpy().squeeze()

        # --- CASE B: Bikol & Tagalog (uses Hugging Face VITS) ---
        elif lang in ['tgl', 'bikol']:
            # Dynamically select the correct tokenizer and model from the dictionary
            tokenizer = models[f'{lang}_tokenizer']
            model = models[f'{lang}_model']
            # Read the rate from the model config; MMS VITS models are usually 16 kHz
            sample_rate = model.config.sampling_rate
            with torch.no_grad():
                inputs = tokenizer(request.text, return_tensors="pt").to(device)
                output = model(**inputs).waveform
            audio_numpy = output.cpu().numpy().squeeze()

        # Encode the waveform as a WAV file in memory and stream it back
        buffer = io.BytesIO()
        sf.write(buffer, audio_numpy, samplerate=sample_rate, format='WAV')
        buffer.seek(0)

        logger.info("--- SYNTHESIS COMPLETE ---")
        return StreamingResponse(buffer, media_type="audio/wav")

    except Exception as e:
        logger.exception(f"Error during synthesis: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An error occurred during audio synthesis: {str(e)}",
        )


# --- 5. Add a Root Endpoint for Health Check ---
@app.get("/")
def read_root():
    available_languages = ['en', 'bikol', 'tgl']
    return {
        "status": "Multilingual TTS Backend is running",
        "available_languages": available_languages,
        "device": device,
    }


# --- 6. Add Model Status Endpoint ---
@app.get("/status")
def get_status():
    """Get the status of all loaded models."""
    return {
        "models_loaded": list(models.keys()),
        "device": device,
        "english_available": 'en' in models,
        "bikol_available": 'bikol_model' in models,  # Updated check for the VITS model key
        "tagalog_available": 'tgl_model' in models,
    }
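

# --- 7. Local Entry Point & Example Usage ---
# A minimal sketch for running the service locally. It assumes uvicorn is
# installed; the host, port, example text, and output filename below are
# illustrative assumptions, not part of the original service.
if __name__ == "__main__":
    import uvicorn

    # Serve the app directly; in a deployment you would more likely invoke
    # something like `uvicorn main:app` against this module instead.
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is up (hypothetical text):
#   curl -X POST http://localhost:8000/synthesize/ \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Marhay na aldaw!", "language": "bikol"}' \
#        --output speech.wav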