Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,9 +6,10 @@ import csv
|
|
| 6 |
import zipfile
|
| 7 |
import re
|
| 8 |
import difflib
|
|
|
|
| 9 |
from typing import List, Optional, Dict, Any
|
| 10 |
|
| 11 |
-
from fastapi import FastAPI, UploadFile, File, HTTPException
|
| 12 |
from fastapi.middleware.cors import CORSMiddleware
|
| 13 |
from pydantic import BaseModel
|
| 14 |
|
|
@@ -16,6 +17,7 @@ import torch
|
|
| 16 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 17 |
from langdetect import detect
|
| 18 |
from transformers import MarianMTModel, MarianTokenizer
|
|
|
|
| 19 |
|
| 20 |
# ======================================================
|
| 21 |
# 0) Configuración general
|
|
@@ -33,6 +35,12 @@ os.makedirs(UPLOAD_DIR, exist_ok=True)
|
|
| 33 |
# { conn_id: { "db_path": str, "label": str } }
|
| 34 |
DB_REGISTRY: Dict[str, Dict[str, Any]] = {}
|
| 35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
# ======================================================
|
| 37 |
# 1) Inicialización de FastAPI
|
| 38 |
# ======================================================
|
|
@@ -641,6 +649,11 @@ class InferResponse(BaseModel):
|
|
| 641 |
candidates: List[Dict[str, Any]]
|
| 642 |
|
| 643 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 644 |
# ======================================================
|
| 645 |
# 7) Endpoints FastAPI
|
| 646 |
# ======================================================
|
|
@@ -784,6 +797,57 @@ async def infer_sql(req: InferRequest):
|
|
| 784 |
return InferResponse(**result)
|
| 785 |
|
| 786 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 787 |
@app.get("/health")
|
| 788 |
async def health():
|
| 789 |
return {
|
|
@@ -804,6 +868,7 @@ async def root():
|
|
| 804 |
"GET /schema/{id} (esquema resumido)",
|
| 805 |
"GET /preview/{id}/{t} (preview de tabla)",
|
| 806 |
"POST /infer (NL→SQL + ejecución)",
|
|
|
|
| 807 |
"GET /health (estado del backend)",
|
| 808 |
"GET /docs (OpenAPI UI)",
|
| 809 |
],
|
|
|
|
| 6 |
import zipfile
|
| 7 |
import re
|
| 8 |
import difflib
|
| 9 |
+
import tempfile
|
| 10 |
from typing import List, Optional, Dict, Any
|
| 11 |
|
| 12 |
+
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
|
| 13 |
from fastapi.middleware.cors import CORSMiddleware
|
| 14 |
from pydantic import BaseModel
|
| 15 |
|
|
|
|
| 17 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 18 |
from langdetect import detect
|
| 19 |
from transformers import MarianMTModel, MarianTokenizer
|
| 20 |
+
from openai import OpenAI
|
| 21 |
|
| 22 |
# ======================================================
|
| 23 |
# 0) Configuración general
|
|
|
|
| 35 |
# { conn_id: { "db_path": str, "label": str } }
|
| 36 |
DB_REGISTRY: Dict[str, Dict[str, Any]] = {}
|
| 37 |
|
| 38 |
+
# Cliente OpenAI para transcripción de audio (Whisper / gpt-4o-transcribe)
|
| 39 |
+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
|
| 40 |
+
if not OPENAI_API_KEY:
|
| 41 |
+
print("⚠️ OPENAI_API_KEY no está definido. El endpoint /speech-infer no funcionará hasta configurarlo.")
|
| 42 |
+
openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
|
| 43 |
+
|
| 44 |
# ======================================================
|
| 45 |
# 1) Inicialización de FastAPI
|
| 46 |
# ======================================================
|
|
|
|
| 649 |
candidates: List[Dict[str, Any]]
|
| 650 |
|
| 651 |
|
| 652 |
+
class SpeechInferResponse(BaseModel):
|
| 653 |
+
transcript: str
|
| 654 |
+
result: InferResponse
|
| 655 |
+
|
| 656 |
+
|
| 657 |
# ======================================================
|
| 658 |
# 7) Endpoints FastAPI
|
| 659 |
# ======================================================
|
|
|
|
| 797 |
return InferResponse(**result)
|
| 798 |
|
| 799 |
|
| 800 |
+
@app.post("/speech-infer", response_model=SpeechInferResponse)
|
| 801 |
+
async def speech_infer(
|
| 802 |
+
connection_id: str = Form(...),
|
| 803 |
+
audio: UploadFile = File(...)
|
| 804 |
+
):
|
| 805 |
+
"""
|
| 806 |
+
Endpoint para consultas por VOZ:
|
| 807 |
+
- Recibe audio desde el navegador (multipart/form-data).
|
| 808 |
+
- Usa gpt-4o-transcribe para obtener el texto.
|
| 809 |
+
- Reutiliza el pipeline NL→SQL existente.
|
| 810 |
+
"""
|
| 811 |
+
if openai_client is None:
|
| 812 |
+
raise HTTPException(
|
| 813 |
+
status_code=500,
|
| 814 |
+
detail="OPENAI_API_KEY no está configurado en el backend."
|
| 815 |
+
)
|
| 816 |
+
|
| 817 |
+
if audio.content_type is None:
|
| 818 |
+
raise HTTPException(status_code=400, detail="Archivo de audio inválido.")
|
| 819 |
+
|
| 820 |
+
# 1) Guardar audio temporalmente
|
| 821 |
+
try:
|
| 822 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as tmp:
|
| 823 |
+
tmp.write(await audio.read())
|
| 824 |
+
tmp_path = tmp.name
|
| 825 |
+
except Exception:
|
| 826 |
+
raise HTTPException(status_code=500, detail="No se pudo procesar el audio recibido.")
|
| 827 |
+
|
| 828 |
+
# 2) Transcribir con gpt-4o-transcribe
|
| 829 |
+
try:
|
| 830 |
+
with open(tmp_path, "rb") as f:
|
| 831 |
+
transcription = openai_client.audio.transcriptions.create(
|
| 832 |
+
model="gpt-4o-transcribe",
|
| 833 |
+
file=f,
|
| 834 |
+
# language="es", # opcional, si quieres forzar español
|
| 835 |
+
)
|
| 836 |
+
transcript_text: str = transcription.text
|
| 837 |
+
except Exception as e:
|
| 838 |
+
raise HTTPException(status_code=500, detail=f"Error al transcribir audio: {e}")
|
| 839 |
+
|
| 840 |
+
# 3) Reutilizar el pipeline NL→SQL con el texto transcrito
|
| 841 |
+
result_dict = nl2sql_with_rerank(transcript_text, connection_id)
|
| 842 |
+
infer_result = InferResponse(**result_dict)
|
| 843 |
+
|
| 844 |
+
# 4) Devolver transcripción + resultado NL→SQL
|
| 845 |
+
return SpeechInferResponse(
|
| 846 |
+
transcript=transcript_text,
|
| 847 |
+
result=infer_result,
|
| 848 |
+
)
|
| 849 |
+
|
| 850 |
+
|
| 851 |
@app.get("/health")
|
| 852 |
async def health():
|
| 853 |
return {
|
|
|
|
| 868 |
"GET /schema/{id} (esquema resumido)",
|
| 869 |
"GET /preview/{id}/{t} (preview de tabla)",
|
| 870 |
"POST /infer (NL→SQL + ejecución)",
|
| 871 |
+
"POST /speech-infer (NL por voz → SQL + ejecución)",
|
| 872 |
"GET /health (estado del backend)",
|
| 873 |
"GET /docs (OpenAPI UI)",
|
| 874 |
],
|