# Page-header residue captured from a Hugging Face Spaces listing ("Spaces: Sleeping").
import sys
from pathlib import Path
from typing import Optional, Union

import numpy
import scipy.io.wavfile
import torch
import torchaudio
from transformers import AutoProcessor, SeamlessM4Tv2Model
class SeamlessTranslator:
    """
    A wrapper class for Facebook's SeamlessM4T translation model.
    Handles both text-to-speech and speech-to-speech translation.

    Attributes:
        processor: Processor loaded with ``AutoProcessor.from_pretrained``;
            prepares text or audio inputs for the model.
        model: The loaded ``SeamlessM4Tv2Model``.
        sample_rate (int): Sampling rate of generated speech, read from the
            model config; used when writing WAV files.
        device: Device inputs and the model are moved to, or ``None`` to
            leave placement as ``from_pretrained`` chose it.
    """

    # Rate that input audio is resampled to before feature extraction; the
    # same value is passed to the processor so it can validate the input.
    _INPUT_SAMPLE_RATE = 16_000

    def __init__(
        self,
        model_name: str = "facebook/seamless-m4t-v2-large",
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Initialize the translator with the specified model.

        Args:
            model_name (str): Name (hub id or local path) of the model to use.
            device (str or torch.device, optional): Device to run inference
                on, e.g. ``"cuda"``. ``None`` (default) keeps the placement
                chosen by ``from_pretrained``.

        Raises:
            RuntimeError: If the processor or model cannot be loaded.
        """
        self.device = device
        try:
            self.processor = AutoProcessor.from_pretrained(model_name)
            self.model = SeamlessM4Tv2Model.from_pretrained(model_name)
            if device is not None:
                self.model.to(device)
            self.sample_rate = self.model.config.sampling_rate
        except Exception as e:
            raise RuntimeError(f"Failed to initialize model: {str(e)}") from e

    def _generate_speech(self, inputs, tgt_lang: str) -> numpy.ndarray:
        """Run speech generation on processor output and return it as a NumPy array."""
        if self.device is not None:
            inputs = inputs.to(self.device)
        # Per the transformers docs, generate() (speech output, default args)
        # returns a tuple whose first item is the waveform tensor; squeeze
        # drops the size-1 batch dimension.
        waveform = self.model.generate(**inputs, tgt_lang=tgt_lang)[0]
        return waveform.cpu().numpy().squeeze()

    def translate_text(self, text: str, src_lang: str, tgt_lang: str) -> numpy.ndarray:
        """
        Translate text to speech in the target language.

        Args:
            text (str): Input text to translate
            src_lang (str): Source language code (e.g., 'eng')
            tgt_lang (str): Target language code (e.g., 'rus')

        Returns:
            numpy.ndarray: Audio waveform array

        Raises:
            RuntimeError: If preprocessing or generation fails; the original
                exception is chained as ``__cause__``.
        """
        try:
            inputs = self.processor(text=text, src_lang=src_lang, return_tensors="pt")
            return self._generate_speech(inputs, tgt_lang)
        except Exception as e:
            raise RuntimeError(f"Text translation failed: {str(e)}") from e

    def translate_audio(self, audio_path: Union[str, Path], tgt_lang: str) -> numpy.ndarray:
        """
        Translate audio to speech in the target language.

        Args:
            audio_path (str or Path): Path to input audio file
            tgt_lang (str): Target language code (e.g., 'rus')

        Returns:
            numpy.ndarray: Audio waveform array

        Raises:
            RuntimeError: If loading, preprocessing or generation fails; the
                original exception is chained as ``__cause__``.
        """
        try:
            # torchaudio.load returns (waveform, sample_rate) with the
            # waveform laid out as (channels, frames).
            audio, orig_freq = torchaudio.load(audio_path)
            # Downmix multi-channel recordings so the processor sees a single
            # (1, frames) utterance rather than one input per channel.
            if audio.shape[0] > 1:
                audio = audio.mean(dim=0, keepdim=True)
            audio = torchaudio.functional.resample(
                audio,
                orig_freq=orig_freq,
                new_freq=self._INPUT_SAMPLE_RATE,
            )
            inputs = self.processor(
                audios=audio,
                sampling_rate=self._INPUT_SAMPLE_RATE,
                return_tensors="pt",
            )
            return self._generate_speech(inputs, tgt_lang)
        except Exception as e:
            raise RuntimeError(f"Audio translation failed: {str(e)}") from e

    def save_audio(self, audio_array: numpy.ndarray, output_path: Union[str, Path]) -> None:
        """
        Save an audio array to a WAV file.

        Args:
            audio_array (numpy.ndarray): Audio data to save
            output_path (str or Path): Path where to save the WAV file

        Raises:
            RuntimeError: If writing fails; the original exception is chained
                as ``__cause__``.
        """
        try:
            scipy.io.wavfile.write(output_path, rate=self.sample_rate, data=audio_array)
        except Exception as e:
            raise RuntimeError(f"Failed to save audio: {str(e)}") from e
def main() -> int:
    """
    Example usage of the SeamlessTranslator class.

    Translates an English sentence and the file ``input_audio.wav`` into
    Russian speech, writing ``output_from_text.wav`` and
    ``output_from_audio.wav`` to the current working directory.

    Returns:
        int: Process exit status: 0 on success, 1 if any step failed.
    """
    try:
        # Initialize translator
        translator = SeamlessTranslator()

        # Example text translation
        text_audio = translator.translate_text(
            text="Hello, my dog is cute",
            src_lang="eng",
            tgt_lang="rus",
        )
        translator.save_audio(text_audio, "output_from_text.wav")

        # Example audio translation
        audio_audio = translator.translate_audio(
            audio_path="input_audio.wav",
            tgt_lang="rus",
        )
        translator.save_audio(audio_audio, "output_from_audio.wav")
    except Exception as e:
        # Top-level boundary: report on stderr and signal failure through the
        # exit status instead of exiting 0 as if everything had worked.
        print(f"Translation failed: {str(e)}", file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())