File size: 10,104 Bytes
2bbfbb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
//! Mel-spectrogram computation
//!
//! Implements Short-Time Fourier Transform (STFT) and mel filterbank

use crate::{Error, Result};
use ndarray::{Array1, Array2, Axis};
use num_complex::Complex;
use realfft::RealFftPlanner;
use std::f32::consts::PI;

use super::AudioConfig;

/// Triangular mel-scale filterbank that projects linear-frequency spectra onto mel bands.
#[derive(Debug, Clone)]
pub struct MelFilterbank {
    /// Filter weights, one row per mel band: shape `(n_mels, n_fft / 2 + 1)`
    pub filters: Array2<f32>,
    /// Sample rate in Hz the filterbank was built for
    pub sample_rate: u32,
    /// Number of mel bands (rows of `filters`)
    pub n_mels: usize,
    /// FFT size the filterbank was built for
    pub n_fft: usize,
}

impl MelFilterbank {
    /// Build a filterbank of `n_mels` triangular filters spanning `fmin..fmax` Hz
    /// for an `n_fft`-point FFT at `sample_rate` Hz.
    pub fn new(sample_rate: u32, n_fft: usize, n_mels: usize, fmin: f32, fmax: f32) -> Self {
        Self {
            filters: create_mel_filterbank(sample_rate, n_fft, n_mels, fmin, fmax),
            sample_rate,
            n_mels,
            n_fft,
        }
    }

    /// Project a `(n_fft / 2 + 1, time_frames)` spectrogram onto the mel bands.
    ///
    /// Returns the `(n_mels, time_frames)` matrix product `filters · spectrogram`.
    ///
    /// # Panics
    /// Panics (inside ndarray's `dot`) if the spectrogram's row count differs from the
    /// number of columns of `filters` (`n_fft / 2 + 1`).
    pub fn apply(&self, spectrogram: &Array2<f32>) -> Array2<f32> {
        let weights = &self.filters;
        weights.dot(spectrogram)
    }
}

/// Convert a frequency in Hz to the mel scale (HTK formula `2595 · log10(1 + hz / 700)`).
pub fn hz_to_mel(hz: f32) -> f32 {
    const SCALE: f32 = 2595.0;
    const CORNER_HZ: f32 = 700.0;
    let ratio = hz / CORNER_HZ + 1.0;
    SCALE * ratio.log10()
}

/// Convert a mel-scale value back to Hz; the inverse of `hz_to_mel`
/// (`700 · (10^(mel / 2595) - 1)`).
pub fn mel_to_hz(mel: f32) -> f32 {
    const SCALE: f32 = 2595.0;
    const CORNER_HZ: f32 = 700.0;
    let growth = 10f32.powf(mel / SCALE) - 1.0;
    CORNER_HZ * growth
}

/// Create mel filterbank matrix
fn create_mel_filterbank(
    sample_rate: u32,
    n_fft: usize,
    n_mels: usize,
    fmin: f32,
    fmax: f32,
) -> Array2<f32> {
    let n_freqs = n_fft / 2 + 1;

    // Convert to mel scale
    let mel_min = hz_to_mel(fmin);
    let mel_max = hz_to_mel(fmax);

    // Create mel points
    let mel_points: Vec<f32> = (0..=n_mels + 1)
        .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32)
        .collect();

    // Convert back to Hz
    let hz_points: Vec<f32> = mel_points.iter().map(|&m| mel_to_hz(m)).collect();

    // Convert to FFT bin numbers
    let bin_points: Vec<usize> = hz_points
        .iter()
        .map(|&hz| ((n_fft as f32 + 1.0) * hz / sample_rate as f32).floor() as usize)
        .collect();

    // Create filterbank
    let mut filters = Array2::zeros((n_mels, n_freqs));

    for m in 0..n_mels {
        let f_left = bin_points[m];
        let f_center = bin_points[m + 1];
        let f_right = bin_points[m + 2];

        // Left slope
        for k in f_left..f_center {
            if k < n_freqs {
                filters[[m, k]] = (k - f_left) as f32 / (f_center - f_left).max(1) as f32;
            }
        }

        // Right slope
        for k in f_center..f_right {
            if k < n_freqs {
                filters[[m, k]] = (f_right - k) as f32 / (f_right - f_center).max(1) as f32;
            }
        }
    }

    filters
}

