sergey21000 commited on
Commit
4ea0220
·
verified ·
1 Parent(s): a2a8a9c

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +169 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,171 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+
2
+
 
3
  import streamlit as st
4
+ import pandas as pd
5
+ import numpy as np
6
+ from datasets import load_dataset, concatenate_datasets
7
+ from sklearn.feature_extraction.text import TfidfVectorizer
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
+ import torch
10
+ from transformers import AutoTokenizer, AutoModelForQuestionAnswering
11
+ import spacy
12
+ import nltk
13
+ from nltk.corpus import stopwords
14
+ from nltk.tokenize import word_tokenize
15
+ import re
16
+ from bs4 import BeautifulSoup
17
+
18
+ # === Загрузка и подготовка данных ===
19
+
20
+
21
+ spacy.cli.download('ru_core_news_lg')
22
+ nltk.download('punkt_tab', download_dir='/usr/local/share/nltk_data')
23
+ nltk.download('stopwords')
24
+
25
+
26
+ @st.cache_resource
27
+ def load_data():
28
+ # Загрузка датасета
29
+ data = load_dataset('Romyx/ru_QA_school_history', split='train')
30
+ df = pd.DataFrame(data)
31
+ df['Pt_question'] = df['question'].apply(preprocess_text)
32
+ df['Pt_answer'] = df['answer'].apply(preprocess_text)
33
+ return df
34
+
35
+ @st.cache_resource
36
+ def load_model_and_tokenizer():
37
+ # Загрузка предобученной модели вопрос-ответа (например, SberQuad)
38
+ model_name = "AlexKay/xlm-roberta-large-qa-multilingual-finedtuned-ru" # замените на нужную модель, например, "bert-base-uncased"
39
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
40
+ model = AutoModelForQuestionAnswering.from_pretrained(model_name)
41
+ return tokenizer, model
42
+
43
+ @st.cache_resource
44
+ def build_vectorizer(_df):
45
+ combined_texts = _df['Pt_question'].tolist() + _df['Pt_answer'].tolist()
46
+ vectorizer = TfidfVectorizer()
47
+ tfidf_matrix = vectorizer.fit_transform(combined_texts)
48
+ return vectorizer, tfidf_matrix
49
+
50
+ # === Предобработка текста ===
51
+
52
+ # Загрузка Spacy модели
53
+ nlp = spacy.load('ru_core_news_lg')
54
+ stop_words = set(stopwords.words('russian'))
55
+
56
+ cache_dict = {}
57
+
58
+ def get_norm_form(word):
59
+ if word in cache_dict:
60
+ return cache_dict[word]
61
+ norm_form = nlp(word)[0].lemma_
62
+ cache_dict[word] = norm_form
63
+ return norm_form
64
+
65
+ def remove_html_tags(text):
66
+ soup = BeautifulSoup(text, 'html.parser')
67
+ return soup.text
68
+
69
+ def preprocess_text(text):
70
+ if pd.isna(text) or text is None:
71
+ return ""
72
+ text = remove_html_tags(text)
73
+ text = text.lower()
74
+
75
+ # Обработка знаков препинания
76
+ text = re.sub(r'([^\w\s-]|_)', r' \1 ', text)
77
+ text = re.sub(r'\s+', ' ', text)
78
+ text = re.sub(r'(\w+)-(\w+)', r'\1 \2', text)
79
+ text = re.sub(r'(\d+)(г|кг|см|м|мм|л|мл)', r'\1 \2', text)
80
+
81
+ # Удаление всего, кроме букв, цифр и пробелов
82
+ text = re.sub(r'[^\w\s]', '', text)
83
+
84
+ tokens = word_tokenize(text)
85
+ tokens = [token for token in tokens if token not in stop_words]
86
+ tokens = [get_norm_form(token) for token in tokens]
87
+
88
+ words_to_remove = {"ответ", "new"}
89
+ tokens = [token for token in tokens if token not in words_to_remove]
90
+
91
+ return ' '.join(tokens)
92
+
93
+ # === Основная функция получения ответа ===
94
+ def get_answer_from_qa_model(user_question, df, vectorizer, tfidf_matrix, model, tokenizer):
95
+ processed = preprocess_text(user_question)
96
+ user_vec = vectorizer.transform([processed])
97
+
98
+ similarities = cosine_similarity(user_vec, tfidf_matrix).flatten()
99
+
100
+ # Проверка, что similarities не пустой
101
+ if len(similarities) == 0:
102
+ return "Тема не входит в программу этих классов."
103
+
104
+ best_match_idx = similarities.argmax()
105
+ best_score = similarities[best_match_idx]
106
+
107
+ if best_score > 0.1:
108
+ # Проверка, что индекс не выходит за границы
109
+ if best_match_idx >= len(df):
110
+ return "Тема не входит в программу этих классов."
111
+
112
+ context = df.iloc[best_match_idx]['answer']
113
+ question = user_question
114
+
115
+ inputs = tokenizer(question, context, return_tensors="pt", truncation=True, padding=True)
116
+
117
+ with torch.no_grad():
118
+ outputs = model(**inputs)
119
+
120
+ start_scores = outputs.start_logits
121
+ end_scores = outputs.end_logits
122
+
123
+ # Проверка на корректность размера логитов
124
+ if len(start_scores.shape) == 2:
125
+ start_idx = torch.argmax(start_scores, dim=1)[0].item()
126
+ end_idx = torch.argmax(end_scores, dim=1)[0].item()
127
+ else:
128
+ start_idx = torch.argmax(start_scores).item()
129
+ end_idx = torch.argmax(end_scores).item()
130
+
131
+ # Проверка, что индексы не выходят за пределы
132
+ seq_len = inputs['input_ids'].shape[1]
133
+ if start_idx >= seq_len or end_idx >= seq_len or start_idx > end_idx:
134
+ return "Ответ не найден."
135
+
136
+ answer = tokenizer.decode(inputs['input_ids'][0][start_idx:end_idx+1], skip_special_tokens=True)
137
+ else:
138
+ answer = "Извините, я не понимаю вопрос."
139
+
140
+ return answer
141
+
142
+ # === Интерфейс Streamlit ===
143
+
144
+ def main():
145
+ st.title("🤖 ИИ-ассистент по истории (на основе вопрос-ответа)")
146
+
147
+ st.write("Задайте вопрос, и я постараюсь найти на него ответ из базы.")
148
+
149
+ # Загрузка данных и модели
150
+ df = load_data()
151
+ tokenizer, model = load_model_and_tokenizer()
152
+ vectorizer, tfidf_matrix = build_vectorizer(df)
153
+
154
+ # Поле ввода вопроса
155
+ user_input = st.text_input("Введите ваш вопрос:")
156
+
157
+ if st.button("Получить ответ"):
158
+ if user_input.strip():
159
+ with st.spinner("Ищем ответ..."):
160
+ response = get_answer_from_qa_model(
161
+ user_input, df, vectorizer, tfidf_matrix, model, tokenizer
162
+ )
163
+ st.success("Ответ:")
164
+ st.write(response)
165
+ else:
166
+ st.warning("Пожалуйста, введите вопрос.")
167
+
168
+ if __name__ == "__main__":
169
+ main()
170
+
171