{ "cells": [ { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/ubuntu/higgs_audio_train\n" ] } ], "source": [ "%cd /home/ubuntu/higgs_audio_train\n", "\n", "import librosa\n", "import torch\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import json\n", "import torch\n", "from IPython.display import Audio as Sawt\n", "from higgs_audio_tokenizer import HiggsAudioTokenizer\n", "import torch\n", "import torch.nn as nn\n", "import warnings" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/home/ubuntu/higgs_audio_train\n", "Loading config...\n", "Creating model...\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Loading checkpoint...\n" ] } ], "source": [ "%cd /home/ubuntu/higgs_audio_train\n", "\n", "import librosa\n", "import torch\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import json\n", "import torch\n", "from IPython.display import Audio as Sawt\n", "from higgs_audio_tokenizer import HiggsAudioTokenizer\n", "import torch\n", "import torch.nn as nn\n", "import warnings\n", "\n", "\n", "class EncodedResult:\n", " def __init__(self, audio_codes, quantized):\n", " self.audio_codes = audio_codes\n", " self.quantized = quantized\n", "\n", "\n", "def encode_batch(model, x_batch):\n", " \"\"\"\n", " Encodes a batch of audio tensors using the HiggsAudioTokenizer model.\n", " Args:\n", " model: The loaded HiggsAudioTokenizer model.\n", " x_batch: A tensor of shape [B, 1, T]\n", " \"\"\"\n", " # Acoustic and Semantic Feature Extraction\n", " e_semantic_input = model.get_regress_target(x_batch).detach()\n", " e_semantic = model.encoder_semantic(e_semantic_input.transpose(1, 2))\n", " e_acoustic = model.encoder(x_batch)\n", "\n", " # This block contains the fix for batch processing\n", " if e_acoustic.shape[2] != e_semantic.shape[2]:\n", " pad_size = 160 * model.semantic_downsample_factor\n", " \n", " # 1. Remove channel dim, preserving batch dim -> [B, T]\n", " x_slice = x_batch[:, 0, :]\n", " \n", " # 2. Pad the tensor\n", " x_padded = F.pad(x_slice, (pad_size, pad_size))\n", " \n", " # 3. Re-add channel dim before passing to encoder -> [B, 1, T_padded]\n", " e_acoustic = model.encoder(x_padded.unsqueeze(1))\n", "\n", " # Ensure dimensions match before concatenating\n", " min_len = min(e_acoustic.shape[2], e_semantic.shape[2])\n", " e_acoustic = e_acoustic[:, :, :min_len]\n", " e_semantic = e_semantic[:, :, :min_len]\n", "\n", " # Remainder of the original encoding logic\n", " e = torch.cat([e_acoustic, e_semantic], dim=1)\n", " e = model.fc_prior(e.transpose(1, 2))\n", "\n", " if model.quantizer_type == \"RVQ\":\n", " e = e.transpose(1, 2)\n", " quantized, codes, _, _ = model.quantizer(e, model.frame_rate, None)\n", " codes = codes.permute(1, 0, 2)\n", " else: # RFSQ\n", " quantized, codes = model.quantizer(e)\n", " codes = codes.permute(0, 2, 1)\n", "\n", " return EncodedResult(audio_codes=codes, quantized=quantized)\n", "\n", "def prepare(checkpoint_path, config_path, device='cuda'):\n", "\n", " # Load config\n", " print(\"Loading config...\")\n", " with open(config_path, 'r') as f:\n", " config = json.load(f)\n", " \n", " # Create model\n", " print(\"Creating model...\")\n", " model = HiggsAudioTokenizer(\n", " n_filters=config['n_filters'],\n", " D=config['D'],\n", " target_bandwidths=config['target_bandwidths'],\n", " ratios=config['ratios'],\n", " sample_rate=config['sample_rate'],\n", " bins=config['bins'],\n", " n_q=config['n_q'],\n", " codebook_dim=config.get('codebook_dim', None),\n", " semantic_techer=config['semantic_techer'],\n", " device=device\n", " ).to(device)\n", " \n", " # Load checkpoint\n", " print(\"Loading checkpoint...\")\n", " checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)\n", " \n", " if 'model_state_dict' in checkpoint:\n", " state_dict = checkpoint['model_state_dict']\n", " else:\n", " state_dict = checkpoint\n", " \n", " # Remove 'module.' prefix if present (from DDP)\n", " new_state_dict = {}\n", " for k, v in state_dict.items():\n", " if k.startswith('module.'):\n", " new_state_dict[k[7:]] = v\n", " else:\n", " new_state_dict[k] = v\n", " \n", " model.load_state_dict(new_state_dict, strict=False)\n", " \n", "\n", " \n", " return model\n", "\n", "# Run the complete pipeline\n", "checkpoint_path = '/home/ubuntu/higgs_audio_train/25hz_CQT_step_99000.pth' #NOTE: this is a 25cps test model trained during a single afternoon on a small dataset. in no way it is an indication of this architecture at its best.\n", "config_path = '/home/ubuntu/higgs_audio_train/config_25.json'\n", "\n", "device = 'cuda'\n", "model = prepare(checkpoint_path, config_path, device)\n", "_ = model.eval()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", " \n", " " ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "\n", "\n", "# ---------------------------------------------------------------------------------------------------\n", "\n", "\n", "path = \"shiki_test.wav\"\n", "# path = \"/home/ubuntu/qatilu.wav\"\n", "wav, sr = librosa.load(path, sr=44100)\n", "\n", "wav = torch.from_numpy(wav).unsqueeze(0).float().to('cuda')\n", "\n", "with torch.no_grad():\n", "\n", " encoded = encode_batch(model, wav.unsqueeze(0)) \n", " recon = model.decode(encoded.audio_codes).squeeze(0)\n", " \n", "display(Sawt(recon, rate=sr))\n", "display(Sawt(path))\n", "\n" ] } ], "metadata": { "kernelspec": { "display_name": "respair", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 2 }