Spaces:
Sleeping
Sleeping
import asyncio
import io
import os
import threading
import traceback

import torch
import torchaudio
from fastapi import FastAPI, File, HTTPException, Query, Request, UploadFile
from fastapi.responses import JSONResponse, StreamingResponse
from peft import PeftModel
from snac import SNAC
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- CONFIGURATION ---
# PEFT adapter repository (the trained weights) applied on top of BASE_MODEL.
ADAPTER_REPO = "rudraPtGenAi/Orpheus-3B-Hindi-syspin-Smart"
# Base checkpoint the adapter is loaded onto; it must be the same model the
# adapter was trained against.
# NOTE(review): a previous comment said this "must be Pretrain", but the id
# below is the "-ft-" research release. Confirm which base the adapter was
# trained on (e.g. base_model_name_or_path in its adapter_config.json).
BASE_MODEL = "canopylabs/3b-hi-ft-research_release"
# Target device for input tensors and the SNAC decoder.
DEVICE = "cpu"
# Largest value accepted as a SNAC code: reconstruct_snac_codes() replaces any
# value outside [0, MAX_SNAC_VALUE] with 0.
MAX_SNAC_VALUE = 4095
app = FastAPI(title="Voice Tech API")
print("--- π Initializing Space ---")
# Pre-bind the pipeline handles so a failed startup leaves them as None
# (detectable at request time) instead of undefined (NameError later).
model = tokenizer = snac_model = None
try:
    # Hugging Face access token from the Space secret; loading continues with
    # token=None when it is missing.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        print("β οΈ Warning: HF_TOKEN not found.")
    print(f"Loading Base: {BASE_MODEL}")
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32,
        device_map=DEVICE,
        low_cpu_mem_usage=True,
        token=hf_token,
    )
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=hf_token)
    print(f"Loading Adapters: {ADAPTER_REPO}")
    # Wrap the base model with the trained adapter weights.
    model = PeftModel.from_pretrained(model, ADAPTER_REPO, token=hf_token)
    model.eval()
    print("Loading SNAC...")
    # Decoder that turns the generated SNAC code tensors back into audio.
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(DEVICE)
    snac_model.eval()
    print("β System Ready!")
except Exception as e:
    # Best-effort startup: keep the process alive so the error is visible in
    # the logs, but never leave a half-initialised pipeline in place (e.g. the
    # base model without its adapter).
    model = tokenizer = snac_model = None
    print(f"β Critical Error: {e}")
    traceback.print_exc()
def reconstruct_snac_codes(flat_list, max_code=None, device=None):
    """Split a flat stream of 7-value frames into three SNAC code tensors.

    Each frame is laid out as ``[c0, c1, c1, c2, c2, c2, c2]``: one value for
    the first tensor, two for the second and four for the third. Trailing
    values that do not fill a whole frame are dropped. Any value outside
    ``[0, max_code]`` is replaced with 0.

    Args:
        flat_list: Sequence of integer codes (e.g. generated token ids).
        max_code: Largest valid code; defaults to ``MAX_SNAC_VALUE``.
        device: Device to place the tensors on; defaults to ``DEVICE``.

    Returns:
        ``[c0, c1, c2]`` int64 tensors of shape ``(1, n)``, ``(1, 2n)`` and
        ``(1, 4n)``, where ``n`` is the number of complete frames (may be 0).
    """
    if max_code is None:
        max_code = MAX_SNAC_VALUE
    if device is None:
        device = DEVICE

    n_frames = len(flat_list) // 7
    clean = [x if 0 <= x <= max_code else 0 for x in flat_list[: n_frames * 7]]
    # Explicit dtype: torch.tensor([]) would otherwise infer float32, giving
    # non-integer code tensors for empty or too-short input.
    frames = torch.tensor(clean, dtype=torch.long).reshape(n_frames, 7)
    c0 = frames[:, 0]
    c1 = frames[:, 1:3].flatten()
    c2 = frames[:, 3:7].flatten()
    return [level.unsqueeze(0).contiguous().to(device) for level in (c0, c1, c2)]
