Spaces:
Runtime error
Runtime error
| from langchain_community.llms import CTransformers | |
| from ctransformers import AutoModelForCausalLM | |
| from langchain.agents import Tool | |
| from langchain.agents import AgentType, initialize_agent | |
| from langchain.chains import RetrievalQA | |
| from langchain.text_splitter import CharacterTextSplitter | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.embeddings import HuggingFaceBgeEmbeddings | |
| import tempfile | |
| import os | |
| import streamlit as st | |
| import timeit | |
| from langchain.callbacks.tracers import ConsoleCallbackHandler | |
| # tt | |
# Lower-cased file extension -> (LangChain document loader class, extra loader kwargs).
# Add more mappings for other file extensions and loaders as needed.
FILE_LOADER_MAPPING = {
    "pdf": (PyPDFLoader, {}),
}


def _load_uploaded_documents(uploaded_files):
    """Load every supported uploaded file into LangChain Documents.

    Each upload is written to a temporary directory, because the loaders take a
    file path, and parsed with the loader registered for its extension in
    FILE_LOADER_MAPPING. Unsupported extensions produce a Streamlit warning and
    are skipped.

    Args:
        uploaded_files: the list returned by ``st.file_uploader`` (may be None/empty).

    Returns:
        A flat list of Documents from all supported files (empty if none).
    """
    loaded_documents = []
    if not uploaded_files:
        return loaded_documents

    with tempfile.TemporaryDirectory() as td:
        for uploaded_file in uploaded_files:
            st.write(f"Uploaded: {uploaded_file.name}")
            ext = os.path.splitext(uploaded_file.name)[-1][1:].lower()
            st.write(f"File type: {ext}")
            if ext not in FILE_LOADER_MAPPING:
                st.warning(f"Unsupported file extension: {ext}, the app currently only supports pdf")
                continue

            loader_class, loader_args = FILE_LOADER_MAPPING[ext]
            # basename() keeps a client-supplied name from escaping the temp dir.
            file_path = os.path.join(td, os.path.basename(uploaded_file.name))
            with open(file_path, "wb") as temp_file:
                # getvalue() returns the whole upload regardless of the buffer's
                # current read position (read() would return b"" if already consumed).
                temp_file.write(uploaded_file.getvalue())
            # Documents are fully materialized here, so the temp dir can be deleted afterwards.
            loader = loader_class(file_path, **loader_args)
            loaded_documents.extend(loader.load())
    return loaded_documents


@st.cache_resource
def _load_llm():
    """Create the quantized Llama-2 chat LLM once and reuse it across Streamlit reruns."""
    return CTransformers(
        model="TheBloke/Llama-2-7B-Chat-GGUF",
        model_file="llama-2-7b-chat.Q3_K_S.gguf",
        model_type="llama",
        # Generation settings must go through `config`; top-level kwargs such as
        # max_new_tokens/temperature are not CTransformers fields and are not applied.
        config={"max_new_tokens": 300, "temperature": 0.3},
        lib="avx2",  # for CPU
    )


@st.cache_resource
def _load_embeddings():
    """Create the BGE sentence-embedding model once and reuse it across reruns."""
    return HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-large-en",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": False},
    )


def _build_agent(llm, embeddings, documents):
    """Index ``documents`` in FAISS and wrap a RetrievalQA chain as a ReAct agent tool.

    Raises:
        ValueError: if the documents yield no text chunks to index.
    """
    text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=0)
    chunked_documents = text_splitter.split_documents(documents)
    if not chunked_documents:
        # FAISS cannot build an index from zero vectors; fail with a readable message.
        raise ValueError("No extractable text was found in the uploaded documents.")
    retriever = FAISS.from_documents(chunked_documents, embeddings).as_retriever()

    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
    tools = [
        Tool(
            name="Comparison tool",
            description="useful when you want to answer questions about the uploaded documents",
            # .run returns the answer string; calling the chain object directly
            # would return a {"query": ..., "result": ...} dict as the observation.
            func=qa_chain.run,
        )
    ]
    return initialize_agent(
        tools=tools,
        llm=llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )


def main():
    """Render the document-comparison Q&A app and answer the user's question."""
    st.title("Document Comparison with Q&A using Agents")

    uploaded_files = st.file_uploader("Upload your documents", type=["pdf"], accept_multiple_files=True)
    loaded_documents = _load_uploaded_documents(uploaded_files)

    st.write("Ask question to get comparison from the documents:")
    query = st.text_input("Ask a question:")
    if not st.button("Get Answer"):
        return
    if not query:
        st.warning("Please enter a question.")
        return
    if not loaded_documents:
        st.warning("Please upload at least one PDF document first.")
        return

    # Top-level UI boundary: surface any failure to the user instead of crashing the page.
    try:
        start = timeit.default_timer()
        llm = _load_llm()
        print("LLM Initialized...")
        agent = _build_agent(llm, _load_embeddings(), loaded_documents)
        response = agent.run(query)
        # Stop the clock only after the agent has produced its answer.
        end = timeit.default_timer()
        st.write("Elapsed time:")
        st.write(end - start)
        st.write("Bot Response:")
        st.write(response)
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")


if __name__ == "__main__":
    main()