import torch
import soundfile as sf
import io
import logging
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel
from fastapi.responses import StreamingResponse
# --- Library Imports ---
from nemo.collections.tts.models import FastPitchModel, HifiGanModel
# BaseCharsTokenizer is no longer needed since we aren't using the old NeMo Bikol model
# from nemo.collections.tts.torch.tts_tokenizers import BaseCharsTokenizer
from transformers import VitsModel, AutoTokenizer
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- 1. Initialize FastAPI App ---
app = FastAPI(
    title="Multilingual TTS API",
    description="A backend service to convert text to speech in English, Bikol, and Tagalog.",
)
# --- 2. Load Models on Startup ---
models = {}
device = "cuda" if torch.cuda.is_available() else "cpu"
@app.on_event("startup")
def load_models():
    """Load all models into memory when the application starts."""
    logger.info("Loading models...")
    try:
        # --- NeMo models (English only) ---
        logger.info("Loading HiFi-GAN vocoder (for English)...")
        models['hifigan'] = HifiGanModel.restore_from("models/hifigan_en.nemo").to(device)
        models['hifigan'].eval()
        logger.info("Loading English FastPitch model...")
        models['en'] = FastPitchModel.restore_from("models/fastpitch_en.nemo").to(device)
        models['en'].eval()

        # --- Transformers MMS-TTS models (Bikol & Tagalog) ---
        # 1. Bikol model (fine-tuned VITS)
        logger.info("Loading Bikol (bcl) MMS-TTS model from Hub...")
        models['bikol_tokenizer'] = AutoTokenizer.from_pretrained("cuhgrel/bikol-mms-finetuned-v2")
        models['bikol_model'] = VitsModel.from_pretrained("cuhgrel/bikol-mms-finetuned-v2").to(device)

        # 2. Tagalog model (base VITS)
        logger.info("Loading Tagalog (tgl) MMS-TTS model from Hub...")
        models['tgl_tokenizer'] = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
        models['tgl_model'] = VitsModel.from_pretrained("facebook/mms-tts-tgl").to(device)
    except Exception as e:
        logger.error(f"FATAL: Could not load models. Error: {e}")
        import traceback
        traceback.print_exc()
        raise  # Re-raise so the app fails fast instead of serving without models
    logger.info("Model loading complete.")
# --- 3. Define API Request and Response Models ---
class TTSRequest(BaseModel):
    text: str
    language: str
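
# The `language` field is matched case-insensitively against 'en', 'bikol',
# and 'tgl' in the endpoint below, e.g. (illustrative payload only):
#   {"text": "Magandang umaga", "language": "tgl"}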
# --- 4. Define the TTS API Endpoint ---
@app.post("/synthesize/")
def synthesize_speech(request: TTSRequest):
    """Generate speech from text using the selected language model."""
    if not models:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Models are not loaded.")

    lang = request.language.lower()
    valid_langs = ['en', 'bikol', 'tgl']
    if lang not in valid_langs:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid language. Use one of {valid_langs}")

    try:
        logger.info(f"--- STARTING SYNTHESIS for '{lang}' ---")

        # --- CASE A: English (NeMo FastPitch + HiFi-GAN) ---
        if lang == 'en':
            sample_rate = 22050  # NeMo's FastPitch/HiFi-GAN pipeline runs at 22.05 kHz
            spectrogram_generator = models['en']
            vocoder = models['hifigan']
            with torch.no_grad():
                parsed = spectrogram_generator.parse(request.text)
                spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)
                audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)
                audio_numpy = audio.to('cpu').detach().numpy().squeeze()

        # --- CASE B: Bikol & Tagalog (Hugging Face VITS) ---
        elif lang in ['tgl', 'bikol']:
            # Dynamically select the matching tokenizer/model pair from the registry
            tokenizer = models[f'{lang}_tokenizer']
            model = models[f'{lang}_model']
            # MMS-TTS checkpoints generate 16 kHz audio; read the rate from the
            # model config instead of hardcoding it
            sample_rate = model.config.sampling_rate
            with torch.no_grad():
                inputs = tokenizer(request.text, return_tensors="pt").to(device)
                output = model(**inputs).waveform
                audio_numpy = output.cpu().numpy().squeeze()

        # Encode the waveform as an in-memory WAV file and stream it back
        buffer = io.BytesIO()
        sf.write(buffer, audio_numpy, samplerate=sample_rate, format='WAV')
        buffer.seek(0)
        logger.info("--- SYNTHESIS COMPLETE ---")
        return StreamingResponse(buffer, media_type="audio/wav")
    except Exception as e:
        logger.error(f"Error during synthesis: {e}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"An error occurred during audio synthesis: {str(e)}")
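
# Example client call (a sketch; assumes the service is reachable at
# http://localhost:8000 and that the `requests` package is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/synthesize/",
#       json={"text": "Hello world", "language": "en"},
#   )
#   resp.raise_for_status()
#   with open("output.wav", "wb") as f:
#       f.write(resp.content)  # raw WAV bytes streamed by the endpoint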
# --- 5. Add a Root Endpoint for Health Check ---
@app.get("/")
def read_root():
available_languages = ['en', 'bikol', 'tgl']
return {"status": "Multilingual TTS Backend is running", "available_languages": available_languages, "device": device}
# --- 6. Add Model Status Endpoint ---
@app.get("/status")
def get_status():
"""Get the status of all loaded models."""
return {
"models_loaded": list(models.keys()),
"device": device,
"english_available": 'en' in models,
"bikol_available": 'bikol_model' in models, # Updated check
"tagalog_available": 'tgl_model' in models
} |
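
# To run locally (a sketch; assumes this file is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 8000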