import time

import numpy as np
import torch
from openunmix import predict

from src.musiclime.print_utils import green_bold


class OpenUnmixFactorization:
"""
Audio factorization using OpenUnmix source separation with temporal segmentation.
Decomposes audio into interpretable components by separating sources
(vocals, bass, drums, other) and segmenting each across time windows.
Creates temporal-source combinations for fine-grained audio explanations.
Attributes
----------
audio : ndarray
Original audio waveform
temporal_segments : list of tuple
Time window boundaries for segmentation
original_components : list of ndarray
Raw separated audio sources
component_names : list of str
Names of separated sources
components : list of ndarray
Final temporal-source component combinations
final_component_names : list of str
Names of temporal-source combinations
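
    Examples
    --------
    A minimal usage sketch (illustrative; assumes ``audio`` is a mono
    44.1 kHz waveform and that OpenUnmix weights exist under
    ``models/musiclime``):

    >>> factorization = OpenUnmixFactorization(audio, temporal_segmentation_params=10)
    >>> factorization.get_number_components()  # 4 sources x 10 segments
    40
    >>> factorization.get_ordered_component_names()[0]
    'vocals_T0'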
"""
def __init__(self, audio, temporal_segmentation_params=10, composition_fn=None):
"""
Initialize audio factorization using OpenUnmix source separation with temporal segmentation.
Parameters
----------
audio : array-like
Raw audio waveform data at 44.1kHz sample rate
temporal_segmentation_params : int, default=10
Number of temporal segments to divide the audio into
composition_fn : callable, optional
Custom function for composing separated sources (unused for now)
"""
print("[MusicLIME] Initializing OpenUnmix factorization...")
self.audio = audio
self.target_sr = 44100
start_time = time.time()
print(
f"[MusicLIME] Computing {temporal_segmentation_params} temporal segments..."
)
self.temporal_segments = self._compute_segments(
audio, temporal_segmentation_params
)
segmentation_time = time.time() - start_time
print(
green_bold(
f"[MusicLIME] Temporal segmentation completed in {segmentation_time:.2f}s"
)
)
# Initialize source separation
start_time = time.time()
print("[MusicLIME] Separating audio sources...")
self.original_components, self.component_names = self._separate_sources()
print(f"[MusicLIME] Found components: {self.component_names}")
separation_time = time.time() - start_time
print(
green_bold(
f"[MusicLIME] Source separation completed in {separation_time:.2f}s"
)
)
start_time = time.time()
print("[MusicLIME] Preparing temporal-source combinations...")
self._prepare_temporal_components()
print(f"[MusicLIME] Created {len(self.components)} total components")
preparation_time = time.time() - start_time
print(
green_bold(
f"[MusicLIME] Component preparation completed in {preparation_time:.2f}s"
)
)

    def _compute_segments(self, signal, n_segments):
        """
        Divide the audio signal into equal temporal segments for factorization.

        Parameters
        ----------
        signal : array-like
            Input audio waveform
        n_segments : int
            Number of temporal segments to create

        Returns
        -------
        list of tuple
            List of (start, end) sample indices for each segment
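
        Examples
        --------
        Illustrative: a 10-sample signal split into 3 segments, with the
        final segment absorbing the remainder:

        >>> factorization._compute_segments(np.arange(10), 3)
        [(0, 3), (3, 6), (6, 10)]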
"""
        audio_length = len(signal)
        samples_per_segment = audio_length // n_segments
        segments = []
        for i in range(n_segments):
            start = i * samples_per_segment
            # Extend the final segment to the end of the signal so remainder
            # samples from the integer division are not silently dropped.
            end = audio_length if i == n_segments - 1 else start + samples_per_segment
            segments.append((start, end))
        return segments

    def _separate_sources(self):
        """
        Perform source separation using OpenUnmix to extract instrument components.

        Returns
        -------
        components : list of ndarray
            Separated audio sources (vocals, bass, drums, other)
        names : list of str
            Names of the separated source components
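
        Notes
        -----
        Expects pre-trained OpenUnmix weights under ``models/musiclime``
        (one ``.pth`` file per target). Each returned source is collapsed
        to mono by averaging its channels.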
"""
        # Add a channel axis: (nb_timesteps,) -> (nb_timesteps, nb_channels)
        waveform = np.expand_dims(self.audio, axis=1)
        # Load the OpenUnmix .pth weight files from a local directory
        model_path = "models/musiclime"
        # Targets (stems) to separate
        targets = ["vocals", "bass", "drums", "other"]
        # Select the device based on availability
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[MusicLIME] Using device for source separation: {device}")
        # Run OpenUnmix separation on the waveform
        prediction = predict.separate(
            torch.as_tensor(waveform).float(),
            rate=self.target_sr,
            model_str_or_path=model_path,
            targets=targets,
            device=device,
        )
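        # Each entry of `prediction` is a tensor of shape
        # (nb_samples, nb_channels, nb_timesteps); take the single batch item
        # and average over channels to obtain a mono waveform per source.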
components = [
prediction[key][0].mean(dim=0).cpu().numpy() for key in prediction
]
names = list(prediction.keys())
return components, names

    def _prepare_temporal_components(self):
        """
        Create temporal-source combinations by applying each source to each
        time segment.

        Each combination, e.g. 'vocals_T0' or 'drums_T5', represents one
        instrument active only within one temporal window.
        """
# Create temporal-source combinations
self.components = []
self.final_component_names = []
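        # The resulting order interleaves sources within each time segment:
        # vocals_T0, bass_T0, drums_T0, other_T0, vocals_T1, ... (source
        # order follows the keys returned by predict.separate).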
for s, (start, end) in enumerate(self.temporal_segments):
for c, component in enumerate(self.original_components):
temp_component = np.zeros_like(self.audio)
temp_component[start:end] = component[start:end]
self.components.append(temp_component)
self.final_component_names.append(f"{self.component_names[c]}_T{s}")

    def get_number_components(self):
        """
        Get the total number of factorized components (sources x temporal segments).

        Returns
        -------
        int
            Total number of temporal-source component combinations
        """
return len(self.components)

    def get_ordered_component_names(self):
        """
        Get the ordered list of component names for explanation display.

        Returns
        -------
        list of str
            Component names in the format '{source}_T{segment}' (e.g., 'vocals_T3')
        """
return self.final_component_names

    def compose_model_input(self, component_indices):
        """
        Reconstruct audio by summing the selected temporal-source components.

        Parameters
        ----------
        component_indices : array-like
            Indices of components to include in the reconstruction

        Returns
        -------
        ndarray
            Reconstructed audio waveform from the selected components
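
        Examples
        --------
        Illustrative (assumes an initialized ``factorization`` instance):

        >>> silent = factorization.compose_model_input([])
        >>> partial = factorization.compose_model_input([0, 1])
        >>> partial.shape == factorization.audio.shape
        True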
"""
if len(component_indices) == 0:
return np.zeros_like(self.audio)
selected_components = [self.components[i] for i in component_indices]
return sum(selected_components)