File size: 5,418 Bytes
501558b
 
 
df5718a
501558b
 
8491e9b
 
 
501558b
13f5b32
 
c542137
df5718a
 
 
 
501558b
 
 
8491e9b
 
501558b
 
 
 
b0bdaef
 
501558b
 
8491e9b
df5718a
501558b
13f5b32
 
501558b
 
 
df5718a
501558b
 
df5718a
13f5b32
8491e9b
13f5b32
 
fa7ae36
 
8491e9b
13f5b32
8491e9b
 
c542137
793ce85
501558b
df5718a
b404ee5
 
8491e9b
 
793ce85
501558b
 
 
c542137
501558b
 
 
 
c542137
8491e9b
 
df5718a
8491e9b
 
 
 
 
501558b
8491e9b
501558b
13f5b32
 
8491e9b
13f5b32
8491e9b
 
 
 
a07ebd0
8491e9b
 
13f5b32
 
 
 
 
 
 
 
8491e9b
 
c542137
8491e9b
a07ebd0
501558b
8491e9b
501558b
 
a07ebd0
501558b
df5718a
501558b
df5718a
b404ee5
 
a07ebd0
501558b
 
 
 
b0bdaef
c542137
df5718a
 
 
 
 
13f5b32
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import soundfile as sf
import io
import logging
from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel
from fastapi.responses import StreamingResponse

# --- Library Imports ---
from nemo.collections.tts.models import FastPitchModel, HifiGanModel
# BaseCharsTokenizer is no longer needed since we aren't using the old NeMo Bikol model
# from nemo.collections.tts.torch.tts_tokenizers import BaseCharsTokenizer 
from transformers import VitsModel, AutoTokenizer

# --- Logging ---
# Emit INFO and above through the root handler; this module logs via `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- 1. Initialize FastAPI App ---
app = FastAPI(
    description="A backend service to convert text to speech in English, Bikol, and Tagalog.",
    title="Multilingual TTS API",
)

# --- 2. Shared model registry and compute device ---
# Filled in by load_models() at startup and read by the endpoints below.
models: dict = {}
# Prefer the GPU whenever PyTorch reports one as available.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

@app.on_event("startup")
def load_models():
    """Load every TTS model into the module-level ``models`` dict at startup.

    Populates:
      * ``hifigan`` / ``en`` -- NeMo HiFi-GAN vocoder and FastPitch model,
        restored from local ``.nemo`` files.
      * ``bikol_tokenizer`` / ``bikol_model`` and ``tgl_tokenizer`` /
        ``tgl_model`` -- Hugging Face VITS (MMS-TTS) checkpoints from the Hub.

    Every model is moved to ``device`` and switched to eval mode.

    Raises:
        Exception: whatever error loading produced is logged (with traceback)
            and re-raised unchanged so the application fails to start.
    """
    logger.info("Loading models...")
    try:
        # --- NeMo Models (English Only now) ---
        logger.info("Loading HiFi-GAN vocoder (for English)...")
        models['hifigan'] = HifiGanModel.restore_from("models/hifigan_en.nemo").to(device)
        models['hifigan'].eval()

        logger.info("Loading English FastPitch model...")
        models['en'] = FastPitchModel.restore_from("models/fastpitch_en.nemo").to(device)
        models['en'].eval()

        # --- Transformers MMS-TTS Models (Bikol & Tagalog) ---
        # Keys follow the '<lang>_tokenizer' / '<lang>_model' scheme that the
        # synthesis endpoint builds from the request language.

        # 1. Bikol model (fine-tuned VITS)
        logger.info("Loading Bikol (bcl) MMS-TTS model from Hub...")
        models['bikol_tokenizer'] = AutoTokenizer.from_pretrained("cuhgrel/bikol-mms-finetuned-v2")
        models['bikol_model'] = VitsModel.from_pretrained("cuhgrel/bikol-mms-finetuned-v2").to(device)
        # from_pretrained already returns eval mode; kept explicit to mirror the NeMo models.
        models['bikol_model'].eval()

        # 2. Tagalog model (base VITS)
        logger.info("Loading Tagalog (tgl) MMS-TTS model from Hub...")
        models['tgl_tokenizer'] = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
        models['tgl_model'] = VitsModel.from_pretrained("facebook/mms-tts-tgl").to(device)
        models['tgl_model'].eval()

    except Exception as e:
        # logger.exception records the traceback through the configured logging
        # handlers instead of writing it straight to stderr.
        logger.exception(f"FATAL: Could not load models. Error: {e}")
        # Bare raise preserves the original traceback without an extra frame.
        raise
    logger.info("Model loading complete.")

