| """ | |
| TTS Engine for Multi-lingual Indian Language Speech Synthesis | |
| This engine uses VITS (Variational Inference with adversarial learning | |
| for Text-to-Speech) models trained on various Indian language datasets. | |
| Supported Languages: | |
| - Hindi, Bengali, Marathi, Telugu, Kannada | |
| - Gujarati (via Facebook MMS), Bhojpuri, Chhattisgarhi | |
| - Maithili, Magahi, English | |
| Model Types: | |
| - JIT traced models (.pt) - Trained using train_vits.py | |
| - Coqui TTS checkpoints (.pth) - For Bhojpuri | |
| - Facebook MMS - For Gujarati | |
| """ | |
import os
import logging
from pathlib import Path
from typing import Dict, Optional, Union, List, Tuple, Any

import numpy as np
import torch
from dataclasses import dataclass

from .config import LANGUAGE_CONFIGS, LanguageConfig, MODELS_DIR, STYLE_PRESETS
from .tokenizer import TTSTokenizer, CharactersConfig, TextNormalizer
from .model_loader import _ensure_models_available, get_model_path, list_available_models

logger = logging.getLogger(__name__)


@dataclass
class TTSOutput:
    """Output from TTS synthesis."""

    audio: np.ndarray
    sample_rate: int
    duration: float
    voice: str
    text: str
    style: Optional[str] = None


class StyleProcessor:
    """
    Prosody/style control via audio post-processing.

    Supports pitch shifting, speed change, and energy modification.
    The methods are stateless, so they are exposed as static methods.
    """

    @staticmethod
    def apply_pitch_shift(audio: np.ndarray, sample_rate: int, pitch_factor: float) -> np.ndarray:
        """Shift pitch without changing duration."""
        if pitch_factor == 1.0:
            return audio
        try:
            import librosa

            # librosa expects the shift in semitones; convert the ratio.
            semitones = 12 * np.log2(pitch_factor)
            shifted = librosa.effects.pitch_shift(
                audio.astype(np.float32), sr=sample_rate, n_steps=semitones
            )
            return shifted
        except ImportError:
            # Pitch shifting at constant duration needs a phase vocoder, which
            # plain resampling cannot provide: resampling down and then back up
            # is a near no-op apart from band-limiting. Warn and return the
            # audio unchanged rather than pretend the shift was applied.
            logger.warning("librosa not installed; pitch shift skipped")
            return audio

    @staticmethod
    def apply_speed_change(audio: np.ndarray, sample_rate: int, speed_factor: float) -> np.ndarray:
        """Change speed/tempo without changing pitch."""
        if speed_factor == 1.0:
            return audio
        try:
            import librosa

            stretched = librosa.effects.time_stretch(
                audio.astype(np.float32), rate=speed_factor
            )
            return stretched
        except ImportError:
            # Fallback: plain resampling changes duration but also shifts the
            # pitch by the same factor; acceptable for small speed changes.
            from scipy import signal

            target_length = int(len(audio) / speed_factor)
            return signal.resample(audio, target_length)

    @staticmethod
    def apply_energy_change(audio: np.ndarray, energy_factor: float) -> np.ndarray:
        """Modify audio energy/volume."""
        if energy_factor == 1.0:
            return audio
        modified = audio * energy_factor
        if energy_factor > 1.0:
            # Soft-clip with tanh instead of hard clipping when boosting
            # pushes peaks past ~0.95.
            max_val = np.max(np.abs(modified))
            if max_val > 0.95:
                modified = np.tanh(modified * 2) * 0.95
        return modified

    @staticmethod
    def apply_style(
        audio: np.ndarray,
        sample_rate: int,
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
    ) -> np.ndarray:
        """Apply all style modifications."""
        result = audio
        if pitch != 1.0:
            result = StyleProcessor.apply_pitch_shift(result, sample_rate, pitch)
        if speed != 1.0:
            result = StyleProcessor.apply_speed_change(result, sample_rate, speed)
        if energy != 1.0:
            result = StyleProcessor.apply_energy_change(result, energy)
        return result

    @staticmethod
    def get_preset(preset_name: str) -> Dict[str, float]:
        """Get style parameters from a preset name, falling back to 'default'."""
        return STYLE_PRESETS.get(preset_name, STYLE_PRESETS["default"])
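

# Illustrative use of StyleProcessor on its own (kept as a comment so the
# module stays import-safe). Any preset in STYLE_PRESETS defines the
# "speed"/"pitch"/"energy" keys used below; 'happy' is one example name:
#
#     preset = StyleProcessor.get_preset("happy")
#     styled = StyleProcessor.apply_style(
#         audio, sample_rate,
#         speed=preset["speed"], pitch=preset["pitch"], energy=preset["energy"],
#     )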


class TTSEngine:
    """
    Multi-lingual TTS Engine using trained VITS models.

    Supports 11 Indian languages with male/female voices. Models are loaded
    from the models/ directory, which contains trained checkpoints exported
    using training/export_model.py.
    """

    def __init__(
        self,
        models_dir: str = MODELS_DIR,
        device: str = "auto",
        preload_voices: Optional[List[str]] = None,
    ):
        """
        Initialize TTS Engine.

        Args:
            models_dir: Directory containing trained models
            device: Device to run inference on ('cpu', 'cuda', 'mps', or 'auto')
            preload_voices: List of voice keys to preload into memory
        """
        self.models_dir = Path(models_dir)
        self.device = self._get_device(device)

        # Ensure models are available
        _ensure_models_available()

        # Model caches
        self._models: Dict[str, torch.jit.ScriptModule] = {}
        self._tokenizers: Dict[str, TTSTokenizer] = {}
        self._coqui_models: Dict[str, Any] = {}
        self._mms_models: Dict[str, Any] = {}
        self._mms_tokenizers: Dict[str, Any] = {}

        # Text normalizer
        self.normalizer = TextNormalizer()

        # Style processor
        self.style_processor = StyleProcessor()

        # Preload specified voices
        if preload_voices:
            for voice in preload_voices:
                self.load_voice(voice)

        logger.info(f"TTS Engine initialized on device: {self.device}")

    def _get_device(self, device: str) -> torch.device:
        """Determine the best device for inference."""
        if device == "auto":
            if torch.cuda.is_available():
                return torch.device("cuda")
            # 'mps' is advertised in the __init__ docstring; use it on Apple
            # Silicon when available.
            if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
                return torch.device("mps")
            return torch.device("cpu")
        return torch.device(device)

    def load_voice(self, voice_key: str) -> bool:
        """
        Load a trained voice model into memory.

        Args:
            voice_key: Key from LANGUAGE_CONFIGS (e.g., 'hi_male')

        Returns:
            True if loaded successfully
        """
        if voice_key in self._models or voice_key in self._coqui_models:
            return True
        if voice_key not in LANGUAGE_CONFIGS:
            raise ValueError(f"Unknown voice: {voice_key}")

        # MMS voices come from the Hugging Face Hub rather than a local model
        # directory, so route them separately (mirrors the check in synthesize()).
        if "mms" in voice_key:
            return self._load_mms_voice(voice_key)

        model_dir = self.models_dir / voice_key
        if not model_dir.exists():
            raise FileNotFoundError(f"Model not found: {model_dir}")

        # Check model type: a .pth checkpoint means Coqui, a .pt file means JIT
        pth_files = list(model_dir.glob("*.pth"))
        pt_files = list(model_dir.glob("*.pt"))
        if pth_files:
            return self._load_coqui_voice(voice_key, model_dir, pth_files[0])
        elif pt_files:
            return self._load_jit_voice(voice_key, model_dir, pt_files[0])
        else:
            raise FileNotFoundError(f"No model file found in {model_dir}")

    def _load_jit_voice(self, voice_key: str, model_dir: Path, model_path: Path) -> bool:
        """Load a JIT traced VITS model."""
        chars_path = model_dir / "chars.txt"
        if chars_path.exists():
            tokenizer = TTSTokenizer.from_chars_file(str(chars_path))
        else:
            chars_files = list(model_dir.glob("*chars*.txt"))
            if chars_files:
                tokenizer = TTSTokenizer.from_chars_file(str(chars_files[0]))
            else:
                raise FileNotFoundError(f"No chars.txt found in {model_dir}")

        logger.info(f"Loading model from {model_path}")
        model = torch.jit.load(str(model_path), map_location=self.device)
        model.eval()

        self._models[voice_key] = model
        self._tokenizers[voice_key] = tokenizer
        logger.info(f"Loaded voice: {voice_key}")
        return True

    def _load_coqui_voice(self, voice_key: str, model_dir: Path, checkpoint_path: Path) -> bool:
        """Load a Coqui TTS checkpoint model."""
        config_path = model_dir / "config.json"
        if not config_path.exists():
            raise FileNotFoundError(f"No config.json found in {model_dir}")
        try:
            from TTS.utils.synthesizer import Synthesizer
        except ImportError as e:
            raise ImportError(
                "Coqui TTS library not installed; install it with `pip install TTS`."
            ) from e

        logger.info(f"Loading checkpoint from {checkpoint_path}")
        use_cuda = self.device.type == "cuda"
        synthesizer = Synthesizer(
            tts_checkpoint=str(checkpoint_path),
            tts_config_path=str(config_path),
            use_cuda=use_cuda,
        )
        self._coqui_models[voice_key] = synthesizer
        logger.info(f"Loaded voice: {voice_key}")
        return True

    def _synthesize_coqui(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
        """Synthesize using a Coqui TTS model."""
        if voice_key not in self._coqui_models:
            self.load_voice(voice_key)
        synthesizer = self._coqui_models[voice_key]
        wav = synthesizer.tts(text)
        audio_np = np.array(wav, dtype=np.float32)
        sample_rate = synthesizer.output_sample_rate
        return audio_np, sample_rate

    def _load_mms_voice(self, voice_key: str) -> bool:
        """Load a Facebook MMS model (used for Gujarati)."""
        if voice_key in self._mms_models:
            return True
        config = LANGUAGE_CONFIGS[voice_key]
        logger.info(f"Loading MMS model: {config.hf_model_id}")
        try:
            from transformers import VitsModel, AutoTokenizer

            model = VitsModel.from_pretrained(config.hf_model_id)
            tokenizer = AutoTokenizer.from_pretrained(config.hf_model_id)
            model = model.to(self.device)
            model.eval()
            self._mms_models[voice_key] = model
            self._mms_tokenizers[voice_key] = tokenizer
            logger.info(f"Loaded MMS voice: {voice_key}")
            return True
        except Exception as e:
            logger.error(f"Failed to load MMS model: {e}")
            raise

    def _synthesize_mms(self, text: str, voice_key: str) -> Tuple[np.ndarray, int]:
        """Synthesize using a Facebook MMS model."""
        if voice_key not in self._mms_models:
            self._load_mms_voice(voice_key)
        model = self._mms_models[voice_key]
        tokenizer = self._mms_tokenizers[voice_key]
        config = LANGUAGE_CONFIGS[voice_key]
        inputs = tokenizer(text, return_tensors="pt")
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            output = model(**inputs)
        audio = output.waveform.squeeze().cpu().numpy()
        # config.sample_rate is assumed to match the model's native rate
        # (model.config.sampling_rate; 16 kHz for the MMS checkpoints).
        return audio, config.sample_rate

    def unload_voice(self, voice_key: str):
        """Unload a voice to free memory."""
        if voice_key in self._models:
            del self._models[voice_key]
            del self._tokenizers[voice_key]
        if voice_key in self._coqui_models:
            del self._coqui_models[voice_key]
        if voice_key in self._mms_models:
            del self._mms_models[voice_key]
            del self._mms_tokenizers[voice_key]
        if self.device.type == "cuda":
            torch.cuda.empty_cache()
        logger.info(f"Unloaded voice: {voice_key}")

    def synthesize(
        self,
        text: str,
        voice: str = "hi_male",
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
        style: Optional[str] = None,
        normalize_text: bool = True,
    ) -> TTSOutput:
        """
        Synthesize speech from text.

        Args:
            text: Input text to synthesize
            voice: Voice key (e.g., 'hi_male', 'bn_female')
            speed: Speech speed multiplier (0.5-2.0)
            pitch: Pitch multiplier (0.5-2.0)
            energy: Energy/volume multiplier (0.5-2.0)
            style: Style preset name (e.g., 'happy', 'sad')
            normalize_text: Whether to apply text normalization
        Returns:
            TTSOutput with audio array and metadata
        """
        if voice not in LANGUAGE_CONFIGS:
            raise ValueError(f"Unknown voice: {voice}")
        config = LANGUAGE_CONFIGS[voice]

        # A style preset scales the explicit speed/pitch/energy arguments
        if style and style in STYLE_PRESETS:
            preset = STYLE_PRESETS[style]
            speed = speed * preset["speed"]
            pitch = pitch * preset["pitch"]
            energy = energy * preset["energy"]

        if normalize_text:
            text = self.normalizer.clean_text(text, config.code)

        # Route to the appropriate model type
        if "mms" in voice:
            audio_np, sample_rate = self._synthesize_mms(text, voice)
        elif voice in self._coqui_models:
            audio_np, sample_rate = self._synthesize_coqui(text, voice)
        else:
            # load_voice() registers the model as JIT or Coqui depending on
            # which checkpoint file it finds, so re-check both caches after it.
            if voice not in self._models:
                self.load_voice(voice)
            if voice in self._coqui_models:
                audio_np, sample_rate = self._synthesize_coqui(text, voice)
            else:
                model = self._models[voice]
                tokenizer = self._tokenizers[voice]
                token_ids = tokenizer.text_to_ids(text)
                # JIT VITS models expect a batch of int64 token ids
                x = torch.from_numpy(np.array(token_ids, dtype=np.int64)).unsqueeze(0).to(self.device)
                with torch.no_grad():
                    audio = model(x)
                audio_np = audio.squeeze().cpu().numpy()
                sample_rate = config.sample_rate

        # Apply style modifications
        audio_np = self.style_processor.apply_style(
            audio_np, sample_rate, speed=speed, pitch=pitch, energy=energy
        )

        duration = len(audio_np) / sample_rate
        return TTSOutput(
            audio=audio_np,
            sample_rate=sample_rate,
            duration=duration,
            voice=voice,
            text=text,
            style=style,
        )
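
    # Illustrative in-memory call (values are examples; voice keys and preset
    # names come from LANGUAGE_CONFIGS and STYLE_PRESETS):
    #
    #     out = engine.synthesize("नमस्ते दुनिया", voice="hi_male", style="happy")
    #     print(f"{out.duration:.2f}s at {out.sample_rate} Hz")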

    def synthesize_to_file(
        self,
        text: str,
        output_path: str,
        voice: str = "hi_male",
        speed: float = 1.0,
        pitch: float = 1.0,
        energy: float = 1.0,
        style: Optional[str] = None,
        normalize_text: bool = True,
    ) -> str:
        """Synthesize speech and save it to an audio file."""
        import soundfile as sf

        output = self.synthesize(text, voice, speed, pitch, energy, style, normalize_text)
        sf.write(output_path, output.audio, output.sample_rate)
        logger.info(f"Saved audio to {output_path} (duration: {output.duration:.2f}s)")
        return output_path

    def get_loaded_voices(self) -> List[str]:
        """Get list of currently loaded voices."""
        return (
            list(self._models.keys())
            + list(self._coqui_models.keys())
            + list(self._mms_models.keys())
        )

    def get_available_voices(self) -> Dict[str, Dict]:
        """Get all available voices with their status."""
        voices = {}
        for key, config in LANGUAGE_CONFIGS.items():
            is_mms = "mms" in key
            model_dir = self.models_dir / key
            if is_mms:
                model_type = "mms"
            elif model_dir.exists() and list(model_dir.glob("*.pth")):
                model_type = "coqui"
            else:
                model_type = "vits"
            # Check 'female' first: 'male' is a substring of 'female', so the
            # reversed order would misclassify every female voice as male.
            if "female" in key:
                gender = "female"
            elif "male" in key:
                gender = "male"
            else:
                gender = "neutral"
            voices[key] = {
                "name": config.name,
                "code": config.code,
                "gender": gender,
                "loaded": key in self._models or key in self._coqui_models or key in self._mms_models,
                "downloaded": is_mms or get_model_path(key) is not None,
                "type": model_type,
            }
        return voices
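
    # Shape of one returned entry (illustrative values):
    #
    #     {"hi_male": {"name": "Hindi", "code": "hi", "gender": "male",
    #                  "loaded": False, "downloaded": True, "type": "vits"}}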

    def get_style_presets(self) -> Dict[str, Dict]:
        """Get available style presets."""
        return STYLE_PRESETS

    def batch_synthesize(self, texts: List[str], voice: str = "hi_male", speed: float = 1.0) -> List[TTSOutput]:
        """Synthesize multiple texts."""
        return [self.synthesize(text, voice, speed) for text in texts]


def synthesize(text: str, voice: str = "hi_male", output_path: Optional[str] = None) -> Union[TTSOutput, str]:
    """Quick one-shot synthesis helper.

    Note: this builds a fresh TTSEngine on every call, which is expensive;
    reuse a single TTSEngine instance for repeated synthesis.
    """
    engine = TTSEngine()
    if output_path:
        return engine.synthesize_to_file(text, output_path, voice)
    return engine.synthesize(text, voice)
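

# Minimal smoke-test sketch, assuming a downloaded 'hi_male' model and the
# `soundfile` dependency are present. Because this module uses relative
# imports, run it as a module (e.g. `python -m <package>.engine`; the package
# name is an assumption), not as a standalone script.
if __name__ == "__main__":
    engine = TTSEngine()
    path = engine.synthesize_to_file(
        "नमस्ते, यह एक परीक्षण है।",  # "Hello, this is a test." (example text)
        output_path="test_hi_male.wav",
        voice="hi_male",
    )
    print(f"Wrote {path}")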