from fastapi import FastAPI, UploadFile, File, Query, HTTPException, Request
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from snac import SNAC
import torch
import torchaudio
import io
import os

# --- CONFIGURATION ---
# LoRA adapter with the fine-tuned weights
ADAPTER_REPO = "rudraPtGenAi/Orpheus-3B-Hindi-syspin-Smart"

# Base model (must be the same checkpoint the adapter was trained on)
BASE_MODEL = "canopylabs/3b-hi-ft-research_release"

DEVICE = "cpu"
MAX_SNAC_VALUE = 4095

app = FastAPI(title="Voice Tech API")

print("--- πŸš€ Initializing Space ---")

try:
    # Get token from Secret
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token: print("⚠️ Warning: HF_TOKEN not found.")

    print(f"Loading Base: {BASE_MODEL}")
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.float32, 
        device_map="cpu",
        low_cpu_mem_usage=True,
        token=hf_token
    )
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=hf_token)

    print(f"Loading Adapters: {ADAPTER_REPO}")
    model = PeftModel.from_pretrained(model, ADAPTER_REPO, token=hf_token)
    model.eval()

    print("Loading SNAC...")
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(DEVICE)
    snac_model.eval()
    print("βœ… System Ready!")

except Exception as e:
    # Swallow the error so the Space stays up and the logs remain visible;
    # requests will fail until the load issue is fixed
    print(f"❌ Critical Error: {e}")

def reconstruct_snac_codes(flat_list):
    # SNAC 24 kHz uses three hierarchical codebooks, giving 7 codes per
    # frame (1 coarse + 2 medium + 4 fine). Drop any trailing partial frame.
    remainder = len(flat_list) % 7
    if remainder != 0: flat_list = flat_list[:-remainder]
    n_frames = len(flat_list) // 7
    c0, c1, c2 = [], [], []
    for t in range(n_frames):
        chunk = flat_list[t*7 : (t+1)*7]
        # Zero out anything outside the codebook range (e.g. stray text/EOS ids)
        clean_chunk = [x if 0 <= x <= MAX_SNAC_VALUE else 0 for x in chunk]
        c0.append(clean_chunk[0])
        c1.extend(clean_chunk[1:3])
        c2.extend(clean_chunk[3:7])
    # SNAC expects long tensors of shape (batch, time), one per codebook
    return [
        torch.tensor(c0, dtype=torch.long).unsqueeze(0).to(DEVICE),
        torch.tensor(c1, dtype=torch.long).unsqueeze(0).to(DEVICE),
        torch.tensor(c2, dtype=torch.long).unsqueeze(0).to(DEVICE)
    ]
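
# Illustration (made-up values, not real tokens): one 7-token frame
# [a, b, c, d, e, f, g] is de-interleaved by reconstruct_snac_codes as
#   codebook 0 -> [a]            (coarsest: 1 code per frame)
#   codebook 1 -> [b, c]         (2 codes per frame)
#   codebook 2 -> [d, e, f, g]   (finest: 4 codes per frame)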

@app.api_route("/Get_Inference", methods=["GET", "POST"])
async def get_inference(
    request: Request,
    text: str = Query(..., description="Input text"),
    lang: str = Query(..., description="Language"),
    speaker_wav: UploadFile = File(None, description="Reference WAV")
):
    print(f"πŸ“© Request: '{text}'")
    if not text: raise HTTPException(400, "Text required")

    try:
        prompt = text + "\n"
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        input_len = inputs.input_ids.shape[1]

        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=400,
                # Fall back to EOS so generate() works even without a pad token
                pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                temperature=0.3,
                do_sample=True
            )

        generated_ids = output_ids[0, input_len:].tolist()
        if not generated_ids: return {"error": "Model generated silence"}

        codes = reconstruct_snac_codes(generated_ids)
        with torch.no_grad():  # keep the waveform grad-free so torchaudio can save it
            audio_hat = snac_model.decode(codes)
        
        buffer = io.BytesIO()
        torchaudio.save(buffer, audio_hat.squeeze(0), 24000, format="wav")
        buffer.seek(0)

        return StreamingResponse(buffer, media_type="audio/wav")

    except Exception as e:
        print(f"Error: {e}")
        return {"error": str(e)}