# modelling util module providing formatting helpers for model functionality

# external imports
import torch
import gradio as gr
from transformers import BitsAndBytesConfig

# function that limits the prompt size to keep the model within its token budget
# tries to keep as much as possible, always keeping at least the message and system prompt
def prompt_limiter(
    tokenizer, message: str, history: list, system_prompt: str, knowledge: str = ""
):
    # initializing the new prompt history as empty
    prompt_history = []
    # getting the combined token count of the message, system prompt, and knowledge
    pre_count = (
        token_counter(tokenizer, message)
        + token_counter(tokenizer, system_prompt)
        + token_counter(tokenizer, knowledge)
    )
    # validating the token count against a threshold of 1024
    # check if the token count is already too high without history
    if pre_count > 1024:
        # check if the token count is too high even without knowledge and history
        if (
            token_counter(tokenizer, message) + token_counter(tokenizer, system_prompt)
            > 1024
        ):
            # show a warning and raise an error
            gr.Warning("Message and system prompt are too long. Please shorten them.")
            raise RuntimeError(
                "Message and system prompt are too long. Please shorten them."
            )
        # show a warning and return with empty history and empty knowledge
        gr.Warning("""
            Input too long.
            Knowledge and conversation history have been removed to keep the model running.
        """)
        return message, prompt_history, system_prompt, ""
    # if the token count is small enough, adding history bit by bit
    if pre_count < 800:
        # setting the count to the pre-count
        count = pre_count
        # iterating through the history in reverse to prioritize recent conversations
        # (reversed() avoids mutating the caller's history list in place)
        for conversation in reversed(history):
            # checking the token count with the current conversation included
            count += token_counter(tokenizer, conversation[0]) + token_counter(
                tokenizer, conversation[1]
            )
            # add the conversation or break the loop depending on the token count
            if count < 1024:
                prompt_history.append(conversation)
            else:
                break
        # restoring chronological order after collecting the most recent turns
        prompt_history.reverse()
    # return the message, adapted history, system prompt, and knowledge
    return message, prompt_history, system_prompt, knowledge
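
# Illustrative only: a minimal usage sketch of prompt_limiter, assuming a
# Hugging Face tokenizer and a Gradio-style history of [user, assistant] pairs.
# The model id and demo_* names are placeholders, not part of this module.
#
# from transformers import AutoTokenizer
#
# demo_tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
# demo_history = [["Hi!", "Hello! How can I help?"], ["Thanks.", "You're welcome."]]
# message, limited_history, system_prompt, knowledge = prompt_limiter(
#     demo_tokenizer,
#     "What is attention?",
#     demo_history,
#     "You are a helpful assistant.",
# )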

# token counter function using the model tokenizer
def token_counter(tokenizer, text: str):
    # tokenize the text
    tokens = tokenizer(text, return_tensors="pt").input_ids
    # return the token count
    return len(tokens[0])
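
# Note: the count typically includes special tokens (e.g. a BOS token) that
# most Hugging Face tokenizers add by default, so it slightly overestimates
# the raw text length. A hypothetical call:
# n_tokens = token_counter(demo_tokenizer, "Hello world")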

# function to determine the device to use
def get_device():
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    return device
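
# Illustrative only: the usual pattern is to move the model and its inputs to
# the selected device before inference (model and inputs are placeholders).
# device = get_device()
# model.to(device)
# inputs = {k: v.to(device) for k, v in inputs.items()}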

# function to set the device loading config for 4-bit quantization
# CREDIT: copied from the captum Llama 2 example
# see https://captum.ai/tutorials/Llama2_LLM_Attribution
def gpu_loading_config(max_memory: str = "15000MB"):
    n_gpus = torch.cuda.device_count()
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return n_gpus, max_memory, bnb_config
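
# Illustrative only: a sketch of how the returned values can feed into
# transformers' from_pretrained, mirroring the captum tutorial linked above.
# The model id is a placeholder.
#
# from transformers import AutoModelForCausalLM
#
# n_gpus, max_memory, bnb_config = gpu_loading_config()
# model = AutoModelForCausalLM.from_pretrained(
#     "mistralai/Mistral-7B-Instruct-v0.2",
#     quantization_config=bnb_config,
#     max_memory={i: max_memory for i in range(n_gpus)},
#     device_map="auto",
# )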

# formatting mistral attention values into a single tensor
# CREDIT: copied from BERTViz
# see https://github.com/jessevig/bertviz
def format_mistral_attention(attention_values, layers=None, heads=None):
    # optionally keep only the requested layers
    if layers:
        attention_values = [attention_values[layer_index] for layer_index in layers]
    squeezed = []
    for layer_attention in attention_values:
        # remove the batch dimension: (1, heads, seq, seq) -> (heads, seq, seq)
        layer_attention = layer_attention.squeeze(0)
        # optionally keep only the requested heads
        if heads:
            layer_attention = layer_attention[heads]
        squeezed.append(layer_attention)
    # stack into a single (layers, heads, seq_len, seq_len) tensor
    return torch.stack(squeezed)
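
# Illustrative only: a shape-level sketch using dummy tensors in place of the
# attentions returned by a forward pass with output_attentions=True.
#
# # 32 layers of (batch=1, heads=32, seq_len=10, seq_len=10) attention maps
# dummy_attentions = [torch.rand(1, 32, 10, 10) for _ in range(32)]
# formatted = format_mistral_attention(dummy_attentions, layers=[0, 31], heads=[0, 15])
# print(formatted.shape)  # torch.Size([2, 2, 10, 10])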