Spaces:
Runtime error
Runtime error
| # interpret module that implements the SHAP (PartitionSHAP) interpretability method | |
| # external imports | |
| from shap import models, maskers, plots, PartitionExplainer | |
| import torch | |
| # internal imports | |
| from utils import formatting as fmt | |
| from .plotting import plot_seq | |
| from .markup import markup_text | |
# global variables
# Module-level SHAP wrappers. Both start out unset; wrap_shap() rebinds them
# to the teacher-forcing model and the text masker it builds.
TEACHER_FORCING = TEXT_MASKER = None
# function to extract summarized sequence wise attribution
def shap_extract_seq_att(shap_values):
    """Pair each entry of the first explained sample with an attribution value.

    Args:
        shap_values: explanation object exposing ``values`` and ``data``;
            only index 0 of each is read.

    Returns:
        A list of ``(token, value)`` tuples, produced by zipping
        ``shap_values.data[0]`` with the flattened attributions.
    """
    # Collapse the first sample's attributions with the project helper. The
    # second argument (1) is passed through as-is; presumably it selects the
    # axis to reduce over -- confirm against utils.formatting.
    summed = fmt.flatten_attribution(shap_values.values[0], 1)
    tokens = shap_values.data[0]
    return [(token, value) for token, value in zip(tokens, summed)]
# function used to wrap the model with a shap model
def wrap_shap(model):
    """(Re)build the module-level SHAP wrappers for ``model``.

    Rebinds the globals ``TEACHER_FORCING`` and ``TEXT_MASKER`` and returns
    nothing.

    Args:
        model: object exposing ``MODEL``, ``TOKENIZER`` and ``set_config``.
    """
    global TEXT_MASKER, TEACHER_FORCING

    # use the GPU when one is available, otherwise fall back to the CPU
    device_name = "cuda" if torch.cuda.is_available() else "cpu"

    # update the model settings with an empty config
    model.set_config({})

    base_model = model.MODEL
    tokenizer = model.TOKENIZER

    # shap text generation model, wrapped in a teacher forcing model
    generator = models.TextGeneration(base_model, tokenizer)
    TEACHER_FORCING = models.TeacherForcing(
        generator,
        tokenizer,
        device=device_name,
        similarity_model=base_model,
        similarity_tokenizer=tokenizer,
    )

    # text masker that uses a single space as its mask token
    TEXT_MASKER = maskers.Text(tokenizer, " ", collapse_mask_token=True)
# graphic plotting function that creates a html graphic (as string) for the explanation
def create_graphic(shap_values):
    """Build the SHAP text plot for ``shap_values`` and return it as a string.

    ``plots.text`` is called with ``display=False`` and its result is coerced
    with ``str`` so the caller can embed it (e.g. in an iFrame).
    """
    return str(plots.text(shap_values, display=False))
# main explain function that returns a chat with explanations
def chat_explained(model, prompt):
    """Explain the model's response to ``prompt`` using a PartitionExplainer.

    Args:
        model: object exposing ``MODEL``, ``TOKENIZER`` and ``set_config``.
        prompt: the single input passed (as a one-element list) to the explainer.

    Returns:
        A 4-tuple ``(response_text, graphic, marked_text, plot)``.
    """
    model.set_config({})

    # NOTE(review): the explainer is built from the raw model and tokenizer;
    # the TEACHER_FORCING / TEXT_MASKER globals prepared by wrap_shap are not
    # used here -- confirm this is intended.
    explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
    shap_values = explainer([prompt])

    # html graphic of the explanation
    graphic = create_graphic(shap_values)

    # marked-up text built from the first sample's data and values
    first_data = shap_values.data[0]
    first_values = shap_values.values[0]
    marked_text = markup_text(first_data, first_values, variant="shap")

    # response text formatted from the explanation's output names
    response_text = fmt.format_output_text(shap_values.output_names)

    # sequence attribution plot
    seq_attributions = shap_extract_seq_att(shap_values)
    plot = plot_seq(seq_attributions, "PartitionSHAP")

    return response_text, graphic, marked_text, plot