cuhgrel committed on
Commit
df5718a
·
1 Parent(s): b404ee5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -20
app.py CHANGED
@@ -1,9 +1,15 @@
1
  import torch
2
  import soundfile as sf
3
  import io
 
4
  from fastapi import FastAPI, HTTPException, status
5
  from pydantic import BaseModel
6
  from nemo.collections.tts.models import FastPitchModel, HifiGanModel
 
 
 
 
 
7
 
8
  # --- 1. Initialize FastAPI App ---
9
  app = FastAPI(
@@ -12,43 +18,91 @@ app = FastAPI(
12
  )
13
 
14
  # --- 2. Load Models on Startup ---
15
- # This dictionary will hold our loaded models to avoid reloading on every request.
16
  models = {}
17
 
18
  @app.on_event("startup")
19
  def load_models():
20
  """Load all NeMo models into memory when the application starts."""
21
- print("Loading models...")
22
  device = "cuda" if torch.cuda.is_available() else "cpu"
23
 
24
- # Load the shared HiFi-GAN Vocoder
25
  try:
 
 
26
  models['hifigan'] = HifiGanModel.restore_from("models/hifigan_en.nemo").to(device)
27
  models['hifigan'].eval()
 
28
 
29
  # Load the English Spectrogram Generator
 
30
  models['en'] = FastPitchModel.restore_from("models/fastpitch_en.nemo").to(device)
31
  models['en'].eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Load the Bikol Spectrogram Generator with strict=False
34
- # This allows loading even if there are size mismatches in embedding layers
35
- models['bikol'] = FastPitchModel.restore_from(
36
- "models/fastpitch_bikol_repacked.nemo",
37
- strict=False # Critical: allows loading with different vocabulary size
38
- ).to(device)
39
- models['bikol'].eval()
40
-
41
- print("All models loaded successfully.")
42
  except Exception as e:
43
- print(f"FATAL: Could not load models. Error: {e}")
44
  import traceback
45
  traceback.print_exc()
46
- # In a real app, you might want the app to fail fast if models can't load.
47
 
48
  # --- 3. Define API Request and Response Models ---
49
  class TTSRequest(BaseModel):
50
  text: str
51
- language: str # Should be 'en' or 'bikol'
52
 
53
  # --- 4. Define the TTS API Endpoint ---
54
  @app.post("/synthesize/")
@@ -61,30 +115,46 @@ def synthesize_speech(request: TTSRequest):
61
  status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
62
  detail="Models are not loaded yet. Please try again in a moment."
63
  )
64
-
65
  # Validate the requested language
66
  if request.language not in ['en', 'bikol']:
67
  raise HTTPException(
68
  status_code=status.HTTP_400_BAD_REQUEST,
69
  detail="Invalid language specified. Use 'en' or 'bikol'."
70
  )
71
-
 
 
 
 
 
 
 
 
72
  try:
73
  # Select the correct FastPitch model
74
  spectrogram_generator = models[request.language]
75
  vocoder = models['hifigan']
76
 
77
  # --- Generate Audio ---
 
 
78
  # Parse text into token IDs
79
  parsed = spectrogram_generator.parse(request.text)
 
80
  # Generate spectrogram
81
  spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)
 
82
  # Convert spectrogram to audio waveform
83
  audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)
84
 
85
  # --- Prepare and return audio file ---
86
  audio_numpy = audio.to('cpu').numpy()
87
 
 
 
 
 
88
  # Use an in-memory buffer to avoid writing to disk
89
  buffer = io.BytesIO()
90
  sf.write(buffer, audio_numpy, samplerate=22050, format='WAV')
@@ -93,9 +163,9 @@ def synthesize_speech(request: TTSRequest):
93
  # Return the audio data as a streaming response
94
  from fastapi.responses import StreamingResponse
95
  return StreamingResponse(buffer, media_type="audio/wav")
96
-
97
  except Exception as e:
98
- print(f"Error during synthesis: {e}")
99
  import traceback
100
  traceback.print_exc()
