| | import os |
| | import requests |
| | import yaml |
| | import torch |
| | import librosa |
| | import numpy as np |
| | import soundfile as sf |
| | from pathlib import Path |
| | from transformers import T5Tokenizer, T5EncoderModel |
| | from tqdm import tqdm |
| | from .src.vc_wrapper import ReDiffVC, DreamVC |
| | from .src.plugin_wrapper import DreamVG |
| | from .src.modules.speaker_encoder.encoder import inference as spk_encoder |
| | from .src.modules.BigVGAN.inference import load_model as load_vocoder |
| | from .src.feats.contentvec_hf import get_content_model, get_content |
| |
|
| |
|
class DreamVoice:
    """Text-prompt-guided voice conversion front-end.

    Loads a vocoder (BigVGAN), a content encoder (ContentVec), a T5 text
    encoder for the style prompt, and one of two diffusion back-ends:

    * ``'plugin'``  — DreamVG generates a speaker embedding from the text
      prompt, then ReDiffVC performs the conversion (also enables
      :meth:`simplevc` with a reference speaker wav).
    * ``'end2end'`` — DreamVC converts directly from the text prompt.

    Missing checkpoints listed in the YAML config are downloaded on first use.
    """

    def __init__(self, config='dreamvc.yaml', mode='plugin', device='cuda', chunk_size=16):
        """
        Args:
            config: YAML config filename, resolved relative to this file.
            mode: ``'plugin'`` or ``'end2end'`` (see class docstring).
            device: torch device string for all models.
            chunk_size: content chunk length fed to the diffusion model,
                multiplied by 50 below — presumably seconds at 50 content
                frames/s (16 kHz / 320-sample hop); TODO confirm.
        """
        # Resolve paths relative to this file so the package works no matter
        # what the caller's working directory is.
        script_dir = Path(__file__).resolve().parent
        config_path = script_dir / config

        with open(config_path, 'r') as fp:
            self.config = yaml.safe_load(fp)

        self.script_dir = script_dir

        # Fetch any checkpoints that are not present locally before loading.
        self._ensure_checkpoints_exist()

        self.device = device
        self.sr = self.config['sample_rate']

        # Vocoder: mel-spectrogram -> waveform.
        vocoder_path = script_dir / self.config['vocoder_path']
        self.hifigan, _ = load_vocoder(vocoder_path, device)
        self.hifigan.eval()

        # Content encoder: 16 kHz waveform -> content features.
        self.content_model = get_content_model().to(device)

        # Text encoder for the style prompt.
        lm_path = self.config['lm_path']
        self.tokenizer = T5Tokenizer.from_pretrained(lm_path)
        self.text_encoder = T5EncoderModel.from_pretrained(lm_path).to(device).eval()

        self.mode = mode
        if mode == 'plugin':
            self._init_plugin_mode()
        elif mode == 'end2end':
            self._init_end2end_mode()
        else:
            raise NotImplementedError("Select mode from 'plugin' and 'end2end'")

        # Content frames per chunk (see chunk_size note above).
        self.chunk_size = chunk_size * 50

    def _ensure_checkpoints_exist(self):
        """Download every configured checkpoint that is missing locally."""
        checkpoints = [
            ('vocoder_path', self.config.get('vocoder_url')),
            ('vocoder_config_path', self.config.get('vocoder_config_url')),
            ('speaker_path', self.config.get('speaker_url')),
            ('dreamvc.ckpt_path', self.config.get('dreamvc', {}).get('ckpt_url')),
            ('rediffvc.ckpt_path', self.config.get('rediffvc', {}).get('ckpt_url')),
            ('dreamvg.ckpt_path', self.config.get('dreamvg', {}).get('ckpt_url'))
        ]

        for path_key, url in checkpoints:
            local_path = self._get_local_path(path_key)
            if not local_path.exists() and url:
                print(f"Downloading {path_key} from {url}")
                self._download_file(url, local_path)

    def _get_local_path(self, path_key):
        """Resolve a dotted config key (e.g. ``'dreamvc.ckpt_path'``) to a Path.

        NOTE(review): a missing key yields ``{}`` and ``script_dir / {}``
        raises TypeError — the config is assumed to contain every key queried
        by `_ensure_checkpoints_exist`.
        """
        keys = path_key.split('.')
        local_path = self.config
        for key in keys:
            local_path = local_path.get(key, {})
        return self.script_dir / local_path

    def _download_file(self, url, local_path):
        """Stream *url* to *local_path* with a progress bar.

        If the anonymous request fails (e.g. a gated Hugging Face repo),
        prompt once for an API key and retry with a Bearer token. A failure
        of the retry is re-raised instead of proceeding with no response.
        """
        response = None
        try:
            response = requests.get(url, stream=True)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            print(f"Error encountered: {e}")
            response = None

        if response is None:
            # Fallback: likely a private checkpoint — ask for an HF API key.
            # (Bug fix: previously this prompt ran even after a successful
            # anonymous download.)
            user_input = input("Private checkpoint, please request authorization and enter your Hugging Face API key.")
            self.hf_key = user_input if user_input else None
            headers = {'Authorization': f'Bearer {self.hf_key}'} if self.hf_key else {}
            try:
                response = requests.get(url, stream=True, headers=headers)
                response.raise_for_status()
            except requests.exceptions.RequestException as e:
                # Bug fix: previously execution continued with response=None
                # and crashed on response.headers with an AttributeError.
                print(f"Error encountered in dev mode: {e}")
                raise

        local_path.parent.mkdir(parents=True, exist_ok=True)

        total_size = int(response.headers.get('content-length', 0))
        block_size = 8192
        t = tqdm(total=total_size, unit='iB', unit_scale=True)

        with open(local_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=block_size):
                t.update(len(chunk))
                f.write(chunk)
        t.close()

    def _init_plugin_mode(self):
        """Load ReDiffVC (voice conversion), DreamVG (prompt -> speaker
        embedding) and the reference speaker encoder."""
        self.dreamvc = ReDiffVC(
            config_path=self.script_dir / self.config['rediffvc']['config_path'],
            ckpt_path=self.script_dir / self.config['rediffvc']['ckpt_path'],
            device=self.device
        )

        self.dreamvg = DreamVG(
            config_path=self.script_dir / self.config['dreamvg']['config_path'],
            ckpt_path=self.script_dir / self.config['dreamvg']['ckpt_path'],
            device=self.device
        )

        # spk_encoder is a module-level singleton; load_model sets its state.
        spk_encoder.load_model(self.script_dir / self.config['speaker_path'], self.device)
        self.spk_encoder = spk_encoder
        # Cache of the last speaker embedding used (for save/load helpers).
        self.spk_embed_cache = None

    def _init_end2end_mode(self):
        """Load the end-to-end DreamVC model (prompt + content -> audio)."""
        self.dreamvc = DreamVC(
            config_path=self.script_dir / self.config['dreamvc']['config_path'],
            ckpt_path=self.script_dir / self.config['dreamvc']['ckpt_path'],
            device=self.device
        )

    def _load_content(self, audio_path):
        """Load *audio_path* at 16 kHz and return content features (B, L, D)."""
        content_audio, _ = librosa.load(audio_path, sr=16000)
        # Pad up to a multiple of 16*160 samples so the content model's
        # framing divides evenly.
        target_length = ((len(content_audio) + 16 * 160 - 1) // (16 * 160)) * (16 * 160)
        if len(content_audio) < target_length:
            content_audio = np.pad(content_audio, (0, target_length - len(content_audio)), mode='constant')
        content_audio = torch.tensor(content_audio).unsqueeze(0).to(self.device)
        return get_content(self.content_model, content_audio)

    def _encode_prompt(self, prompt):
        """Tokenize and T5-encode the text prompt; returns (text, text_mask)."""
        text_batch = self.tokenizer(prompt, max_length=32,
                                    padding='max_length', truncation=True, return_tensors="pt")
        text, text_mask = text_batch.input_ids.to(self.device), \
            text_batch.attention_mask.to(self.device)
        text = self.text_encoder(input_ids=text, attention_mask=text_mask)[0]
        return text, text_mask

    def _chunked_inference(self, content_clip, infer_chunk):
        """Run *infer_chunk* over fixed-size slices of the content features
        along the time axis and concatenate the generated audio.

        Args:
            content_clip: content features of shape (B, L, D).
            infer_chunk: callable mapping one (B, l, D) chunk to audio.
        """
        _, L, _ = content_clip.shape
        gen_audio_chunks = []
        num_chunks = (L + self.chunk_size - 1) // self.chunk_size
        for i in range(num_chunks):
            start_idx = i * self.chunk_size
            end_idx = min((i + 1) * self.chunk_size, L)
            gen_audio_chunks.append(infer_chunk(content_clip[:, start_idx:end_idx, :]))
        return torch.cat(gen_audio_chunks, dim=-1)

    def _vocode(self, gen_audio):
        """Run the vocoder and return a 1-D numpy waveform."""
        gen_audio = self.hifigan(gen_audio.squeeze(1))
        return gen_audio.cpu().numpy().squeeze(0).squeeze(0)

    def load_spk_embed(self, emb_path):
        """Load a previously saved speaker embedding into the cache."""
        self.spk_embed_cache = torch.load(emb_path, map_location=self.device)

    def save_spk_embed(self, emb_path):
        """Persist the cached speaker embedding; requires a prior conversion."""
        assert self.spk_embed_cache is not None
        torch.save(self.spk_embed_cache.cpu(), emb_path)

    def save_audio(self, output_path, audio, sr):
        """Write *audio* to *output_path* at sample rate *sr*."""
        sf.write(output_path, audio, samplerate=sr)

    @torch.no_grad()
    def genvc(self, content_audio, prompt,
              prompt_guidance_scale=3, prompt_guidance_rescale=0.0,
              prompt_ddim_steps=100, prompt_eta=1, prompt_random_seed=None,
              vc_guidance_scale=3, vc_guidance_rescale=0.0,
              vc_ddim_steps=50, vc_eta=1, vc_random_seed=None,
              ):
        """Convert *content_audio* to a voice described by *prompt*.

        Returns:
            (waveform ndarray, sample rate).
        """
        content_clip = self._load_content(content_audio)
        text, text_mask = self._encode_prompt(prompt)

        if self.mode == 'plugin':
            # Stage 1: text prompt -> speaker embedding.
            spk_embed = self.dreamvg.inference([text, text_mask],
                                               guidance_scale=prompt_guidance_scale,
                                               guidance_rescale=prompt_guidance_rescale,
                                               ddim_steps=prompt_ddim_steps, eta=prompt_eta,
                                               random_seed=prompt_random_seed)

            # Stage 2: chunked voice conversion with that embedding.
            gen_audio = self._chunked_inference(
                content_clip,
                lambda chunk: self.dreamvc.inference(
                    spk_embed, chunk, None,
                    guidance_scale=vc_guidance_scale,
                    guidance_rescale=vc_guidance_rescale,
                    ddim_steps=vc_ddim_steps,
                    eta=vc_eta,
                    random_seed=vc_random_seed))

            self.spk_embed_cache = spk_embed

        elif self.mode == 'end2end':
            # Bug fix: the original passed the full content_clip to every
            # chunk iteration, reprocessing (and duplicating) the whole clip.
            gen_audio = self._chunked_inference(
                content_clip,
                lambda chunk: self.dreamvc.inference(
                    [text, text_mask], chunk,
                    guidance_scale=prompt_guidance_scale,
                    guidance_rescale=prompt_guidance_rescale,
                    ddim_steps=prompt_ddim_steps,
                    eta=prompt_eta, random_seed=prompt_random_seed))

        else:
            raise NotImplementedError("Select mode from 'plugin' and 'end2end'")

        return self._vocode(gen_audio), self.sr

    @torch.no_grad()
    def simplevc(self, content_audio, speaker_audio=None, use_spk_cache=False,
                 vc_guidance_scale=3, vc_guidance_rescale=0.0,
                 vc_ddim_steps=50, vc_eta=1, vc_random_seed=None,
                 ):
        """Convert *content_audio* to the voice of *speaker_audio* (plugin
        mode only). With ``use_spk_cache=True`` the embedding from the last
        conversion (or :meth:`load_spk_embed`) is reused instead.

        Returns:
            (waveform ndarray, sample rate).
        """
        assert self.mode == 'plugin'
        if speaker_audio is not None:
            speaker_audio, _ = librosa.load(speaker_audio, sr=16000)
            speaker_audio = torch.tensor(speaker_audio).unsqueeze(0).to(self.device)
            spk_embed = self.spk_encoder.embed_utterance_batch(speaker_audio)
            self.spk_embed_cache = spk_embed
        elif use_spk_cache:
            assert self.spk_embed_cache is not None
            spk_embed = self.spk_embed_cache
        else:
            raise NotImplementedError

        content_clip = self._load_content(content_audio)

        gen_audio = self._chunked_inference(
            content_clip,
            lambda chunk: self.dreamvc.inference(
                spk_embed, chunk, None,
                guidance_scale=vc_guidance_scale,
                guidance_rescale=vc_guidance_rescale,
                ddim_steps=vc_ddim_steps,
                eta=vc_eta,
                random_seed=vc_random_seed))

        return self._vocode(gen_audio), self.sr
| |
|
| |
|
if __name__ == '__main__':
    # Smoke test: plugin-mode, prompt-guided conversion of a local wav.
    dreamvoice = DreamVoice(config='dreamvc.yaml', mode='plugin', device='cuda')
    content_audio = 'test.wav'
    prompt = 'young female voice, sounds young and cute'
    # Bug fix: pass the variable instead of re-typing the literal path
    # (the unused `speaker_audio` local was removed).
    gen_audio, sr = dreamvoice.genvc(content_audio, prompt)
    dreamvoice.save_audio('debug.wav', gen_audio, sr)