In [20]:
%cd /home/ubuntu/higgs_audio_train

import librosa
import torch
import torch.nn.functional as F
import numpy as np
import json
import torch
from IPython.display import Audio as Sawt
from higgs_audio_tokenizer import HiggsAudioTokenizer
import torch
import torch.nn as nn
import warnings

/home/ubuntu/higgs_audio_train


In [None]:
%cd /home/ubuntu/higgs_audio_train

import librosa
import torch
import torch.nn.functional as F
import numpy as np
import json
import torch
from IPython.display import Audio as Sawt
from higgs_audio_tokenizer import HiggsAudioTokenizer
import torch
import torch.nn as nn
import warnings


class EncodedResult:
    """Plain holder for the two values produced when audio is encoded.

    Attributes:
        audio_codes: Stored exactly as passed to the constructor.
        quantized: Stored exactly as passed to the constructor.
    """

    def __init__(self, audio_codes, quantized):
        # Both values are kept untouched so callers can read them back as attributes.
        self.audio_codes, self.quantized = audio_codes, quantized


def encode_batch(model, x_batch):
    """Run a batch of waveforms through the tokenizer's encode path.

    Args:
        model: The loaded HiggsAudioTokenizer model.
        x_batch: A tensor of shape [B, 1, T].

    Returns:
        An ``EncodedResult`` holding the permuted codes (``audio_codes``) and
        the quantizer's ``quantized`` output.
    """
    # Semantic branch: the regression target is detached before being encoded.
    semantic_target = model.get_regress_target(x_batch).detach()
    semantic_feats = model.encoder_semantic(semantic_target.transpose(1, 2))

    # Acoustic branch works on the raw batch.
    acoustic_feats = model.encoder(x_batch)

    if acoustic_feats.shape[2] != semantic_feats.shape[2]:
        # Frame counts disagree: re-encode a copy padded on both sides.
        margin = 160 * model.semantic_downsample_factor
        # Drop the channel axis ([B, T]), pad the last axis, then restore the
        # channel axis at dim 1 so the batch dimension survives.
        channel_free = x_batch[:, 0, :]
        acoustic_feats = model.encoder(F.pad(channel_free, (margin, margin)).unsqueeze(1))

    # Trim both branches to the shorter one so they can be concatenated.
    n_frames = min(acoustic_feats.shape[2], semantic_feats.shape[2])
    acoustic_feats = acoustic_feats[:, :, :n_frames]
    semantic_feats = semantic_feats[:, :, :n_frames]

    fused = torch.cat([acoustic_feats, semantic_feats], dim=1)
    fused = model.fc_prior(fused.transpose(1, 2))

    if model.quantizer_type != "RVQ":
        # RFSQ path
        quantized, codes = model.quantizer(fused)
        return EncodedResult(audio_codes=codes.permute(0, 2, 1), quantized=quantized)

    # RVQ path: the quantizer takes the transposed features plus the frame rate.
    rvq_input = fused.transpose(1, 2)
    quantized, codes, _, _ = model.quantizer(rvq_input, model.frame_rate, None)
    return EncodedResult(audio_codes=codes.permute(1, 0, 2), quantized=quantized)

def prepare(checkpoint_path, config_path, device='cuda'):
    """Build a HiggsAudioTokenizer from a JSON config and load checkpoint weights.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load``. It may hold
            either a bare state dict or a dict with a ``'model_state_dict'``
            entry.
        config_path: Path to a JSON file providing the constructor arguments
            (``n_filters``, ``D``, ``target_bandwidths``, ``ratios``,
            ``sample_rate``, ``bins``, ``n_q``, ``semantic_techer`` and,
            optionally, ``codebook_dim``).
        device: Device the model is created on and the checkpoint is mapped to.

    Returns:
        The model with the checkpoint weights applied. The model is not put
        into eval mode here; the caller is expected to do that.

    Raises:
        KeyError: If a required entry is absent from the config.
    """
    # Load config
    print("Loading config...")
    with open(config_path, 'r') as f:
        config = json.load(f)

    # Create model
    print("Creating model...")
    model = HiggsAudioTokenizer(
        n_filters=config['n_filters'],
        D=config['D'],
        target_bandwidths=config['target_bandwidths'],
        ratios=config['ratios'],
        sample_rate=config['sample_rate'],
        bins=config['bins'],
        n_q=config['n_q'],
        codebook_dim=config.get('codebook_dim', None),
        semantic_techer=config['semantic_techer'],
        device=device
    ).to(device)

    # Load checkpoint
    print("Loading checkpoint...")
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # so only point this at checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Training checkpoints wrap the weights; plain exports are the weights themselves.
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint

    # Remove 'module.' prefix if present (from DDP)
    ddp_prefix = 'module.'
    new_state_dict = {
        (key[len(ddp_prefix):] if key.startswith(ddp_prefix) else key): value
        for key, value in state_dict.items()
    }

    # strict=False tolerates key mismatches. Report them instead of dropping
    # them silently, otherwise a wrong checkpoint/config pair would appear to
    # load while leaving layers at their freshly initialised values.
    load_result = model.load_state_dict(new_state_dict, strict=False)
    for label, keys in (("missing", load_result.missing_keys),
                        ("unexpected", load_result.unexpected_keys)):
        if keys:
            shown = ", ".join(keys[:5])
            more = "" if len(keys) <= 5 else f", ... (+{len(keys) - 5} more)"
            print(f"Warning: {len(keys)} {label} key(s) when loading checkpoint: {shown}{more}")

    return model

# Run the complete pipeline
# NOTE: this checkpoint is a 25cps test model trained during a single afternoon
# on a small dataset; it is in no way an indication of this architecture at its best.
checkpoint_path = '/home/ubuntu/higgs_audio_train/25hz_CQT_step_99000.pth'
config_path = '/home/ubuntu/higgs_audio_train/config_25.json'

# Prefer the GPU, but fall back to the CPU so the model can still be loaded on
# a machine without CUDA instead of failing here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = prepare(checkpoint_path, config_path, device)
# Put the model in eval mode; binding the result to `_` keeps the return value
# of eval() out of the cell output.
_ = model.eval()

/home/ubuntu/higgs_audio_train
Loading config...
Creating model...


Loading checkpoint...


In [None]:


# ---------------------------------------------------------------------------------------------------


path = "shiki_test.wav"
# path = "/home/ubuntu/qatilu.wav"

# NOTE(review): 44100 is presumably the sample rate the checkpoint was trained
# at (config['sample_rate']) — confirm before pointing this at another model.
wav, sr = librosa.load(path, sr=44100)

# Put the input on whatever device the model's weights live on rather than a
# hard-coded 'cuda', so this cell keeps working if the model is on another device.
wav = torch.from_numpy(wav).unsqueeze(0).float().to(next(model.parameters()).device)

with torch.no_grad():
    # Second unsqueeze adds the batch axis -> [1, 1, T] as encode_batch expects
    # (assuming librosa returned a mono, 1-D signal).
    encoded = encode_batch(model, wav.unsqueeze(0))
    recon = model.decode(encoded.audio_codes).squeeze(0)

# Audio() converts its data through NumPy, which cannot read a tensor that sits
# on the GPU or tracks gradients. decode()'s return type is not visible here,
# so convert only when it actually is a tensor.
if torch.is_tensor(recon):
    recon = recon.detach().cpu().numpy()

display(Sawt(recon, rate=sr))
display(Sawt(path))