# --- 3. Define API Request and Response Models ---
class TTSRequest(BaseModel):
    """JSON request body accepted by the ``/synthesize/`` endpoint."""

    # The text to be converted to speech.
    text: str
    # Language selector; the endpoint lower-cases it and expects 'en', 'bikol' or 'tgl'.
    language: str

# --- 4. Define the TTS API Endpoint ---
@app.post("/synthesize/")
def synthesize_speech(request: TTSRequest):
    """Generate speech for ``request.text`` in ``request.language``.

    English uses the NeMo FastPitch spectrogram generator followed by the
    HiFi-GAN vocoder. Bikol and Tagalog use Hugging Face VITS (MMS-TTS) models.

    Returns:
        A ``StreamingResponse`` whose body is a WAV file (``audio/wav``).

    Raises:
        HTTPException: 503 if no models are loaded; 400 for an unsupported
            language or empty/whitespace-only text; 500 if synthesis fails.
    """
    if not models:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Models are not loaded.")

    # Normalize so inputs such as " EN " or "Bikol" are accepted too.
    lang = request.language.strip().lower()
    valid_langs = ['en', 'bikol', 'tgl']
    if lang not in valid_langs:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid language. Use one of {valid_langs}")

    # Validate before the try block so this client error is not rewrapped as a 500.
    if not request.text.strip():
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Text must not be empty.")

    try:
        logger.info("--- STARTING SYNTHESIS for '%s' ---", lang)

        if lang == 'en':
            # --- CASE A: English (NeMo FastPitch -> HiFi-GAN) ---
            sample_rate = 22050
            spectrogram_generator = models['en']
            vocoder = models['hifigan']
            with torch.no_grad():
                parsed = spectrogram_generator.parse(request.text)
                spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)
                audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)
            audio_numpy = audio.to('cpu').detach().numpy().squeeze()
        else:
            # --- CASE B: Bikol & Tagalog (Hugging Face VITS); lang is 'bikol' or 'tgl' ---
            tokenizer = models[f'{lang}_tokenizer']
            model = models[f'{lang}_model']
            # Use the checkpoint's own output rate rather than assuming 16 kHz;
            # fall back to 16 kHz (the MMS default) if the config lacks the field.
            sample_rate = getattr(model.config, "sampling_rate", 16000)
            with torch.no_grad():
                inputs = tokenizer(request.text, return_tensors="pt").to(device)
                output = model(**inputs).waveform
            audio_numpy = output.cpu().numpy().squeeze()

        buffer = io.BytesIO()
        sf.write(buffer, audio_numpy, samplerate=sample_rate, format='WAV')
        buffer.seek(0)

        logger.info("--- SYNTHESIS COMPLETE ---")
        return StreamingResponse(buffer, media_type="audio/wav")

    except Exception as e:
        # Log the traceback through logging and chain the cause onto the HTTP error.
        logger.exception("Error during synthesis: %s", e)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"An error occurred during audio synthesis: {str(e)}") from e

# --- 5. Add a Root Endpoint for Health Check ---
@app.get("/")
def read_root():
    """Health check: report that the service is up, its languages and the device."""
    payload = {"status": "Multilingual TTS Backend is running"}
    payload["available_languages"] = ['en', 'bikol', 'tgl']
    payload["device"] = device
    return payload

# --- 6. Add Model Status Endpoint ---
@app.get("/status")
def get_status():
    """Report the keys present in the ``models`` registry and per-language availability."""
    loaded_keys = list(models)
    # A language counts as available when its model entry is in the registry.
    return dict(
        models_loaded=loaded_keys,
        device=device,
        english_available='en' in models,
        bikol_available='bikol_model' in models,
        tagalog_available='tgl_model' in models,
    )