NeoPy committed · verified
Commit f178dc2 · 1 Parent(s): 06e0aa0

Update infer/lib/rmvpe.py

Files changed (1):
  1. infer/lib/rmvpe.py +454 -532
infer/lib/rmvpe.py CHANGED
@@ -1,670 +1,592 @@
- from io import BytesIO
  import os
- from typing import List, Optional, Tuple
- import numpy as np
  import torch

- from infer.lib import jit

- try:
-     # Fix "Torch not compiled with CUDA enabled"
-     import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import

-     if torch.xpu.is_available():
-         from infer.modules.ipex import ipex_init

-         ipex_init()
- except Exception:  # pylint: disable=broad-exception-caught
-     pass
- import torch.nn as nn
- import torch.nn.functional as F
- from librosa.util import normalize, pad_center, tiny
- from scipy.signal import get_window
-
- import logging
-
- logger = logging.getLogger(__name__)
-
-
- class STFT(torch.nn.Module):
-     def __init__(
-         self, filter_length=1024, hop_length=512, win_length=None, window="hann"
-     ):
-         """
-         This module implements an STFT using 1D convolution and 1D transpose convolutions.
-         This is a bit tricky so there are some cases that probably won't work as working
-         out the same sizes before and after in all overlap add setups is tough. Right now,
-         this code should work with hop lengths that are half the filter length (50% overlap
-         between frames).
-
-         Keyword Arguments:
-             filter_length {int} -- Length of filters used (default: {1024})
-             hop_length {int} -- Hop length of STFT (restrict to 50% overlap between frames) (default: {512})
-             win_length {[type]} -- Length of the window function applied to each frame (if not specified, it
-                 equals the filter length). (default: {None})
-             window {str} -- Type of window to use (options are bartlett, hann, hamming, blackman, blackmanharris)
-                 (default: {'hann'})
-         """
-         super(STFT, self).__init__()
-         self.filter_length = filter_length
-         self.hop_length = hop_length
-         self.win_length = win_length if win_length else filter_length
-         self.window = window
-         self.forward_transform = None
-         self.pad_amount = int(self.filter_length / 2)
-         fourier_basis = np.fft.fft(np.eye(self.filter_length))
-
-         cutoff = int((self.filter_length / 2 + 1))
-         fourier_basis = np.vstack(
-             [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
-         )
-         forward_basis = torch.FloatTensor(fourier_basis)
-         inverse_basis = torch.FloatTensor(np.linalg.pinv(fourier_basis))
-
-         assert filter_length >= self.win_length
-         # get window and zero center pad it to filter_length
-         fft_window = get_window(window, self.win_length, fftbins=True)
-         fft_window = pad_center(fft_window, size=filter_length)
-         fft_window = torch.from_numpy(fft_window).float()
-
-         # window the bases
-         forward_basis *= fft_window
-         inverse_basis = (inverse_basis.T * fft_window).T
-
-         self.register_buffer("forward_basis", forward_basis.float())
-         self.register_buffer("inverse_basis", inverse_basis.float())
-         self.register_buffer("fft_window", fft_window.float())
-
-     def transform(self, input_data, return_phase=False):
-         """Take input data (audio) to STFT domain.
-
-         Arguments:
-             input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
-
-         Returns:
-             magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
-                 num_frequencies, num_frames)
-             phase {tensor} -- Phase of STFT with shape (num_batch,
-                 num_frequencies, num_frames)
-         """
-         input_data = F.pad(
-             input_data,
-             (self.pad_amount, self.pad_amount),
-             mode="reflect",
-         )
-         forward_transform = input_data.unfold(
-             1, self.filter_length, self.hop_length
-         ).permute(0, 2, 1)
-         forward_transform = torch.matmul(self.forward_basis, forward_transform)
-         cutoff = int((self.filter_length / 2) + 1)
-         real_part = forward_transform[:, :cutoff, :]
-         imag_part = forward_transform[:, cutoff:, :]
-         magnitude = torch.sqrt(real_part**2 + imag_part**2)
-         if return_phase:
-             phase = torch.atan2(imag_part.data, real_part.data)
-             return magnitude, phase
-         else:
-             return magnitude
-
-     def inverse(self, magnitude, phase):
-         """Call the inverse STFT (iSTFT), given magnitude and phase tensors produced
-         by the ```transform``` function.
-
-         Arguments:
-             magnitude {tensor} -- Magnitude of STFT with shape (num_batch,
-                 num_frequencies, num_frames)
-             phase {tensor} -- Phase of STFT with shape (num_batch,
-                 num_frequencies, num_frames)
-
-         Returns:
-             inverse_transform {tensor} -- Reconstructed audio given magnitude and phase. Of
-                 shape (num_batch, num_samples)
-         """
-         cat = torch.cat(
-             [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
-         )
-         fold = torch.nn.Fold(
-             output_size=(1, (cat.size(-1) - 1) * self.hop_length + self.filter_length),
-             kernel_size=(1, self.filter_length),
-             stride=(1, self.hop_length),
-         )
-         inverse_transform = torch.matmul(self.inverse_basis, cat)
-         inverse_transform = fold(inverse_transform)[
-             :, 0, 0, self.pad_amount : -self.pad_amount
-         ]
-         window_square_sum = (
-             self.fft_window.pow(2).repeat(cat.size(-1), 1).T.unsqueeze(0)
-         )
-         window_square_sum = fold(window_square_sum)[
-             :, 0, 0, self.pad_amount : -self.pad_amount
-         ]
-         inverse_transform /= window_square_sum
-         return inverse_transform

-     def forward(self, input_data):
-         """Take input data (audio) to STFT domain and then back to audio.
-
-         Arguments:
-             input_data {tensor} -- Tensor of floats, with shape (num_batch, num_samples)
-
-         Returns:
-             reconstruction {tensor} -- Reconstructed audio given magnitude and phase. Of
-                 shape (num_batch, num_samples)
-         """
-         self.magnitude, self.phase = self.transform(input_data, return_phase=True)
-         reconstruction = self.inverse(self.magnitude, self.phase)
-         return reconstruction


- from time import time as ttime


- class BiGRU(nn.Module):
-     def __init__(self, input_features, hidden_features, num_layers):
-         super(BiGRU, self).__init__()
-         self.gru = nn.GRU(
-             input_features,
-             hidden_features,
-             num_layers=num_layers,
-             batch_first=True,
-             bidirectional=True,
-         )

      def forward(self, x):
-         return self.gru(x)[0]


  class ConvBlockRes(nn.Module):
      def __init__(self, in_channels, out_channels, momentum=0.01):
          super(ConvBlockRes, self).__init__()
          self.conv = nn.Sequential(
              nn.Conv2d(
-                 in_channels=in_channels,
-                 out_channels=out_channels,
-                 kernel_size=(3, 3),
-                 stride=(1, 1),
-                 padding=(1, 1),
-                 bias=False,
-             ),
-             nn.BatchNorm2d(out_channels, momentum=momentum),
-             nn.ReLU(),
              nn.Conv2d(
-                 in_channels=out_channels,
-                 out_channels=out_channels,
-                 kernel_size=(3, 3),
-                 stride=(1, 1),
-                 padding=(1, 1),
-                 bias=False,
-             ),
-             nn.BatchNorm2d(out_channels, momentum=momentum),
-             nn.ReLU(),
          )
-         # self.shortcut:Optional[nn.Module] = None
          if in_channels != out_channels:
              self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))

-     def forward(self, x: torch.Tensor):
-         if not hasattr(self, "shortcut"):
-             return self.conv(x) + x
-         else:
-             return self.conv(x) + self.shortcut(x)


  class Encoder(nn.Module):
-     def __init__(
-         self,
-         in_channels,
-         in_size,
-         n_encoders,
-         kernel_size,
-         n_blocks,
-         out_channels=16,
-         momentum=0.01,
-     ):
          super(Encoder, self).__init__()
          self.n_encoders = n_encoders
          self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
          self.layers = nn.ModuleList()
-         self.latent_channels = []
-         for i in range(self.n_encoders):
-             self.layers.append(
-                 ResEncoderBlock(
-                     in_channels, out_channels, kernel_size, n_blocks, momentum=momentum
-                 )
-             )
-             self.latent_channels.append([out_channels, in_size])
              in_channels = out_channels
              out_channels *= 2
              in_size //= 2
          self.out_size = in_size
          self.out_channel = out_channels

-     def forward(self, x: torch.Tensor):
-         concat_tensors: List[torch.Tensor] = []
          x = self.bn(x)
-         for i, layer in enumerate(self.layers):
              t, x = layer(x)
              concat_tensors.append(t)
-         return x, concat_tensors
-
-
- class ResEncoderBlock(nn.Module):
-     def __init__(
-         self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01
-     ):
-         super(ResEncoderBlock, self).__init__()
-         self.n_blocks = n_blocks
-         self.conv = nn.ModuleList()
-         self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
-         for i in range(n_blocks - 1):
-             self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
-         self.kernel_size = kernel_size
-         if self.kernel_size is not None:
-             self.pool = nn.AvgPool2d(kernel_size=kernel_size)
-
-     def forward(self, x):
-         for i, conv in enumerate(self.conv):
-             x = conv(x)
-         if self.kernel_size is not None:
-             return x, self.pool(x)
-         else:
-             return x


- class Intermediate(nn.Module):  #
      def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
          super(Intermediate, self).__init__()
-         self.n_inters = n_inters
          self.layers = nn.ModuleList()
-         self.layers.append(
-             ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum)
-         )
-         for i in range(self.n_inters - 1):
-             self.layers.append(
-                 ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum)
-             )

      def forward(self, x):
-         for i, layer in enumerate(self.layers):
              x = layer(x)
-         return x


  class ResDecoderBlock(nn.Module):
      def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
          super(ResDecoderBlock, self).__init__()
          out_padding = (0, 1) if stride == (1, 2) else (1, 1)
-         self.n_blocks = n_blocks
          self.conv1 = nn.Sequential(
              nn.ConvTranspose2d(
-                 in_channels=in_channels,
-                 out_channels=out_channels,
-                 kernel_size=(3, 3),
-                 stride=stride,
-                 padding=(1, 1),
-                 output_padding=out_padding,
-                 bias=False,
-             ),
-             nn.BatchNorm2d(out_channels, momentum=momentum),
-             nn.ReLU(),
          )
          self.conv2 = nn.ModuleList()
          self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
-         for i in range(n_blocks - 1):
              self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

      def forward(self, x, concat_tensor):
-         x = self.conv1(x)
-         x = torch.cat((x, concat_tensor), dim=1)
-         for i, conv2 in enumerate(self.conv2):
              x = conv2(x)
-         return x


  class Decoder(nn.Module):
      def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
          super(Decoder, self).__init__()
          self.layers = nn.ModuleList()
-         self.n_decoders = n_decoders
-         for i in range(self.n_decoders):
              out_channels = in_channels // 2
-             self.layers.append(
-                 ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum)
-             )
              in_channels = out_channels

-     def forward(self, x: torch.Tensor, concat_tensors: List[torch.Tensor]):
          for i, layer in enumerate(self.layers):
              x = layer(x, concat_tensors[-1 - i])
-         return x


  class DeepUnet(nn.Module):
-     def __init__(
-         self,
-         kernel_size,
-         n_blocks,
-         en_de_layers=5,
-         inter_layers=4,
-         in_channels=1,
-         en_out_channels=16,
-     ):
          super(DeepUnet, self).__init__()
-         self.encoder = Encoder(
-             in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels
-         )
-         self.intermediate = Intermediate(
-             self.encoder.out_channel // 2,
-             self.encoder.out_channel,
-             inter_layers,
-             n_blocks,
          )
-         self.decoder = Decoder(
-             self.encoder.out_channel, en_de_layers, kernel_size, n_blocks
          )

-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x, concat_tensors = self.encoder(x)
-         x = self.intermediate(x)
-         x = self.decoder(x, concat_tensors)
-         return x


  class E2E(nn.Module):
-     def __init__(
-         self,
-         n_blocks,
-         n_gru,
-         kernel_size,
-         en_de_layers=5,
-         inter_layers=4,
-         in_channels=1,
-         en_out_channels=16,
-     ):
          super(E2E, self).__init__()
-         self.unet = DeepUnet(
-             kernel_size,
-             n_blocks,
-             en_de_layers,
-             inter_layers,
-             in_channels,
-             en_out_channels,
          )
          self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
-         if n_gru:
-             self.fc = nn.Sequential(
-                 BiGRU(3 * 128, 256, n_gru),
-                 nn.Linear(512, 360),
-                 nn.Dropout(0.25),
-                 nn.Sigmoid(),
              )
-         else:
-             self.fc = nn.Sequential(
-                 nn.Linear(3 * nn.N_MELS, nn.N_CLASS), nn.Dropout(0.25), nn.Sigmoid()
              )

      def forward(self, mel):
-         # print(mel.shape)
-         mel = mel.transpose(-1, -2).unsqueeze(1)
-         x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
-         x = self.fc(x)
-         # print(x.shape)
-         return x
-
-
- from librosa.filters import mel
-

- class MelSpectrogram(torch.nn.Module):
-     def __init__(
-         self,
-         is_half,
-         n_mel_channels,
-         sampling_rate,
-         win_length,
-         hop_length,
-         n_fft=None,
-         mel_fmin=0,
-         mel_fmax=None,
-         clamp=1e-5,
-     ):
          super().__init__()
          n_fft = win_length if n_fft is None else n_fft
          self.hann_window = {}
-         mel_basis = mel(
-             sr=sampling_rate,
-             n_fft=n_fft,
-             n_mels=n_mel_channels,
-             fmin=mel_fmin,
-             fmax=mel_fmax,
-             htk=True,
-         )
          mel_basis = torch.from_numpy(mel_basis).float()
          self.register_buffer("mel_basis", mel_basis)
          self.n_fft = win_length if n_fft is None else n_fft
          self.hop_length = hop_length
          self.win_length = win_length
-         self.sampling_rate = sampling_rate
          self.n_mel_channels = n_mel_channels
          self.clamp = clamp
-         self.is_half = is_half

      def forward(self, audio, keyshift=0, speed=1, center=True):
          factor = 2 ** (keyshift / 12)
-         n_fft_new = int(np.round(self.n_fft * factor))
          win_length_new = int(np.round(self.win_length * factor))
-         hop_length_new = int(np.round(self.hop_length * speed))
          keyshift_key = str(keyshift) + "_" + str(audio.device)
-         if keyshift_key not in self.hann_window:
-             self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(
-                 audio.device
-             )
-         if "privateuseone" in str(audio.device):
-             if not hasattr(self, "stft"):
-                 self.stft = STFT(
-                     filter_length=n_fft_new,
-                     hop_length=hop_length_new,
-                     win_length=win_length_new,
-                     window="hann",
-                 ).to(audio.device)
-             magnitude = self.stft.transform(audio)
-         else:
-             fft = torch.stft(
-                 audio,
-                 n_fft=n_fft_new,
-                 hop_length=hop_length_new,
-                 win_length=win_length_new,
-                 window=self.hann_window[keyshift_key],
-                 center=center,
-                 return_complex=True,
-             )
-             magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
          if keyshift != 0:
              size = self.n_fft // 2 + 1
              resize = magnitude.size(1)
-             if resize < size:
-                 magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
              magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
-         mel_output = torch.matmul(self.mel_basis, magnitude)
-         if self.is_half == True:
-             mel_output = mel_output.half()
-         log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
-         return log_mel_spec


  class RMVPE:
-     def __init__(self, model_path: str, is_half, device=None, use_jit=False):
-         self.resample_kernel = {}
-         self.resample_kernel = {}
-         self.is_half = is_half
-         if device is None:
-             device = "cuda:0" if torch.cuda.is_available() else "cpu"
-         self.device = device
-         self.mel_extractor = MelSpectrogram(
-             is_half, 128, 16000, 1024, 160, None, 30, 8000
-         ).to(device)
-         if "privateuseone" in str(device):
              import onnxruntime as ort

-             ort_session = ort.InferenceSession(
-                 "%s/rmvpe.onnx" % os.environ["rmvpe_root"],
-                 providers=["DmlExecutionProvider"],
-             )
-             self.model = ort_session
          else:
-             if str(self.device) == "cuda":
-                 self.device = torch.device("cuda:0")
-
-             def get_jit_model():
-                 jit_model_path = model_path.rstrip(".pth")
-                 jit_model_path += ".half.jit" if is_half else ".jit"
-                 reload = False
-                 if os.path.exists(jit_model_path):
-                     ckpt = jit.load(jit_model_path)
-                     model_device = ckpt["device"]
-                     if model_device != str(self.device):
-                         reload = True
-                 else:
-                     reload = True
-
-                 if reload:
-                     ckpt = jit.rmvpe_jit_export(
-                         model_path=model_path,
-                         mode="script",
-                         inputs_path=None,
-                         save_path=jit_model_path,
-                         device=device,
-                         is_half=is_half,
-                     )
-                 model = torch.jit.load(BytesIO(ckpt["model"]), map_location=device)
-                 return model
-
-             def get_default_model():
-                 model = E2E(4, 1, (2, 2))
-                 ckpt = torch.load(model_path, map_location="cpu")
-                 model.load_state_dict(ckpt)
-                 model.eval()
-                 if is_half:
-                     model = model.half()
-                 else:
-                     model = model.float()
-                 return model
-
-             if use_jit:
-                 if is_half and "cpu" in str(self.device):
-                     logger.warning(
-                         "Use default rmvpe model. \
-                 Jit is not supported on the CPU for half floating point"
-                     )
-                     self.model = get_default_model()
-                 else:
-                     self.model = get_jit_model()
-             else:
-                 self.model = get_default_model()
-
-             self.model = self.model.to(device)
-         cents_mapping = 20 * np.arange(360) + 1997.3794084376191
-         self.cents_mapping = np.pad(cents_mapping, (4, 4))  # 368
-
-     def mel2hidden(self, mel):
          with torch.no_grad():
              n_frames = mel.shape[-1]
-             n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
-             if n_pad > 0:
-                 mel = F.pad(mel, (0, n_pad), mode="constant")
-             if "privateuseone" in str(self.device):
-                 onnx_input_name = self.model.get_inputs()[0].name
-                 onnx_outputs_names = self.model.get_outputs()[0].name
-                 hidden = self.model.run(
-                     [onnx_outputs_names],
-                     input_feed={onnx_input_name: mel.cpu().numpy()},
-                 )[0]
-             else:
-                 mel = mel.half() if self.is_half else mel.float()
-                 hidden = self.model(mel)
              return hidden[:, :n_frames]

      def decode(self, hidden, thred=0.03):
-         cents_pred = self.to_local_average_cents(hidden, thred=thred)
-         f0 = 10 * (2 ** (cents_pred / 1200))
          f0[f0 == 10] = 0
-         # f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
          return f0

      def infer_from_audio(self, audio, thred=0.03):
-         # torch.cuda.synchronize()
-         # t0 = ttime()
-         if not torch.is_tensor(audio):
-             audio = torch.from_numpy(audio)
-         mel = self.mel_extractor(
-             audio.float().to(self.device).unsqueeze(0), center=True
-         )
-         # print(123123123,mel.device.type)
-         # torch.cuda.synchronize()
-         # t1 = ttime()
-         hidden = self.mel2hidden(mel)
-         # torch.cuda.synchronize()
-         # t2 = ttime()
-         # print(234234,hidden.device.type)
-         if "privateuseone" not in str(self.device):
-             hidden = hidden.squeeze(0).cpu().numpy()
-         else:
-             hidden = hidden[0]
-         if self.is_half == True:
-             hidden = hidden.astype("float32")
-
-         f0 = self.decode(hidden, thred=thred)
-         # torch.cuda.synchronize()
-         # t3 = ttime()
-         # print("hmvpe:%s\t%s\t%s\t%s"%(t1-t0,t2-t1,t3-t2,t3-t0))
          return f0

      def to_local_average_cents(self, salience, thred=0.05):
-         # t0 = ttime()
-         center = np.argmax(salience, axis=1)  # n_frames; argmax index
-         salience = np.pad(salience, ((0, 0), (4, 4)))  # n_frames, 368
-         # t1 = ttime()
          center += 4
-         todo_salience = []
-         todo_cents_mapping = []
          starts = center - 4
          ends = center + 5
          for idx in range(salience.shape[0]):
              todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
              todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
-         # t2 = ttime()
-         todo_salience = np.array(todo_salience)  # n_frames, 9
-         todo_cents_mapping = np.array(todo_cents_mapping)  # n_frames, 9
-         product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
-         weight_sum = np.sum(todo_salience, 1)  # n_frames
-         devided = product_sum / weight_sum  # n_frames
-         # t3 = ttime()
-         maxx = np.max(salience, axis=1)  # n_frames
-         devided[maxx <= thred] = 0
-         # t4 = ttime()
-         # print("decode:%s\t%s\t%s\t%s" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
-         return devided
-
-
- if __name__ == "__main__":
-     import librosa
-     import soundfile as sf
-
-     audio, sampling_rate = sf.read(r"C:\Users\liujing04\Desktop\Z\冬之花clip1.wav")
-     if len(audio.shape) > 1:
-         audio = librosa.to_mono(audio.transpose(1, 0))
-     audio_bak = audio.copy()
-     if sampling_rate != 16000:
-         audio = librosa.resample(audio, orig_sr=sampling_rate, target_sr=16000)
-     model_path = r"D:\BaiduNetdiskDownload\RVC-beta-v2-0727AMD_realtime\rmvpe.pt"
-     thred = 0.03  # 0.01
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     rmvpe = RMVPE(model_path, is_half=False, device=device)
-     t0 = ttime()
-     f0 = rmvpe.infer_from_audio(audio, thred=thred)
-     # f0 = rmvpe.infer_from_audio(audio, thred=thred)
-     # f0 = rmvpe.infer_from_audio(audio, thred=thred)
-     # f0 = rmvpe.infer_from_audio(audio, thred=thred)
-     # f0 = rmvpe.infer_from_audio(audio, thred=thred)
-     t1 = ttime()
-     logger.info("%s %.2f", f0.shape, t1 - t0)
 
  import os
+ import sys

  import torch

+ import numpy as np
+ import torch.nn as nn
+ import torch.nn.functional as F

+ from librosa.filters import mel

+ sys.path.append(os.getcwd())

+ N_MELS, N_CLASS = 128, 360


+ def autopad(k, p=None):
+     if p is None: p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
+     return p

+ class Conv(nn.Module):
+     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True):
+         super().__init__()
+         self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)
+         self.bn = nn.BatchNorm2d(c2)
+         self.act = nn.SiLU() if act else nn.Identity()

+     def forward(self, x):
+         return self.act(self.bn(self.conv(x)))

+ class DSConv(nn.Module):
+     def __init__(self, c1, c2, k=3, s=1, p=None, act=True):
+         super().__init__()
+         self.dwconv = nn.Conv2d(c1, c1, k, s, autopad(k, p), groups=c1, bias=False)
+         self.pwconv = nn.Conv2d(c1, c2, 1, 1, 0, bias=False)
+         self.bn = nn.BatchNorm2d(c2)
+         self.act = nn.SiLU() if act else nn.Identity()

+     def forward(self, x):
+         return self.act(self.bn(self.pwconv(self.dwconv(x))))

+ class DS_Bottleneck(nn.Module):
+     def __init__(self, c1, c2, k=3, shortcut=True):
+         super().__init__()
+         self.dsconv1 = DSConv(c1, c1, k=3, s=1)
+         self.dsconv2 = DSConv(c1, c2, k=k, s=1)
+         self.shortcut = shortcut and c1 == c2

+     def forward(self, x):
+         return x + self.dsconv2(self.dsconv1(x)) if self.shortcut else self.dsconv2(self.dsconv1(x))
+
+ class DS_C3k(nn.Module):
+     def __init__(self, c1, c2, n=1, k=3, e=0.5):
+         super().__init__()
+         self.cv1 = Conv(c1, int(c2 * e), 1, 1)
+         self.cv2 = Conv(c1, int(c2 * e), 1, 1)
+         self.cv3 = Conv(2 * int(c2 * e), c2, 1, 1)
+         self.m = nn.Sequential(*[DS_Bottleneck(int(c2 * e), int(c2 * e), k=k, shortcut=True) for _ in range(n)])
+
+     def forward(self, x):
+         return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))
+
+ class DS_C3k2(nn.Module):
+     def __init__(self, c1, c2, n=1, k=3, e=0.5):
+         super().__init__()
+         self.cv1 = Conv(c1, int(c2 * e), 1, 1)
+         self.m = DS_C3k(int(c2 * e), int(c2 * e), n=n, k=k, e=1.0)
+         self.cv2 = Conv(int(c2 * e), c2, 1, 1)
+
+     def forward(self, x):
+         return self.cv2(self.m(self.cv1(x)))
+
+ class AdaptiveHyperedgeGeneration(nn.Module):
+     def __init__(self, in_channels, num_hyperedges, num_heads):
+         super().__init__()
+         self.num_hyperedges = num_hyperedges
+         self.num_heads = num_heads
+         self.head_dim = max(1, in_channels // num_heads)
+         self.global_proto = nn.Parameter(torch.randn(num_hyperedges, in_channels))
+         self.context_mapper = nn.Linear(2 * in_channels, num_hyperedges * in_channels, bias=False)
+         self.query_proj = nn.Linear(in_channels, in_channels, bias=False)
+         self.scale = self.head_dim ** -0.5
+
+     def forward(self, x):
+         B, N, C = x.shape
+         P = self.global_proto.unsqueeze(0) + self.context_mapper(torch.cat((F.adaptive_avg_pool1d(x.permute(0, 2, 1), 1).squeeze(-1), F.adaptive_max_pool1d(x.permute(0, 2, 1), 1).squeeze(-1)), dim=1)).view(B, self.num_hyperedges, C)
+
+         return F.softmax(((self.query_proj(x).view(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3) @ P.view(B, self.num_hyperedges, self.num_heads, self.head_dim).permute(0, 2, 3, 1)) * self.scale).mean(dim=1).permute(0, 2, 1), dim=-1)
+
+ class HypergraphConvolution(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+         self.W_e = nn.Linear(in_channels, in_channels, bias=False)
+         self.W_v = nn.Linear(in_channels, out_channels, bias=False)
+         self.act = nn.SiLU()
+
+     def forward(self, x, A):
+         return x + self.act(self.W_v(A.transpose(1, 2).bmm(self.act(self.W_e(A.bmm(x))))))
+
+ class AdaptiveHypergraphComputation(nn.Module):
+     def __init__(self, in_channels, out_channels, num_hyperedges, num_heads):
+         super().__init__()
+         self.adaptive_hyperedge_gen = AdaptiveHyperedgeGeneration(in_channels, num_hyperedges, num_heads)
+         self.hypergraph_conv = HypergraphConvolution(in_channels, out_channels)
+
+     def forward(self, x):
+         B, _, H, W = x.shape
+         x_flat = x.flatten(2).permute(0, 2, 1)
+         return self.hypergraph_conv(x_flat, self.adaptive_hyperedge_gen(x_flat)).permute(0, 2, 1).view(B, -1, H, W)
+
+ class C3AH(nn.Module):
+     def __init__(self, c1, c2, num_hyperedges, num_heads, e=0.5):
+         super().__init__()
+         self.cv1 = Conv(c1, int(c1 * e), 1, 1)
+         self.cv2 = Conv(c1, int(c1 * e), 1, 1)
+         self.ahc = AdaptiveHypergraphComputation(int(c1 * e), int(c1 * e), num_hyperedges, num_heads)
+         self.cv3 = Conv(2 * int(c1 * e), c2, 1, 1)

      def forward(self, x):
+         return self.cv3(torch.cat((self.ahc(self.cv2(x)), self.cv1(x)), dim=1))

+ class HyperACE(nn.Module):
+     def __init__(self, in_channels, out_channels, num_hyperedges=16, num_heads=8, k=2, l=1, c_h=0.5, c_l=0.25):
+         super().__init__()
+         c2, c3, c4, c5 = in_channels
+         c_mid = c4
+         self.fuse_conv = Conv(c2 + c3 + c4 + c5, c_mid, 1, 1)
+         self.c_h = int(c_mid * c_h)
+         self.c_l = int(c_mid * c_l)
+         self.c_s = c_mid - self.c_h - self.c_l
+         self.high_order_branch = nn.ModuleList([C3AH(self.c_h, self.c_h, num_hyperedges=num_hyperedges, num_heads=num_heads, e=1.0) for _ in range(k)])
+         self.high_order_fuse = Conv(self.c_h * k, self.c_h, 1, 1)
+         self.low_order_branch = nn.Sequential(*[DS_C3k(self.c_l, self.c_l, n=1, k=3, e=1.0) for _ in range(l)])
+         self.final_fuse = Conv(self.c_h + self.c_l + self.c_s, out_channels, 1, 1)
+
+     def forward(self, x):
+         B2, B3, B4, B5 = x
+         _, _, H4, W4 = B4.shape
+
+         x_h, x_l, x_s = self.fuse_conv(
+             torch.cat(
+                 (
+                     F.interpolate(B2, size=(H4, W4), mode='bilinear', align_corners=False),
+                     F.interpolate(B3, size=(H4, W4), mode='bilinear', align_corners=False),
+                     B4,
+                     F.interpolate(B5, size=(H4, W4), mode='bilinear', align_corners=False)
+                 ),
+                 dim=1
+             )
+         ).split([self.c_h, self.c_l, self.c_s], dim=1)
+
+         return self.final_fuse(torch.cat((self.high_order_fuse(torch.cat([m(x_h) for m in self.high_order_branch], dim=1)), self.low_order_branch(x_l), x_s), dim=1))
+
+ class GatedFusion(nn.Module):
+     def __init__(self, in_channels):
+         super().__init__()
+         self.gamma = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
+
+     def forward(self, f_in, h):
+         return f_in + self.gamma * h
+
+ class YOLO13Encoder(nn.Module):
+     def __init__(self, in_channels, base_channels=32):
+         super().__init__()
+         self.stem = DSConv(in_channels, base_channels, k=3, s=1)
+
+         self.p2 = nn.Sequential(
+             DSConv(base_channels, base_channels*2, k=3, s=(2, 2)),
+             DS_C3k2(base_channels*2, base_channels*2, n=1)
+         )
+
+         self.p3 = nn.Sequential(
+             DSConv(base_channels*2, base_channels*4, k=3, s=(2, 2)),
+             DS_C3k2(base_channels*4, base_channels*4, n=2)
+         )
+
+         self.p4 = nn.Sequential(
+             DSConv(base_channels*4, base_channels*8, k=3, s=(2, 2)),
+             DS_C3k2(base_channels*8, base_channels*8, n=2)
+         )
+
+         self.p5 = nn.Sequential(
+             DSConv(base_channels*8, base_channels*16, k=3, s=(2, 2)),
+             DS_C3k2(base_channels*16, base_channels*16, n=1)
+         )
+
+         self.out_channels = [base_channels*2, base_channels*4, base_channels*8, base_channels*16]
+
+     def forward(self, x):
+         x = self.stem(x)
+         p2 = self.p2(x)
+         p3 = self.p3(p2)
+         p4 = self.p4(p3)
+         p5 = self.p5(p4)
+         return [p2, p3, p4, p5]
+
+ class YOLO13FullPADDecoder(nn.Module):
+     def __init__(self, encoder_channels, hyperace_out_c, out_channels_final):
+         super().__init__()
+         c_p2, c_p3, c_p4, c_p5 = encoder_channels
+         c_d5, c_d4, c_d3, c_d2 = c_p5, c_p4, c_p3, c_p2
+
+         self.h_to_d5 = Conv(hyperace_out_c, c_d5, 1, 1)
+         self.h_to_d4 = Conv(hyperace_out_c, c_d4, 1, 1)
+         self.h_to_d3 = Conv(hyperace_out_c, c_d3, 1, 1)
+         self.h_to_d2 = Conv(hyperace_out_c, c_d2, 1, 1)
+
+         self.fusion_d5 = GatedFusion(c_d5)
+         self.fusion_d4 = GatedFusion(c_d4)
+         self.fusion_d3 = GatedFusion(c_d3)
+         self.fusion_d2 = GatedFusion(c_d2)
+
+         self.skip_p5 = Conv(c_p5, c_d5, 1, 1)
+         self.skip_p4 = Conv(c_p4, c_d4, 1, 1)
+         self.skip_p3 = Conv(c_p3, c_d3, 1, 1)
+         self.skip_p2 = Conv(c_p2, c_d2, 1, 1)
+
+         self.up_d5 = DS_C3k2(c_d5, c_d4, n=1)
+         self.up_d4 = DS_C3k2(c_d4, c_d3, n=1)
+         self.up_d3 = DS_C3k2(c_d3, c_d2, n=1)
+
+         self.final_d2 = DS_C3k2(c_d2, c_d2, n=1)
+         self.final_conv = Conv(c_d2, out_channels_final, 1, 1)
+
+     def forward(self, enc_feats, h_ace):
+         p2, p3, p4, p5 = enc_feats
+
+         d5 = self.skip_p5(p5)
+         d4 = self.up_d5(F.interpolate(self.fusion_d5(d5, self.h_to_d5(F.interpolate(h_ace, size=d5.shape[2:], mode='bilinear', align_corners=False))), size=p4.shape[2:], mode='bilinear', align_corners=False)) + self.skip_p4(p4)
+         d3 = self.up_d4(F.interpolate(self.fusion_d4(d4, self.h_to_d4(F.interpolate(h_ace, size=d4.shape[2:], mode='bilinear', align_corners=False))), size=p3.shape[2:], mode='bilinear', align_corners=False)) + self.skip_p3(p3)
+         d2 = self.up_d3(F.interpolate(self.fusion_d3(d3, self.h_to_d3(F.interpolate(h_ace, size=d3.shape[2:], mode='bilinear', align_corners=False))), size=p2.shape[2:], mode='bilinear', align_corners=False)) + self.skip_p2(p2)
+
+         return self.final_conv(self.final_d2(self.fusion_d2(d2, self.h_to_d2(F.interpolate(h_ace, size=d2.shape[2:], mode='bilinear', align_corners=False)))))

  class ConvBlockRes(nn.Module):
      def __init__(self, in_channels, out_channels, momentum=0.01):
          super(ConvBlockRes, self).__init__()
          self.conv = nn.Sequential(
              nn.Conv2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=(1, 1),
+                 bias=False
+             ),
+             nn.BatchNorm2d(
+                 out_channels,
+                 momentum=momentum
+             ),
+             nn.ReLU(),
              nn.Conv2d(
+                 in_channels=out_channels,
+                 out_channels=out_channels,
+                 kernel_size=(3, 3),
+                 stride=(1, 1),
+                 padding=(1, 1),
+                 bias=False
+             ),
+             nn.BatchNorm2d(
+                 out_channels,
+                 momentum=momentum
+             ),
+             nn.ReLU()
          )
+
          if in_channels != out_channels:
              self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
+             self.is_shortcut = True
+         else: self.is_shortcut = False

+     def forward(self, x):
+         return (self.conv(x) + self.shortcut(x)) if self.is_shortcut else (self.conv(x) + x)
+
+ class ResEncoderBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
+         super(ResEncoderBlock, self).__init__()
+         self.n_blocks = n_blocks
+         self.conv = nn.ModuleList()
+         self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))

+         for _ in range(n_blocks - 1):
+             self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
+
+         self.kernel_size = kernel_size
+         if self.kernel_size is not None: self.pool = nn.AvgPool2d(kernel_size=kernel_size)
+
+     def forward(self, x):
+         for i in range(self.n_blocks):
+             x = self.conv[i](x)
+
+         if self.kernel_size is not None: return x, self.pool(x)
+         else: return x

  class Encoder(nn.Module):
+     def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
          super(Encoder, self).__init__()
          self.n_encoders = n_encoders
          self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
          self.layers = nn.ModuleList()
+
+         for _ in range(self.n_encoders):
+             self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
              in_channels = out_channels
              out_channels *= 2
              in_size //= 2
+
          self.out_size = in_size
          self.out_channel = out_channels

+     def forward(self, x):
+         concat_tensors = []
          x = self.bn(x)
+
+         for layer in self.layers:
              t, x = layer(x)
              concat_tensors.append(t)

+         return x, concat_tensors

+ class Intermediate(nn.Module):
      def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
          super(Intermediate, self).__init__()
          self.layers = nn.ModuleList()
+         self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
+
+         for _ in range(n_inters - 1):
+             self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))

      def forward(self, x):
+         for layer in self.layers:
              x = layer(x)

+         return x

  class ResDecoderBlock(nn.Module):
      def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
          super(ResDecoderBlock, self).__init__()
          out_padding = (0, 1) if stride == (1, 2) else (1, 1)
          self.conv1 = nn.Sequential(
              nn.ConvTranspose2d(
+                 in_channels=in_channels,
+                 out_channels=out_channels,
+                 kernel_size=(3, 3),
+                 stride=stride,
+                 padding=(1, 1),
+                 output_padding=out_padding,
+                 bias=False
+             ),
+             nn.BatchNorm2d(
+                 out_channels,
+                 momentum=momentum
+             ),
+             nn.ReLU()
          )
+
          self.conv2 = nn.ModuleList()
          self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
+
+         for _ in range(n_blocks - 1):
              self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))

      def forward(self, x, concat_tensor):
+         x = torch.cat((self.conv1(x), concat_tensor), dim=1)
+         for conv2 in self.conv2:
              x = conv2(x)

+         return x

  class Decoder(nn.Module):
      def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
          super(Decoder, self).__init__()
          self.layers = nn.ModuleList()
+
+         for _ in range(n_decoders):
              out_channels = in_channels // 2
+             self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
              in_channels = out_channels

+     def forward(self, x, concat_tensors):
          for i, layer in enumerate(self.layers):
              x = layer(x, concat_tensors[-1 - i])

+         return x

  class DeepUnet(nn.Module):
+     def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
          super(DeepUnet, self).__init__()
+         self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
+         self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
+         self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
+
+     def forward(self, x):
+         x, concat_tensors = self.encoder(x)
+         return self.decoder(self.intermediate(x), concat_tensors)
+
+ class HPADeepUnet(nn.Module):
+     def __init__(self, in_channels=1, en_out_channels=16, base_channels=64, hyperace_k=2, hyperace_l=1, num_hyperedges=16, num_heads=8):
+         super().__init__()
+         self.encoder = YOLO13Encoder(in_channels, base_channels)
+         enc_ch = self.encoder.out_channels
+
+         self.hyperace = HyperACE(
+             in_channels=enc_ch,
+             out_channels=enc_ch[-1],
+             num_hyperedges=num_hyperedges,
+             num_heads=num_heads,
+             k=hyperace_k,
+             l=hyperace_l
          )
+
+         self.decoder = YOLO13FullPADDecoder(
+             encoder_channels=enc_ch,
+             hyperace_out_c=enc_ch[-1],
+             out_channels_final=en_out_channels
          )

+     def forward(self, x):
+         features = self.encoder(x)
+         return nn.functional.interpolate(self.decoder(features, self.hyperace(features)), size=x.shape[2:], mode='bilinear', align_corners=False)

+ class BiGRU(nn.Module):
+     def __init__(self, input_features, hidden_features, num_layers):
+         super(BiGRU, self).__init__()
+         self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)

+     def forward(self, x):
+         try:
+             return self.gru(x)[0]
+         except:
+             torch.backends.cudnn.enabled = False
+             return self.gru(x)[0]
+
  class E2E(nn.Module):
+     def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16, hpa=False):
          super(E2E, self).__init__()
+         self.unet = (
+             HPADeepUnet(
+                 in_channels=in_channels,
+                 en_out_channels=en_out_channels,
+                 base_channels=64,
+                 hyperace_k=2,
+                 hyperace_l=1,
+                 num_hyperedges=16,
+                 num_heads=4
+             )
+         ) if hpa else (
+             DeepUnet(
+                 kernel_size,
+                 n_blocks,
+                 en_de_layers,
+                 inter_layers,
+                 in_channels,
+                 en_out_channels
+             )
          )
+
          self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
+         self.fc = (
+             nn.Sequential(
+                 BiGRU(3 * 128, 256, n_gru),
+                 nn.Linear(512, N_CLASS),
+                 nn.Dropout(0.25),
+                 nn.Sigmoid()
              )
+         ) if n_gru else (
+             nn.Sequential(
+                 nn.Linear(3 * N_MELS, N_CLASS),
+                 nn.Dropout(0.25),
+                 nn.Sigmoid()
              )
+         )

      def forward(self, mel):
+         return self.fc(self.cnn(self.unet(mel.transpose(-1, -2).unsqueeze(1))).transpose(1, 2).flatten(-2))

+ class MelSpectrogram(nn.Module):
+     def __init__(self, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
          super().__init__()
          n_fft = win_length if n_fft is None else n_fft
          self.hann_window = {}
+         mel_basis = mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
          mel_basis = torch.from_numpy(mel_basis).float()
          self.register_buffer("mel_basis", mel_basis)
          self.n_fft = win_length if n_fft is None else n_fft
          self.hop_length = hop_length
          self.win_length = win_length
+         self.sample_rate = sample_rate
          self.n_mel_channels = n_mel_channels
          self.clamp = clamp

      def forward(self, audio, keyshift=0, speed=1, center=True):
          factor = 2 ** (keyshift / 12)
          win_length_new = int(np.round(self.win_length * factor))
          keyshift_key = str(keyshift) + "_" + str(audio.device)
+         if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
+
+         n_fft = int(np.round(self.n_fft * factor))
+         hop_length = int(np.round(self.hop_length * speed))
+
+         fft = torch.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
+         magnitude = (fft.real.pow(2) + fft.imag.pow(2)).sqrt()
+
          if keyshift != 0:
              size = self.n_fft // 2 + 1
              resize = magnitude.size(1)
+             if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
              magnitude = magnitude[:, :size, :] * self.win_length / win_length_new

+         mel_output = self.mel_basis @ magnitude
+         return mel_output.clamp(min=self.clamp).log()

  class RMVPE:
+     def __init__(self, model_path, is_half, device=None, providers=None, onnx=False, hpa=False):
+         self.onnx = onnx
+
+         if self.onnx:
              import onnxruntime as ort

+             sess_options = ort.SessionOptions()
+             sess_options.log_severity_level = 3
+             self.model = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
          else:
+             model = E2E(4, 1, (2, 2), 5, 4, 1, 16, hpa=hpa)
+
+             model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
+             model.eval()
+             if is_half: model = model.half()
+             self.model = model.to(device)
+
+         self.device = device
+         self.is_half = is_half
+         self.mel_extractor = MelSpectrogram(N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
+         cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
+         self.cents_mapping = np.pad(cents_mapping, (4, 4))
+
+     def mel2hidden(self, mel, chunk_size=32000):
          with torch.no_grad():
              n_frames = mel.shape[-1]
+             mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")
+
+             output_chunks = []
+             pad_frames = mel.shape[-1]
+
+             for start in range(0, pad_frames, chunk_size):
+                 mel_chunk = mel[..., start:min(start + chunk_size, pad_frames)]
+                 assert mel_chunk.shape[-1] % 32 == 0
+
+                 if self.onnx:
+                     mel_chunk = mel_chunk.cpu().numpy().astype(np.float32)
+                     out_chunk = torch.as_tensor(self.model.run([self.model.get_outputs()[0].name], {self.model.get_inputs()[0].name: mel_chunk})[0], device=self.device)
+                 else:
+                     if self.is_half: mel_chunk = mel_chunk.half()
+                     out_chunk = self.model(mel_chunk)
+
+                 output_chunks.append(out_chunk)
+
+             hidden = torch.cat(output_chunks, dim=1)
              return hidden[:, :n_frames]

      def decode(self, hidden, thred=0.03):
+         f0 = 10 * (2 ** (self.to_local_average_cents(hidden, thred=thred) / 1200))
          f0[f0 == 10] = 0
+
          return f0

      def infer_from_audio(self, audio, thred=0.03):
+         hidden = self.mel2hidden(self.mel_extractor(torch.from_numpy(audio).float().to(self.device).unsqueeze(0), center=True))
+
+         return self.decode(hidden.squeeze(0).cpu().numpy().astype(np.float32), thred=thred)
+
+     def infer_from_audio_with_pitch(self, audio, thred=0.03, f0_min=50, f0_max=1100):
+         f0 = self.infer_from_audio(audio, thred)
+         f0[(f0 < f0_min) | (f0 > f0_max)] = 0
+
          return f0

      def to_local_average_cents(self, salience, thred=0.05):
+         center = np.argmax(salience, axis=1)
+         salience = np.pad(salience, ((0, 0), (4, 4)))
          center += 4
+         todo_salience, todo_cents_mapping = [], []
          starts = center - 4
          ends = center + 5
+
          for idx in range(salience.shape[0]):
              todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
              todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
+
+         todo_salience = np.array(todo_salience)
+         devided = np.sum(todo_salience * np.array(todo_cents_mapping), 1) / np.sum(todo_salience, 1)
+         devided[np.max(salience, axis=1) <= thred] = 0
+
+         return devided
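
A minimal usage sketch of the RMVPE class after this commit, for reference. The checkpoint path, audio file name, and loading code below are illustrative assumptions, not part of the commit itself.

import librosa
import torch

from infer.lib.rmvpe import RMVPE

# Hypothetical inputs: a local rmvpe.pt checkpoint and a 16 kHz mono clip.
audio, _ = librosa.load("audio.wav", sr=16000, mono=True)  # float32 ndarray, as infer_from_audio expects numpy
device = "cuda" if torch.cuda.is_available() else "cpu"

# onnx=False loads the PyTorch E2E checkpoint; hpa=True would instead build the new
# HPADeepUnet backbone, which needs a checkpoint trained with that architecture.
rmvpe = RMVPE("assets/rmvpe/rmvpe.pt", is_half=False, device=device)

f0 = rmvpe.infer_from_audio(audio, thred=0.03)  # per-frame F0 in Hz; 0 marks unvoiced frames
f0_ranged = rmvpe.infer_from_audio_with_pitch(audio, thred=0.03, f0_min=50, f0_max=1100)  # also zeroes out-of-range frames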