stvnnnnnn committed
Commit 48fbd23 · verified · 1 Parent(s): 2b2e5ad

Update app.py

Files changed (1)
  1. app.py +170 -419
app.py CHANGED
@@ -1,5 +1,4 @@
  import os
- import uuid
  import io
  import zipfile
  import re
@@ -17,8 +16,7 @@ from langdetect import detect
  from transformers import MarianMTModel, MarianTokenizer
  from openai import OpenAI

- import psycopg2
- import mysql.connector
+ from sqlmanager import SQLManager

  # ======================================================
  # 0) General configuration
@@ -28,25 +26,8 @@ import mysql.connector
  MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider")
  DEVICE = torch.device("cpu")  # CPU inference

- # === Real engines: environment variables ===
- # PostgreSQL: we use ONE database (POSTGRES_DB) and one schema per logical connection
- POSTGRES_HOST = os.getenv("POSTGRES_HOST", "localhost")
- POSTGRES_PORT = int(os.getenv("POSTGRES_PORT", "5432"))
- POSTGRES_USER = os.getenv("POSTGRES_USER", "postgres")
- POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD", "postgres")
- POSTGRES_DB = os.getenv("POSTGRES_DB", "postgres")
-
- # MySQL: we create one database per logical connection
- MYSQL_HOST = os.getenv("MYSQL_HOST", "localhost")
- MYSQL_PORT = int(os.getenv("MYSQL_PORT", "3306"))
- MYSQL_USER = os.getenv("MYSQL_USER", "root")
- MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD", "root")
-
- # In-memory registry of connections:
- # { conn_id: { "engine": "postgresql"|"mysql", "namespace": str, "label": str } }
- #   - engine = real engine
- #   - namespace = schema (Postgres) or database (MySQL) where that DB lives
- DB_REGISTRY: Dict[str, Dict[str, Any]] = {}
+ # Manager for the real connections (MySQL/PostgreSQL)
+ sql_manager = SQLManager()

  # OpenAI client for audio transcription (Whisper / gpt-4o-transcribe)
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
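Note: the sqlmanager module is not included in this commit. From the call sites in app.py (connections, create_database_from_dump, get_schema, execute_sql, get_preview), its interface would look roughly like the sketch below; the names and return shapes are inferred, not the actual implementation.

    from typing import Any, Dict

    class SQLManager:
        """Sketch of the interface app.py relies on (assumed, not from this commit)."""

        def __init__(self) -> None:
            # conn_id -> {"engine": "mysql" | "postgresql", "db_name": str, "label": str}
            self.connections: Dict[str, Dict[str, Any]] = {}

        def create_database_from_dump(self, label: str, sql_text: str) -> str:
            """Create a real database from a dump and return its connection_id."""
            raise NotImplementedError

        def get_schema(self, conn_id: str) -> Dict[str, Any]:
            """Return {"tables": {table: {"columns": [col, ...]}}}."""
            raise NotImplementedError

        def execute_sql(self, conn_id: str, sql: str) -> Dict[str, Any]:
            """Return {"ok": bool, "error": str | None, "rows": list | None, "columns": list}."""
            raise NotImplementedError

        def get_preview(self, conn_id: str, table: str, limit: int) -> Dict[str, Any]:
            """Return {"columns": [...], "rows": [[...], ...]}."""
            raise NotImplementedError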
@@ -62,8 +43,8 @@ app = FastAPI(
  title="NL2SQL T5-large Backend (MySQL/PostgreSQL)",
  description=(
      "NL→SQL interpreter (T5-large Spider) for non-expert users. "
-     "The user uploads .sql / .zip dumps and they are loaded into real "
-     "engines (MySQL/PostgreSQL)."
+     "The user uploads .sql dumps (or a ZIP of .sql files) and real "
+     "MySQL/PostgreSQL databases are spun up; queries run against them."
  ),
  version="2.0.0",
  )
@@ -137,271 +118,12 @@ def translate_es_to_en(text: str) -> str:
137
 
138
 
139
  # ======================================================
140
- # 3) Conexiones a motores reales y helpers
141
- # ======================================================
142
-
143
- def get_pg_conn():
144
- return psycopg2.connect(
145
- host=POSTGRES_HOST,
146
- port=POSTGRES_PORT,
147
- user=POSTGRES_USER,
148
- password=POSTGRES_PASSWORD,
149
- dbname=POSTGRES_DB,
150
- )
151
-
152
-
153
- def get_mysql_conn(db_name: Optional[str] = None):
154
- params = dict(
155
- host=MYSQL_HOST,
156
- port=MYSQL_PORT,
157
- user=MYSQL_USER,
158
- password=MYSQL_PASSWORD,
159
- )
160
- if db_name:
161
- params["database"] = db_name
162
- return mysql.connector.connect(**params)
163
-
164
-
165
- def detect_sql_dialect(sql_text: str) -> str:
166
- """
167
- Heurística simple:
168
- - Si ve ENGINE, AUTO_INCREMENT, backticks, etc. → MySQL
169
- - Si ve SERIAL, search_path, ::, PL/pgSQL, etc. → PostgreSQL
170
- """
171
- text = sql_text.lower()
172
- if any(kw in text for kw in ["engine=", "auto_increment", "unsigned", " collate ", " character set ", "`"]):
173
- return "mysql"
174
- if any(kw in text for kw in ["serial", " set search_path", "copy ", "::", "language plpgsql"]):
175
- return "postgresql"
176
- return "unknown"
177
-
178
-
179
- def create_logical_db(engine: str, label: str) -> str:
180
- """
181
- Crea una "BD lógica":
182
- - En PostgreSQL: un SCHEMA nuevo dentro de POSTGRES_DB
183
- - En MySQL: una DATABASE nueva
184
- """
185
- conn_id = f"db_{uuid.uuid4().hex[:8]}"
186
-
187
- if engine == "postgresql":
188
- schema_name = conn_id
189
- conn = get_pg_conn()
190
- try:
191
- conn.autocommit = True
192
- cur = conn.cursor()
193
- cur.execute(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}";')
194
- cur.close()
195
- finally:
196
- conn.close()
197
- DB_REGISTRY[conn_id] = {"engine": "postgresql", "namespace": schema_name, "label": label}
198
-
199
- elif engine == "mysql":
200
- db_name = conn_id
201
- conn = get_mysql_conn()
202
- try:
203
- cur = conn.cursor()
204
- cur.execute(f"CREATE DATABASE IF NOT EXISTS `{db_name}`;")
205
- conn.commit()
206
- cur.close()
207
- finally:
208
- conn.close()
209
- DB_REGISTRY[conn_id] = {"engine": "mysql", "namespace": db_name, "label": label}
210
-
211
- else:
212
- raise ValueError(f"Engine no soportado: {engine}")
213
-
214
- return conn_id
215
-
216
-
217
- def ensure_connection(conn_id: str) -> Dict[str, Any]:
218
- if conn_id not in DB_REGISTRY:
219
- raise HTTPException(status_code=404, detail=f"connection_id '{conn_id}' no registrado")
220
- return DB_REGISTRY[conn_id]
221
-
222
-
223
- # ======================================================
224
- # 4) Carga de scripts SQL (schema.sql / data.sql / ZIP)
225
- # ======================================================
226
-
227
- def _execute_sql_script_postgres(schema: str, sql_text: str) -> None:
228
- conn = get_pg_conn()
229
- try:
230
- conn.autocommit = False
231
- cur = conn.cursor()
232
- # Trabajamos siempre dentro del schema lógico
233
- cur.execute(f'SET search_path TO "{schema}";')
234
- # Ejecución simple por ';' (suficiente para Sakila/Pagila)
235
- parts = sql_text.split(";")
236
- for stmt in parts:
237
- s = stmt.strip()
238
- if not s:
239
- continue
240
- cur.execute(s + ";")
241
- conn.commit()
242
- cur.close()
243
- finally:
244
- conn.close()
245
-
246
-
247
- def _execute_sql_script_mysql(db_name: str, sql_text: str) -> None:
248
- conn = get_mysql_conn()
249
- try:
250
- cur = conn.cursor()
251
- cur.execute(f"USE `{db_name}`;")
252
- # mysql-connector permite multi=True
253
- for _ in cur.execute(sql_text, multi=True):
254
- pass
255
- conn.commit()
256
- cur.close()
257
- finally:
258
- conn.close()
259
-
260
-
261
- def load_sql_into_connection(sql_text: str, engine: str, conn_id: Optional[str], label: str) -> str:
262
- """
263
- - Si conn_id es None → crea nueva BD lógica y carga el script.
264
- - Si conn_id existe → ejecuta el script encima (ej: primero schema.sql, luego data.sql).
265
- """
266
- if not engine or engine not in ("mysql", "postgresql"):
267
- raise HTTPException(status_code=400, detail="Engine inválido o no soportado (usa 'mysql' o 'postgresql').")
268
-
269
- if conn_id is None:
270
- conn_id = create_logical_db(engine, label)
271
- else:
272
- ensure_connection(conn_id)
273
-
274
- info = DB_REGISTRY[conn_id]
275
- namespace = info["namespace"]
276
-
277
- if engine == "postgresql":
278
- _execute_sql_script_postgres(namespace, sql_text)
279
- else:
280
- _execute_sql_script_mysql(namespace, sql_text)
281
-
282
- return conn_id
283
-
284
-
285
- # ======================================================
286
- # 5) Introspección de esquema y ejecución (sobre motores reales)
287
- # ======================================================
288
-
289
- def introspect_schema(conn_id: str) -> Dict[str, Any]:
290
- info = ensure_connection(conn_id)
291
- engine = info["engine"]
292
- ns = info["namespace"]
293
-
294
- tables_info: Dict[str, Dict[str, List[str]]] = {}
295
- parts: List[str] = []
296
-
297
- if engine == "postgresql":
298
- conn = get_pg_conn()
299
- try:
300
- cur = conn.cursor()
301
- cur.execute(
302
- """
303
- SELECT table_name, column_name
304
- FROM information_schema.columns
305
- WHERE table_schema = %s
306
- ORDER BY table_name, ordinal_position;
307
- """,
308
- (ns,),
309
- )
310
- rows = cur.fetchall()
311
- cur.close()
312
- finally:
313
- conn.close()
314
- for table, col in rows:
315
- tables_info.setdefault(table, {"columns": []})
316
- tables_info[table]["columns"].append(col)
317
-
318
- else: # MySQL
319
- conn = get_mysql_conn(ns)
320
- try:
321
- cur = conn.cursor()
322
- cur.execute(
323
- """
324
- SELECT table_name, column_name
325
- FROM information_schema.columns
326
- WHERE table_schema = %s
327
- ORDER BY table_name, ordinal_position;
328
- """,
329
- (ns,),
330
- )
331
- rows = cur.fetchall()
332
- cur.close()
333
- finally:
334
- conn.close()
335
- for table, col in rows:
336
- tables_info.setdefault(table, {"columns": []})
337
- tables_info[table]["columns"].append(col)
338
-
339
- for t, info_t in tables_info.items():
340
- parts.append(f"{t}(" + ", ".join(info_t["columns"]) + ")")
341
-
342
- schema_str = " ; ".join(parts) if parts else "(empty_schema)"
343
-
344
- return {
345
- "tables": tables_info,
346
- "foreign_keys": [], # podrías enriquecer esto luego
347
- "schema_str": schema_str,
348
- }
349
-
350
-
351
- def execute_sql(conn_id: str, sql: str) -> Dict[str, Any]:
352
- """
353
- Ejecuta SOLO consultas de lectura (SELECT).
354
- Bloquea operaciones destructivas por seguridad en el demo.
355
- """
356
- info = ensure_connection(conn_id)
357
- engine = info["engine"]
358
- ns = info["namespace"]
359
-
360
- forbidden = ["drop ", "delete ", "update ", "insert ", "alter ", "replace ", "truncate ", "create "]
361
- sql_low = sql.lower()
362
- if any(f in sql_low for f in forbidden):
363
- return {
364
- "ok": False,
365
- "error": "Query bloqueada por seguridad (operación potencialmente destructiva).",
366
- "rows": None,
367
- "columns": [],
368
- }
369
-
370
- try:
371
- if engine == "postgresql":
372
- conn = get_pg_conn()
373
- try:
374
- cur = conn.cursor()
375
- cur.execute(f'SET search_path TO "{ns}";')
376
- cur.execute(sql)
377
- rows = cur.fetchall()
378
- cols = [desc[0] for desc in cur.description] if cur.description else []
379
- cur.close()
380
- finally:
381
- conn.close()
382
- else:
383
- conn = get_mysql_conn(ns)
384
- try:
385
- cur = conn.cursor()
386
- cur.execute(sql)
387
- rows = cur.fetchall()
388
- cols = [desc[0] for desc in cur.description] if cur.description else []
389
- cur.close()
390
- finally:
391
- conn.close()
392
-
393
- return {"ok": True, "error": None, "rows": rows, "columns": cols}
394
- except Exception as e:
395
- return {"ok": False, "error": str(e), "rows": None, "columns": []}
396
-
397
-
398
- # ======================================================
399
- # 6) SQL REPAIR LAYER (igual que antes, pero agnóstico de motor)
400
  # ======================================================
401
 
402
  def _normalize_name_for_match(name: str) -> str:
403
  s = name.lower()
404
- s = s.replace('"', "").replace("`", "")
405
  s = s.replace("_", "")
406
  if s.endswith("s") and len(s) > 3:
407
  s = s[:-1]
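Note: the normalization above lowercases, strips quotes/backticks and underscores, and drops a trailing plural "s"; for reference (inputs illustrative):

    _normalize_name_for_match('"Singers"')   # -> "singer"
    _normalize_name_for_match("singer_ids")  # -> "singerid"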
@@ -430,9 +152,11 @@ def _build_schema_indexes(tables_info: Dict[str, Dict[str, List[str]]]) -> Dict[
  def _best_match_name(missing: str, index: Dict[str, List[str]]) -> Optional[str]:
      if not index:
          return None
+
      key = _normalize_name_for_match(missing)
      if key in index and index[key]:
          return index[key][0]
+
      candidates = difflib.get_close_matches(key, list(index.keys()), n=1, cutoff=0.7)
      if not candidates:
          return None
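Note: the fuzzy fallback relies on difflib.get_close_matches from the standard library; a standalone illustration of the cutoff behavior (names illustrative):

    import difflib

    keys = ["singer", "concert", "stadium"]
    # A near-miss like "singr" still resolves to "singer" at cutoff 0.7
    print(difflib.get_close_matches("singr", keys, n=1, cutoff=0.7))  # ['singer']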
@@ -465,6 +189,9 @@ DOMAIN_SYNONYMS_COLUMN = {


  def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optional[str]:
+     """
+     Tries to repair table/column names based on the real schema.
+     """
      tables_info = schema_meta["tables"]
      idx = _build_schema_indexes(tables_info)
      table_index = idx["table_index"]
@@ -476,11 +203,15 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optional[str]:
      missing_table = None
      missing_column = None

-     m_t = re.search(r"no such table: ([\w\.]+)", error or "", re.IGNORECASE)
+     m_t = re.search(r"relation \"([\w\.]+)\" does not exist", error, re.IGNORECASE)
+     if not m_t:
+         m_t = re.search(r"no such table: ([\w\.]+)", error)
      if m_t:
          missing_table = m_t.group(1)

-     m_c = re.search(r"no such column: ([\w\.]+)", error or "", re.IGNORECASE)
+     m_c = re.search(r"column \"([\w\.]+)\" does not exist", error, re.IGNORECASE)
+     if not m_c:
+         m_c = re.search(r"no such column: ([\w\.]+)", error)
      if m_c:
          missing_column = m_c.group(1)
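Note: the new first pattern matches PostgreSQL's error wording, while the "no such table/column" fallback appears to keep the older SQLite-style message. A quick standalone check of the PostgreSQL branch (error string illustrative):

    import re

    err = 'relation "custmer" does not exist'
    m = re.search(r"relation \"([\w\.]+)\" does not exist", err, re.IGNORECASE)
    assert m is not None and m.group(1) == "custmer"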
@@ -522,7 +253,7 @@ def try_repair_sql(sql: str, error: str, schema_meta: Dict[str, Any]) -> Optional[str]:


  # ======================================================
- # 7) Prompt construction and NL→SQL + re-ranking
+ # 4) Prompt construction and NL→SQL + re-ranking
  # ======================================================

  def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:
@@ -534,9 +265,18 @@ def build_prompt(question_en: str, db_id: str, schema_str: str) -> str:


  def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
-     ensure_connection(conn_id)
-     meta = introspect_schema(conn_id)
-     schema_str = meta["schema_str"]
+     if conn_id not in sql_manager.connections:
+         raise HTTPException(status_code=404, detail=f"connection_id '{conn_id}' not registered")
+
+     # Fetch the real schema from MySQL/Postgres
+     meta = sql_manager.get_schema(conn_id)
+     tables_info = meta["tables"]
+
+     parts = []
+     for t, info in tables_info.items():
+         cols = info.get("columns", [])
+         parts.append(f"{t}(" + ", ".join(cols) + ")")
+     schema_str = " ; ".join(parts) if parts else "(empty_schema)"

      detected = detect_language(question)
      question_en = translate_es_to_en(question) if detected == "es" else question
@@ -582,19 +322,21 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
          "raw_sql_model": raw_sql,
      }

-     exec_info = execute_sql(conn_id, raw_sql)
+     exec_info = sql_manager.execute_sql(conn_id, raw_sql)

+     # Attempt a repair only if the error is about a table/column
      if (not exec_info["ok"]) and (
          "no such table" in (exec_info["error"] or "").lower()
          or "no such column" in (exec_info["error"] or "").lower()
+         or "does not exist" in (exec_info["error"] or "").lower()
      ):
          current_sql = raw_sql
-         last_error = exec_info["error"]
+         last_error = exec_info["error"] or ""
          for step in range(1, 4):
-             repaired_sql = try_repair_sql(current_sql, last_error or "", meta)
+             repaired_sql = try_repair_sql(current_sql, last_error, meta)
              if not repaired_sql or repaired_sql == current_sql:
                  break
-             exec_info2 = execute_sql(conn_id, repaired_sql)
+             exec_info2 = sql_manager.execute_sql(conn_id, repaired_sql)
              cand["repaired_from"] = current_sql if cand["repaired_from"] is None else cand["repaired_from"]
              cand["repair_note"] = f"auto-repair (table/column name, step {step})"
              cand["sql"] = repaired_sql
@@ -602,12 +344,12 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:
              current_sql = repaired_sql
              if exec_info2["ok"]:
                  break
-             last_error = exec_info2["error"]
+             last_error = exec_info2["error"] or ""

      cand["exec_ok"] = exec_info["ok"]
      cand["exec_error"] = exec_info["error"]
      cand["rows_preview"] = (
-         [list(r) for r in exec_info["rows"][:5]] if exec_info["ok"] and exec_info["rows"] else None
+         exec_info["rows"][:5] if exec_info["ok"] and exec_info["rows"] else None
      )
      cand["columns"] = exec_info["columns"]
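Note: the schema summary passed to the prompt is a flat "table(col, ...)" listing joined with " ; "; with illustrative tables:

    tables_info = {
        "singer": {"columns": ["singer_id", "name"]},
        "concert": {"columns": ["concert_id", "singer_id", "year"]},
    }
    # schema_str -> "singer(singer_id, name) ; concert(concert_id, singer_id, year)"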
@@ -641,33 +383,31 @@ def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]:


  # ======================================================
- # 8) Pydantic schemas
+ # 5) Pydantic schemas
  # ======================================================

  class UploadResponse(BaseModel):
      connection_id: str
      label: str
-     engine: str
-     namespace: str  # schema (PG) or database (MySQL)
+     db_path: str  # now a pseudo-path (engine://db_name)
      note: Optional[str] = None


  class ConnectionInfo(BaseModel):
      connection_id: str
      label: str
-     engine: str
+     engine: Optional[str] = None
+     db_name: Optional[str] = None


  class SchemaResponse(BaseModel):
      connection_id: str
-     engine: str
      schema_summary: str
      tables: Dict[str, Dict[str, List[str]]]


  class PreviewResponse(BaseModel):
      connection_id: str
-     engine: str
      table: str
      columns: List[str]
      rows: List[List[Any]]
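Note: with these models, a successful upload response would be shaped like the following (the connection_id/db_name format depends on SQLManager and is illustrative):

    {
        "connection_id": "db_1a2b3c4d",
        "label": "sakila.sql",
        "db_path": "mysql://db_1a2b3c4d",
        "note": "SQL dump imported into MYSQL database 'db_1a2b3c4d'."
    }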
@@ -698,27 +438,61 @@ class SpeechInferResponse(BaseModel):


  # ======================================================
- # 9) FastAPI endpoints
+ # 6) Helpers for /upload (.sql and .zip)
+ # ======================================================
+
+ def _combine_sql_files_from_zip(zip_bytes: bytes) -> str:
+     """
+     Reads a ZIP, keeps only the .sql files and concatenates them.
+     Order:
+       1) files with 'schema' or 'structure' in the name
+       2) the rest (data, etc.)
+     """
+     try:
+         with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf:
+             names = [info.filename for info in zf.infolist() if not info.is_dir()]
+             sql_names = [n for n in names if n.lower().endswith(".sql")]
+
+             if not sql_names:
+                 raise ValueError("The ZIP contains no usable .sql files.")
+
+             def sort_key(name: str) -> int:
+                 nl = name.lower()
+                 if "schema" in nl or "structure" in nl:
+                     return 0
+                 return 1
+
+             sql_names_sorted = sorted(sql_names, key=sort_key)
+
+             parts: List[str] = []
+             for name in sql_names_sorted:
+                 with zf.open(name) as f:
+                     text = f.read().decode("utf-8", errors="ignore")
+                 parts.append(f"-- FILE: {name}\n{text}\n")
+
+             return "\n\n".join(parts)
+     except zipfile.BadZipFile:
+         raise ValueError("Invalid or corrupt ZIP file.")
+
+
+ # ======================================================
+ # 7) FastAPI endpoints
  # ======================================================

  @app.on_event("startup")
  async def startup_event():
      load_nl2sql_model()
-     print("✅ NL2SQL backend initialized (MySQL/PostgreSQL engines).")
+     print("✅ NL2SQL backend initialized (MySQL/PostgreSQL).")
+     print(f"MODEL_DIR={MODEL_DIR}, DEVICE={DEVICE}")
+     print(f"Active connections at startup: {len(sql_manager.connections)}")


- @app.post("/upload-sql", response_model=UploadResponse)
- async def upload_sql(
-     db_file: UploadFile = File(...),
-     engine: Optional[str] = Form(None),         # "mysql" | "postgresql" (optional)
-     connection_id: Optional[str] = Form(None),  # for the schema.sql + data.sql case
- ):
+ @app.post("/upload", response_model=UploadResponse)
+ async def upload_database(db_file: UploadFile = File(...)):
      """
-     SQL dump upload.
-     - Accepts a plain .sql or a .zip with several .sql files
-     - Auto-detects the engine (MySQL/PostgreSQL) when engine is not given
-     - If connection_id is None, creates a new logical DB
-     - If connection_id exists, runs the SQL on top of it (e.g. schema.sql then data.sql)
+     Dump-based DB upload:
+     - MySQL/PostgreSQL .sql dump (schema + data) → real DB
+     - .zip must contain one or more .sql files (they are concatenated)
      """
      filename = db_file.filename
      if not filename:
@@ -727,140 +501,118 @@ async def upload_sql(
      fname_lower = filename.lower()
      contents = await db_file.read()

-     if not (fname_lower.endswith(".sql") or fname_lower.endswith(".zip")):
-         raise HTTPException(
-             status_code=400,
-             detail="Unsupported format. Use: .sql or .zip (with .sql files).",
-         )
-
-     sql_text = ""
-     note = None
+     note: Optional[str] = None

-     # Case: single .sql file
+     # Case 1: .sql dump
      if fname_lower.endswith(".sql"):
          sql_text = contents.decode("utf-8", errors="ignore")
-         detected = detect_sql_dialect(sql_text)
-         final_engine = engine or detected
-         if final_engine == "unknown":
+         try:
+             conn_id = sql_manager.create_database_from_dump(label=filename, sql_text=sql_text)
+         except Exception as e:
              raise HTTPException(
                  status_code=400,
-                 detail=(
-                     "Could not detect the SQL engine (MySQL/PostgreSQL). "
-                     "Upload the file again specifying engine='mysql' or 'postgresql'."
-                 ),
+                 detail=f"Could not create the DB from the SQL dump: {e}",
              )
+         meta = sql_manager.connections[conn_id]
+         engine = meta["engine"]
+         db_name = meta["db_name"]
+         note = f"SQL dump imported into {engine.upper()} database '{db_name}'."

-         conn_id = load_sql_into_connection(sql_text, final_engine, connection_id, filename)
-         info = DB_REGISTRY[conn_id]
-         note = f"SQL executed on engine {final_engine}."
-         return UploadResponse(
-             connection_id=conn_id,
-             label=info["label"],
-             engine=info["engine"],
-             namespace=info["namespace"],
-             note=note,
-         )
-
-     # Case: ZIP with several .sql files (e.g. schema.sql + data.sql, or many scripts)
-     try:
-         with zipfile.ZipFile(io.BytesIO(contents)) as zf:
-             sql_names = [n for n in zf.namelist() if n.lower().endswith(".sql")]
-             if not sql_names:
-                 raise HTTPException(
-                     status_code=400,
-                     detail="The ZIP contains no usable .sql files.",
-                 )
-
-             combined_sql_parts = []
-             for name in sorted(sql_names):
-                 with zf.open(name) as f:
-                     combined_sql_parts.append(f"-- FILE: {name}\n")
-                     combined_sql_parts.append(f.read().decode("utf-8", errors="ignore"))
-             sql_text = "\n\n".join(combined_sql_parts)
-
-     except zipfile.BadZipFile:
-         raise HTTPException(status_code=400, detail="Invalid or corrupt ZIP file.")
+     # Case 2: ZIP with one or more .sql files
+     elif fname_lower.endswith(".zip"):
+         try:
+             sql_text = _combine_sql_files_from_zip(contents)
+         except ValueError as ve:
+             raise HTTPException(status_code=400, detail=str(ve))

-     detected = detect_sql_dialect(sql_text)
-     final_engine = engine or detected
-     if final_engine == "unknown":
+         try:
+             conn_id = sql_manager.create_database_from_dump(label=filename, sql_text=sql_text)
+         except Exception as e:
+             raise HTTPException(
+                 status_code=400,
+                 detail=f"Could not create the DB from the .sql files inside the ZIP: {e}",
+             )
+         meta = sql_manager.connections[conn_id]
+         engine = meta["engine"]
+         db_name = meta["db_name"]
+         note = f"ZIP with SQL dumps imported into {engine.upper()} database '{db_name}'."

+     else:
          raise HTTPException(
              status_code=400,
-             detail=(
-                 "Could not detect the SQL engine (MySQL/PostgreSQL) in the ZIP. "
-                 "Upload it again specifying engine='mysql' or 'postgresql'."
-             ),
+             detail="Unsupported format. Use: .sql or .zip (with .sql files inside).",
          )

-     conn_id = load_sql_into_connection(sql_text, final_engine, connection_id, filename)
-     info = DB_REGISTRY[conn_id]
-     note = f"ZIP with SQL scripts executed on engine {final_engine}."
+     meta = sql_manager.connections[conn_id]
+     engine = meta["engine"]
+     db_name = meta["db_name"]
+
+     # db_path is now a pseudo-path kept for backward compatibility
+     db_path = f"{engine}://{db_name}"
+
      return UploadResponse(
          connection_id=conn_id,
-         label=info["label"],
-         engine=info["engine"],
-         namespace=info["namespace"],
+         label=meta["label"],
+         db_path=db_path,
          note=note,
      )


  @app.get("/connections", response_model=List[ConnectionInfo])
  async def list_connections():
-     out: List[ConnectionInfo] = []
-     for cid, info in DB_REGISTRY.items():
-         out.append(ConnectionInfo(connection_id=cid, label=info["label"], engine=info["engine"]))
-     return out
+     return [
+         ConnectionInfo(
+             connection_id=cid,
+             label=meta.get("label", ""),
+             engine=meta.get("engine"),
+             db_name=meta.get("db_name"),
+         )
+         for cid, meta in sql_manager.connections.items()
+     ]


  @app.get("/schema/{connection_id}", response_model=SchemaResponse)
  async def get_schema(connection_id: str):
-     info = ensure_connection(connection_id)
-     meta = introspect_schema(connection_id)
+     if connection_id not in sql_manager.connections:
+         raise HTTPException(status_code=404, detail="connection_id not found")
+
+     meta = sql_manager.get_schema(connection_id)
+     tables = meta["tables"]
+
+     parts = []
+     for t, info in tables.items():
+         cols = info.get("columns", [])
+         parts.append(f"{t}(" + ", ".join(cols) + ")")
+     schema_str = " ; ".join(parts) if parts else "(empty_schema)"
+
      return SchemaResponse(
          connection_id=connection_id,
-         engine=info["engine"],
-         schema_summary=meta["schema_str"],
-         tables=meta["tables"],
+         schema_summary=schema_str,
+         tables=tables,
      )


  @app.get("/preview/{connection_id}/{table}", response_model=PreviewResponse)
  async def preview_table(connection_id: str, table: str, limit: int = 20):
-     info = ensure_connection(connection_id)
-     engine = info["engine"]
-     ns = info["namespace"]
+     if connection_id not in sql_manager.connections:
+         raise HTTPException(status_code=404, detail="connection_id not found")

      try:
-         if engine == "postgresql":
-             conn = get_pg_conn()
-             try:
-                 cur = conn.cursor()
-                 cur.execute(f'SET search_path TO "{ns}";')
-                 cur.execute(f'SELECT * FROM "{table}" LIMIT %s;', (int(limit),))
-                 rows = cur.fetchall()
-                 cols = [d[0] for d in cur.description] if cur.description else []
-                 cur.close()
-             finally:
-                 conn.close()
-         else:
-             conn = get_mysql_conn(ns)
-             try:
-                 cur = conn.cursor()
-                 cur.execute(f"SELECT * FROM `{table}` LIMIT %s;", (int(limit),))
-                 rows = cur.fetchall()
-                 cols = [d[0] for d in cur.description] if cur.description else []
-                 cur.close()
-             finally:
-                 conn.close()
+         preview = sql_manager.get_preview(connection_id, table, limit)
      except Exception as e:
          raise HTTPException(status_code=400, detail=f"Error reading table '{table}': {e}")

      return PreviewResponse(
          connection_id=connection_id,
-         engine=engine,
          table=table,
-         columns=cols,
-         rows=[list(r) for r in rows],
+         columns=preview["columns"],
+         rows=preview["rows"],
      )
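Note: a minimal client sketch for the endpoints above, assuming the backend runs locally on port 8000 and a Sakila-style dump; the table name is illustrative:

    import requests

    BASE = "http://localhost:8000"  # assumed local dev address

    # Upload a dump; the response carries the connection_id used by later calls
    with open("sakila.sql", "rb") as f:
        resp = requests.post(f"{BASE}/upload", files={"db_file": ("sakila.sql", f)})
    conn_id = resp.json()["connection_id"]

    # Inspect the introspected schema, then preview a table
    print(requests.get(f"{BASE}/schema/{conn_id}").json()["schema_summary"])
    print(requests.get(f"{BASE}/preview/{conn_id}/actor", params={"limit": 5}).json())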
@@ -873,12 +625,12 @@ async def infer_sql(req: InferRequest):
  @app.post("/speech-infer", response_model=SpeechInferResponse)
  async def speech_infer(
      connection_id: str = Form(...),
-     audio: UploadFile = File(...),
+     audio: UploadFile = File(...)
  ):
      if openai_client is None:
          raise HTTPException(
              status_code=500,
-             detail="OPENAI_API_KEY is not configured in the backend.",
+             detail="OPENAI_API_KEY is not configured in the backend."
          )

      if audio.content_type is None:
@@ -915,8 +667,7 @@ async def health():
      return {
          "status": "ok",
          "model_loaded": t5_model is not None,
-         "connections": len(DB_REGISTRY),
-         "engines_in_use": list({info["engine"] for info in DB_REGISTRY.values()}),
+         "connections": len(sql_manager.connections),
          "device": str(DEVICE),
      }

@@ -924,13 +675,13 @@ async def health():
  @app.get("/")
  async def root():
      return {
-         "message": "NL2SQL T5-large backend is running (MySQL/PostgreSQL engines).",
+         "message": "NL2SQL T5-large backend is running with real MySQL/PostgreSQL engines.",
          "endpoints": [
-             "POST /upload-sql (upload .sql or .zip of .sql, engine auto or manual)",
-             "GET /connections (list logical DBs)",
-             "GET /schema/{id} (summarized schema from the real engine)",
+             "POST /upload (upload .sql or .zip of .sql to create a DB on the fly)",
+             "GET /connections (list uploaded DBs)",
+             "GET /schema/{id} (summarized schema)",
              "GET /preview/{id}/{t} (table preview)",
-             "POST /infer (NL→SQL + safe execution on the engine)",
+             "POST /infer (NL→SQL + execution on the real DB)",
              "POST /speech-infer (spoken NL → SQL + execution)",
              "GET /health (backend status)",
              "GET /docs (OpenAPI UI)",