# Serialises model inference. Generation is CPU-heavy and previously ran one
# request at a time (it blocked the event loop); keep that one-at-a-time
# behaviour now that the work runs in a worker thread.
_INFERENCE_LOCK = threading.Lock()


def _synthesize_wav(text):
    """Synchronously turn *text* into an in-memory WAV buffer.

    Returns a ``BytesIO`` positioned at 0, or ``None`` when the model did not
    produce enough non-special tokens to form one 7-token SNAC frame.
    """
    # NOTE(review): the prompt is the raw text plus a newline, with no special
    # tokens or speaker prefix; confirm this matches the adapter's training format.
    prompt = text + "\n"
    with _INFERENCE_LOCK, torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        input_len = inputs.input_ids.shape[1]
        # The tokenizer may define no pad token; fall back to EOS so
        # generate() always receives a concrete id.
        pad_id = tokenizer.pad_token_id
        if pad_id is None:
            pad_id = tokenizer.eos_token_id
        output_ids = model.generate(
            **inputs,
            max_new_tokens=400,
            pad_token_id=pad_id,
            eos_token_id=tokenizer.eos_token_id,
            temperature=0.3,
            do_sample=True,
        )
        # Keep only newly generated tokens, dropping EOS/padding so they are
        # not decoded as audio codes.
        stop_ids = {tokenizer.eos_token_id, pad_id}
        generated_ids = [
            tok for tok in output_ids[0, input_len:].tolist() if tok not in stop_ids
        ]
        if len(generated_ids) < 7:
            return None
        # NOTE(review): raw LM token ids are used directly as SNAC codes (ids
        # outside [0, MAX_SNAC_VALUE] become 0). Orpheus-style checkpoints
        # usually emit offset audio tokens with a different interleave; confirm
        # this matches how the adapter was trained.
        codes = reconstruct_snac_codes(generated_ids)
        # Decoding inside no_grad avoids building an autograd graph and yields
        # a tensor without grad tracking, which saving to WAV may require.
        audio_hat = snac_model.decode(codes)
    buffer = io.BytesIO()
    torchaudio.save(buffer, audio_hat.squeeze(0).cpu(), 24000, format="wav")
    buffer.seek(0)
    return buffer


# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# handler, so as shown it is never registered with ``app``; confirm the
# intended method/path and register it.
async def get_inference(
    request: Request,
    text: str = Query(..., description="Input text"),
    lang: str = Query(..., description="Language"),
    speaker_wav: UploadFile = File(None, description="Reference WAV"),
):
    """Synthesize speech for *text* and stream it back as a 24 kHz WAV.

    ``request``, ``lang`` and ``speaker_wav`` are accepted but currently unused.

    Responses:
        * 200 ``audio/wav`` stream on success.
        * 400 (HTTPException) when ``text`` is empty or whitespace only.
        * 503 ``{"error": ...}`` when startup failed and the models are not loaded.
        * 500 ``{"error": ...}`` when generation yields no audio frame or raises.
    """
    print(f"π© Request: '{text}'")
    if not text.strip():
        raise HTTPException(400, "Text required")
    # Startup may have failed (handles unset or None); report that explicitly
    # instead of surfacing a NameError/AttributeError from inside the pipeline.
    if any(globals().get(name) is None for name in ("model", "tokenizer", "snac_model")):
        return JSONResponse(
            status_code=503,
            content={"error": "Model not loaded; check startup logs"},
        )
    try:
        # Run the blocking, CPU-heavy pipeline off the event loop so the server
        # keeps serving other requests while generating.
        loop = asyncio.get_running_loop()
        buffer = await loop.run_in_executor(None, _synthesize_wav, text)
    except Exception as e:
        print(f"Error: {e}")
        traceback.print_exc()
        return JSONResponse(status_code=500, content={"error": str(e)})
    if buffer is None:
        return JSONResponse(status_code=500, content={"error": "Model generated silence"})
    return StreamingResponse(buffer, media_type="audio/wav")