"""Voice Tech API: Hindi text-to-speech with an Orpheus-3B LoRA adapter + SNAC."""

import io
import logging
import os
import threading

import torch
import torchaudio
from fastapi import FastAPI, File, HTTPException, Query, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from peft import PeftModel
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- CONFIGURATION ---
# Your adapter (trained LoRA weights).
ADAPTER_REPO = "rudraPtGenAi/Orpheus-3B-Hindi-syspin-Smart"
# The base checkpoint the adapter was trained on; it must be exactly the same model.
BASE_MODEL = "canopylabs/3b-hi-ft-research_release"
DEVICE = "cpu"
SAMPLE_RATE = 24000
# Upper bound on generated tokens per request (400 tokens = ~57 frames of 7 tokens).
MAX_NEW_TOKENS = 400

# SNAC codes lie in [0, MAX_SNAC_VALUE]; one codebook therefore has 4096 entries.
MAX_SNAC_VALUE = 4095
SNAC_CODEBOOK_SIZE = MAX_SNAC_VALUE + 1
TOKENS_PER_FRAME = 7

# Special token IDs from the public Orpheus-TTS prompt/decoding convention.
# NOTE(review): these assume the adapter was trained with the standard Orpheus
# data pipeline; adjust them if a custom token layout was used in training.
START_OF_HUMAN = 128259
END_OF_TEXT = 128009
END_OF_HUMAN = 128260
START_OF_SPEECH = 128257
END_OF_SPEECH = 128258
# Vocabulary ID of SNAC code 0 at frame position 0; position k adds k * 4096.
AUDIO_TOKEN_OFFSET = 128266

app = FastAPI(title="Voice Tech API")

# Filled in by the loader below. They stay None if startup fails, so the
# endpoint can answer 503 instead of crashing on an undefined name.
model = None
tokenizer = None
snac_model = None
load_error = None

# The endpoint runs in FastAPI's threadpool; this lock keeps CPU inference
# one-request-at-a-time (as the old event-loop-blocking version effectively was)
# without freezing the server for other traffic.
_inference_lock = threading.Lock()

logger.info("--- 🚀 Initializing Space ---")
try:
    # Get token from Secret
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        logger.warning("⚠️ Warning: HF_TOKEN not found.")

    logger.info("Loading Base: %s", BASE_MODEL)
    _base = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32,
        device_map="cpu",
        low_cpu_mem_usage=True,
        token=hf_token,
    )
    _tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=hf_token)

    logger.info("Loading Adapters: %s", ADAPTER_REPO)
    _model = PeftModel.from_pretrained(_base, ADAPTER_REPO, token=hf_token)
    _model.eval()

    logger.info("Loading SNAC...")
    _snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(DEVICE)
    _snac.eval()

    # Publish only once everything loaded, so a partial failure never looks "ready".
    model, tokenizer, snac_model = _model, _tokenizer, _snac
    logger.info("✅ System Ready!")
except Exception as e:
    load_error = str(e)
    logger.exception("❌ Critical Error: %s", e)


def _build_prompt_ids(text):
    """Tokenize *text* and wrap it in the Orpheus human-turn markers.

    Returns ``(input_ids, attention_mask)``, both shaped ``(1, seq_len)`` on DEVICE.
    """
    text_ids = tokenizer(text, return_tensors="pt").input_ids
    start = torch.tensor([[START_OF_HUMAN]], dtype=text_ids.dtype)
    end = torch.tensor([[END_OF_TEXT, END_OF_HUMAN]], dtype=text_ids.dtype)
    input_ids = torch.cat([start, text_ids, end], dim=1).to(DEVICE)
    return input_ids, torch.ones_like(input_ids)


def _extract_audio_codes(token_ids):
    """Map generated vocabulary IDs to per-position SNAC codes.

    Orpheus encodes the k-th token of each 7-token frame as
    ``AUDIO_TOKEN_OFFSET + k * SNAC_CODEBOOK_SIZE + code``. Everything up to
    the last start-of-speech marker is discarded. Non-audio tokens (IDs below
    the offset, e.g. the end-of-speech marker) are dropped before the offsets
    are removed.
    """
    if START_OF_SPEECH in token_ids:
        last_marker = len(token_ids) - 1 - token_ids[::-1].index(START_OF_SPEECH)
        token_ids = token_ids[last_marker + 1:]
    audio_tokens = [t for t in token_ids if t >= AUDIO_TOKEN_OFFSET]
    return [
        t - AUDIO_TOKEN_OFFSET - (i % TOKENS_PER_FRAME) * SNAC_CODEBOOK_SIZE
        for i, t in enumerate(audio_tokens)
    ]


