File size: 10,104 Bytes
2bbfbb7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 |
//! Mel-spectrogram computation
//!
//! Implements Short-Time Fourier Transform (STFT) and mel filterbank
use crate::{Error, Result};
use ndarray::{Array1, Array2, Axis};
use num_complex::Complex;
use realfft::RealFftPlanner;
use std::f32::consts::PI;
use super::AudioConfig;
/// Mel filterbank for converting linear spectrogram to mel scale
#[derive(Debug, Clone)]
pub struct MelFilterbank {
/// Filterbank matrix (n_mels x n_fft/2+1)
pub filters: Array2<f32>,
/// Sample rate
pub sample_rate: u32,
/// Number of mel bands
pub n_mels: usize,
/// FFT size
pub n_fft: usize,
}
impl MelFilterbank {
/// Create mel filterbank
pub fn new(sample_rate: u32, n_fft: usize, n_mels: usize, fmin: f32, fmax: f32) -> Self {
let filters = create_mel_filterbank(sample_rate, n_fft, n_mels, fmin, fmax);
Self {
filters,
sample_rate,
n_mels,
n_fft,
}
}
/// Apply filterbank to power spectrogram
pub fn apply(&self, spectrogram: &Array2<f32>) -> Array2<f32> {
// spectrogram: (n_fft/2+1, time_frames)
// filters: (n_mels, n_fft/2+1)
// output: (n_mels, time_frames)
self.filters.dot(spectrogram)
}
}
/// Convert frequency to mel scale
pub fn hz_to_mel(hz: f32) -> f32 {
2595.0 * (1.0 + hz / 700.0).log10()
}
/// Convert mel to frequency
pub fn mel_to_hz(mel: f32) -> f32 {
700.0 * (10f32.powf(mel / 2595.0) - 1.0)
}
/// Create mel filterbank matrix
fn create_mel_filterbank(
sample_rate: u32,
n_fft: usize,
n_mels: usize,
fmin: f32,
fmax: f32,
) -> Array2<f32> {
let n_freqs = n_fft / 2 + 1;
// Convert to mel scale
let mel_min = hz_to_mel(fmin);
let mel_max = hz_to_mel(fmax);
// Create mel points
let mel_points: Vec<f32> = (0..=n_mels + 1)
.map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32)
.collect();
// Convert back to Hz
let hz_points: Vec<f32> = mel_points.iter().map(|&m| mel_to_hz(m)).collect();
// Convert to FFT bin numbers
let bin_points: Vec<usize> = hz_points
.iter()
.map(|&hz| ((n_fft as f32 + 1.0) * hz / sample_rate as f32).floor() as usize)
.collect();
// Create filterbank
let mut filters = Array2::zeros((n_mels, n_freqs));
for m in 0..n_mels {
let f_left = bin_points[m];
let f_center = bin_points[m + 1];
let f_right = bin_points[m + 2];
// Left slope
for k in f_left..f_center {
if k < n_freqs {
filters[[m, k]] = (k - f_left) as f32 / (f_center - f_left).max(1) as f32;
}
}
// Right slope
for k in f_center..f_right {
if k < n_freqs {
filters[[m, k]] = (f_right - k) as f32 / (f_right - f_center).max(1) as f32;
}
}
}
filters
}
/// Compute Hann window
fn hann_window(size: usize) -> Vec<f32> {
(0..size)
.map(|n| 0.5 * (1.0 - (2.0 * PI * n as f32 / size as f32).cos()))
.collect()
}
/// Compute Short-Time Fourier Transform (STFT)
///
/// # Arguments
/// * `signal` - Input audio signal
/// * `n_fft` - FFT size
/// * `hop_length` - Hop length between frames
/// * `win_length` - Window length (padded to n_fft)
///
/// # Returns
/// Complex STFT matrix (n_fft/2+1, time_frames)
pub fn stft(
signal: &[f32],
n_fft: usize,
hop_length: usize,
win_length: usize,
) -> Result<Array2<Complex<f32>>> {
if signal.is_empty() {
return Err(Error::Audio("Empty signal".into()));
}
// Create window
let window = hann_window(win_length);
// Pad signal
let pad_length = n_fft / 2;
let mut padded = vec![0.0f32; pad_length];
padded.extend_from_slice(signal);
padded.extend(vec![0.0f32; pad_length]);
// Calculate number of frames
let num_frames = (padded.len() - n_fft) / hop_length + 1;
let n_freqs = n_fft / 2 + 1;
// Create FFT planner
let mut planner = RealFftPlanner::<f32>::new();
let fft = planner.plan_fft_forward(n_fft);
// Output matrix
let mut stft_matrix = Array2::zeros((n_freqs, num_frames));
// Process each frame
let mut input_buffer = vec![0.0f32; n_fft];
let mut output_buffer = vec![Complex::new(0.0f32, 0.0f32); n_freqs];
for (frame_idx, start) in (0..padded.len() - n_fft + 1)
.step_by(hop_length)
.enumerate()
{
if frame_idx >= num_frames {
break;
}
// Extract and window the frame
for i in 0..win_length {
input_buffer[i] = padded[start + i] * window[i];
}
// Zero pad if win_length < n_fft
for i in win_length..n_fft {
input_buffer[i] = 0.0;
}
// Perform FFT
fft.process(&mut input_buffer, &mut output_buffer)
.map_err(|e| Error::Audio(format!("FFT failed: {}", e)))?;
// Store result
for (freq_idx, &val) in output_buffer.iter().enumerate() {
stft_matrix[[freq_idx, frame_idx]] = val;
}
}
Ok(stft_matrix)
}
/// Compute magnitude spectrogram from STFT
pub fn magnitude_spectrogram(stft_matrix: &Array2<Complex<f32>>) -> Array2<f32> {
stft_matrix.mapv(|c| c.norm())
}
/// Compute power spectrogram from STFT
pub fn power_spectrogram(stft_matrix: &Array2<Complex<f32>>) -> Array2<f32> {
stft_matrix.mapv(|c| c.norm_sqr())
}
/// Compute mel spectrogram from audio signal
///
/// # Arguments
/// * `signal` - Audio samples
/// * `config` - Audio configuration
///
/// # Returns
/// Log mel spectrogram (n_mels, time_frames)
pub fn mel_spectrogram(signal: &[f32], config: &AudioConfig) -> Result<Array2<f32>> {
// Compute STFT
let stft_matrix = stft(signal, config.n_fft, config.hop_length, config.win_length)?;
// Compute power spectrogram
let power_spec = power_spectrogram(&stft_matrix);
// Create mel filterbank
let mel_fb = MelFilterbank::new(
config.sample_rate,
config.n_fft,
config.n_mels,
config.fmin,
config.fmax,
);
// Apply mel filterbank
let mel_spec = mel_fb.apply(&power_spec);
// Apply log compression
let log_mel_spec = mel_spec.mapv(|x| (x.max(1e-10)).ln());
Ok(log_mel_spec)
}
/// Compute mel spectrogram with normalization
pub fn mel_spectrogram_normalized(
signal: &[f32],
config: &AudioConfig,
mean: Option<f32>,
std: Option<f32>,
) -> Result<Array2<f32>> {
let mut mel_spec = mel_spectrogram(signal, config)?;
// Normalize
if let (Some(m), Some(s)) = (mean, std) {
mel_spec.mapv_inplace(|x| (x - m) / s);
} else {
// Compute statistics from spectrogram
let m = mel_spec.mean().unwrap_or(0.0);
let s = mel_spec.std(0.0);
if s > 1e-8 {
mel_spec.mapv_inplace(|x| (x - m) / s);
}
}
Ok(mel_spec)
}
/// Convert mel spectrogram back to linear spectrogram (approximate)
pub fn mel_to_linear(mel_spec: &Array2<f32>, mel_fb: &MelFilterbank) -> Array2<f32> {
// Pseudo-inverse of mel filterbank
let filters_t = mel_fb.filters.t();
let gram = mel_fb.filters.dot(&filters_t);
// Simple approximation using transpose
filters_t.dot(mel_spec)
}
/// Compute spectrogram energy per frame
pub fn frame_energy(mel_spec: &Array2<f32>) -> Array1<f32> {
mel_spec.sum_axis(Axis(0))
}
/// Detect voice activity based on energy threshold
pub fn voice_activity_detection(mel_spec: &Array2<f32>, threshold_db: f32) -> Vec<bool> {
let energy = frame_energy(mel_spec);
let max_energy = energy.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let threshold = max_energy + threshold_db; // threshold_db is negative
energy.iter().map(|&e| e > threshold).collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hz_to_mel() {
// Test known conversions
assert!((hz_to_mel(0.0) - 0.0).abs() < 1e-6);
assert!((hz_to_mel(1000.0) - 1000.0).abs() < 50.0); // Roughly linear at low freqs
}
#[test]
fn test_mel_to_hz() {
// Round trip
let hz = 440.0;
let mel = hz_to_mel(hz);
let hz_back = mel_to_hz(mel);
assert!((hz - hz_back).abs() < 1e-4);
}
#[test]
fn test_mel_filterbank_creation() {
let fb = MelFilterbank::new(22050, 1024, 80, 0.0, 8000.0);
assert_eq!(fb.filters.shape(), &[80, 513]);
// Check that filters are non-empty (some filter banks have coverage)
let total_sum: f32 = fb.filters.iter().sum();
assert!(total_sum > 0.0, "Filterbank should have some non-zero values");
}
#[test]
fn test_hann_window() {
let window = hann_window(1024);
assert_eq!(window.len(), 1024);
// Check endpoints are near zero
assert!(window[0].abs() < 1e-6);
// Check middle is near 1
assert!((window[512] - 1.0).abs() < 1e-4);
}
#[test]
fn test_stft_basic() {
// Create a simple sine wave
let sr = 22050;
let freq = 440.0;
let duration = 0.1;
let num_samples = (sr as f32 * duration) as usize;
let signal: Vec<f32> = (0..num_samples)
.map(|i| (2.0 * PI * freq * i as f32 / sr as f32).sin())
.collect();
let result = stft(&signal, 1024, 256, 1024);
assert!(result.is_ok());
let stft_matrix = result.unwrap();
assert_eq!(stft_matrix.shape()[0], 513); // n_fft/2 + 1
assert!(stft_matrix.shape()[1] > 0); // Some frames
}
#[test]
fn test_mel_spectrogram() {
let config = AudioConfig::default();
let num_samples = (config.sample_rate as f32 * 0.1) as usize;
let signal: Vec<f32> = (0..num_samples).map(|i| (i as f32 * 0.01).sin()).collect();
let result = mel_spectrogram(&signal, &config);
assert!(result.is_ok());
let mel_spec = result.unwrap();
assert_eq!(mel_spec.shape()[0], config.n_mels);
assert!(mel_spec.shape()[1] > 0);
}
}
|