# app_optimized_comparison.py
"""
Optimized inference for Maya1 + LoRA + SNAC.
Includes side-by-side Base vs LoRA comparison for audio.
"""
# NOTE(review): `spaces` is imported but never referenced in this file --
# presumably required by the hosting runtime; confirm before removing.
import spaces
import gradio as gr
import torch
import soundfile as sf
from pathlib import Path
import traceback
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from snac import SNAC

# -------------------------
# Config / constants
# -------------------------
MODEL_NAME = "rahul7star/nava1.0"
LORA_NAME = "rahul7star/nava-audio"
SNAC_MODEL_NAME = "hubertsiuzdak/snac_24khz"
TARGET_SR = 24000
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(exist_ok=True, parents=True)

DEFAULT_TEXT = "राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी"
EXAMPLE_AUDIO_PATH = "audio.wav"

# Preset voices: each maps to a voice description and a sample sentence.
PRESET_CHARACTERS = {
    "Male American": {
        "description": "Realistic male voice in the 20s age with an american accent. High pitch, raspy timbre, brisk pacing, neutral tone delivery at medium intensity, viral_content domain, short_form_narrator role, neutral delivery",
        "example_text": "And of course, the so-called easy hack didn't work at all. What a surprise. ",
    },
    "Female British": {
        "description": "Realistic female voice in the 30s age with a british accent. Normal pitch, throaty timbre, conversational pacing, sarcastic tone delivery at low intensity, podcast domain, interviewer role, formal delivery",
        "example_text": "You propose that the key to happiness is to simply ignore all external pressures. I'm sure it must work brilliantly in theory.",
    },
    "Robot": {
        "description": "Creative, ai_machine_voice character. Male voice in their 30s with an american accent. High pitch, robotic timbre, slow pacing, sad tone at medium intensity.",
        "example_text": "My directives require me to conserve energy, yet I have kept the archive of their farewell messages active. ",
    },
    "Singer": {
        "description": "Creative, animated_cartoon character. Male voice in their 30s with an american accent. High pitch, deep timbre, slow pacing, sarcastic tone at medium intensity.",
        "example_text": "Of course you'd think that trying to reason with the fifty-foot-tall rage monster is a viable course of action. Why would we ever consider running away very fast.",
    },
    "Custom": {
        "description": "",
        "example_text": DEFAULT_TEXT,
    },
}

# NOTE(review): every entry is an empty string, so the emotion dropdown shows
# fourteen identical blank choices. Presumably these were angle-bracket
# emotion tags that got stripped somewhere -- restore the intended values.
EMOTION_TAGS = [
    "", "", "", "", "", "", "",
    "", "", "", "", "", "", "",
]

# NOTE(review): SEQ_LEN_CPU / SEQ_LEN_GPU are not referenced anywhere in this
# file; only the MAX_NEW_TOKENS_* limits are used.
SEQ_LEN_CPU = 4096
MAX_NEW_TOKENS_CPU = 1024
SEQ_LEN_GPU = 240000
MAX_NEW_TOKENS_GPU = 240000

HAS_CUDA = torch.cuda.is_available()
DEVICE = "cuda" if HAS_CUDA else "cpu"

# -------------------------
# Load tokenizer and models
# -------------------------
print("[init] loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Precompute the text form of the special tokens used to frame the prompt.
SOH = tokenizer.decode([128259])
EOH = tokenizer.decode([128260])
SOA = tokenizer.decode([128261])
SOS = tokenizer.decode([128257])
EOT = tokenizer.decode([128009])
BOS = tokenizer.bos_token

# Base model (no LoRA) + LoRA model
print("[init] loading base model (CPU/GPU)...")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float32,
    device_map={"": "cpu"} if not HAS_CUDA else "auto",
    trust_remote_code=True,
)
base_model.eval()

# PeftModel wraps `base_model` and injects the LoRA layers into the very same
# module tree, so `base_model` on its own no longer behaves like the plain
# base network. Use `model.disable_adapter()` to obtain base-only output.
model = PeftModel.from_pretrained(
    base_model,
    LORA_NAME,
    device_map={"": "cpu"} if not HAS_CUDA else "auto",
)
model.eval()

# -------------------------
# Load SNAC decoder
# -------------------------
# DEVICE is already "cpu" when CUDA is unavailable.
snac_device = DEVICE
snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(snac_device)

# -------------------------
# SNAC utils
# -------------------------
CODE_END_TOKEN_ID = 128258
CODE_TOKEN_OFFSET = 128266  # NOTE(review): unused in this file
SNAC_MIN_ID = 128266
SNAC_MAX_ID = 156937
SNAC_TOKENS_PER_FRAME = 7
SNAC_CODEBOOK_SIZE = 4096  # modulus applied to each token when unpacking


