| import os | |
| import io | |
| import zipfile | |
| import re | |
| import difflib | |
| import tempfile | |
| import uuid | |
| from typing import List, Optional, Dict, Any | |
| from fastapi import FastAPI, UploadFile, File, HTTPException, Form, Header | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from langdetect import detect | |
| from transformers import MarianMTModel, MarianTokenizer | |
| from openai import OpenAI | |
| # ---- Postgres ---- | |
| import psycopg2 | |
| from psycopg2 import sql as pgsql | |
| # ---- Supabase ---- | |
| from supabase import create_client, Client | |
| SUPABASE_URL = "https://bnvmqgjawtaslczewqyd.supabase.co" | |
| SUPABASE_ANON_KEY = ( | |
| "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImJudm1x" | |
| "Z2phd3Rhc2xjemV3cXlkIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NjQ0NjM5NDAsImV4cCI6MjA4" | |
| "MDAzOTk0MH0.9zkyqrsm-QOSwMTUPZEWqyFeNpbbuar01rB7pmObkUI" | |
| ) | |
| supabase: Client = create_client(SUPABASE_URL, SUPABASE_ANON_KEY) | |
# ======================================================
# 0) General configuration: paths / model / OpenAI
# ======================================================
| MODEL_DIR = os.getenv("MODEL_DIR", "stvnnnnnn/t5-large-nl2sql-spider") | |
| DEVICE = torch.device("cpu") | |
| OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") | |
| openai_client = OpenAI(api_key=OPENAI_API_KEY) if OPENAI_API_KEY else None | |
# Supabase Postgres DSN – EXAMPLE:
# postgresql://postgres:[email protected]:5432/postgres
| POSTGRES_DSN = os.getenv("POSTGRES_DSN") | |
| if not POSTGRES_DSN: | |
| raise RuntimeError( | |
| "⚠️ POSTGRES_DSN no está definido. " | |
| "Configúralo en los secrets del Space con la cadena de conexión de Supabase." | |
| ) | |
# ======================================================
# 1) Dynamic connection manager: Postgres (Neon)
# ======================================================
class PostgresManager:
    """
    Each upload restores a dump into the Neon database and registers
    the schema that ends up holding its tables.
    connections[connection_id] = {
        "label": str,     # original file name
        "engine": "postgres",
        "schema": str     # name of the schema in Neon
    }
    """
| def __init__(self, dsn: str): | |
| self.dsn = dsn | |
| self.connections: Dict[str, Dict[str, Any]] = {} | |
    # ---------- internal utilities ----------
| def _new_connection_id(self) -> str: | |
| return f"db_{uuid.uuid4().hex[:8]}" | |
| def _get_info(self, connection_id: str) -> Dict[str, Any]: | |
| if connection_id not in self.connections: | |
| raise KeyError(f"connection_id '{connection_id}' no registrado") | |
| return self.connections[connection_id] | |
| def _get_conn(self, autocommit: bool = True): | |
| conn = psycopg2.connect(self.dsn) | |
| conn.autocommit = autocommit | |
| return conn | |
    # ---------- dump sanitization helpers ----------
    def _rewrite_line_for_schema(self, line: str, schema_name: str) -> str:
        """
        Simplified version:
        - Only drops lines that modify the search_path.
        - Does NOT rewrite public./pagila. prefixes: the dump keeps its own schema.
        """
        if "search_path" in line.lower():
            return ""
        return line
    def _should_skip_statement(self, stmt: str) -> bool:
        """
        Returns True if the statement must NOT be executed
        (grants, ownership, CREATE DATABASE, domains, etc.).
        Universal filter for PostgreSQL dumps (Neon, Pagila, ...).
        """
        if not stmt:
            return True
        upper = stmt.upper().strip()
        # 1) Global / administrative statements that are ALWAYS skipped
| skip_prefixes = ( | |
| "SET ", | |
| "RESET ", | |
| "SELECT PG_CATALOG.SET_CONFIG", | |
| "COMMENT ON EXTENSION", | |
| "COMMENT ON SCHEMA", | |
| "COMMENT ON DATABASE", | |
| "COMMENT ON COLLATION", | |
| "COMMENT ON CONVERSION", | |
| "COMMENT ON LANGUAGE", | |
| "COMMENT ON TEXT SEARCH", | |
| "COMMENT ON FOREIGN", | |
| "CREATE DATABASE", | |
| "ALTER DATABASE", | |
| "DROP DATABASE", | |
| "CREATE EXTENSION", | |
| "ALTER EXTENSION", | |
| "DROP EXTENSION", | |
| "REVOKE ", | |
| "GRANT ", | |
| "ALTER ROLE", | |
| "CREATE ROLE", | |
| "DROP ROLE", | |
| "CREATE USER", | |
| "ALTER USER", | |
| "DROP USER", | |
| "ALTER DEFAULT PRIVILEGES", | |
| "SECURITY LABEL", | |
| "BEGIN", | |
| "COMMIT", | |
| "ROLLBACK", | |
| ) | |
| if upper.startswith(skip_prefixes): | |
| return True | |
        # 2) Anything that touches OWNER / AUTHORIZATION is skipped
| owner_markers = ( | |
| " OWNER TO ", | |
| " OWNER ", | |
| "AUTHORIZATION POSTGRES", | |
| "AUTHORIZATION PUBLIC", | |
| "AUTHORIZATION CURRENT_USER", | |
| "AUTHORIZATION \"POSTGRES\"", | |
| ) | |
| if any(marker in upper for marker in owner_markers): | |
| return True | |
        # 3) Explicit grants/revokes to postgres or public (even when they don't start with GRANT/REVOKE)
| if " TO POSTGRES" in upper or " FROM POSTGRES" in upper: | |
| return True | |
| if " TO PUBLIC" in upper or " FROM PUBLIC" in upper: | |
| return True | |
| return False | |
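    # Illustrative sketch (comments only, not executed): how the filter above
    # classifies a few typical dump statements; the DSN is a made-up placeholder.
    #
    #   mgr = PostgresManager("postgresql://user:pass@host:5432/db")
    #   mgr._should_skip_statement("GRANT ALL ON actor TO postgres;")   # -> True
    #   mgr._should_skip_statement("ALTER TABLE actor OWNER TO neon;")  # -> True
    #   mgr._should_skip_statement("SET search_path TO public;")        # -> True
    #   mgr._should_skip_statement("CREATE TABLE actor (id int);")      # -> False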
    def _execute_sanitized_pg_dump(
        self, cur, sql_text: str, schema_name: str
    ) -> None:
        """
        Executes a PostgreSQL dump, applying sanitization and supporting
        COPY ... FROM stdin;.
        - Drops search_path statements (the dump keeps its own schemas).
        - Respects $$...$$ function bodies (does not split on internal ';').
        - Skips dangerous statements via _should_skip_statement().
        """
        in_copy = False
        copy_sql = ""
        copy_lines: list[str] = []
        buffer = ""                # accumulated statement
        in_dollar = False          # inside a $$...$$ block?
        dollar_tag = ""            # e.g. "$func$"
        in_domain_block = False    # 👈 inside a CREATE DOMAIN block?
        in_function_block = False  # 👈 inside a CREATE FUNCTION block?
| def flush_statement(): | |
| nonlocal buffer | |
| stmt = buffer.strip() | |
| buffer = "" | |
| if not stmt: | |
| return | |
| if self._should_skip_statement(stmt): | |
| return | |
| try: | |
| cur.execute(stmt) | |
| except Exception as e: | |
| msg = str(e).lower() | |
                # Ignore common, non-fatal dump errors
| if "already exists" in msg or "duplicate key value" in msg: | |
| print("[WARN] Ignorando error no crítico:", e) | |
| return | |
| raise | |
        # Process the dump line by line
| for raw_line in sql_text.splitlines(): | |
| line = raw_line.rstrip("\n") | |
| stripped = line.strip() | |
            # ====== CREATE FUNCTION block (skipped entirely) ======
            if in_function_block:
                # Close the block when we see something like "$_$;" or "$func$;"
| if re.search(r"\$[A-Za-z0-9_]*\$;", stripped): | |
| in_function_block = False | |
| continue | |
            # Comments and blank lines (outside COPY / DOMAIN / FUNCTION blocks)
| if not in_copy and not in_domain_block: | |
| if not stripped or stripped.startswith("--"): | |
| continue | |
| upper_line = stripped.upper() | |
| if ( | |
| upper_line.startswith("CREATE FUNCTION") | |
| or upper_line.startswith("CREATE OR REPLACE FUNCTION") | |
| or upper_line.startswith("ALTER FUNCTION") | |
| ): | |
                # Skip the whole function (header + body)
| in_function_block = True | |
| continue | |
            # ====== COPY ... FROM stdin block ======
| if in_copy: | |
| if stripped == r"\.": | |
| # fin de COPY | |
| data = "\n".join(copy_lines) + "\n" | |
| cur.copy_expert(copy_sql, io.StringIO(data)) | |
| in_copy = False | |
| copy_sql = "" | |
| copy_lines.clear() | |
| else: | |
| copy_lines.append(line) | |
| continue | |
            # Sanitize the line (only drops search_path statements)
| line = self._rewrite_line_for_schema(line, schema_name) | |
| stripped = line.strip() | |
| if not stripped: | |
| continue | |
            # Detect the start of a COPY block now that the line is sanitized
            if stripped.upper().startswith("COPY ") and "FROM STDIN" in stripped.upper():
                # Flush any pending statement before the COPY
                flush_statement()
                in_copy = True
                copy_sql = stripped  # already sanitized
| copy_lines = [] | |
| continue | |
            # Scan the line character by character for $tag$ delimiters and ';'
| i = 0 | |
| start_seg = 0 | |
| length = len(line) | |
| while i < length: | |
| ch = line[i] | |
                # Handle $tag$ dollar-quote delimiters
                if ch == "$":
                    # Start or end of a dollar-quoted block?
| j = i + 1 | |
| while j < length and (line[j].isalnum() or line[j] == "_"): | |
| j += 1 | |
| if j < length and line[j] == "$": | |
| tag = line[i : j + 1] # incluye ambos '$' | |
| if not in_dollar: | |
| in_dollar = True | |
| dollar_tag = tag | |
| else: | |
| if tag == dollar_tag: | |
| in_dollar = False | |
| dollar_tag = "" | |
| i = j + 1 | |
| continue | |
                # End of statement: ';' outside a dollar-quoted block
| if ch == ";" and not in_dollar: | |
| segment = line[start_seg : i + 1] | |
| buffer += segment + "\n" | |
| flush_statement() | |
| start_seg = i + 1 | |
| i += 1 | |
| continue | |
| i += 1 | |
            # Rest of the line (after the last ';', or the whole line if there was none)
| if start_seg < length: | |
| buffer += line[start_seg:] + "\n" | |
        # Flush whatever is still pending
        flush_statement()
        # Safety check: make sure no COPY block was left open
| if in_copy: | |
| raise RuntimeError("Dump SQL inválido: COPY sin terminación '\\.'") | |
    # ---------- database creation from a dump ----------
    def create_database_from_dump(self, label: str, sql_text: str) -> str:
        """
        Restores a Postgres dump (schema + data) into the Neon database.
        Does NOT create per-session schemas: the dump keeps its own schemas
        (public, pagila, etc.). Afterwards, the schema with the most tables
        is detected and registered.
        """
| connection_id = self._new_connection_id() | |
| schema_name: str | None = None | |
| conn = self._get_conn() | |
| try: | |
| with conn.cursor() as cur: | |
                # 1) Run the dump as-is (only search_path lines are dropped)
| self._execute_sanitized_pg_dump(cur, sql_text, schema_name="public") | |
                # 2) Detect the REAL schema where the dump's tables ended up
| cur.execute( | |
| """ | |
| SELECT table_schema, COUNT(*) AS n | |
| FROM information_schema.tables | |
| WHERE table_type = 'BASE TABLE' | |
| AND table_schema NOT IN ('pg_catalog','information_schema') | |
| GROUP BY table_schema | |
| ORDER BY n DESC; | |
| """ | |
| ) | |
| rows = cur.fetchall() | |
| if not rows: | |
| raise RuntimeError( | |
| "El dump se ejecutó pero no se encontraron tablas de usuario." | |
| ) | |
                # Take the schema with the most tables (pagila, public, etc.)
| schema_name = rows[0][0] | |
| except Exception as e: | |
| conn.close() | |
| raise RuntimeError(f"Error ejecutando dump SQL en Postgres: {e}") | |
| finally: | |
| conn.close() | |
| self.connections[connection_id] = { | |
| "label": label, | |
| "engine": "postgres", | |
| "schema": schema_name, # 👈 ahora es el schema REAL con tablas | |
| } | |
| return connection_id | |
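    # Illustrative usage sketch (not executed; the file name and the resulting
    # schema are hypothetical):
    #
    #   dump_text = open("pagila.sql", encoding="utf-8").read()
    #   conn_id = sql_manager.create_database_from_dump("pagila.sql", dump_text)
    #   sql_manager.connections[conn_id]
    #   # -> {"label": "pagila.sql", "engine": "postgres", "schema": "public"}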
    # ---------- safe SQL execution ----------
    def execute_sql(self, connection_id: str, sql_text: str) -> Dict[str, Any]:
        """
        Runs a SELECT inside the schema associated with the connection_id.
        Destructive operations are blocked for safety.
        """
| info = self._get_info(connection_id) | |
| schema = info["schema"] | |
| forbidden = ["drop ", "delete ", "update ", "insert ", "alter ", "replace "] | |
| sql_low = sql_text.lower() | |
| if any(tok in sql_low for tok in forbidden): | |
| return { | |
| "ok": False, | |
| "error": "Query bloqueada por seguridad (operación destructiva).", | |
| "rows": None, | |
| "columns": [], | |
| } | |
| conn = self._get_conn() | |
| try: | |
| with conn.cursor() as cur: | |
                # use the schema registered for this connection
| cur.execute( | |
| pgsql.SQL("SET search_path TO {}").format( | |
| pgsql.Identifier(schema) | |
| ) | |
| ) | |
| cur.execute(sql_text) | |
| if cur.description: | |
| rows = cur.fetchall() | |
| cols = [d[0] for d in cur.description] | |
| else: | |
| rows, cols = [], [] | |
| return { | |
| "ok": True, | |
| "error": None, | |
| "rows": [list(r) for r in rows], | |
| "columns": cols, | |
| } | |
| except Exception as e: | |
| return {"ok": False, "error": str(e), "rows": None, "columns": []} | |
| finally: | |
| conn.close() | |
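    # Illustrative sketch of the guard above (result values are hypothetical):
    #
    #   sql_manager.execute_sql(conn_id, "SELECT count(*) FROM actor;")
    #   # -> {"ok": True, "error": None, "rows": [[200]], "columns": ["count"]}
    #   sql_manager.execute_sql(conn_id, "DROP TABLE actor;")
    #   # -> {"ok": False, "error": "Query bloqueada por seguridad ...", ...}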
    # ---------- schema introspection ----------
    def get_schema(self, connection_id: str) -> Dict[str, Any]:
        info = self._get_info(connection_id)
        schema = info["schema"]  # the "ideal" schema we registered
| conn = self._get_conn() | |
| try: | |
| tables_info: Dict[str, Dict[str, Any]] = {} | |
| foreign_keys: List[Dict[str, Any]] = [] | |
| with conn.cursor() as cur: | |
                # 1) Try the registered schema first
| cur.execute( | |
| """ | |
| SELECT table_name | |
| FROM information_schema.tables | |
| WHERE table_schema = %s | |
| AND table_type = 'BASE TABLE' | |
| ORDER BY table_name; | |
| """, | |
| (schema,), | |
| ) | |
| tables = [r[0] for r in cur.fetchall()] | |
                # 2) 🔁 Fallback: if that schema has no tables,
                #    look across ALL user schemas
| if not tables: | |
| cur.execute( | |
| """ | |
| SELECT table_schema, table_name | |
| FROM information_schema.tables | |
| WHERE table_type = 'BASE TABLE' | |
| AND table_schema NOT IN ('pg_catalog','information_schema') | |
| ORDER BY table_schema, table_name; | |
| """ | |
| ) | |
| rows = cur.fetchall() | |
| if not rows: | |
                        # No tables in any user schema
| return { | |
| "tables": {}, | |
| "foreign_keys": [], | |
| } | |
                    # Candidate schemas that do contain tables
                    schemas = sorted({s for (s, _) in rows})
                    # Preference order:
                    #   1) the registered schema (if it somehow has tables)
                    #   2) 'pagila'
                    #   3) 'public'
                    #   4) the first one found
| target_schema = None | |
| if schema in schemas: | |
| target_schema = schema | |
| elif "pagila" in schemas: | |
| target_schema = "pagila" | |
| elif "public" in schemas: | |
| target_schema = "public" | |
| else: | |
| target_schema = schemas[0] | |
| print( | |
| f"[WARN] Schema '{schema}' sin tablas; usando schema real '{target_schema}'" | |
| ) | |
                    # Update the schema associated with this connection
| schema = target_schema | |
| info["schema"] = schema | |
| tables = [t for (s, t) in rows if s == schema] | |
                # 3) Columns per table of the selected schema
| for t in tables: | |
| cur.execute( | |
| """ | |
| SELECT column_name | |
| FROM information_schema.columns | |
| WHERE table_schema = %s | |
| AND table_name = %s | |
| ORDER BY ordinal_position; | |
| """, | |
| (schema, t), | |
| ) | |
| cols = [r[0] for r in cur.fetchall()] | |
| tables_info[t] = {"columns": cols} | |
                # 4) Foreign keys of the selected schema
| cur.execute( | |
| """ | |
| SELECT | |
| tc.table_name AS from_table, | |
| kcu.column_name AS from_column, | |
| ccu.table_name AS to_table, | |
| ccu.column_name AS to_column | |
| FROM information_schema.table_constraints AS tc | |
| JOIN information_schema.key_column_usage AS kcu | |
| ON tc.constraint_name = kcu.constraint_name | |
| AND tc.table_schema = kcu.table_schema | |
| JOIN information_schema.constraint_column_usage AS ccu | |
| ON ccu.constraint_name = tc.constraint_name | |
| AND ccu.table_schema = tc.table_schema | |
| WHERE tc.constraint_type = 'FOREIGN KEY' | |
| AND tc.table_schema = %s; | |
| """, | |
| (schema,), | |
| ) | |
| for ft, fc, tt, tc2 in cur.fetchall(): | |
| foreign_keys.append( | |
| { | |
| "from_table": ft, | |
| "from_column": fc, | |
| "to_table": tt, | |
| "to_column": tc2, | |
| } | |
| ) | |
| return { | |
| "tables": tables_info, | |
| "foreign_keys": foreign_keys, | |
| } | |
| finally: | |
| conn.close() | |
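    # Sketch of the return shape of get_schema() (names are illustrative):
    #
    #   {
    #       "tables": {"actor": {"columns": ["actor_id", "first_name"]}},
    #       "foreign_keys": [
    #           {"from_table": "film_actor", "from_column": "actor_id",
    #            "to_table": "actor", "to_column": "actor_id"},
    #       ],
    #   }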
    # ---------- table preview ----------
| def get_preview( | |
| self, connection_id: str, table: str, limit: int = 20 | |
| ) -> Dict[str, Any]: | |
| info = self._get_info(connection_id) | |
| schema = info["schema"] | |
| conn = self._get_conn() | |
| try: | |
| with conn.cursor() as cur: | |
| cur.execute( | |
| pgsql.SQL("SET search_path TO {}").format( | |
| pgsql.Identifier(schema) | |
| ) | |
| ) | |
| query = pgsql.SQL("SELECT * FROM {} LIMIT %s").format( | |
| pgsql.Identifier(table) | |
| ) | |
| cur.execute(query, (int(limit),)) | |
| rows = cur.fetchall() | |
| cols = [d[0] for d in cur.description] if cur.description else [] | |
| return { | |
| "columns": cols, | |
| "rows": [list(r) for r in rows], | |
| } | |
| finally: | |
| conn.close() | |
# Global PostgresManager instance
| sql_manager = PostgresManager(POSTGRES_DSN) | |
# ======================================================
# 2) FastAPI initialization
# ======================================================
| app = FastAPI( | |
| title="NL2SQL Backend", | |
| version="3.0.0", | |
| ) | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
# ======================================================
# 3) NL→SQL model and ES→EN translator
# ======================================================
| t5_tokenizer = None | |
| t5_model = None | |
| mt_tokenizer = None | |
| mt_model = None | |
| def load_nl2sql_model(): | |
| """Carga el modelo NL→SQL (T5-large fine-tuned en Spider) desde HF Hub.""" | |
| global t5_tokenizer, t5_model | |
| if t5_model is not None: | |
| return | |
| print(f"🔁 Cargando modelo NL→SQL desde: {MODEL_DIR}") | |
| t5_tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True) | |
| t5_model = AutoModelForSeq2SeqLM.from_pretrained( | |
| MODEL_DIR, torch_dtype=torch.float32 | |
| ) | |
| t5_model.to(DEVICE) | |
| t5_model.eval() | |
| print("✅ Modelo NL→SQL listo en memoria.") | |
| def load_es_en_translator(): | |
| """Carga el modelo Helsinki-NLP para traducción ES→EN (solo una vez).""" | |
| global mt_tokenizer, mt_model | |
| if mt_model is not None: | |
| return | |
| model_name = "Helsinki-NLP/opus-mt-es-en" | |
| print(f"🔁 Cargando traductor ES→EN: {model_name}") | |
| mt_tokenizer = MarianTokenizer.from_pretrained(model_name) | |
| mt_model = MarianMTModel.from_pretrained(model_name) | |
| mt_model.to(DEVICE) | |
| mt_model.eval() | |
| print("✅ Traductor ES→EN listo.") | |
| def detect_language(text: str) -> str: | |
| try: | |
| return detect(text) | |
| except Exception: | |
| return "unknown" | |
| def translate_es_to_en(text: str) -> str: | |
| """ | |
| Usa Marian ES→EN solo si el texto se detecta como español ('es'). | |
| Si no, devuelve el texto tal cual. | |
| """ | |
| lang = detect_language(text) | |
| if lang != "es": | |
| return text | |
| if mt_model is None: | |
| load_es_en_translator() | |
| inputs = mt_tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE) | |
| with torch.no_grad(): | |
| out = mt_model.generate(**inputs, max_length=256) | |
| return mt_tokenizer.decode(out[0], skip_special_tokens=True) | |
# ======================================================
# 4) SQL repair layer (uses the real schema)
# ======================================================
| def _normalize_name_for_match(name: str) -> str: | |
| s = name.lower() | |
| s = s.replace('"', "").replace("`", "") | |
| s = s.replace("_", "") | |
| if s.endswith("s") and len(s) > 3: | |
| s = s[:-1] | |
| return s | |
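# Quick sketch of how the normalizer above folds name variants together:
#
#   _normalize_name_for_match('"Tracks"')  # -> "track"
#   _normalize_name_for_match("track_id")  # -> "trackid"
#   _normalize_name_for_match("Artists")   # -> "artist"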
| def _build_schema_indexes( | |
| tables_info: Dict[str, Dict[str, List[str]]] | |
| ) -> Dict[str, Dict[str, List[str]]]: | |
| table_index: Dict[str, List[str]] = {} | |
| column_index: Dict[str, List[str]] = {} | |
| for t, info in tables_info.items(): | |
| tn = _normalize_name_for_match(t) | |
| table_index.setdefault(tn, []) | |
| if t not in table_index[tn]: | |
| table_index[tn].append(t) | |
| for c in info.get("columns", []): | |
| cn = _normalize_name_for_match(c) | |
| column_index.setdefault(cn, []) | |
| if c not in column_index[cn]: | |
| column_index[cn].append(c) | |
| return {"table_index": table_index, "column_index": column_index} | |
| def _best_match_name(missing: str, index: Dict[str, List[str]]) -> Optional[str]: | |
| if not index: | |
| return None | |
| key = _normalize_name_for_match(missing) | |
| if key in index and index[key]: | |
| return index[key][0] | |
| candidates = difflib.get_close_matches(key, list(index.keys()), n=1, cutoff=0.7) | |
| if not candidates: | |
| return None | |
| best_key = candidates[0] | |
| if index[best_key]: | |
| return index[best_key][0] | |
| return None | |
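# Sketch of the fuzzy lookup, against a hypothetical one-table schema:
#
#   idx = _build_schema_indexes({"track": {"columns": ["track_id", "name"]}})
#   _best_match_name("Tracks", idx["table_index"])    # -> "track"
#   _best_match_name("customer", idx["table_index"])  # -> None (no close match)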
| DOMAIN_SYNONYMS_TABLE = { | |
| "song": "track", | |
| "songs": "track", | |
| "tracks": "track", | |
| "artist": "artist", | |
| "artists": "artist", | |
| "album": "album", | |
| "albums": "album", | |
| "order": "invoice", | |
| "orders": "invoice", | |
| } | |
| DOMAIN_SYNONYMS_COLUMN = { | |
| "song": "name", | |
| "songs": "name", | |
| "track": "name", | |
| "title": "name", | |
| "length": "milliseconds", | |
| "duration": "milliseconds", | |
| } | |
| def try_repair_sql( | |
| sql: str, error: str, schema_meta: Dict[str, Any] | |
| ) -> Optional[str]: | |
| """ | |
| Intenta reparar nombres de tablas/columnas basándose en el esquema real. | |
| Compatible con mensajes de Postgres y también con los de SQLite | |
| (por si algún día reusamos la lógica). | |
| """ | |
| tables_info = schema_meta["tables"] | |
| idx = _build_schema_indexes(tables_info) | |
| table_index = idx["table_index"] | |
| column_index = idx["column_index"] | |
| repaired_sql = sql | |
| changed = False | |
| missing_table = None | |
| missing_column = None | |
| m_t = re.search(r'relation "([\w\.]+)" does not exist', error, re.IGNORECASE) | |
| if not m_t: | |
| m_t = re.search(r"no such table: ([\w\.]+)", error) | |
| if m_t: | |
| missing_table = m_t.group(1) | |
| m_c = re.search(r'column "([\w\.]+)" does not exist', error, re.IGNORECASE) | |
| if not m_c: | |
| m_c = re.search(r"no such column: ([\w\.]+)", error) | |
| if m_c: | |
| missing_column = m_c.group(1) | |
| if missing_table: | |
| short = missing_table.split(".")[-1] | |
| syn = DOMAIN_SYNONYMS_TABLE.get(short.lower()) | |
| target = None | |
| if syn: | |
| target = _best_match_name(syn, table_index) or syn | |
| if not target: | |
| target = _best_match_name(short, table_index) | |
| if target: | |
| pattern = r"\b" + re.escape(short) + r"\b" | |
| new_sql = re.sub(pattern, target, repaired_sql) | |
| if new_sql != repaired_sql: | |
| repaired_sql = new_sql | |
| changed = True | |
| if missing_column: | |
| short = missing_column.split(".")[-1] | |
| syn = DOMAIN_SYNONYMS_COLUMN.get(short.lower()) | |
| target = None | |
| if syn: | |
| target = _best_match_name(syn, column_index) or syn | |
| if not target: | |
| target = _best_match_name(short, column_index) | |
| if target: | |
| pattern = r"\b" + re.escape(short) + r"\b" | |
| new_sql = re.sub(pattern, target, repaired_sql) | |
| if new_sql != repaired_sql: | |
| repaired_sql = new_sql | |
| changed = True | |
| if not changed: | |
| return None | |
| return repaired_sql | |
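# Sketch of the repair flow with a made-up Postgres error message:
#
#   try_repair_sql(
#       "SELECT name FROM songs",
#       'relation "songs" does not exist',
#       {"tables": {"track": {"columns": ["track_id", "name"]}}},
#   )
#   # -> "SELECT name FROM track"  (via DOMAIN_SYNONYMS_TABLE["songs"] -> "track")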
| # ====================================================== | |
| # 5) Prompt NL→SQL + re-ranking | |
| # ====================================================== | |
| def build_prompt(question_en: str, db_id: str, schema_str: str) -> str: | |
| return ( | |
| f"translate to SQL: {question_en} | " | |
| f"db: {db_id} | schema: {schema_str} | " | |
| f"note: use JOIN when foreign keys link tables" | |
| ) | |
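# Example of the prompt string this helper produces (values are illustrative):
#
#   build_prompt(
#       "How many artists are there?",
#       db_id="db_ab12cd34",
#       schema_str="artist(artist_id, name) ; album(album_id, title, artist_id)",
#   )
#   # -> "translate to SQL: How many artists are there? | db: db_ab12cd34 | "
#   #    "schema: artist(artist_id, name) ; album(album_id, title, artist_id) | "
#   #    "note: use JOIN when foreign keys link tables"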
| def normalize_score(raw: float) -> float: | |
| """Normaliza el score logit del modelo a un porcentaje 0-100.""" | |
| norm = (raw + 20) / 25 | |
| norm = max(0, min(1, norm)) | |
| return round(norm * 100, 2) | |
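# Worked example: a beam score of -7.5 maps to (-7.5 + 20) / 25 = 0.5 -> 50.0 %;
# scores at or below -20 clamp to 0.0 and scores at or above +5 clamp to 100.0.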
| def nl2sql_with_rerank(question: str, conn_id: str) -> Dict[str, Any]: | |
| if conn_id not in sql_manager.connections: | |
| raise HTTPException( | |
| status_code=404, detail=f"connection_id '{conn_id}' no registrado" | |
| ) | |
| meta = sql_manager.get_schema(conn_id) | |
| tables_info = meta["tables"] | |
| parts = [] | |
| for t, info in tables_info.items(): | |
| cols = info.get("columns", []) | |
| parts.append(f"{t}(" + ", ".join(cols) + ")") | |
| schema_str = " ; ".join(parts) if parts else "(empty_schema)" | |
| detected = detect_language(question) | |
| question_en = translate_es_to_en(question) if detected == "es" else question | |
| prompt = build_prompt(question_en, db_id=conn_id, schema_str=schema_str) | |
| if t5_model is None: | |
| load_nl2sql_model() | |
| inputs = t5_tokenizer( | |
| [prompt], return_tensors="pt", truncation=True, max_length=768 | |
| ).to(DEVICE) | |
| num_beams = 6 | |
| num_return = 6 | |
| with torch.no_grad(): | |
| out = t5_model.generate( | |
| **inputs, | |
| max_length=220, | |
| num_beams=num_beams, | |
| num_return_sequences=num_return, | |
| return_dict_in_generate=True, | |
| output_scores=True, | |
| ) | |
| sequences = out.sequences | |
| scores = out.sequences_scores | |
| if scores is not None: | |
| scores = scores.cpu().tolist() | |
| else: | |
| scores = [0.0] * sequences.size(0) | |
| candidates: List[Dict[str, Any]] = [] | |
| best = None | |
| best_exec = False | |
| best_score = -1e9 | |
| for i in range(sequences.size(0)): | |
| raw_sql = t5_tokenizer.decode( | |
| sequences[i], skip_special_tokens=True | |
| ).strip() | |
| cand: Dict[str, Any] = { | |
| "sql": raw_sql, | |
| "score": float(scores[i]), | |
| "repaired_from": None, | |
| "repair_note": None, | |
| "raw_sql_model": raw_sql, | |
| } | |
| exec_info = sql_manager.execute_sql(conn_id, raw_sql) | |
| err_lower = (exec_info["error"] or "").lower() | |
| if (not exec_info["ok"]) and ( | |
| "no such table" in err_lower | |
| or "no such column" in err_lower | |
| or "does not exist" in err_lower | |
| ): | |
| current_sql = raw_sql | |
| last_error = exec_info["error"] or "" | |
| for step in range(1, 4): | |
| repaired_sql = try_repair_sql(current_sql, last_error, meta) | |
| if not repaired_sql or repaired_sql == current_sql: | |
| break | |
| exec_info2 = sql_manager.execute_sql(conn_id, repaired_sql) | |
| cand["repaired_from"] = ( | |
| current_sql | |
| if cand["repaired_from"] is None | |
| else cand["repaired_from"] | |
| ) | |
| cand["repair_note"] = ( | |
| f"auto-repair (table/column name, step {step})" | |
| ) | |
| cand["sql"] = repaired_sql | |
| exec_info = exec_info2 | |
| current_sql = repaired_sql | |
| if exec_info2["ok"]: | |
| break | |
| last_error = exec_info2["error"] or "" | |
| cand["exec_ok"] = exec_info["ok"] | |
| cand["exec_error"] = exec_info["error"] | |
| cand["rows_preview"] = ( | |
| exec_info["rows"][:5] if exec_info["ok"] and exec_info["rows"] else None | |
| ) | |
| cand["columns"] = exec_info["columns"] | |
| candidates.append(cand) | |
| if exec_info["ok"]: | |
| if (not best_exec) or cand["score"] > best_score: | |
| best_exec = True | |
| best_score = cand["score"] | |
| best = cand | |
| elif not best_exec and cand["score"] > best_score: | |
| best_score = cand["score"] | |
| best = cand | |
| if best is None and candidates: | |
| best = candidates[0] | |
| return { | |
| "question_original": question, | |
| "detected_language": detected, | |
| "question_en": question_en, | |
| "connection_id": conn_id, | |
| "schema_summary": schema_str, | |
| "best_sql": best["sql"], | |
| "best_exec_ok": best.get("exec_ok", False), | |
| "best_exec_error": best.get("exec_error"), | |
| "best_rows_preview": best.get("rows_preview"), | |
| "best_columns": best.get("columns", []), | |
| "candidates": candidates, | |
| "score_percent": normalize_score(best["score"]), | |
| } | |
# ======================================================
# 6) Pydantic schemas
# ======================================================
| class UploadResponse(BaseModel): | |
| connection_id: str | |
| label: str | |
| db_path: str | |
| note: Optional[str] = None | |
| class ConnectionInfo(BaseModel): | |
| connection_id: str | |
| label: str | |
| engine: Optional[str] = None | |
    db_name: Optional[str] = None  # no longer a file, but the field is kept
| class SchemaResponse(BaseModel): | |
| connection_id: str | |
| schema_summary: str | |
| tables: Dict[str, Dict[str, List[str]]] | |
| class PreviewResponse(BaseModel): | |
| connection_id: str | |
| table: str | |
| columns: List[str] | |
| rows: List[List[Any]] | |
| class InferRequest(BaseModel): | |
| connection_id: str | |
| question: str | |
| class InferResponse(BaseModel): | |
| question_original: str | |
| detected_language: str | |
| question_en: str | |
| connection_id: str | |
| schema_summary: str | |
| best_sql: str | |
| best_exec_ok: bool | |
| best_exec_error: Optional[str] | |
| best_rows_preview: Optional[List[List[Any]]] | |
| best_columns: List[str] | |
| candidates: List[Dict[str, Any]] | |
| class SpeechInferResponse(BaseModel): | |
| transcript: str | |
| result: InferResponse | |
# ======================================================
# 7) Helpers for /upload (.sql and .zip)
# ======================================================
def _combine_sql_files_from_zip(zip_bytes: bytes) -> str:
    """
    Reads a ZIP, keeps only the .sql files and concatenates them.
    Order:
      1) files with 'schema' or 'structure' in the name
      2) everything else (data, etc.)
    """
| try: | |
| with zipfile.ZipFile(io.BytesIO(zip_bytes)) as zf: | |
| names = [info.filename for info in zf.infolist() if not info.is_dir()] | |
| sql_names = [n for n in names if n.lower().endswith(".sql")] | |
| if not sql_names: | |
| raise ValueError("El ZIP no contiene archivos .sql utilizables.") | |
| def sort_key(name: str) -> int: | |
| nl = name.lower() | |
| if "schema" in nl or "structure" in nl: | |
| return 0 | |
| return 1 | |
| sql_names_sorted = sorted(sql_names, key=sort_key) | |
| parts: List[str] = [] | |
| for name in sql_names_sorted: | |
| with zf.open(name) as f: | |
| text = f.read().decode("utf-8", errors="ignore") | |
| parts.append(f"-- FILE: {name}\n{text}\n") | |
| return "\n\n".join(parts) | |
| except zipfile.BadZipFile: | |
| raise ValueError("Archivo ZIP inválido o corrupto.") | |
# ======================================================
# 8) FastAPI endpoints
# ======================================================
@app.on_event("startup")
async def startup_event():
| load_nl2sql_model() | |
| print("✅ Backend NL2SQL inicializado.") | |
| print(f"MODEL_DIR={MODEL_DIR}, DEVICE={DEVICE}") | |
| print(f"Conexiones activas al inicio: {len(sql_manager.connections)}") | |
@app.post("/upload")
async def upload_database(
    mode: str = Form("full"),                # "full" | "schema_data" | "zip"
    db_files: List[UploadFile] = File(...),  # one or more files
    authorization: Optional[str] = Header(None),
):
| """ | |
| Sube uno o varios archivos SQL/ZIP según el modo: | |
| - mode = "full": | |
| * Espera EXACTAMENTE 1 archivo .sql | |
| * El .sql trae esquema + datos juntos (dump de PostgreSQL) | |
| - mode = "schema_data": | |
| * Espera EXACTAMENTE 2 archivos .sql | |
| * Uno de esquema y otro de datos (el orden lo resolvemos nosotros) | |
| - mode = "zip": | |
| * Espera EXACTAMENTE 1 archivo .zip | |
| * Dentro del zip buscamos SOLO archivos .sql (ignoramos el resto) | |
| """ | |
| if authorization is None: | |
| raise HTTPException(401, "Missing Authorization header") | |
| jwt = authorization.replace("Bearer ", "") | |
| user = supabase.auth.get_user(jwt) | |
| if not user or not user.user: | |
| raise HTTPException(401, "Invalid Supabase token") | |
| if not db_files: | |
| raise HTTPException(400, "No se recibió ningún archivo.") | |
| mode = mode.lower().strip() | |
    # =======================
    # MODE 1: FULL (single .sql)
    # =======================
| if mode == "full": | |
| if len(db_files) != 1: | |
| raise HTTPException( | |
| 400, "Modo FULL requiere exactamente 1 archivo .sql." | |
| ) | |
| file = db_files[0] | |
| filename = file.filename or "" | |
| if not filename.lower().endswith(".sql"): | |
| raise HTTPException(400, "Modo FULL solo acepta archivos .sql.") | |
| contents = await file.read() | |
| sql_text = contents.decode("utf-8", errors="ignore") | |
    # ====================================
    # MODE 2: SCHEMA + DATA (2 files)
    # ====================================
| elif mode == "schema_data": | |
| if len(db_files) != 2: | |
| raise HTTPException( | |
| 400, | |
| "Modo esquema+datos requiere exactamente 2 archivos .sql.", | |
| ) | |
| print("FILES RECEIVED:", [f.filename for f in db_files]) | |
| files_info: List[tuple[str, str]] = [] | |
| for f in db_files: | |
| fname = f.filename or "" | |
| if not fname.lower().endswith(".sql"): | |
| raise HTTPException(400, "Todos los archivos deben ser .sql.") | |
| contents = await f.read() | |
| files_info.append( | |
| (fname, contents.decode("utf-8", errors="ignore")) | |
| ) | |
        # Try to put the schema file first and the data file second
| def weight(name: str) -> int: | |
| nl = name.lower().replace("-", "_").replace(" ", "_") | |
| if any(x in nl for x in ["schema", "structure", "ddl"]): | |
| return 0 | |
| if any(x in nl for x in ["data", "dml", "insert", "rows"]): | |
| return 1 | |
| return 2 | |
| files_info_sorted = sorted(files_info, key=lambda x: weight(x[0])) | |
| sql_parts: List[str] = [] | |
| for fname, text in files_info_sorted: | |
| sql_parts.append(f"-- FILE: {fname}\n{text}\n") | |
| sql_text = "\n\n".join(sql_parts) | |
        # use the first file's name as the "main" label
| filename = files_info_sorted[0][0] | |
    # ==================
    # MODE 3: ZIP (.zip)
    # ==================
| elif mode == "zip": | |
| if len(db_files) != 1: | |
| raise HTTPException( | |
| 400, "Modo ZIP requiere exactamente 1 archivo .zip." | |
| ) | |
| file = db_files[0] | |
| filename = file.filename or "" | |
| if not filename.lower().endswith(".zip"): | |
| raise HTTPException(400, "Modo ZIP solo acepta archivos .zip.") | |
| contents = await file.read() | |
        # the helper ignores folders and only concatenates .sql files
| sql_text = _combine_sql_files_from_zip(contents) | |
| else: | |
| raise HTTPException(400, f"Modo no soportado: {mode}") | |
    # --- restore the dump into Postgres (Neon) ---
| try: | |
| conn_id = sql_manager.create_database_from_dump( | |
| label=filename, sql_text=sql_text | |
| ) | |
| except Exception as e: | |
| raise HTTPException(400, f"Error creando BD: {e}") | |
| meta = sql_manager.connections[conn_id] | |
    # --- store metadata in Supabase (without breaking the upload if it fails) ---
| try: | |
| supabase.table("databases").insert( | |
| { | |
| "user_id": user.user.id, | |
| "filename": filename, | |
| "engine": meta["engine"], | |
| "connection_id": conn_id, | |
| } | |
| ).execute() | |
| except Exception as e: | |
        # Log only; do NOT fail the endpoint
| print("[WARN] No se pudieron guardar metadatos en Supabase:", repr(e)) | |
| return UploadResponse( | |
| connection_id=conn_id, | |
| label=filename, | |
| db_path=f"{meta['engine']}://schema/{meta['schema']}", | |
| note="Database schema created in Neon and indexed in Supabase.", | |
| ) | |
@app.get("/connections")
async def list_connections():
| return [ | |
| ConnectionInfo( | |
| connection_id=cid, | |
| label=meta.get("label", ""), | |
| engine=meta.get("engine"), | |
            db_name=meta.get("schema"),  # the schema doubles as the "name"
| ) | |
| for cid, meta in sql_manager.connections.items() | |
| ] | |
@app.get("/schema/{connection_id}")
async def get_schema(connection_id: str):
| if connection_id not in sql_manager.connections: | |
| raise HTTPException(status_code=404, detail="connection_id no encontrado") | |
| meta = sql_manager.get_schema(connection_id) | |
| tables = meta["tables"] | |
| parts = [] | |
| for t, info in tables.items(): | |
| cols = info.get("columns", []) | |
| parts.append(f"{t}(" + ", ".join(cols) + ")") | |
| schema_str = " ; ".join(parts) if parts else "(empty_schema)" | |
| return SchemaResponse( | |
| connection_id=connection_id, | |
| schema_summary=schema_str, | |
| tables=tables, | |
| ) | |
@app.get("/preview/{connection_id}/{table}")
async def preview_table(connection_id: str, table: str, limit: int = 20):
| if connection_id not in sql_manager.connections: | |
| raise HTTPException(status_code=404, detail="connection_id no encontrado") | |
| try: | |
| preview = sql_manager.get_preview(connection_id, table, limit) | |
| except Exception as e: | |
| raise HTTPException( | |
| status_code=400, detail=f"Error al leer tabla '{table}': {e}" | |
| ) | |
| return PreviewResponse( | |
| connection_id=connection_id, | |
| table=table, | |
| columns=preview["columns"], | |
| rows=preview["rows"], | |
| ) | |
@app.post("/infer")
async def infer_sql(
| req: InferRequest, | |
| authorization: Optional[str] = Header(None), | |
| ): | |
| if authorization is None: | |
| raise HTTPException(401, "Missing Authorization header") | |
| jwt = authorization.replace("Bearer ", "") | |
| user = supabase.auth.get_user(jwt) | |
| if not user or not user.user: | |
| raise HTTPException(401, "Invalid Supabase token") | |
| result = nl2sql_with_rerank(req.question, req.connection_id) | |
| score = normalize_score(result["candidates"][0]["score"]) | |
| db_row = ( | |
| supabase.table("databases") | |
| .select("id") | |
| .eq("connection_id", req.connection_id) | |
| .eq("user_id", user.user.id) | |
| .execute() | |
| ) | |
| db_id = db_row.data[0]["id"] if db_row.data else None | |
| supabase.table("queries").insert( | |
| { | |
| "user_id": user.user.id, | |
| "db_id": db_id, | |
| "nl": result["question_original"], | |
| "sql_generated": result["best_sql"], | |
| "sql_repaired": result["candidates"][0]["sql"], | |
| "execution_ok": result["best_exec_ok"], | |
| "error": result["best_exec_error"], | |
| "rows_preview": result["best_rows_preview"], | |
| "score": score, | |
| } | |
| ).execute() | |
| result["score_percent"] = score | |
| return InferResponse(**result) | |
@app.post("/speech-infer")
async def speech_infer(
| connection_id: str = Form(...), | |
| audio: UploadFile = File(...), | |
| ): | |
| if openai_client is None: | |
| raise HTTPException( | |
| status_code=500, | |
| detail="OPENAI_API_KEY no está configurado en el backend.", | |
| ) | |
| if audio.content_type is None: | |
| raise HTTPException(status_code=400, detail="Archivo de audio inválido.") | |
| try: | |
| with tempfile.NamedTemporaryFile(delete=False, suffix=".webm") as tmp: | |
| tmp.write(await audio.read()) | |
| tmp_path = tmp.name | |
| except Exception: | |
| raise HTTPException( | |
| status_code=500, detail="No se pudo procesar el audio recibido." | |
| ) | |
| try: | |
| with open(tmp_path, "rb") as f: | |
| transcription = openai_client.audio.transcriptions.create( | |
| model="gpt-4o-transcribe", | |
| file=f, | |
| ) | |
| transcript_text: str = transcription.text | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=f"Error al transcribir audio: {e}") | |
| result_dict = nl2sql_with_rerank(transcript_text, connection_id) | |
| infer_result = InferResponse(**result_dict) | |
| return SpeechInferResponse( | |
| transcript=transcript_text, | |
| result=infer_result, | |
| ) | |
@app.get("/health")
async def health():
| return { | |
| "status": "ok", | |
| "model_loaded": t5_model is not None, | |
| "connections": len(sql_manager.connections), | |
| "device": str(DEVICE), | |
| "engine": "postgres", | |
| } | |
@app.get("/history")
def get_history(authorization: Optional[str] = Header(None)):
| if authorization is None: | |
| raise HTTPException(401, "Missing Authorization") | |
| jwt = authorization.replace("Bearer ", "") | |
| user = supabase.auth.get_user(jwt) | |
| rows = ( | |
| supabase.table("queries") | |
| .select("*") | |
| .eq("user_id", user.user.id) | |
| .order("created_at", desc=True) | |
| .execute() | |
| ) | |
| return rows.data | |
@app.get("/my-databases")
def get_my_databases(authorization: Optional[str] = Header(None)):
| if authorization is None: | |
| raise HTTPException(401, "Missing Authorization") | |
| jwt = authorization.replace("Bearer ", "") | |
| user = supabase.auth.get_user(jwt) | |
| rows = ( | |
| supabase.table("databases") | |
| .select("*") | |
| .eq("user_id", user.user.id) | |
| .execute() | |
| ) | |
| return rows.data | |
@app.get("/")
async def root():
| return { | |
| "message": "NL2SQL T5-large backend running.", | |
| "endpoints": [ | |
| "POST /upload (subir .sql o .zip con .sql → crea schema en Supabase)", | |
| "GET /connections (listar BDs subidas en esta instancia)", | |
| "GET /schema/{id} (esquema resumido)", | |
| "GET /preview/{id}/{t} (preview de tabla)", | |
| "POST /infer (NL→SQL + ejecución en BD)", | |
| "POST /speech-infer (voz → NL→SQL + ejecución)", | |
| "GET /history (historial de consultas en Supabase)", | |
| "GET /my-databases (BDs del usuario en Supabase)", | |
| "GET /health (estado del backend)", | |
| "GET /docs (OpenAPI UI)", | |
| ], | |
| } |