Spaces:
Paused
Paused
| # import streamlit as st | |
| # from text2sql import ChatBot | |
| # from langdetect import detect | |
| # from utils.translate_utils import translate_zh_to_en | |
| # from utils.db_utils import add_a_record | |
| # from langdetect.lang_detect_exception import LangDetectException | |
| # # Initialize chatbot and other variables | |
| # text2sql_bot = ChatBot() | |
| # baidu_api_token = None | |
| # # Define database schemas for demonstration | |
| # db_schemas = { | |
| # "singer": """ | |
| # CREATE TABLE "singer" ( | |
| # "Singer_ID" int, | |
| # "Name" text, | |
| # "Birth_Year" real, | |
| # "Net_Worth_Millions" real, | |
| # "Citizenship" text, | |
| # PRIMARY KEY ("Singer_ID") | |
| # ); | |
| # CREATE TABLE "song" ( | |
| # "Song_ID" int, | |
| # "Title" text, | |
| # "Singer_ID" int, | |
| # "Sales" real, | |
| # "Highest_Position" real, | |
| # PRIMARY KEY ("Song_ID"), | |
| # FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID") | |
| # ); | |
| # """, | |
| # # Add other schemas as needed | |
| # } | |
| # # Streamlit UI | |
| # st.title("Text-to-SQL Chatbot") | |
| # st.sidebar.header("Select a Database") | |
| # # Sidebar for selecting a database | |
| # selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas.keys())) | |
| # # Display the selected schema | |
| # st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600) | |
| # # User input section | |
| # question = st.text_input("Enter your question:") | |
| # db_id = selected_db # Use selected database for DB ID | |
| # if question: | |
| # add_a_record(question, db_id) | |
| # try: | |
| # if baidu_api_token is not None and detect(question) != "en": | |
| # print("Before translation:", question) | |
| # question = translate_zh_to_en(question, baidu_api_token) | |
| # print("After translation:", question) | |
| # except LangDetectException as e: | |
| # print("Language detection error:", str(e)) | |
| # predicted_sql = text2sql_bot.get_response(question, db_id) | |
| # st.write(f"**Database:** {db_id}") | |
| # st.write(f"**Predicted SQL query:** {predicted_sql}") | |
| import streamlit as st | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from langdetect import detect | |
| from utils.translate_utils import translate_zh_to_en | |
| from utils.db_utils import add_a_record | |
| from langdetect.lang_detect_exception import LangDetectException | |
class SchemaItemClassifierInference:
    """Wrap a Hugging Face sequence-classification checkpoint for inference.

    The tokenizer and model for ``model_name`` are loaded once at
    construction; :meth:`predict` then returns the raw logits for a text.
    """

    def __init__(self, model_name):
        """Load the tokenizer and classification model for ``model_name``.

        Args:
            model_name: Identifier or path handed to ``from_pretrained``.
        """
        # NOTE(review): use_auth_token is presumably needed because the
        # checkpoint is private/gated; recent transformers releases deprecate
        # it in favour of `token=` -- confirm the pinned version before
        # switching.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, use_auth_token=True
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_name, use_auth_token=True
        )

    def predict(self, text):
        """Return the model's logits for ``text``.

        Args:
            text: Input handed to the tokenizer (presumably a string or a
                list of strings -- confirm against callers).

        Returns:
            The ``logits`` attribute of the model output, computed with
            autograd disabled.
        """
        # Function-scope import: the file's import block does not import
        # torch, but return_tensors="pt" already requires it at runtime.
        import torch

        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True
        )
        # Inference only: skip building the autograd graph so activations are
        # not retained and the returned logits carry no grad_fn.
        with torch.no_grad():
            outputs = self.model(**inputs)
        return outputs.logits
class ChatBot:
    """Thin wrapper that answers questions via the schema item classifier."""

    def __init__(self):
        """Load the classifier from the "Roxanne-WANG/LangSQL" checkpoint."""
        self.sic = SchemaItemClassifierInference("Roxanne-WANG/LangSQL")

    def get_response(self, question, db_id):
        """Return the classifier's prediction for ``question``.

        Args:
            question: Text forwarded unchanged to the classifier.
            db_id: Accepted for interface compatibility; not used here.

        Returns:
            Whatever ``self.sic.predict(question)`` returns.
        """
        return self.sic.predict(question)
@st.cache_resource
def _load_chatbot():
    """Build the ChatBot once and reuse it across Streamlit reruns.

    Streamlit re-executes this script on every widget interaction; without
    caching, the model checkpoint would be reloaded each time.
    """
    return ChatBot()


text2sql_bot = _load_chatbot()

# Baidu translation API token. While this is None, the translation step for
# non-English questions is skipped.
baidu_api_token = None
# DDL of the demo "singer" database; displayed verbatim in the sidebar.
_SINGER_SCHEMA = """
CREATE TABLE "singer" (
"Singer_ID" int,
"Name" text,
"Birth_Year" real,
"Net_Worth_Millions" real,
"Citizenship" text,
PRIMARY KEY ("Singer_ID")
);
CREATE TABLE "song" (
"Song_ID" int,
"Title" text,
"Singer_ID" int,
"Sales" real,
"Highest_Position" real,
PRIMARY KEY ("Song_ID"),
FOREIGN KEY ("Singer_ID") REFERENCES "singer"("Singer_ID")
);
"""

# Maps each selectable database id to the schema text shown for it.
db_schemas = {"singer": _SINGER_SCHEMA}
# --- Streamlit page -------------------------------------------------------
st.title("Text-to-SQL Chatbot")

# Sidebar: choose a database and show its schema text.
st.sidebar.header("Select a Database")
selected_db = st.sidebar.selectbox("Choose a database:", list(db_schemas))
st.sidebar.text_area("Database Schema", db_schemas[selected_db], height=600)

# Main area: take a question and show the model output for it.
question = st.text_input("Enter your question:")
db_id = selected_db

if question:
    # Record the question as typed, before any translation.
    add_a_record(question, db_id)

    try:
        # Translation runs only when a Baidu token is configured and the
        # detected language is not English.
        needs_translation = (
            baidu_api_token is not None and detect(question) != "en"
        )
        if needs_translation:
            print("Before translation:", question)
            question = translate_zh_to_en(question, baidu_api_token)
            print("After translation:", question)
    except LangDetectException as exc:
        # Best effort: on detection failure, continue with the original text.
        print("Language detection error:", str(exc))

    predicted_sql = text2sql_bot.get_response(question, db_id)

    st.write(f"**Database:** {db_id}")
    st.write(f"**Predicted SQL query:** {predicted_sql}")