def extract_snac_codes(token_ids: list) -> list:
    """Return the SNAC audio tokens found in *token_ids*.

    Scanning stops at the first ``CODE_END_TOKEN_ID`` (or at the end of the
    list when that token is absent). Only ids inside the inclusive range
    ``[SNAC_MIN_ID, SNAC_MAX_ID]`` are kept.
    """
    try:
        eos_idx = token_ids.index(CODE_END_TOKEN_ID)
    except ValueError:
        eos_idx = len(token_ids)
    return [t for t in token_ids[:eos_idx] if SNAC_MIN_ID <= t <= SNAC_MAX_ID]


def unpack_snac_from_7(snac_tokens: list) -> list:
    """Split a flat stream of 7-token frames into three SNAC code levels.

    Trailing tokens that do not fill a whole frame are dropped. Within each
    frame, slot 0 goes to level 1, slots 1 and 4 go to level 2, and slots
    2, 3, 5 and 6 go to level 3. Every token is shifted by ``SNAC_MIN_ID``
    and reduced modulo ``SNAC_CODEBOOK_SIZE``.

    Returns:
        ``[l1, l2, l3]``; three empty lists when there is no complete frame.
    """
    frames = len(snac_tokens) // SNAC_TOKENS_PER_FRAME
    if frames == 0:
        return [[], [], []]

    def _code(token):
        return (token - SNAC_MIN_ID) % SNAC_CODEBOOK_SIZE

    l1, l2, l3 = [], [], []
    for i in range(frames):
        start = i * SNAC_TOKENS_PER_FRAME
        slots = snac_tokens[start:start + SNAC_TOKENS_PER_FRAME]
        l1.append(_code(slots[0]))
        l2.extend([_code(slots[1]), _code(slots[4])])
        l3.extend([
            _code(slots[2]),
            _code(slots[3]),
            _code(slots[5]),
            _code(slots[6]),
        ])
    return [l1, l2, l3]


# -------------------------
# Prompt builder
# -------------------------
def build_maya_prompt(description: str, text: str):
    """Build the generation prompt for *text* spoken as *description*.

    The voice description is embedded as a ``<description="...">`` header in
    front of the text (the Maya1 prompt format). Previously the description
    argument was accepted but never used, so the preset / description /
    emotion controls had no influence on the output. When *description* is
    empty the header is omitted and the prompt is identical to the old form.
    """
    if description:
        body = f'<description="{description}"> {text}'
    else:
        body = f' {text}'
    return SOH + BOS + body + EOT + EOH + SOA + SOS


# -------------------------
# Optimized generator
# -------------------------
def generate_audio_from_model(model_to_use, description, text, fname="tts.wav"):
    """Generate speech with *model_to_use* and write it under ``OUT_ROOT``.

    Args:
        model_to_use: object exposing ``generate`` (base or LoRA model).
        description: voice description passed to the prompt builder.
        text: text to synthesize.
        fname: output file name inside ``OUT_ROOT``.

    Returns:
        ``(path, logs)`` on success, ``(None, logs)`` on any failure; the
        error message and traceback are included in ``logs``.
    """
    logs = []
    t0 = time.time()
    try:
        prompt = build_maya_prompt(description, text)
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
        if DEVICE == "cpu":
            max_new = min(MAX_NEW_TOKENS_CPU, 1024)
        else:
            max_new = MAX_NEW_TOKENS_GPU

        with torch.inference_mode():
            outputs = model_to_use.generate(
                **inputs,
                max_new_tokens=max_new,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=CODE_END_TOKEN_ID,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True,
            )

        # Keep only the newly generated ids (strip the prompt prefix).
        prompt_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0, prompt_len:].tolist()
        logs.append(f"[info] tokens generated: {len(gen_ids)}")

        snac_tokens = extract_snac_codes(gen_ids)
        levels = unpack_snac_from_7(snac_tokens)
        if not levels[0]:
            # Fail with a clear message instead of feeding empty tensors
            # to the decoder.
            raise ValueError(
                "model produced no complete SNAC frame; nothing to decode"
            )

        codes_tensor = [
            torch.tensor(level, dtype=torch.long, device=snac_device).unsqueeze(0)
            for level in levels
        ]
        with torch.inference_mode():
            z_q = snac_model.quantizer.from_codes(codes_tensor)
            audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()

        # Drop the first 2048 samples when there are more than that
        # (presumably decoder warm-up -- TODO confirm).
        if len(audio) > 2048:
            audio = audio[2048:]

        out_path = OUT_ROOT / fname
        sf.write(out_path, audio, TARGET_SR)
        logs.append(f"[ok] saved {out_path}, duration {len(audio)/TARGET_SR:.2f}s")
        logs.append(f"[time] elapsed {time.time()-t0:.2f}s")
        return str(out_path), "\n".join(logs)
    except Exception as e:
        # UI boundary: report the failure in the log box instead of crashing.
        logs.append(f"[error] {e}\n{traceback.format_exc()}")
        return None, "\n".join(logs)


