stvnnnnnn commited on
Commit
f9bcb56
·
verified ·
1 Parent(s): 54858c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -1
app.py CHANGED
@@ -6,9 +6,10 @@ import csv
6
  import zipfile
7
  import re
8
  import difflib
 
9
  from typing import List, Optional, Dict, Any
10
 
11
- from fastapi import FastAPI, UploadFile, File, HTTPException
12
  from fastapi.middleware.cors import CORSMiddleware
13
  from pydantic import BaseModel
14
 
@@ -16,6 +17,7 @@ import torch
16
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
17
  from langdetect import detect
18
  from transformers import MarianMTModel, MarianTokenizer
 
19
 
20
  # ======================================================
21
  # 0) Configuración general
@@ -33,6 +35,12 @@ os.makedirs(UPLOAD_DIR, exist_ok=True)
33
  # { conn_id: { "db_path": str, "label": str } }
34
  DB_REGISTRY: Dict[str, Dict[str, Any]] = {}
35
 
 
 
 
 
 
 
36
  # ======================================================
37
  # 1) Inicialización de FastAPI
38
  # ======================================================
@@ -641,6 +649,11 @@ class InferResponse(BaseModel):
641
  candidates: List[Dict[str, Any]]
642
 
643
 
 
 
 
 
 
644
  # ======================================================
645
  # 7) Endpoints FastAPI
646
  # ======================================================
@@ -784,6 +797,57 @@ async def infer_sql(req: InferRequest):
784
  return InferResponse(**result)
785
 
786
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
787
  @app.get("/health")
788
  async def health():
789
  return {
@@ -804,6 +868,7 @@ async def root():
804
  "GET /schema/{id} (esquema resumido)",
805
  "GET /preview/{id}/{t} (preview de tabla)",
806
  "POST /infer (NL→SQL + ejecución)",
 
807
  "GET /health (estado del backend)",
808
  "GET /docs (OpenAPI UI)",
809
  ],
 
6
  import zipfile
7
  import re
8
  import difflib
9
+ import tempfile
10
  from typing import List, Optional, Dict, Any
11
 
12
+ from fastapi import FastAPI, UploadFile, File, HTTPException, Form
13
  from fastapi.middleware.cors import CORSMiddleware
14
  from pydantic import BaseModel
15
 
 
17
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
18
  from langdetect import detect
19
  from transformers import MarianMTModel, MarianTokenizer
20
+ from openai import OpenAI
21
 
22
  # ======================================================
23
  # 0) Configuración general
 
35
  # { conn_id: { "db_path": str, "label": str } }
36
  DB_REGISTRY: Dict[str, Dict[str, Any]] = {}
37
 
38
+ # Cliente OpenAI para transcripción de audio (Whisper / gpt-4o-transcribe)
39
+ OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
40
+ if not OPENAI_API_KEY:
41
+ print("⚠️ OPENAI_API_KEY no está definido. El endpoint /speech-infer no funcionará hasta configurarlo.")
42
+ openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None
43
+
44
  # ======================================================
45
  # 1) Inicialización de FastAPI
46
  # ======================================================
 
649
  candidates: List[Dict[str, Any]]
650
 
651
 
652
+ class SpeechInferResponse(BaseModel):
653
+ transcript: str
654
+ result: InferResponse
655
+
656
+
657
  # ======================================================
658
  # 7) Endpoints FastAPI
659
  # ======================================================
 
797
  return InferResponse(**result)
798
 
799
 
800
+ @app.post("/speech-infer", response_model=SpeechInferResponse)
801
+ async def speech_infer(
802
+ connection_id: str = Form(...),
803
+ audio: UploadFile = File(...)
804
+ ):
805
+ """
806
+ Endpoint para consultas por VOZ:
807
+ - Recibe audio desde el navegador (multipart/form-data).
808
+ - Usa gpt-4o-transcribe para obtener el texto.
809
+ - Reutiliza el pipeline NL→SQL existente.
810
+ """
811
+ if openai_client is None:
812
+ raise HTTPException(
813
+ status_code=500,
814
+ detail="OPENAI_API_KEY no está configurado en el backend."
815
+ )
816
+
817
+ if audio.content_type is None:
818
+ raise HTTPException(status_code=400, detail="Archivo de audio inválido.")
819
+
820
+ # 1) Guardar audio temporalmente
821
+ try:
822
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as tmp:
823
+ tmp.write(await audio.read())
824
+ tmp_path = tmp.name
825
+ except Exception:
826
+ raise HTTPException(status_code=500, detail="No se pudo procesar el audio recibido.")
827
+
828
+ # 2) Transcribir con gpt-4o-transcribe
829
+ try:
830
+ with open(tmp_path, "rb") as f:
831
+ transcription = openai_client.audio.transcriptions.create(
832
+ model="gpt-4o-transcribe",
833
+ file=f,
834
+ # language="es", # opcional, si quieres forzar español
835
+ )
836
+ transcript_text: str = transcription.text
837
+ except Exception as e:
838
+ raise HTTPException(status_code=500, detail=f"Error al transcribir audio: {e}")
839
+
840
+ # 3) Reutilizar el pipeline NL→SQL con el texto transcrito
841
+ result_dict = nl2sql_with_rerank(transcript_text, connection_id)
842
+ infer_result = InferResponse(**result_dict)
843
+
844
+ # 4) Devolver transcripción + resultado NL→SQL
845
+ return SpeechInferResponse(
846
+ transcript=transcript_text,
847
+ result=infer_result,
848
+ )
849
+
850
+
851
  @app.get("/health")
852
  async def health():
853
  return {
 
868
  "GET /schema/{id} (esquema resumido)",
869
  "GET /preview/{id}/{t} (preview de tabla)",
870
  "POST /infer (NL→SQL + ejecución)",
871
+ "POST /speech-infer (NL por voz → SQL + ejecución)",
872
  "GET /health (estado del backend)",
873
  "GET /docs (OpenAPI UI)",
874
  ],