import sys
sys.path.append('./BigVGAN')

import time
import torch
import torchaudio
import argparse
from tqdm import tqdm
import librosa

from BigVGAN import bigvgan
from BigVGAN.meldataset import get_mel_spectrogram
from model import OptimizedAudioRestorationModel
# Set the device; falls back to CPU on machines without CUDA (e.g., M1 MacBooks)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Initialize the BigVGAN vocoder
bigvgan_model = bigvgan.BigVGAN.from_pretrained(
    'nvidia/bigvgan_v2_24khz_100band_256x',
    use_cuda_kernel=False,
    force_download=False
).to(device)
# Remove weight norm, as recommended for BigVGAN inference
bigvgan_model.remove_weight_norm()

def measure_gpu_memory():
    if device == 'cuda':
        torch.cuda.synchronize()
        return torch.cuda.max_memory_allocated() / (1024 ** 2)  # Convert bytes to MB
    return 0

def apply_overlap_windowing_waveform(waveform, window_size_samples, overlap):
    """Split a waveform into overlapping, fixed-size chunks.

    Note: trailing samples that do not fill a complete window are dropped.
    """
    step_size = int(window_size_samples * (1 - overlap))
    num_chunks = (waveform.shape[-1] - window_size_samples) // step_size + 1
    windows = []
    for i in range(num_chunks):
        start_idx = i * step_size
        end_idx = start_idx + window_size_samples
        chunk = waveform[..., start_idx:end_idx]
        windows.append(chunk)
    return torch.stack(windows)
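
# Worked example of the windowing arithmetic above: with sr = 24000,
# window_size_sec = 5.0 and overlap = 0.5, each window spans 120,000 samples
# and the step size is 60,000 samples, so a 60-second file (1,440,000 samples)
# yields (1440000 - 120000) // 60000 + 1 = 23 windows.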

def reconstruct_waveform_from_windows(windows, window_size_samples, overlap):
    step_size = int(window_size_samples * (1 - overlap))
    shape = windows.shape
    if len(shape) == 2:
        # windows.shape == (num_windows, window_len)
        num_windows, window_len = shape
        channels = 1
        windows = windows.unsqueeze(1)  # Now windows.shape == (num_windows, 1, window_len)
    elif len(shape) == 3:
        num_windows, channels, window_len = shape
    else:
        raise ValueError(f"Unexpected windows.shape: {windows.shape}")
    # Size the buffers with window_len (the actual chunk length, which the
    # vocoder may round to a multiple of its hop size) rather than
    # window_size_samples, so every chunk fits.
    output_length = (num_windows - 1) * step_size + window_len
    reconstructed = torch.zeros((channels, output_length))
    window_sums = torch.zeros((channels, output_length))
    # Overlap-add: accumulate each chunk, then divide by how many chunks
    # covered each sample to average the overlapping regions.
    for i in range(num_windows):
        start_idx = i * step_size
        end_idx = start_idx + window_len
        reconstructed[:, start_idx:end_idx] += windows[i]
        window_sums[:, start_idx:end_idx] += 1
    reconstructed = reconstructed / window_sums.clamp(min=1e-6)
    if channels == 1:
        reconstructed = reconstructed.squeeze(0)  # Remove channel dimension if single channel
    return reconstructed
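
# NOTE: overlapping regions are blended by uniform averaging (each sample is
# divided by the number of windows covering it). A tapered crossfade, e.g.
# applying a Hann window to each chunk before the overlap-add, is a common
# alternative that can reduce audible seams; it is not implemented here.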

def load_model(save_path):
    """
    Load the optimized audio restoration model.

    Parameters:
    - save_path: Path to the checkpoint file.
    """
    optimized_model = OptimizedAudioRestorationModel(device=device, bigvgan_model=bigvgan_model)
    state_dict = torch.load(save_path, map_location=device)
    # Training checkpoints wrap the weights in a 'model_state_dict' key
    if 'model_state_dict' in state_dict:
        state_dict = state_dict['model_state_dict']
    optimized_model.voice_restore.load_state_dict(state_dict, strict=True)
    return optimized_model

def restore_audio(model, input_path, output_path, steps=16, cfg_strength=0.5, window_size_sec=5.0, overlap=0.5):
    start_time = time.time()
    initial_gpu_memory = measure_gpu_memory()

    # Load the audio file, resampled to BigVGAN's 24 kHz and downmixed to mono
    wav, sr = librosa.load(input_path, sr=24000, mono=True)
    wav = torch.FloatTensor(wav).unsqueeze(0)  # Shape: [1, num_samples]

    window_size_samples = int(window_size_sec * sr)

    # Apply overlapping windowing to the waveform
    wav_windows = apply_overlap_windowing_waveform(wav, window_size_samples, overlap)

    restored_wav_windows = []
    for wav_window in tqdm(wav_windows):
        wav_window = wav_window.to(device)  # Shape: [1, window_size_samples]
        # Convert to mel spectrogram
        processed_mel = get_mel_spectrogram(wav_window, bigvgan_model.h).to(device)
        # Restore the mel spectrogram
        with torch.no_grad():
            with torch.autocast(device_type=device):
                restored_mel = model.voice_restore.sample(processed_mel.transpose(1, 2), steps=steps, cfg_strength=cfg_strength)
                restored_mel = restored_mel.squeeze(0).transpose(0, 1)
        # Convert the restored mel spectrogram back to a waveform
        with torch.no_grad():
            with torch.autocast(device_type=device):
                restored_wav = bigvgan_model(restored_mel.unsqueeze(0)).squeeze(0).float().cpu()  # Shape: [1, num_samples]
        restored_wav_windows.append(restored_wav)
        del wav_window, processed_mel, restored_mel, restored_wav
        if device == 'cuda':
            torch.cuda.empty_cache()
    restored_wav_windows = torch.stack(restored_wav_windows)  # Shape: [num_windows, 1, num_samples]

    # Reconstruct the full waveform from the processed windows
    restored_wav = reconstruct_waveform_from_windows(restored_wav_windows, window_size_samples, overlap)

    # Ensure restored_wav has a channel dimension for saving
    if restored_wav.dim() == 1:
        restored_wav = restored_wav.unsqueeze(0)  # Shape: [1, num_samples]

    # Save the restored audio
    torchaudio.save(output_path, restored_wav, 24000)

    total_time = time.time() - start_time
    peak_gpu_memory = measure_gpu_memory()
    gpu_memory_used = peak_gpu_memory - initial_gpu_memory
    print(f"Total inference time: {total_time:.2f} seconds")
    print(f"Peak GPU memory usage: {peak_gpu_memory:.2f} MB")
    print(f"GPU memory used: {gpu_memory_used:.2f} MB")

if __name__ == "__main__":
    # Argument parser setup
    parser = argparse.ArgumentParser(description="Audio restoration using OptimizedAudioRestorationModel for long-form audio.")
    parser.add_argument('--checkpoint', type=str, required=True, help="Path to the checkpoint file")
    parser.add_argument('--input', type=str, required=True, help="Path to the input audio file")
    parser.add_argument('--output', type=str, required=True, help="Path to save the restored audio file")
    parser.add_argument('--steps', type=int, default=16, help="Number of sampling steps")
    parser.add_argument('--cfg_strength', type=float, default=0.5, help="CFG strength value")
    parser.add_argument('--window_size_sec', type=float, default=5.0, help="Window size in seconds for overlapping")
    parser.add_argument('--overlap', type=float, default=0.5, help="Overlap ratio for windowing")
    args = parser.parse_args()

    # Load the optimized model
    optimized_model = load_model(args.checkpoint)
    if device == 'cuda':
        optimized_model.bfloat16()
    optimized_model.eval()
    optimized_model.to(device)

    # Use the model to restore audio
    restore_audio(
        optimized_model,
        args.input,
        args.output,
        steps=args.steps,
        cfg_strength=args.cfg_strength,
        window_size_sec=args.window_size_sec,
        overlap=args.overlap
    )
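
# Example invocation (assuming this file is saved as inference_long.py; the
# checkpoint and audio paths below are placeholders):
#
#   python inference_long.py \
#       --checkpoint checkpoints/voice-restore.pt \
#       --input noisy_speech.wav \
#       --output restored_speech.wav \
#       --steps 16 --cfg_strength 0.5 --window_size_sec 5.0 --overlap 0.5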