Spaces:
Runtime error
Runtime error
| # controller for the application that calls the model and explanation functions | |
| # returns the updated conversation history and extra elements | |
| # external imports | |
| import gradio as gr | |
| # internal imports | |
| from model import godel | |
| from model import mistral | |
| from explanation import ( | |
| attention as attention_viz, | |
| interpret_shap as shap_int, | |
| interpret_captum as cpt_int, | |
| ) | |
# simple chat function that calls the model
# formats the prompt, asks the model for an answer and returns the updated
# conversation history
def vanilla_chat(
    model, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    """Answer ``message`` with ``model``, without any explanation.

    Args:
        model: Model module/object exposing
            ``format_prompt(message, history, system_prompt, knowledge)`` and
            ``respond(prompt)``.
        message: The new user message.
        history: Prior conversation; entries are presumably
            ``(message, answer)`` pairs like the one appended here. ``None`` is
            treated as an empty history. The caller's list is not modified.
        system_prompt: System prompt forwarded to ``model.format_prompt``.
        knowledge: Optional extra context forwarded to ``model.format_prompt``.

    Returns:
        A tuple ``("", new_history)``. The empty string is presumably used to
        clear the input box in the UI. ``new_history`` is a new list: the old
        history with ``(message, answer)`` appended.
    """
    print(f"Running normal chat with {model}.")
    # copy so the caller's history list is never mutated in place
    history = list(history or [])
    # build the model-specific prompt, then generate the answer
    prompt = model.format_prompt(message, history, system_prompt, knowledge)
    answer = model.respond(prompt)
    history.append((message, answer))
    return "", history
def explained_chat(
    model, xai, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    """Answer ``message`` with ``model`` and an explanation produced by ``xai``.

    Args:
        model: Model module/object exposing
            ``format_prompt(message, history, system_prompt, knowledge)``.
            It is also passed through to ``xai.chat_explained``.
        xai: Explanation module/object exposing ``chat_explained(model, prompt)``.
            That call must return a 4-tuple ``(answer, graphic, markup, plot)``.
        message: The new user message.
        history: Prior conversation; entries are presumably
            ``(message, answer)`` pairs like the one appended here. ``None`` is
            treated as an empty history. The caller's list is not modified.
        system_prompt: System prompt forwarded to ``model.format_prompt``.
        knowledge: Optional extra context forwarded to ``model.format_prompt``.

    Returns:
        A tuple ``("", new_history, xai_graphic, xai_markup, xai_plot)``.
        ``new_history`` is a new list with ``(message, answer)`` appended. The
        last three items are exactly what ``xai.chat_explained`` returned.
    """
    print(f"Running explained chat with {xai} with {model}.")
    # copy so the caller's history list is never mutated in place
    history = list(history or [])
    prompt = model.format_prompt(message, history, system_prompt, knowledge)
    # the XAI module produces the answer together with its explanation outputs
    answer, xai_graphic, xai_markup, xai_plot = xai.chat_explained(model, prompt)
    history.append((message, answer))
    return "", history, xai_graphic, xai_markup, xai_plot
| # main interference function that calls chat functions depending on selections | |
| def interference( | |
| prompt: str, | |
| history: list, | |
| knowledge: str, | |
| system_prompt: str, | |
| xai_selection: str, | |
| model_selection: str, | |
| ): | |
| # if no proper system prompt is given, use a default one | |
| if system_prompt in ("", " "): | |
| system_prompt = ( | |
| "You are a helpful, respectful and honest assistant." | |
| "Always answer as helpfully as possible, while being safe." | |
| ) | |
| # if a model is selected, grab the model instance | |
| if model_selection.lower() == "mistral": | |
| model = mistral | |
| print("Identified model as Mistral") | |
| else: | |
| model = godel | |
| print("Identified model as GODEL") | |
| # if a XAI approach is selected, grab the XAI module instance | |
| # and call the explained chat function | |
| if xai_selection in ("SHAP", "Attention"): | |
| # matching selection | |
| match xai_selection.lower(): | |
| case "shap": | |
| if model_selection.lower() == "mistral": | |
| xai = cpt_int | |
| else: | |
| xai = shap_int | |
| case "attention": | |
| xai = attention_viz | |
| case _: | |
| # use Gradio warning to display error message | |
| gr.Warning(f""" | |
| There was an error in the selected XAI Approach. | |
| It is "{xai_selection}" | |
| """) | |
| # raise runtime exception | |
| raise RuntimeError("There was an error in the selected XAI approach.") | |
| # call the explained chat function with the model instance | |
| prompt_output, history_output, xai_interactive, xai_markup, xai_plot = ( | |
| explained_chat( | |
| model=model, | |
| xai=xai, | |
| message=prompt, | |
| history=history, | |
| system_prompt=system_prompt, | |
| knowledge=knowledge, | |
| ) | |
| ) | |
| # if no XAI approach is selected call the vanilla chat function | |
| else: | |
| # calling the vanilla chat function | |
| prompt_output, history_output = vanilla_chat( | |
| model=model, | |
| message=prompt, | |
| history=history, | |
| system_prompt=system_prompt, | |
| knowledge=knowledge, | |
| ) | |
| # set XAI outputs to disclaimer html/none | |
| xai_interactive, xai_markup, xai_plot = ( | |
| """ | |
| <div style="text-align: center"><h4>Without Selected XAI Approach, | |
| no graphic will be displayed</h4></div> | |
| """, | |
| [("", "")], | |
| None, | |
| ) | |
| # return the outputs | |
| return prompt_output, history_output, xai_interactive, xai_markup, xai_plot | |