def reconstruct_snac_codes(flat_list):
    """Split a flat list of SNAC codes into the three SNAC codebook tensors.

    Each 7-code frame carries 1 coarse, 2 medium and 4 fine codes in the
    Orpheus order ``[c0, c1, c2, c2, c1, c2, c2]``. A trailing partial frame is
    dropped, and codes outside ``[0, MAX_SNAC_VALUE]`` are replaced by 0.

    Returns:
        Three int64 tensors shaped ``(1, n)``, ``(1, 2n)`` and ``(1, 4n)`` on DEVICE.

    Raises:
        ValueError: if ``flat_list`` does not contain at least one full frame.
    """
    n_frames = len(flat_list) // TOKENS_PER_FRAME
    if n_frames == 0:
        # torch.tensor([]) would be float and crash inside SNAC with an obscure error.
        raise ValueError("need at least one full 7-code SNAC frame")

    c0, c1, c2 = [], [], []
    for t in range(n_frames):
        chunk = flat_list[t * TOKENS_PER_FRAME:(t + 1) * TOKENS_PER_FRAME]
        frame = [x if 0 <= x <= MAX_SNAC_VALUE else 0 for x in chunk]
        c0.append(frame[0])
        c1.extend((frame[1], frame[4]))
        c2.extend((frame[2], frame[3], frame[5], frame[6]))

    return [
        torch.tensor(layer, dtype=torch.long).unsqueeze(0).to(DEVICE)
        for layer in (c0, c1, c2)
    ]


@app.api_route("/Get_Inference", methods=["GET", "POST"])
def get_inference(
    request: Request,
    text: str = Query(..., description="Input text"),
    lang: str = Query(..., description="Language"),
    speaker_wav: UploadFile = File(None, description="Reference WAV"),
):
    """Synthesize *text* and stream it back as a 24 kHz WAV.

    ``lang`` and ``speaker_wav`` are accepted for API compatibility but are
    not used by the current pipeline.

    Declared with ``def`` (not ``async def``) so FastAPI runs the blocking CPU
    inference in its threadpool instead of stalling the event loop.

    Responses:
        200 ``audio/wav`` on success; 400 for empty text; 503 if the models
        failed to load at startup; 500 with ``{"error": ...}`` on failure.
    """
    logger.info("📩 Request: %r", text)
    if not text.strip():
        raise HTTPException(400, "Text required")
    if model is None or tokenizer is None or snac_model is None:
        return JSONResponse(
            status_code=503,
            content={"error": f"Model not loaded: {load_error}"},
        )

    try:
        stop_ids = [END_OF_SPEECH]
        if tokenizer.eos_token_id is not None:
            stop_ids.append(tokenizer.eos_token_id)
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else stop_ids[0]

        # no_grad also covers SNAC decode, so the waveform can be saved (numpy) directly.
        with _inference_lock, torch.no_grad():
            input_ids, attention_mask = _build_prompt_ids(text)
            output_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=MAX_NEW_TOKENS,
                pad_token_id=pad_id,
                eos_token_id=stop_ids,
                temperature=0.3,
                do_sample=True,
            )
            generated_ids = output_ids[0, input_ids.shape[1]:].tolist()
            codes = _extract_audio_codes(generated_ids)
            if len(codes) < TOKENS_PER_FRAME:
                return JSONResponse(
                    status_code=500, content={"error": "Model generated silence"}
                )
            audio_hat = snac_model.decode(reconstruct_snac_codes(codes))

        buffer = io.BytesIO()
        # decode() output is (batch, channels, samples); drop the batch axis.
        torchaudio.save(buffer, audio_hat.squeeze(0).cpu(), SAMPLE_RATE, format="wav")
        buffer.seek(0)
        return StreamingResponse(buffer, media_type="audio/wav")
    except Exception as e:
        logger.exception("Error: %s", e)
        return JSONResponse(status_code=500, content={"error": str(e)})