Commit
Β·
317f434
1
Parent(s):
8515a17
Update app.py
Browse files
app.py
CHANGED
|
@@ -12,11 +12,7 @@ from streamlit_chat import message
|
|
| 12 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
|
| 13 |
import torch
|
| 14 |
|
| 15 |
-
|
| 16 |
-
CHECKPOINT = "MBZUAI/LaMini-T5-738M"
|
| 17 |
-
TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)
|
| 18 |
-
BASE_MODEL = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT, device_map=torch.device('cpu'), torch_dtype=torch.float32)
|
| 19 |
-
|
| 20 |
|
| 21 |
def process_answer(instruction, qa_chain):
|
| 22 |
response = ''
|
|
@@ -50,7 +46,11 @@ def data_ingestion():
|
|
| 50 |
|
| 51 |
|
| 52 |
@st.cache_resource
|
| 53 |
-
def initialize_qa_chain():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
pipe = pipeline(
|
| 55 |
'text2text-generation',
|
| 56 |
model=BASE_MODEL,
|
|
@@ -101,7 +101,10 @@ def display_conversation(history):
|
|
| 101 |
|
| 102 |
|
| 103 |
def main():
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
| 105 |
st.markdown("<h1 style='text-align: center; color: blue;'>Custom PDF Chatbot π¦π </h1>", unsafe_allow_html=True)
|
| 106 |
st.markdown("<h2 style='text-align: center; color:red;'>Upload your PDF, and Ask Questions π</h2>", unsafe_allow_html=True)
|
| 107 |
|
|
@@ -125,6 +128,7 @@ def main():
|
|
| 125 |
pdf_view = display_pdf(filepath)
|
| 126 |
|
| 127 |
with col2:
|
|
|
|
| 128 |
with st.spinner('Embeddings are in process...'):
|
| 129 |
ingested_data = data_ingestion()
|
| 130 |
st.success('Embeddings are created successfully!')
|
|
@@ -140,7 +144,7 @@ def main():
|
|
| 140 |
|
| 141 |
# Search the database for a response based on user input and update session state
|
| 142 |
if user_input:
|
| 143 |
-
answer = process_answer({'query': user_input}, initialize_qa_chain())
|
| 144 |
st.session_state["past"].append(user_input)
|
| 145 |
response = answer
|
| 146 |
st.session_state["generated"].append(response)
|
|
|
|
| 12 |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
|
| 13 |
import torch
|
| 14 |
|
| 15 |
+
st.set_page_config(layout="wide")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
def process_answer(instruction, qa_chain):
|
| 18 |
response = ''
|
|
|
|
| 46 |
|
| 47 |
|
| 48 |
@st.cache_resource
|
| 49 |
+
def initialize_qa_chain(selected_model):
|
| 50 |
+
# Constants
|
| 51 |
+
CHECKPOINT = selected_model
|
| 52 |
+
TOKENIZER = AutoTokenizer.from_pretrained(CHECKPOINT)
|
| 53 |
+
BASE_MODEL = AutoModelForSeq2SeqLM.from_pretrained(CHECKPOINT, device_map=torch.device('cpu'), torch_dtype=torch.float32)
|
| 54 |
pipe = pipeline(
|
| 55 |
'text2text-generation',
|
| 56 |
model=BASE_MODEL,
|
|
|
|
| 101 |
|
| 102 |
|
| 103 |
def main():
|
| 104 |
+
# Add a sidebar for model selection
|
| 105 |
+
model_options = ["MBZUAI/LaMini-T5-738M", "google/flan-t5-base", "google/flan-t5-small"]
|
| 106 |
+
selected_model = st.sidebar.selectbox("Select Model", model_options)
|
| 107 |
+
|
| 108 |
st.markdown("<h1 style='text-align: center; color: blue;'>Custom PDF Chatbot π¦π </h1>", unsafe_allow_html=True)
|
| 109 |
st.markdown("<h2 style='text-align: center; color:red;'>Upload your PDF, and Ask Questions π</h2>", unsafe_allow_html=True)
|
| 110 |
|
|
|
|
| 128 |
pdf_view = display_pdf(filepath)
|
| 129 |
|
| 130 |
with col2:
|
| 131 |
+
st.success(f'model selected successfully: {selected_model}')
|
| 132 |
with st.spinner('Embeddings are in process...'):
|
| 133 |
ingested_data = data_ingestion()
|
| 134 |
st.success('Embeddings are created successfully!')
|
|
|
|
| 144 |
|
| 145 |
# Search the database for a response based on user input and update session state
|
| 146 |
if user_input:
|
| 147 |
+
answer = process_answer({'query': user_input}, initialize_qa_chain(selected_model))
|
| 148 |
st.session_state["past"].append(user_input)
|
| 149 |
response = answer
|
| 150 |
st.session_state["generated"].append(response)
|