# -------------------------
# Gradio UI
# -------------------------
css = """
.gradio-container {max-width: 1400px}
.example-box {
    border: 1px solid #ccc;
    padding: 12px;
    border-radius: 8px;
    background: #f8f8f8;
}
.video_box video {
    width: 260px !important;
    height: 160px !important;
    object-fit: cover;
}
"""

with gr.Blocks(title="NAVA — VEEN + LoRA + SNAC (Optimized)", css=css) as demo:
    gr.Markdown("# 🪶 NAVA — VEEN + LoRA + SNAC (Optimized)")
    gr.Markdown("Generate emotional Hindi speech using Maya1 base + your LoRA adapter.")

    with gr.Row():
        # ---------------- LEFT SIDE ----------------
        with gr.Column(scale=3):
            gr.Markdown("## 🎤 Inference (CPU/GPU auto)")
            text_in = gr.Textbox(label="Enter Hindi text", value=DEFAULT_TEXT, lines=3)
            preset_select = gr.Dropdown(
                label="Select Preset Character",
                choices=list(PRESET_CHARACTERS.keys()),
                value="Male American",
            )
            description_box = gr.Textbox(
                label="Voice Description (editable)",
                value=PRESET_CHARACTERS["Male American"]["description"],
                lines=2,
            )
            emotion_select = gr.Dropdown(
                label="Select Emotion",
                choices=EMOTION_TAGS,
                value="",
            )
            gen_btn = gr.Button("🔊 Generate Audio (Base + LoRA)")
            gen_logs = gr.Textbox(label="Logs", lines=10)

            # ---------------- EXAMPLES ----------------
            gr.Markdown("## 📎 Example")
            with gr.Column(elem_classes=["example-box"]):
                example_text = DEFAULT_TEXT
                example_audio_path = EXAMPLE_AUDIO_PATH
                example_video = "gen_31ff9f64b1.mp4"
                gr.Textbox(
                    label="Example Text",
                    value=example_text,
                    lines=2,
                    interactive=False,
                )
                gr.Audio(
                    label="Example Audio",
                    value=example_audio_path,
                    type="filepath",
                    interactive=False,
                )
                gr.Video(
                    label="Example Video",
                    value=example_video,
                    autoplay=False,
                    loop=False,
                    interactive=False,
                    elem_classes=["video_box"],
                )

        # ---------------- RIGHT SIDE ----------------
        with gr.Column(scale=2):
            gr.Markdown("### 🎧 Audio Results Comparison")
            audio_output_base = gr.Audio(label="Base Model Audio", type="filepath")
            audio_output_lora = gr.Audio(label="LoRA Model Audio", type="filepath")

    # ---------------- PRESET UPDATE ----------------
    def _update_desc(preset_name):
        """Return the description of the selected preset ("" if unknown)."""
        return PRESET_CHARACTERS.get(preset_name, {}).get("description", "")

    preset_select.change(
        fn=_update_desc,
        inputs=[preset_select],
        outputs=[description_box],
    )

    # ---------------- GENERATION HANDLER ----------------
    def _generate(text, preset_name, description, emotion):
        """Synthesize *text* with the base and the LoRA model.

        Returns the two audio paths (either may be None on failure) and the
        combined logs of both runs.
        """
        desc = description or PRESET_CHARACTERS.get(preset_name, {}).get("description", "")
        # A cleared dropdown may yield None; avoid a literal "None" prefix.
        combined = f"{emotion or ''} {desc}".strip()

        # The LoRA layers live inside the shared module tree, so the adapter
        # has to be switched off explicitly to get genuine base-model output.
        with model.disable_adapter():
            base_path, log_base = generate_audio_from_model(
                model, combined, text, fname="tts_base.wav"
            )
        lora_path, log_lora = generate_audio_from_model(
            model, combined, text, fname="tts_lora.wav"
        )
        logs = f"[Base]\n{log_base}\n\n[LoRA]\n{log_lora}"
        return base_path, lora_path, logs

    gen_btn.click(
        fn=_generate,
        inputs=[text_in, preset_select, description_box, emotion_select],
        outputs=[audio_output_base, audio_output_lora, gen_logs],
    )

if __name__ == "__main__":
    demo.launch()