/// Build a periodic Hann window of `size` samples.
///
/// Sample `n` is `0.5 · (1 - cos(2π n / size))`. The denominator is `size` rather than
/// `size - 1`, so the window starts at 0 and reaches 1 at `n = size / 2`.
fn hann_window(size: usize) -> Vec<f32> {
    let denom = size as f32;
    let mut window = Vec::with_capacity(size);
    for n in 0..size {
        let phase = 2.0 * PI * n as f32 / denom;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}

/// Compute Short-Time Fourier Transform (STFT)
///
/// # Arguments
/// * `signal` - Input audio signal
/// * `n_fft` - FFT size
/// * `hop_length` - Hop length between frames
/// * `win_length` - Window length (padded to n_fft)
///
/// # Returns
/// Complex STFT matrix (n_fft/2+1, time_frames)
pub fn stft(
    signal: &[f32],
    n_fft: usize,
    hop_length: usize,
    win_length: usize,
) -> Result<Array2<Complex<f32>>> {
    if signal.is_empty() {
        return Err(Error::Audio("Empty signal".into()));
    }

    // Create window
    let window = hann_window(win_length);

    // Pad signal
    let pad_length = n_fft / 2;
    let mut padded = vec![0.0f32; pad_length];
    padded.extend_from_slice(signal);
    padded.extend(vec![0.0f32; pad_length]);

    // Calculate number of frames
    let num_frames = (padded.len() - n_fft) / hop_length + 1;
    let n_freqs = n_fft / 2 + 1;

    // Create FFT planner
    let mut planner = RealFftPlanner::<f32>::new();
    let fft = planner.plan_fft_forward(n_fft);

    // Output matrix
    let mut stft_matrix = Array2::zeros((n_freqs, num_frames));

    // Process each frame
    let mut input_buffer = vec![0.0f32; n_fft];
    let mut output_buffer = vec![Complex::new(0.0f32, 0.0f32); n_freqs];

    for (frame_idx, start) in (0..padded.len() - n_fft + 1)
        .step_by(hop_length)
        .enumerate()
    {
        if frame_idx >= num_frames {
            break;
        }

        // Extract and window the frame
        for i in 0..win_length {
            input_buffer[i] = padded[start + i] * window[i];
        }
        // Zero pad if win_length < n_fft
        for i in win_length..n_fft {
            input_buffer[i] = 0.0;
        }

        // Perform FFT
        fft.process(&mut input_buffer, &mut output_buffer)
            .map_err(|e| Error::Audio(format!("FFT failed: {}", e)))?;

        // Store result
        for (freq_idx, &val) in output_buffer.iter().enumerate() {
            stft_matrix[[freq_idx, frame_idx]] = val;
        }
    }

    Ok(stft_matrix)
}

/// Magnitude spectrogram `|X|` of a complex STFT matrix (same shape as the input).
pub fn magnitude_spectrogram(stft_matrix: &Array2<Complex<f32>>) -> Array2<f32> {
    stft_matrix.map(|z| z.re.hypot(z.im))
}

/// Power spectrogram `|X|^2 = re^2 + im^2` of a complex STFT matrix (same shape as the input).
pub fn power_spectrogram(stft_matrix: &Array2<Complex<f32>>) -> Array2<f32> {
    stft_matrix.map(|z| z.re * z.re + z.im * z.im)
}

/// Compute mel spectrogram from audio signal
///
/// # Arguments
/// * `signal` - Audio samples
/// * `config` - Audio configuration
///
/// # Returns
/// Log mel spectrogram (n_mels, time_frames)
pub fn mel_spectrogram(signal: &[f32], config: &AudioConfig) -> Result<Array2<f32>> {
    // Compute STFT
    let stft_matrix = stft(signal, config.n_fft, config.hop_length, config.win_length)?;

    // Compute power spectrogram
    let power_spec = power_spectrogram(&stft_matrix);

    // Create mel filterbank
    let mel_fb = MelFilterbank::new(
        config.sample_rate,
        config.n_fft,
        config.n_mels,
        config.fmin,
        config.fmax,
    );

    // Apply mel filterbank
    let mel_spec = mel_fb.apply(&power_spec);

    // Apply log compression
    let log_mel_spec = mel_spec.mapv(|x| (x.max(1e-10)).ln());

    Ok(log_mel_spec)
}

/// Compute mel spectrogram with normalization
pub fn mel_spectrogram_normalized(
    signal: &[f32],
    config: &AudioConfig,
    mean: Option<f32>,
    std: Option<f32>,
) -> Result<Array2<f32>> {
    let mut mel_spec = mel_spectrogram(signal, config)?;

    // Normalize
    if let (Some(m), Some(s)) = (mean, std) {
        mel_spec.mapv_inplace(|x| (x - m) / s);
    } else {
        // Compute statistics from spectrogram
        let m = mel_spec.mean().unwrap_or(0.0);
        let s = mel_spec.std(0.0);
        if s > 1e-8 {
            mel_spec.mapv_inplace(|x| (x - m) / s);
        }
    }

    Ok(mel_spec)
}

/// Approximately map a mel spectrogram back onto linear-frequency bins.
///
/// This is **not** a pseudo-inverse. It multiplies by the transpose of the filterbank,
/// `filters^T · mel_spec`, and produces an `(n_fft / 2 + 1, time_frames)` matrix. Each
/// linear bin receives the mel values of the filters covering it, weighted by their
/// response. Bins no filter covers come out as zero.
///
/// NOTE(review): presumably `mel_spec` should hold linear power (not the log output of
/// `mel_spectrogram`) for the result to be meaningful — confirm with callers.
///
/// # Panics
/// Panics (inside ndarray's `dot`) if `mel_spec` does not have `mel_fb.filters.nrows()` rows.
pub fn mel_to_linear(mel_spec: &Array2<f32>, mel_fb: &MelFilterbank) -> Array2<f32> {
    mel_fb.filters.t().dot(mel_spec)
}

/// Compute spectrogram energy per frame
///
/// Sums `mel_spec` over axis 0 (the rows), returning one value per column. This assumes
/// the `(n_mels, time_frames)` layout produced by `mel_spectrogram`, so the result has one
/// entry per time frame.
///
/// Note: `mel_spectrogram` returns natural-log values. For that input this is a sum of
/// log values per frame, not a linear energy.
pub fn frame_energy(mel_spec: &Array2<f32>) -> Array1<f32> {
    mel_spec.sum_axis(Axis(0))
}

/// Flag frames whose energy lies within `threshold_db` of the loudest frame.
///
/// Per-frame energy comes from [`frame_energy`]. A frame is active when
/// `energy > max_energy + threshold_db`, so `threshold_db` is expected to be negative.
/// Returns one flag per frame. A spectrogram with no frames yields an empty vector.
///
/// NOTE(review): `threshold_db` is added directly to the summed per-frame values. It is a
/// true dB offset only if `mel_spec` holds dB values. For the natural-log output of
/// `mel_spectrogram` it is an offset in summed log units — confirm the intended units.
pub fn voice_activity_detection(mel_spec: &Array2<f32>, threshold_db: f32) -> Vec<bool> {
    let energy = frame_energy(mel_spec);
    let peak = energy
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, f32::max);
    let cutoff = peak + threshold_db;

    let mut active = Vec::with_capacity(energy.len());
    for &frame in energy.iter() {
        active.push(frame > cutoff);
    }
    active
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Unit-amplitude sine of `len` samples at `freq` Hz for a sample rate of `sr` Hz.
    fn sine_wave(freq: f32, sr: f32, len: usize) -> Vec<f32> {
        (0..len)
            .map(|i| (2.0 * PI * freq * i as f32 / sr).sin())
            .collect()
    }

    #[test]
    fn test_hz_to_mel() {
        // 0 Hz is 0 mel; around 1 kHz the mel scale is roughly the identity.
        assert!(hz_to_mel(0.0).abs() < 1e-6);
        assert!((hz_to_mel(1000.0) - 1000.0).abs() < 50.0);
    }

    #[test]
    fn test_mel_to_hz() {
        // hz -> mel -> hz must recover the input.
        let original_hz = 440.0;
        let recovered_hz = mel_to_hz(hz_to_mel(original_hz));
        assert!((original_hz - recovered_hz).abs() < 1e-4);
    }

    #[test]
    fn test_mel_filterbank_creation() {
        let fb = MelFilterbank::new(22050, 1024, 80, 0.0, 8000.0);
        assert_eq!(fb.filters.shape(), &[80, 513]);

        // All weights are non-negative, so "some weight is positive" == "total is positive".
        assert!(
            fb.filters.iter().any(|&w| w > 0.0),
            "Filterbank should have some non-zero values"
        );
    }

    #[test]
    fn test_hann_window() {
        let window = hann_window(1024);
        assert_eq!(window.len(), 1024);
        assert!(window[0].abs() < 1e-6, "window should start at zero");
        assert!((window[512] - 1.0).abs() < 1e-4, "window should peak at the centre");
    }

    #[test]
    fn test_stft_basic() {
        let sr = 22050;
        let num_samples = (sr as f32 * 0.1) as usize;
        let signal = sine_wave(440.0, sr as f32, num_samples);

        let stft_matrix = stft(&signal, 1024, 256, 1024).expect("stft should succeed");
        assert_eq!(stft_matrix.nrows(), 513); // n_fft / 2 + 1
        assert!(stft_matrix.ncols() > 0);
    }

    #[test]
    fn test_mel_spectrogram() {
        let config = AudioConfig::default();
        let num_samples = (config.sample_rate as f32 * 0.1) as usize;
        let signal: Vec<f32> = (0..num_samples).map(|i| (i as f32 * 0.01).sin()).collect();

        let mel_spec = mel_spectrogram(&signal, &config).expect("mel_spectrogram should succeed");
        assert_eq!(mel_spec.nrows(), config.n_mels);
        assert!(mel_spec.ncols() > 0);
    }
}