101
  raise HTTPException(
@@ -106,4 +176,20 @@ def synthesize_speech(request: TTSRequest):
106
  # --- 5. Add a Root Endpoint for Health Check ---
107
  @app.get("/")
108
  def read_root():
109
- return {"status": "Nemo TTS Backend is running."}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import soundfile as sf
3
  import io
4
+ import logging
5
  from fastapi import FastAPI, HTTPException, status
6
  from pydantic import BaseModel
7
  from nemo.collections.tts.models import FastPitchModel, HifiGanModel
8
+ from omegaconf import OmegaConf, open_dict
9
+
10
+ # Configure logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
 
14
  # --- 1. Initialize FastAPI App ---
15
  app = FastAPI(
 
18
  )
19
 
20
  # --- 2. Load Models on Startup ---
 
21
  models = {}
22
 
23
@app.on_event("startup")
def load_models():
    """Load all NeMo models into memory when the application starts.

    Populates the module-level ``models`` dict with:
      * ``'hifigan'`` - shared HiFi-GAN vocoder
      * ``'en'``      - English FastPitch spectrogram generator
      * ``'bikol'``   - Bikol FastPitch generator (optional; several
                        fallback loading strategies are attempted)

    Failures are logged but never raised, so the app can still start and
    serve whichever languages did load.
    """
    logger.info("Loading models...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        # Load the shared HiFi-GAN Vocoder
        logger.info("Loading HiFi-GAN vocoder...")
        models['hifigan'] = HifiGanModel.restore_from("models/hifigan_en.nemo").to(device)
        models['hifigan'].eval()
        logger.info("HiFi-GAN loaded successfully")

        # Load the English Spectrogram Generator
        logger.info("Loading English FastPitch model...")
        models['en'] = FastPitchModel.restore_from("models/fastpitch_en.nemo").to(device)
        models['en'].eval()
        logger.info("English model loaded successfully")

        # The Bikol checkpoint is best-effort: on total failure the app
        # still starts with English only.
        _load_bikol_model(device)

        logger.info("Model loading complete. Available models: " + ", ".join(models.keys()))
    except Exception as e:
        logger.error(f"FATAL: Could not load models. Error: {e}")
        import traceback
        traceback.print_exc()
        # Allow app to start even if models fail - better for debugging


def _load_bikol_model(device):
    """Try successive strategies to load the Bikol FastPitch model.

    The repacked checkpoint has a different vocabulary size than the stock
    config, so a plain restore can fail; each fallback relaxes loading a
    bit further. On success ``models['bikol']`` is set; on total failure
    the model is simply left out of ``models`` and the failure is logged.
    """
    logger.info("Loading Bikol FastPitch model...")
    checkpoint = "models/fastpitch_bikol_repacked.nemo"

    # NOTE(review): NeMo's `override_config_path` normally expects a path
    # to (or DictConfig of) a *full* model config, not a partial override
    # like this - confirm this strategy works against the NeMo version in
    # use before relying on it.
    override_config = OmegaConf.create({
        'text_tokenizer': {
            '_target_': 'nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.BaseCharsTokenizer',
            'pad_with_space': True
        }
    })

    # Each strategy: (description, restore_from kwargs, move result to device?).
    # The third uses map_location instead of .to(device), matching the
    # original fallback order.
    strategies = [
        ("strict=False", {'strict': False}, True),
        ("config override", {'override_config_path': override_config, 'strict': False}, True),
        ("map_location", {'map_location': device, 'strict': False}, False),
    ]

    for description, kwargs, move_to_device in strategies:
        try:
            model = FastPitchModel.restore_from(checkpoint, **kwargs)
            if move_to_device:
                model = model.to(device)
            model.eval()
            models['bikol'] = model
            logger.info(f"Bikol model loaded successfully ({description})")
            return
        except Exception as e:
            logger.warning(f"Bikol load attempt '{description}' failed: {e}")

    logger.error("All attempts to load Bikol model failed")
    logger.error("Bikol language will not be available")
    # Don't raise - allow app to start with just English
102
# --- 3. Define API Request and Response Models ---
class TTSRequest(BaseModel):
    """Request payload for the /synthesize/ endpoint."""

    text: str
    language: str  # Should be 'en' or 'bikol'
106
 
107
  # --- 4. Define the TTS API Endpoint ---
108
  @app.post("/synthesize/")
 
115
  status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
116
  detail="Models are not loaded yet. Please try again in a moment."
117
  )
118
+
119
  # Validate the requested language
120
  if request.language not in ['en', 'bikol']:
121
  raise HTTPException(
122
  status_code=status.HTTP_400_BAD_REQUEST,
123
  detail="Invalid language specified. Use 'en' or 'bikol'."
124
  )
125
+
126
+ # Check if requested model is available
127
+ if request.language not in models:
128
+ available = [k for k in models.keys() if k != 'hifigan']
129
+ raise HTTPException(
130
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
131
+ detail=f"The '{request.language}' model is not available. Available languages: {', '.join(available)}"
132
+ )
133
+
134
  try:
135
  # Select the correct FastPitch model
136
  spectrogram_generator = models[request.language]
137
  vocoder = models['hifigan']
138
 
139
  # --- Generate Audio ---
140
+ logger.info(f"Generating speech for text: '{request.text}' in language: {request.language}")
141
+
142
  # Parse text into token IDs
143
  parsed = spectrogram_generator.parse(request.text)
144
+
145
  # Generate spectrogram
146
  spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)
147
+
148
  # Convert spectrogram to audio waveform
149
  audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)
150
 
151
  # --- Prepare and return audio file ---
152
  audio_numpy = audio.to('cpu').numpy()
153
 
154
+ # Ensure audio is 1D
155
+ if len(audio_numpy.shape) > 1:
156
+ audio_numpy = audio_numpy.squeeze()
157
+
158
  # Use an in-memory buffer to avoid writing to disk
159
  buffer = io.BytesIO()
160
  sf.write(buffer, audio_numpy, samplerate=22050, format='WAV')
 
163
  # Return the audio data as a streaming response
164
  from fastapi.responses import StreamingResponse
165
  return StreamingResponse(buffer, media_type="audio/wav")
166
+
167
  except Exception as e:
168
+ logger.error(f"Error during synthesis: {e}")
169
  import traceback
170
  traceback.print_exc()
171
  raise HTTPException(
 
176
# --- 5. Add a Root Endpoint for Health Check ---
@app.get("/")
def read_root():
    """Health check: report service status, loaded languages, and device."""
    # Everything in `models` except the vocoder is a language model.
    languages = [name for name in models if name != 'hifigan']
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return {
        "status": "NeMo TTS Backend is running",
        "available_languages": languages,
        "device": device,
    }
185
+
186
# --- 6. Add Model Status Endpoint ---
@app.get("/status")
def get_status():
    """Get the status of all loaded models."""
    loaded = list(models.keys())
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return {
        "models_loaded": loaded,
        "device": device,
        "english_available": 'en' in models,
        "bikol_available": 'bikol' in models,
    }