{
"cells": [
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/ubuntu/higgs_audio_train\n"
]
}
],
"source": [
"%cd /home/ubuntu/higgs_audio_train\n",
"\n",
"import librosa\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"import json\n",
"import torch\n",
"from IPython.display import Audio as Sawt\n",
"from higgs_audio_tokenizer import HiggsAudioTokenizer\n",
"import torch\n",
"import torch.nn as nn\n",
"import warnings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/home/ubuntu/higgs_audio_train\n",
"Loading config...\n",
"Creating model...\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loading checkpoint...\n"
]
}
],
"source": [
"%cd /home/ubuntu/higgs_audio_train\n",
"\n",
"import librosa\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import numpy as np\n",
"import json\n",
"import torch\n",
"from IPython.display import Audio as Sawt\n",
"from higgs_audio_tokenizer import HiggsAudioTokenizer\n",
"import torch\n",
"import torch.nn as nn\n",
"import warnings\n",
"\n",
"\n",
"class EncodedResult:\n",
"    \"\"\"Simple holder for the two things `encode_batch` returns: codes and quantized latents.\"\"\"\n",
"\n",
"    def __init__(self, audio_codes, quantized):\n",
"        # Both values are stored as given: no copying, conversion or validation.\n",
"        self.audio_codes, self.quantized = audio_codes, quantized\n",
"\n",
"\n",
"def encode_batch(model, x_batch):\n",
"    \"\"\"\n",
"    Encode a batch of waveforms with a HiggsAudioTokenizer.\n",
"\n",
"    The semantic and acoustic encoders both see `x_batch`; their outputs are\n",
"    cropped to a shared length on the last axis, concatenated on dim 1,\n",
"    passed through `model.fc_prior` and finally quantized.\n",
"\n",
"    Args:\n",
"        model: The loaded HiggsAudioTokenizer model.\n",
"        x_batch: A tensor of shape [B, 1, T].\n",
"\n",
"    Returns:\n",
"        EncodedResult whose `audio_codes` are the permuted quantizer codes and\n",
"        whose `quantized` is the quantizer's first return value.\n",
"    \"\"\"\n",
"    # Semantic branch: the regression target is detached before it is encoded.\n",
"    semantic_target = model.get_regress_target(x_batch).detach()\n",
"    semantic_feats = model.encoder_semantic(semantic_target.transpose(1, 2))\n",
"\n",
"    # Acoustic branch on the untouched batch.\n",
"    acoustic_feats = model.encoder(x_batch)\n",
"\n",
"    if acoustic_feats.shape[2] != semantic_feats.shape[2]:\n",
"        # The two branches disagree on length: re-encode the waveform padded on both sides.\n",
"        margin = 160 * model.semantic_downsample_factor\n",
"        # [B, 1, T] -> [B, T] -> pad last axis -> [B, 1, T + 2 * margin]\n",
"        padded_batch = F.pad(x_batch[:, 0, :], (margin, margin)).unsqueeze(1)\n",
"        acoustic_feats = model.encoder(padded_batch)\n",
"\n",
"    # Crop both feature maps to the shorter one so they can be concatenated.\n",
"    shared_len = min(acoustic_feats.shape[2], semantic_feats.shape[2])\n",
"    fused = torch.cat(\n",
"        [acoustic_feats[:, :, :shared_len], semantic_feats[:, :, :shared_len]],\n",
"        dim=1,\n",
"    )\n",
"    fused = model.fc_prior(fused.transpose(1, 2))\n",
"\n",
"    if model.quantizer_type != \"RVQ\":\n",
"        # Any non-RVQ quantizer (RFSQ) takes the projected features directly.\n",
"        quantized, codes = model.quantizer(fused)\n",
"        return EncodedResult(audio_codes=codes.permute(0, 2, 1), quantized=quantized)\n",
"\n",
"    # RVQ takes the transposed features plus the frame rate; the last two outputs are unused.\n",
"    quantized, codes, _, _ = model.quantizer(fused.transpose(1, 2), model.frame_rate, None)\n",
"    return EncodedResult(audio_codes=codes.permute(1, 0, 2), quantized=quantized)\n",
"\n",
"def prepare(checkpoint_path, config_path, device='cuda'):\n",
"    \"\"\"\n",
"    Build a HiggsAudioTokenizer from a JSON config and load checkpoint weights into it.\n",
"\n",
"    Args:\n",
"        checkpoint_path: Path to a checkpoint readable by `torch.load`. It may be a\n",
"            bare state dict or a dict holding one under 'model_state_dict'.\n",
"        config_path: Path to the JSON file holding the constructor arguments.\n",
"        device: Device the model is created on and the checkpoint is mapped to.\n",
"\n",
"    Returns:\n",
"        The model with the checkpoint weights loaded. The model is not switched to\n",
"        eval mode here; callers do that themselves.\n",
"\n",
"    Warns:\n",
"        UserWarning: if the checkpoint and the model disagree on parameter names.\n",
"            Loading is non-strict, so such mismatches are reported instead of raised.\n",
"    \"\"\"\n",
"    # Load config\n",
"    print(\"Loading config...\")\n",
"    with open(config_path, 'r') as f:\n",
"        config = json.load(f)\n",
"\n",
"    # Create model\n",
"    print(\"Creating model...\")\n",
"    model = HiggsAudioTokenizer(\n",
"        n_filters=config['n_filters'],\n",
"        D=config['D'],\n",
"        target_bandwidths=config['target_bandwidths'],\n",
"        ratios=config['ratios'],\n",
"        sample_rate=config['sample_rate'],\n",
"        bins=config['bins'],\n",
"        n_q=config['n_q'],\n",
"        codebook_dim=config.get('codebook_dim', None),\n",
"        semantic_techer=config['semantic_techer'],\n",
"        device=device\n",
"    ).to(device)\n",
"\n",
"    # Load checkpoint\n",
"    print(\"Loading checkpoint...\")\n",
"    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python objects,\n",
"    # so only point this at checkpoints from a trusted source.\n",
"    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)\n",
"\n",
"    # Training checkpoints wrap the weights; plain exports are the state dict itself.\n",
"    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint\n",
"\n",
"    # Remove 'module.' prefix if present (from DDP)\n",
"    ddp_prefix = 'module.'\n",
"    new_state_dict = {\n",
"        (k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k): v\n",
"        for k, v in state_dict.items()\n",
"    }\n",
"\n",
"    # strict=False tolerates key mismatches; report them so a wrong or partial\n",
"    # checkpoint does not load silently.\n",
"    load_result = model.load_state_dict(new_state_dict, strict=False)\n",
"    missing = list(load_result.missing_keys)\n",
"    unexpected = list(load_result.unexpected_keys)\n",
"    if missing:\n",
"        warnings.warn(\n",
"            f'{len(missing)} model key(s) are absent from the checkpoint and keep the '\n",
"            f'values set by the constructor, e.g. {missing[:5]}',\n",
"            stacklevel=2,\n",
"        )\n",
"    if unexpected:\n",
"        warnings.warn(\n",
"            f'{len(unexpected)} checkpoint key(s) do not exist in the model and were '\n",
"            f'ignored, e.g. {unexpected[:5]}',\n",
"            stacklevel=2,\n",
"        )\n",
"\n",
"    return model\n",
"\n",
"# Run the complete pipeline\n",
"PROJECT_DIR = '/home/ubuntu/higgs_audio_train'\n",
"\n",
"# NOTE: this is a 25cps test model trained in a single afternoon on a small dataset;\n",
"# it is in no way indicative of this architecture at its best.\n",
"checkpoint_path = f'{PROJECT_DIR}/25hz_CQT_step_99000.pth'\n",
"config_path = f'{PROJECT_DIR}/config_25.json'\n",
"\n",
"# Prefer the GPU, but fall back to CPU instead of crashing on machines without CUDA.\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"model = prepare(checkpoint_path, config_path, device)\n",
"# eval() returns the module; bind it to `_` so the cell does not echo the model repr.\n",
"_ = model.eval()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"text/html": [
"\n",
" \n",
" "
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"\n",
"# ---------------------------------------------------------------------------------------------------\n",
"\n",
"\n",
"path = \"shiki_test.wav\"\n",
"# path = \"/home/ubuntu/qatilu.wav\"\n",
"\n",
"# Load the clip at 44.1 kHz; `sr` is reused below as the playback rate.\n",
"wav, sr = librosa.load(path, sr=44100)\n",
"\n",
"# Add a leading axis and move the samples to the device chosen in the previous cell,\n",
"# so the input always sits on the device that was passed to prepare().\n",
"wav = torch.from_numpy(wav).unsqueeze(0).float().to(device)\n",
"\n",
"with torch.no_grad():\n",
"    # encode_batch documents its input as [B, 1, T], hence the extra unsqueeze.\n",
"    encoded = encode_batch(model, wav.unsqueeze(0))\n",
"    recon = model.decode(encoded.audio_codes).squeeze(0)\n",
"\n",
"# IPython's Audio widget needs host-side array data. If decode() returned a tensor,\n",
"# copy it to the CPU first; anything else is passed through unchanged.\n",
"recon_audio = recon.detach().cpu().numpy() if torch.is_tensor(recon) else recon\n",
"\n",
"# Reconstruction first, then the original file for an A/B comparison.\n",
"display(Sawt(recon_audio, rate=sr))\n",
"display(Sawt(path))\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "respair",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}