# VoiceAPI / src/engine.py
# Committed by Harshil748
# Commit b0dbe7f: "Refactor: Hide model loading, focus on training pipeline"
"""
TTS Engine for Multi-lingual Indian Language Speech Synthesis
This engine uses VITS (Variational Inference with adversarial learning
for Text-to-Speech) models trained on various Indian language datasets.
Supported Languages:
- Hindi, Bengali, Marathi, Telugu, Kannada
- Gujarati (via Facebook MMS), Bhojpuri, Chhattisgarhi
- Maithili, Magahi, English
Model Types:
- JIT traced models (.pt) - Trained using train_vits.py
- Coqui TTS checkpoints (.pth) - For Bhojpuri
- Facebook MMS - For Gujarati
"""
import os
import logging
from pathlib import Path
from typing import Dict, Optional, Union, List, Tuple, Any
import numpy as np
import torch
from dataclasses import dataclass
from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR, STYLE_PRESETS
from .tokenizer import TTSTokenizer, CharactersConfig, TextNormalizer
from .model_loader import _ensure_models_available, get_model_path, list_available_models
logger = logging.getLogger(__name__)
@dataclass
class TTSOutput:
    """
    Result of a single TTS synthesis call.

    Bundles the generated waveform with the metadata needed to play,
    save, or describe it.
    """

    # Synthesized waveform samples
    audio: np.ndarray
    # Samples per second of ``audio``
    sample_rate: int
    # Length of ``audio`` in seconds (len(audio) / sample_rate as computed by the engine)
    duration: float
    # Voice key used for synthesis, e.g. 'hi_male'
    voice: str
    # Text that was synthesized (after optional normalization)
    text: str
    # Style preset name, if one was requested
    style: Optional[str] = None
