Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| import torch | |
| import soundfile as sf | |
| from pathlib import Path | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from peft import PeftModel | |
| from snac import SNAC | |
# -----------------------------
# CONFIG
# -----------------------------
# Hugging Face repo ids: the base causal LM and the LoRA adapter loaded on top of it.
MODEL_NAME = "rahul7star/nava1.0"
LORA_NAME = "rahul7star/nava-audio"
# Upper bound on newly generated tokens per request (passed as max_new_tokens).
# NOTE(review): 240000 is far beyond what one utterance needs; if the model never
# emits its EOS id, a CPU run can go on for a very long time. Confirm intent.
SEQ_LEN = 240000
# Sample rate written into the output WAV header.
# The decoder unpacks 7 tokens per frame into three SNAC codebook levels
# (1 + 2 + 4), which is the 3-codebook 24 kHz SNAC layout. The previous value
# (240000) was 10x too high and made saved audio play back at the wrong speed.
TARGET_SR = 24000
# Directory where generated WAV files are written; created at import time.
OUT_ROOT = Path("/tmp/data")
OUT_ROOT.mkdir(parents=True, exist_ok=True)
# Hindi (Devanagari) sample sentence pre-filled in the UI text box.
DEFAULT_TEXT = (
    "राजनीतिज्ञों ने कहा कि उन्होंने निर्णायक मत को अनावश्यक रूप से "
    "निर्धारित करने के लिए अफ़गान संविधान में काफी अस्पष्टता पाई थी"
)
# -----------------------------
# GENERATE AUDIO (LoRA)
# -----------------------------
# Token-id constants used when parsing generated output.
_EOS_ID = 128258
# Inclusive id range of SNAC audio-code tokens: 7 interleaved slots x 4096 codes.
_SNAC_MIN, _SNAC_MAX = 128266, 156937
_CODEBOOK_SIZE = 4096
_TOKENS_PER_FRAME = 7
# Number of leading decoded samples discarded (presumably decoder warm-up — confirm).
_TRIM_SAMPLES = 2048

# Loaded (tokenizer, LoRA model, SNAC decoder) tuples keyed by device, so the
# multi-GB checkpoints are read once per process instead of on every click.
_MODEL_CACHE = {}


def _load_models(device="cpu"):
    """Load (once per device) and return ``(tokenizer, lora_model, snac_model)``."""
    if device not in _MODEL_CACHE:
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
        base_model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            device_map={"": device},
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )
        model = PeftModel.from_pretrained(base_model, LORA_NAME, device_map={"": device})
        model.eval()
        snac_model = SNAC.from_pretrained("rahul7star/nava-snac").eval().to(device)
        _MODEL_CACHE[device] = (tokenizer, model, snac_model)
    return _MODEL_CACHE[device]


def _build_prompt(tokenizer, text):
    """Wrap *text* in the special tokens that precede audio generation.

    The ids are specific to this checkpoint's vocabulary. Their roles are only
    suggested by the original variable names (soh/eoh/soa/sos/eot — presumably
    start/end of human turn, start of audio/speech, end of text); confirm
    against the model card.
    """
    soh_token = tokenizer.decode([128259])
    eoh_token = tokenizer.decode([128260])
    soa_token = tokenizer.decode([128261])
    sos_token = tokenizer.decode([128257])
    eot_token = tokenizer.decode([128009])
    return soh_token + tokenizer.bos_token + text + eot_token + eoh_token + soa_token + sos_token


def _snac_levels_from_ids(generated_ids):
    """Split generated token ids into the three SNAC codebook levels.

    Only ids before the first ``_EOS_ID`` that fall inside the SNAC id range are
    kept. A trailing partial frame is dropped. Each 7-token frame is
    de-interleaved as level1=[s0], level2=[s1, s4], level3=[s2, s3, s5, s6],
    with every id mapped to a 0..4095 code.

    Returns:
        ``(level1, level2, level3)`` lists of ints; all empty if no full frame exists.
    """
    try:
        eos_idx = generated_ids.index(_EOS_ID)
    except ValueError:
        eos_idx = len(generated_ids)
    codes = [
        (tok - _SNAC_MIN) % _CODEBOOK_SIZE
        for tok in generated_ids[:eos_idx]
        if _SNAC_MIN <= tok <= _SNAC_MAX
    ]
    level1, level2, level3 = [], [], []
    for start in range(0, len(codes) - _TOKENS_PER_FRAME + 1, _TOKENS_PER_FRAME):
        s = codes[start:start + _TOKENS_PER_FRAME]
        level1.append(s[0])
        level2.extend((s[1], s[4]))
        level3.extend((s[2], s[3], s[5], s[6]))
    return level1, level2, level3


