| import gradio as gr |
|
|
| from llms.tiny_llama import TinyLlama |
| from knowledgebase import KnowledgeBase |
| from langchain.chains import ConversationalRetrievalChain |
| from langchain.memory import ConversationBufferMemory |
|
|
|
|
| LLM = None |
| KNOWLEDGEBASE = None |
| SYTEM_PROMPT = ( |
| "You are an assistant. Answer to a user's query based on a given context." |
| ) |
|
|
| |
| |
| |
| QA_CHAIN = None |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| def init_llm(llm="TinyLlama"): |
| global LLM |
| LLM = TinyLlama() |
|
|
|
|
| def chat(message, history): |
| global LLM, KNOWLEDGEBASE, SYTEM_PROMPT |
| context = KNOWLEDGEBASE.invoke(message)[0].page_content |
| system = {"role": "system", "content": SYTEM_PROMPT} |
| user = { |
| "role": "user", |
| "content": f"prompt: ```{message}``\ncontext:```{context}```", |
| } |
| response = LLM([system, user]).split("<|assistant|>")[-1] |
| return response |
|
|
|
|
| def init_rag(system_prompt, url_input, file_input): |
| global SYTEM_PROMPT |
| if SYTEM_PROMPT != system_prompt: |
| SYTEM_PROMPT = system_prompt |
| gr.Info("Saved new system prompt") |
| if url_input and file_input: |
| gr.Error(message="Provide either an URL or a File") |
| path = url_input if url_input else file_input |
| load_knowledgebase(path) |
|
|
|
|
| def load_knowledgebase(path): |
| global KNOWLEDGEBASE |
| if not KNOWLEDGEBASE: |
| KNOWLEDGEBASE = KnowledgeBase() |
| print("Loading knowledgebase:", path) |
| if not path: |
| return |
| if "https://" in path: |
| KNOWLEDGEBASE.load_url(path) |
| gr.Info(message="Succesfully loaded URL") |
| else: |
| if path.split(".")[-1] == "pdf": |
| KNOWLEDGEBASE.load_pdf(path) |
| gr.Info(message="Succesfully loaded pdf") |
| else: |
| KNOWLEDGEBASE.load_txt(path) |
| gr.Info(message="Succesfully loaded txt") |
|
|
|
|
| def show_file(file): |
| print(file) |
| return file |
|
|
|
|
| with gr.Blocks(title="d-RAG") as iface: |
| gr.Markdown( |
| """# d-RAG [](https://github.com/rumbleFTW/d-RAG/watchers) [](https://github.com/rumbleFTW/d-RAG/stargazers) |
| """ |
| ) |
| with gr.Row(equal_height=True): |
| with gr.Column(): |
| with gr.Row(): |
| model = gr.Dropdown( |
| label="Model", |
| choices=[ |
| "TinyLlama-1.1B-Chat-v1.0", |
| "Mixtral-8x7B-Instruct-v0.1", |
| "Mistral-7B-Instruct-v0.2", |
| ], |
| value="TinyLlama-1.1B-Chat-v1.0", |
| scale=1, |
| interactive=True, |
| ) |
| system_prompt = gr.Textbox( |
| label="System prompt", |
| value="You are an assistant. Answer to a user's query based on a given context.", |
| scale=2, |
| ) |
| with gr.Accordion(label="Knowledge base", open=True): |
| url_input = gr.Textbox(placeholder="URL", value=None) |
| gr.Markdown("OR") |
| file_input = gr.File( |
| file_count="multiple", |
| file_types=[".txt", ".pdf"], |
| show_label=True, |
| visible=True, |
| ) |
| submit = gr.Button("Submit") |
| submit.click( |
| fn=init_rag, |
| inputs=[system_prompt, url_input, file_input], |
| ) |
| with gr.Column(): |
| demo = gr.ChatInterface(fn=chat, examples=["Namaste!", "Hello!", "Hola!"]) |
|
|
| if __name__ == "__main__": |
| init_llm() |
| iface.launch(debug=True) |
|
|