class StyleProcessor:
    """
    Prosody/style control via audio post-processing

    Supports pitch shifting, speed (tempo) change, and energy (volume)
    modification. librosa is used for pitch/tempo when it is installed;
    otherwise lower-quality scipy/numpy fallbacks are used.
    """

    @staticmethod
    def _ola_time_stretch(
        audio: np.ndarray,
        rate: float,
        frame_length: int = 1024,
        hop_length: int = 256,
    ) -> np.ndarray:
        """
        Dependency-free overlap-add (OLA) time stretch used as a librosa fallback.

        Windowed frames are read from the input every ``hop_length * rate``
        samples and overlap-added to the output every ``hop_length`` samples,
        so the duration changes by ``1 / rate`` while the pitch is approximately
        kept. Plain OLA does not align phases between frames, so expect some
        smearing; it is only meant as a fallback.

        Args:
            audio: 1-D input signal
            rate: Stretch rate (> 1 shortens, < 1 lengthens); must be > 0
            frame_length: Window length in samples (clipped to the input length)
            hop_length: Output hop in samples (clipped to a quarter window)

        Returns:
            float64 array of length ``round(len(audio) / rate)``
        """
        x = np.asarray(audio, dtype=np.float64)
        n_in = len(x)
        n_out = int(round(n_in / rate))
        if n_in == 0 or n_out == 0:
            return np.zeros(n_out, dtype=np.float64)

        frame_length = min(frame_length, n_in)
        # Keep at least 75% overlap so every output sample is covered by some frame.
        hop = max(1, min(hop_length, frame_length // 4))
        # Hann window with its zero end-points trimmed, so every covered output
        # sample receives a strictly positive normalisation weight.
        window = np.hanning(frame_length + 2)[1:-1]
        # Zero-pad so frames starting near the end of the input stay full length.
        padded = np.concatenate([x, np.zeros(frame_length)])

        out = np.zeros(n_out + frame_length)
        weight = np.zeros(n_out + frame_length)
        for dst in range(0, n_out, hop):
            src = min(int(round(dst * rate)), n_in)
            out[dst:dst + frame_length] += padded[src:src + frame_length] * window
            weight[dst:dst + frame_length] += window
        return out[:n_out] / np.maximum(weight[:n_out], 1e-8)

    @staticmethod
    def apply_pitch_shift(audio: np.ndarray, sample_rate: int, pitch_factor: float) -> np.ndarray:
        """
        Shift pitch without changing duration

        Args:
            audio: 1-D audio signal
            sample_rate: Sample rate of ``audio`` in Hz
            pitch_factor: Frequency ratio (2.0 = 12 semitones up, 0.5 = 12 down)

        Returns:
            Pitch-shifted audio with the same number of samples

        Raises:
            ValueError: If pitch_factor is not a positive number
        """
        if pitch_factor == 1.0:
            return audio
        if not pitch_factor > 0:
            raise ValueError(f"pitch_factor must be positive, got {pitch_factor}")
        if len(audio) == 0:
            return audio
        try:
            import librosa
        except ImportError:
            from scipy import signal

            # Stretch the duration by pitch_factor while keeping pitch, then
            # resample back to the original length: squeezing
            # pitch_factor * N samples into N multiplies every frequency by
            # pitch_factor at the same sample rate.
            stretched = StyleProcessor._ola_time_stretch(audio, 1.0 / pitch_factor)
            return signal.resample(stretched, len(audio))
        semitones = 12 * np.log2(pitch_factor)
        return librosa.effects.pitch_shift(
            audio.astype(np.float32), sr=sample_rate, n_steps=semitones
        )

    @staticmethod
    def apply_speed_change(audio: np.ndarray, sample_rate: int, speed_factor: float) -> np.ndarray:
        """
        Change speed/tempo without changing pitch

        Args:
            audio: 1-D audio signal
            sample_rate: Sample rate of ``audio`` in Hz (currently unused)
            speed_factor: Tempo ratio (2.0 = twice as fast, i.e. half as long)

        Returns:
            Audio roughly ``len(audio) / speed_factor`` samples long

        Raises:
            ValueError: If speed_factor is not a positive number
        """
        if speed_factor == 1.0:
            return audio
        if not speed_factor > 0:
            raise ValueError(f"speed_factor must be positive, got {speed_factor}")
        if len(audio) == 0:
            return audio
        try:
            import librosa
        except ImportError:
            # Plain resampling would also shift the pitch; OLA keeps it.
            return StyleProcessor._ola_time_stretch(audio, speed_factor)
        return librosa.effects.time_stretch(audio.astype(np.float32), rate=speed_factor)

    @staticmethod
    def apply_energy_change(audio: np.ndarray, energy_factor: float) -> np.ndarray:
        """
        Modify audio energy/volume

        Scales the signal by ``energy_factor``. When a boost pushes the peak
        above 0.95, a tanh soft-limiter is applied. The limiter has unit gain
        for small signals, so quiet passages keep the requested gain and
        peaks approach (but never exceed) 0.95.
        """
        if energy_factor == 1.0:
            return audio
        modified = audio * energy_factor
        if energy_factor > 1.0 and modified.size and np.max(np.abs(modified)) > 0.95:
            modified = 0.95 * np.tanh(modified / 0.95)
        return modified

    @staticmethod
    def apply_style(
        audio: np.ndarray,
        sample_rate: int,
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
    ) -> np.ndarray:
        """Apply all style modifications (pitch, then speed, then energy)"""
        result = audio
        if pitch != 1.0:
            result = StyleProcessor.apply_pitch_shift(result, sample_rate, pitch)
        if speed != 1.0:
            result = StyleProcessor.apply_speed_change(result, sample_rate, speed)
        if energy != 1.0:
            result = StyleProcessor.apply_energy_change(result, energy)
        return result

    @staticmethod
    def get_preset(preset_name: str) -> Dict[str, float]:
        """Get style parameters from preset name (falls back to the 'default' preset)"""
        return STYLE_PRESETS.get(preset_name, STYLE_PRESETS["default"])
class TTSEngine:
    """
    Multi-lingual TTS Engine using trained VITS models

    Supports 11 Indian languages with male/female voices.
    Models are loaded from the models/ directory which contains
    trained checkpoints exported using training/export_model.py.

    Each voice is served by one of three back-ends:
    - JIT traced VITS model (``*.pt`` plus a chars file) in ``models_dir/<voice>``
    - Coqui TTS checkpoint (``*.pth`` plus ``config.json``) in ``models_dir/<voice>``
    - Facebook MMS model loaded via ``transformers`` for voice keys containing "mms"
    """

    def __init__(
        self,
        models_dir: str = MODELS_DIR,
        device: str = "auto",
        preload_voices: Optional[List[str]] = None,
    ):
        """
        Initialize TTS Engine

        Args:
            models_dir: Directory containing trained models
            device: Device to run inference on ('cpu', 'cuda', 'mps', or 'auto';
                'auto' picks CUDA when available, otherwise CPU)
            preload_voices: List of voice keys to preload into memory
        """
        self.models_dir = Path(models_dir)
        self.device = self._get_device(device)

        # Ensure models are available (delegated to model_loader)
        _ensure_models_available()

        # Per-backend caches, keyed by voice key
        self._models: Dict[str, torch.jit.ScriptModule] = {}
        self._tokenizers: Dict[str, TTSTokenizer] = {}
        self._coqui_models: Dict[str, Any] = {}
        self._mms_models: Dict[str, Any] = {}
        self._mms_tokenizers: Dict[str, Any] = {}

        self.normalizer = TextNormalizer()
        self.style_processor = StyleProcessor()

        for voice in preload_voices or []:
            self.load_voice(voice)

        logger.info(f"TTS Engine initialized on device: {self.device}")

    def _get_device(self, device: str) -> torch.device:
        """Determine the device for inference ('auto' -> CUDA if available, else CPU)"""
        if device == "auto":
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return torch.device(device)

    @staticmethod
    def _is_mms_voice(voice_key: str) -> bool:
        """MMS voices are identified by the substring "mms" in their key."""
        return "mms" in voice_key

    @staticmethod
    def _voice_gender(voice_key: str) -> str:
        """
        Infer 'male' / 'female' / 'neutral' from a voice key.

        'female' is tested first because it contains 'male' as a substring.
        """
        if "female" in voice_key:
            return "female"
        if "male" in voice_key:
            return "male"
        return "neutral"

    def _is_loaded(self, voice_key: str) -> bool:
        """True if voice_key is cached by any back-end."""
        return (
            voice_key in self._models
            or voice_key in self._coqui_models
            or voice_key in self._mms_models
        )

    def load_voice(self, voice_key: str) -> bool:
        """
        Load a trained voice model into memory

        Args:
            voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male')
        Returns:
            True if loaded successfully (or already loaded)
        Raises:
            ValueError: If voice_key is not in LANGUAGE_CONFIGS
            FileNotFoundError: If a local voice's model directory or files are missing
        """
        if self._is_loaded(voice_key):
            return True
        if voice_key not in LANGUAGE_CONFIGS:
            raise ValueError(f"Unknown voice: {voice_key}")

        # MMS voices are loaded from config.hf_model_id, not from models_dir
        if self._is_mms_voice(voice_key):
            return self._load_mms_voice(voice_key)

        model_dir = self.models_dir / voice_key
        if not model_dir.exists():
            raise FileNotFoundError(f"Model not found: {model_dir}")

        # A .pth checkpoint means Coqui; otherwise expect a JIT-traced .pt.
        # Sorted so the choice is deterministic if several files are present.
        pth_files = sorted(model_dir.glob("*.pth"))
        if pth_files:
            return self._load_coqui_voice(voice_key, model_dir, pth_files[0])
        pt_files = sorted(model_dir.glob("*.pt"))
        if pt_files:
            return self._load_jit_voice(voice_key, model_dir, pt_files[0])
        raise FileNotFoundError(f"No model file found in {model_dir}")

    def _load_jit_voice(self, voice_key: str, model_dir: Path, model_path: Path) -> bool:
        """
        Load a JIT traced VITS model and its character tokenizer.

        Uses ``chars.txt`` when present, otherwise the first ``*chars*.txt``
        file (sorted by name) in model_dir.
        """
        chars_path = model_dir / "chars.txt"
        if not chars_path.exists():
            candidates = sorted(model_dir.glob("*chars*.txt"))
            if not candidates:
                raise FileNotFoundError(f"No chars.txt found in {model_dir}")
            chars_path = candidates[0]
        tokenizer = TTSTokenizer.from_chars_file(str(chars_path))

        logger.info(f"Loading model from {model_path}")
        model = torch.jit.load(str(model_path), map_location=self.device)
        model.eval()

        self._models[voice_key] = model
        self._tokenizers[voice_key] = tokenizer
        logger.info(f"Loaded voice: {voice_key}")
        return True

    def _load_coqui_voice(self, voice_key: str, model_dir: Path, checkpoint_path: Path) -> bool:
        """
        Load a Coqui TTS checkpoint model (requires config.json next to it).

        Raises:
            FileNotFoundError: If config.json is missing
            ImportError: If the Coqui TTS package is not installed
        """
        config_path = model_dir / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"No config.json found in {model_dir}")

        # Only the import is guarded, so ImportErrors raised while building the
        # synthesizer are not misreported as "library not installed".
        try:
            from TTS.utils.synthesizer import Synthesizer
        except ImportError as err:
            raise ImportError("Coqui TTS library not installed.") from err

        logger.info(f"Loading checkpoint from {checkpoint_path}")
        synthesizer = Synthesizer(
            tts_checkpoint=str(checkpoint_path),
            tts_config_path=str(config_path),
            use_cuda=self.device.type == "cuda",
        )
        self._coqui_models[voice_key] = synthesizer
        logger.info(f"Loaded voice: {voice_key}")
        return True

    def _synthesize_coqui(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
        """Synthesize using a Coqui TTS model; returns (float32 audio, sample rate)"""
        if voice_key not in self._coqui_models:
            self.load_voice(voice_key)
        synthesizer = self._coqui_models[voice_key]
        wav = synthesizer.tts(text)
        return np.array(wav, dtype=np.float32), synthesizer.output_sample_rate

    def _load_mms_voice(self, voice_key: str) -> bool:
        """Load a Facebook MMS model (config.hf_model_id) via transformers"""
        if voice_key in self._mms_models:
            return True
        config = LANGUAGE_CONFIGS[voice_key]
        logger.info(f"Loading MMS model: {config.hf_model_id}")
        try:
            from transformers import VitsModel, AutoTokenizer

            model = VitsModel.from_pretrained(config.hf_model_id)
            tokenizer = AutoTokenizer.from_pretrained(config.hf_model_id)
            model = model.to(self.device)
            model.eval()
        except Exception as e:
            logger.error(f"Failed to load MMS model: {e}")
            raise
        self._mms_models[voice_key] = model
        self._mms_tokenizers[voice_key] = tokenizer
        logger.info(f"Loaded MMS voice: {voice_key}")
        return True

    def _synthesize_mms(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
        """Synthesize using a Facebook MMS model; sample rate comes from the voice config"""
        if voice_key not in self._mms_models:
            self._load_mms_voice(voice_key)
        model = self._mms_models[voice_key]
        tokenizer = self._mms_tokenizers[voice_key]
        config = LANGUAGE_CONFIGS[voice_key]

        inputs = tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            output = model(**inputs)
        audio = output.waveform.squeeze().cpu().numpy()
        return audio, config.sample_rate

    def _synthesize_jit(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
        """Synthesize using a loaded JIT traced VITS model"""
        model = self._models[voice_key]
        tokenizer = self._tokenizers[voice_key]
        token_ids = tokenizer.text_to_ids(text)
        # Explicit int64: np.array(list_of_ints) is int32 on some platforms
        # (e.g. Windows with NumPy < 2), which embedding lookups may reject.
        x = torch.as_tensor(token_ids, dtype=torch.long, device=self.device).unsqueeze(0)
        with torch.no_grad():
            audio = model(x)
        return audio.squeeze().cpu().numpy(), LANGUAGE_CONFIGS[voice_key].sample_rate

    def unload_voice(self, voice_key: str):
        """Unload a voice from every back-end cache to free memory"""
        self._models.pop(voice_key, None)
        self._tokenizers.pop(voice_key, None)
        self._coqui_models.pop(voice_key, None)
        self._mms_models.pop(voice_key, None)
        self._mms_tokenizers.pop(voice_key, None)
        if self.device.type == "cuda":
            torch.cuda.empty_cache()
        logger.info(f"Unloaded voice: {voice_key}")

    def synthesize(
        self,
        text: str,
        voice: str = "hi_male",
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
        style: Optional[str] = None,
        normalize_text: bool = True,
    ) -> TTSOutput:
        """
        Synthesize speech from text

        Args:
            text: Input text to synthesize
            voice: Voice key (e.g., 'hi_male', 'bn_female')
            speed: Speech speed multiplier (0.5-2.0)
            pitch: Pitch multiplier (0.5-2.0)
            energy: Energy/volume multiplier (0.5-2.0)
            style: Style preset name (e.g., 'happy', 'sad'); multiplies
                speed/pitch/energy, unknown names are ignored
            normalize_text: Whether to apply text normalization
        Returns:
            TTSOutput with audio array and metadata
        Raises:
            KeyError: If voice is not in LANGUAGE_CONFIGS
        """
        if style and style in STYLE_PRESETS:
            preset = STYLE_PRESETS[style]
            speed *= preset["speed"]
            pitch *= preset["pitch"]
            energy *= preset["energy"]

        config = LANGUAGE_CONFIGS[voice]
        if normalize_text:
            text = self.normalizer.clean_text(text, config.code)

        # Route to the back-end that owns this voice
        if self._is_mms_voice(voice):
            audio_np, sample_rate = self._synthesize_mms(text, voice)
        else:
            if voice not in self._models and voice not in self._coqui_models:
                self.load_voice(voice)
            if voice in self._coqui_models:
                audio_np, sample_rate = self._synthesize_coqui(text, voice)
            else:
                audio_np, sample_rate = self._synthesize_jit(text, voice)

        audio_np = self.style_processor.apply_style(
            audio_np, sample_rate, speed=speed, pitch=pitch, energy=energy
        )

        return TTSOutput(
            audio=audio_np,
            sample_rate=sample_rate,
            duration=len(audio_np) / sample_rate,
            voice=voice,
            text=text,
            style=style,
        )

    def synthesize_to_file(
        self,
        text: str,
        output_path: str,
        voice: str = "hi_male",
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
        style: Optional[str] = None,
        normalize_text: bool = True,
    ) -> str:
        """Synthesize speech, write it with soundfile to output_path, and return the path"""
        import soundfile as sf

        output = self.synthesize(text, voice, speed, pitch, energy, style, normalize_text)
        sf.write(output_path, output.audio, output.sample_rate)
        logger.info(f"Saved audio to {output_path} (duration: {output.duration:.2f}s)")
        return output_path

    def get_loaded_voices(self) -> List[str]:
        """Get list of currently loaded voices (JIT, then Coqui, then MMS)"""
        return (
            list(self._models.keys())
            + list(self._coqui_models.keys())
            + list(self._mms_models.keys())
        )

    def get_available_voices(self) -> Dict[str, Dict]:
        """Get all configured voices with name, code, gender, load/download status and back-end type"""
        voices = {}
        for key, config in LANGUAGE_CONFIGS.items():
            is_mms = self._is_mms_voice(key)
            model_dir = self.models_dir / key
            if is_mms:
                model_type = "mms"
            elif model_dir.exists() and list(model_dir.glob("*.pth")):
                model_type = "coqui"
            else:
                model_type = "vits"
            voices[key] = {
                "name": config.name,
                "code": config.code,
                "gender": self._voice_gender(key),
                "loaded": self._is_loaded(key),
                "downloaded": is_mms or get_model_path(key) is not None,
                "type": model_type,
            }
        return voices

    def get_style_presets(self) -> Dict[str, Dict]:
        """Get available style presets"""
        return STYLE_PRESETS

    def batch_synthesize(self, texts: List[str], voice: str = "hi_male", speed: float = 1.0) -> List[TTSOutput]:
        """Synthesize multiple texts sequentially with the same voice and speed"""
        return [self.synthesize(text, voice, speed) for text in texts]
def synthesize(text: str, voice: str = "hi_male", output_path: Optional[str] = None) -> Union[TTSOutput, str]:
    """
    One-shot convenience wrapper around TTSEngine.

    A fresh engine is constructed on every call, so its model caches start
    empty each time. Keep a TTSEngine instance around for repeated use.

    Args:
        text: Text to speak
        voice: Voice key, e.g. 'hi_male'
        output_path: If given (non-empty), the audio is written there

    Returns:
        The output path when output_path is set, otherwise a TTSOutput
    """
    tts = TTSEngine()
    if not output_path:
        return tts.synthesize(text, voice)
    return tts.synthesize_to_file(text, output_path, voice)