import os

import soundfile as sf
import torch
from safetensors.torch import load_file

from src.chatterbox_.tts import ChatterboxTTS


# ---------------------------------------------------------------------------
# Script configuration. Edit these values to point at your own files.
# ---------------------------------------------------------------------------

# Directory with the base model files; handed to ChatterboxTTS.from_local().
MODEL_DIR = "./pretrained_models"

# Fine-tuned checkpoint in safetensors format. Its tensors are loaded into
# the engine's `t3` sub-module; when the file does not exist the script
# falls back to the base weights.
FINETUNED_WEIGHTS = "./models/best_finnish_multilingual_cp986.safetensors"

# Text to synthesize (Finnish).
TEXT = "Tervetuloa kokeilemaan hienoviritettyä suomenkielistä Chatterbox-puhesynteesiä."

# Audio file passed to generate() as `audio_prompt_path`
# (presumably the reference voice to imitate -- confirm against the engine).
REFERENCE_AUDIO = "./samples/reference_finnish.wav"

# Path the generated audio is written to.
OUTPUT_FILE = "output_finnish.wav"


def _load_finetuned_t3(engine, weights_path):
    """Load a fine-tuned safetensors checkpoint into ``engine.t3`` and report the outcome.

    Keys that start with ``"t3."`` have that prefix removed so they line up
    with the parameter names of the ``t3`` sub-module; all other keys are
    passed through unchanged. Loading is non-strict, so mismatched keys do not
    raise -- instead the number of applied, missing and unexpected keys is
    printed so a checkpoint that does not fit the model is not silently
    ignored.

    Args:
        engine: Object exposing a ``t3`` module with ``load_state_dict``.
        weights_path: Path to the ``.safetensors`` checkpoint file.
    """
    checkpoint_state = load_file(weights_path)

    # Strip the leading "t3." (3 characters) so keys match the sub-module.
    t3_state_dict = {
        (key[3:] if key.startswith("t3.") else key): tensor
        for key, tensor in checkpoint_state.items()
    }

    # strict=False tolerates partial checkpoints, but it also hides a total
    # mismatch -- inspect the returned key lists instead of discarding them.
    result = engine.t3.load_state_dict(t3_state_dict, strict=False)
    missing = list(result.missing_keys)
    unexpected = list(result.unexpected_keys)
    applied = len(t3_state_dict) - len(unexpected)

    print(f"Applied {applied} of {len(t3_state_dict)} checkpoint tensors to T3.")
    if applied == 0:
        print(
            "Warning: no checkpoint tensor matched the T3 module; "
            "generation will effectively use base weights."
        )
    if missing:
        print(
            f"Warning: {len(missing)} T3 parameters were not found in the "
            f"checkpoint (e.g. {', '.join(missing[:5])})."
        )
    if unexpected:
        print(
            f"Warning: {len(unexpected)} checkpoint keys were ignored "
            f"(e.g. {', '.join(unexpected[:5])})."
        )


def main():
    """Generate speech for ``TEXT`` and write it to ``OUTPUT_FILE``.

    Steps:
      1. Pick CUDA when available, otherwise CPU.
      2. Load the base model from ``MODEL_DIR``.
      3. If ``FINETUNED_WEIGHTS`` exists, load it into the ``t3`` module;
         otherwise print a warning and continue with base weights.
      4. Call ``generate`` with ``REFERENCE_AUDIO`` as the audio prompt.
      5. Save the result with ``soundfile`` at the engine's sample rate.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # 1. Base model.
    print(f"Loading base model from {MODEL_DIR}...")
    engine = ChatterboxTTS.from_local(MODEL_DIR, device=device)

    # 2. Optional fine-tuned weights.
    if os.path.exists(FINETUNED_WEIGHTS):
        print(f"Loading finetuned weights from {FINETUNED_WEIGHTS}...")
        _load_finetuned_t3(engine, FINETUNED_WEIGHTS)
    else:
        print(f"Warning: Finetuned weights not found at {FINETUNED_WEIGHTS}. Using base weights.")

    # 3. Synthesis.
    print(f"Generating audio for: '{TEXT}'")
    wav_tensor = engine.generate(
        text=TEXT,
        audio_prompt_path=REFERENCE_AUDIO,
        repetition_penalty=1.2,
        temperature=0.8,
        exaggeration=0.6,
    )

    # 4. Drop singleton dimensions and move to host memory before writing.
    wav_np = wav_tensor.squeeze().cpu().numpy()
    sf.write(OUTPUT_FILE, wav_np, engine.sr)
    print(f"Successfully saved audio to {OUTPUT_FILE}")


# Run the demo only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()