def generate_audio_cpu_lora(text: str):
    """Synthesize speech for *text* on CPU with the LoRA-adapted model.

    Args:
        text: Input sentence. Empty or whitespace-only input is rejected
            without loading any model.

    Returns:
        ``(audio_path, audio_path, logs)`` on success, where ``audio_path`` is
        the written WAV as a string. On failure the result is
        ``(None, None, logs)`` and *logs* holds the error (and traceback for
        unexpected exceptions).
    """
    logs = []
    if not text or not text.strip():
        logs.append("[❌] CPU LoRA TTS error: input text is empty")
        return None, None, "\n".join(logs)

    device = "cpu"
    try:
        print(text)
        tokenizer, model, snac_model = _load_models(device)

        inputs = tokenizer(_build_prompt(tokenizer, text), return_tensors="pt").to(device)
        with torch.inference_mode():
            outputs = model.generate(
                **inputs,
                max_new_tokens=SEQ_LEN,
                temperature=0.4,
                top_p=0.9,
                repetition_penalty=1.1,
                do_sample=True,
                eos_token_id=_EOS_ID,
                pad_token_id=tokenizer.pad_token_id,
            )
        # Keep only the newly generated continuation, not the echoed prompt.
        generated_ids = outputs[0, inputs["input_ids"].shape[1]:].tolist()
        logs.append(f"[i] generated {len(generated_ids)} tokens")

        levels = _snac_levels_from_ids(generated_ids)
        if not levels[0]:
            # An empty code list would otherwise fail deep inside SNAC with an opaque error.
            logs.append("[❌] CPU LoRA TTS error: model produced no complete SNAC frames")
            return None, None, "\n".join(logs)
        logs.append(f"[i] decoded {len(levels[0])} SNAC frames")

        codes_tensor = [
            torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0)
            for level in levels
        ]
        with torch.inference_mode():
            z_q = snac_model.quantizer.from_codes(codes_tensor)
            audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()
        if len(audio) > _TRIM_SAMPLES:
            audio = audio[_TRIM_SAMPLES:]

        # Write at the decoder's own rate; TARGET_SR is only a fallback.
        sample_rate = getattr(snac_model, "sampling_rate", TARGET_SR)
        audio_path = OUT_ROOT / "tts_output_cpu_lora.wav"
        sf.write(audio_path, audio, sample_rate)
        logs.append(f"[✅] wrote {len(audio)} samples at {sample_rate} Hz to {audio_path}")
        return str(audio_path), str(audio_path), "\n".join(logs)
    except Exception as e:
        import traceback
        logs.append(f"[❌] CPU LoRA TTS error: {e}\n{traceback.format_exc()}")
        print(e)
        return None, None, "\n".join(logs)
# -----------------------------
# GRADIO UI
# -----------------------------
def _build_demo():
    """Assemble and return the Gradio Blocks app.

    Layout, top to bottom:
    - two header lines;
    - a text box and a generate button wired to ``generate_audio_cpu_lora``,
      with audio player, file download and logs outputs;
    - a read-only example section.
    """
    with gr.Blocks() as app:
        for heading in (
            "# Maya LoRA TTS (CPU)- 10 mins gen time else switch to GPU ",
            "# Full Credit to Maya Team members ",
        ):
            gr.Markdown(heading)

        # Interactive generation: text in -> (player, download, logs) out.
        prompt_box = gr.Textbox(label="Enter text", lines=2, value=DEFAULT_TEXT)
        generate_btn = gr.Button("🔊 Generate Audio")
        result_components = [
            gr.Audio(label="Play Generated Audio", type="filepath"),
            gr.File(label="Download Audio"),
            gr.Textbox(label="Logs", lines=12),
        ]
        generate_btn.click(
            fn=generate_audio_cpu_lora,
            inputs=[prompt_box],
            outputs=result_components,
        )

        # Static example: default sentence plus the clip at "audio.wav"
        # (relative to the working directory).
        gr.Markdown("### Example")
        gr.Textbox(label="Example Text", value=DEFAULT_TEXT, lines=2, interactive=False)
        gr.Audio(label="Example Audio", value="audio.wav", type="filepath", interactive=False)
    return app


demo = _build_demo()

if __name__ == "__main__":
    demo.launch()