render4 / src /streamlit_app.py
AIEcosystem's picture
Update src/streamlit_app.py
dcb0b21 verified
import os
os.environ['HF_HOME'] = '/tmp'
import time
import streamlit as st
import streamlit.components.v1 as components
import pandas as pd
import io
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import re
import string
import json
from itertools import cycle
from io import BytesIO
import plotly.io as pio
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import LatentDirichletAllocation
from gliner import GLiNER
from streamlit_extras.stylable_container import stylable_container
import time # Optional: for simulating database processing
st.set_page_config(
    page_title="Premium Dashboard",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# --- Email retrieval ---
# Read the 'user_email' parameter from the page URL (None when it is absent).
# NOTE(review): the value is used exactly as it appears in the URL; nothing in
# this block checks that a payment exists for that address -- confirm that
# verification happens before real access is granted.
query_params = st.query_params
user_email = query_params.get("user_email")

# --- Main application logic ---
st.title("Premium Subscriber Dashboard")
st.markdown("---")
if user_email:
    # Step 1: confirmation banner.
    st.balloons()
    st.success(f"Payment Confirmed! Welcome to Premium, **{user_email}**! You now have full access. ")
    st.header("Granting Access...")

    # Step 2: backend processing. Currently only simulated with a sleep; a
    # real implementation would look the user up by email in the database
    # (e.g. Firestore) and mark the subscription as active.
    with st.spinner(f"Processing subscription for {user_email}..."):
        time.sleep(2)  # Simulate network delay/database write

    # The emoji below were mojibake ("βœ…" / "πŸ”‘") in the previous revision.
    st.info(f"✅ Your premium access is now permanently linked to **{user_email}**.")

    # Step 3: premium feature summary.
    st.markdown("""
## 🔑 Your Exclusive Premium Features
""")
    col1, col2 = st.columns(2)
    with col1:
        st.metric(label="Subscription Status", value="Active (Annual)")
    with col2:
        st.metric(label="Access Tier", value="Unlimited")
    st.button("Access Advanced Reports & Tools", type="primary")
    st.markdown("---")
    st.write("Enjoy your enhanced experience!")
else:
    # The visitor arrived without the 'user_email' URL parameter.
    st.error("Access Denied or Subscription Details Missing.")
    st.markdown("""
It looks like you arrived here without a confirmation link. If you have already paid:
1. Please check the email address you used for payment.
2. Contact support with your PayPal transaction ID for manual activation.
If you have not paid, please return to the free app to upgrade.
""")
# --- Comet ML import (optional) ---
try:
    from comet_ml import Experiment
except ImportError:

    class Experiment:
        """No-op stand-in used when the `comet_ml` package cannot be imported."""

        def __init__(self, **kwargs):
            """Accept and ignore any keyword arguments."""

        def log_parameter(self, *args):
            """Ignore the call; nothing is logged."""

        def log_table(self, *args):
            """Ignore the call; nothing is logged."""

        def end(self):
            """Ignore the call; there is nothing to close."""
# --- Fixed label definitions and mappings ---
# Labels predicted in "Fixed Labels" mode.
FIXED_LABELS = [
    "person",
    "country",
    "city",
    "organization",
    "date",
    "time",
    "cardinal",
    "money",
    "position",
]
# Comma-separated default shown in the custom-label input.
DEFAULT_CUSTOM_LABELS = "person, location, organization, product, date, time, event"

# Highlight / chart colour for each fixed label.
FIXED_ENTITY_COLOR_MAP = {
    "person": "#10b981",        # Green
    "country": "#3b82f6",       # Blue
    "city": "#4ade80",          # Light Green
    "organization": "#f59e0b",  # Orange
    "date": "#8b5cf6",          # Purple
    "time": "#ec4899",          # Pink
    "cardinal": "#06b6d4",      # Cyan
    "money": "#f43f5e",         # Red
    "position": "#a855f7",      # Violet
}

# --- Fixed category mapping: category name -> labels it groups ---
FIXED_CATEGORY_MAPPING = {
    "People & Roles": ["person", "organization", "position"],
    "Locations": ["country", "city"],
    "Time & Dates": ["date", "time"],
    "Numbers & Finance": ["money", "cardinal"],
}
# Inverse lookup: label -> category name.
REVERSE_FIXED_CATEGORY_MAPPING = {
    label: category
    for category, label_list in FIXED_CATEGORY_MAPPING.items()
    for label in label_list
}

# --- Dynamic colour generator for custom labels ---
# Endless iterator over two Plotly qualitative palettes joined together.
COLOR_PALETTE = cycle(px.colors.qualitative.Alphabet + px.colors.qualitative.Bold)
def extract_label(node_name):
    """Return the label from the trailing parentheses of a 'Text (Label)' string.

    Returns "Unknown" when the string does not end in a non-empty
    parenthesised group.
    """
    found = re.search(r'\(([^)]+)\)$', node_name)
    if found is None:
        return "Unknown"
    return found.group(1)
def remove_trailing_punctuation(text_string):
    """Return `text_string` without any run of ASCII punctuation at its end."""
    punctuation_chars = frozenset(string.punctuation)
    keep = len(text_string)
    # Walk backwards until the first non-punctuation character.
    while keep > 0 and text_string[keep - 1] in punctuation_chars:
        keep -= 1
    return text_string[:keep]
def get_dynamic_color_map(active_labels, fixed_map):
    """Build a label -> colour mapping for `active_labels`.

    When the active labels are exactly the keys of `fixed_map`, that same
    mapping object is returned. Otherwise a new dict is built: labels present
    in `fixed_map` keep their fixed colour and every other label takes the
    next colour from a freshly reset module-level `COLOR_PALETTE`.
    """
    global COLOR_PALETTE

    if set(active_labels) == set(fixed_map.keys()):
        return fixed_map

    # Restart the palette so the same label order always yields the same colours.
    COLOR_PALETTE = cycle(px.colors.qualitative.Alphabet + px.colors.qualitative.Bold)
    palette = COLOR_PALETTE
    return {
        label: fixed_map[label] if label in fixed_map else next(palette)
        for label in active_labels
    }
def highlight_entities(text, df_entities, entity_color_map):
    """Generate HTML that shows `text` with each entity highlighted.

    Args:
        text: The analysed document.
        df_entities: DataFrame with 'start', 'end' and 'label' columns whose
            offsets index into `text`.
        entity_color_map: Mapping label -> CSS colour; labels missing from it
            are drawn in '#000000'.

    Returns:
        An HTML string. When `df_entities` is empty this is just the escaped
        text; otherwise the text wrapped in a bordered `<div>` with one
        `<span>` per highlighted entity.

    The text and labels are HTML-escaped because callers render the result
    with `unsafe_allow_html=True`. Entities that overlap an already
    highlighted span (e.g. nested entities) are skipped, since wrapping both
    would produce interleaved, broken markup.
    """
    import html  # stdlib; imported locally so the block stays self-contained

    if df_entities.empty:
        return html.escape(text, quote=False)

    # Left-to-right; at equal starts the longest span comes first and wins.
    ordered = (
        df_entities.reset_index(drop=True)
        .sort_values(by=['start', 'end'], ascending=[True, False])
        .to_dict('records')
    )

    pieces = []
    cursor = 0  # end offset of the last emitted character of `text`
    for entity in ordered:
        start = max(0, int(entity['start']))
        end = min(len(text), int(entity['end']))
        if start < cursor or start >= end:
            # Overlaps a previous highlight, or is empty after clamping.
            continue
        label = entity['label']
        color = entity_color_map.get(label, '#000000')
        safe_label = html.escape(str(label))
        safe_entity = html.escape(text[start:end], quote=False)
        pieces.append(html.escape(text[cursor:start], quote=False))
        pieces.append(
            f'<span style="background-color: {color}; color: white; padding: 2px 4px; border-radius: 3px; cursor: help;" title="{safe_label}">{safe_entity}</span>'
        )
        cursor = end
    pieces.append(html.escape(text[cursor:], quote=False))

    highlighted_text = "".join(pieces)
    return f'<div style="border: 1px solid #888888; padding: 15px; border-radius: 5px; background-color: #ffffff; font-family: monospace; white-space: pre-wrap; margin-bottom: 20px;">{highlighted_text}</div>'
def perform_topic_modeling(df_entities, num_topics=2, num_top_words=10):
    """Run a basic LDA topic model over the unique entity strings.

    Args:
        df_entities: DataFrame with a 'text' column; each unique value is
            treated as one document.
        num_topics: Number of LDA components to fit.
        num_top_words: Upper bound on the words reported per topic (also
            capped at the number of documents).

    Returns:
        A DataFrame with columns 'Topic_ID', 'Word' and 'Weight', or None
        when there are fewer than two documents, too few features, or the
        modelling step fails.
    """
    documents = df_entities['text'].unique().tolist()
    if len(documents) < 2:
        return None
    N = min(num_top_words, len(documents))

    def _vectorize(**df_bounds):
        """Fit a TF-IDF vectorizer on `documents`; return (matrix, feature names)."""
        vectorizer = TfidfVectorizer(stop_words='english', ngram_range=(1, 3), **df_bounds)
        matrix = vectorizer.fit_transform(documents)
        return matrix, vectorizer.get_feature_names_out()

    try:
        # Step 1: aggressive filtering. On small inputs the vectorizer raises
        # ValueError (no terms survive pruning, or max_df < min_df); that
        # must lead to the relaxed fallback rather than abort the function.
        try:
            tfidf, tfidf_feature_names = _vectorize(max_df=0.95, min_df=2)
        except ValueError:
            tfidf, tfidf_feature_names = None, []

        # Step 2: relaxed fallback when filtering left too few features.
        if len(tfidf_feature_names) < num_topics:
            tfidf, tfidf_feature_names = _vectorize(max_df=1.0, min_df=1)
            if len(tfidf_feature_names) < num_topics:
                return None

        lda = LatentDirichletAllocation(n_components=num_topics, max_iter=5, learning_method='online', random_state=42, n_jobs=-1)
        lda.fit(tfidf)

        topic_data_list = []
        for topic_idx, topic in enumerate(lda.components_):
            # Indices of the N largest weights, highest first.
            top_words_indices = topic.argsort()[:-N - 1:-1]
            for i in top_words_indices:
                topic_data_list.append({
                    'Topic_ID': f'Topic #{topic_idx + 1}',
                    'Word': tfidf_feature_names[i],
                    'Weight': topic[i],
                })
        return pd.DataFrame(topic_data_list)
    except Exception:
        # Deliberate best-effort: topic modelling is optional for the report,
        # so any failure is reported to the caller as "no topics".
        return None
def create_topic_word_bubbles(df_topic_data):
    """Build a Plotly bubble chart of the top words of every topic.

    Expects the 'Topic_ID', 'Word' and 'Weight' columns and returns the
    figure, or None when the frame has no rows. The input frame is not
    modified (the work happens on the renamed copy).
    """
    bubbles = df_topic_data.rename(
        columns={'Topic_ID': 'topic', 'Word': 'word', 'Weight': 'weight'}
    )
    # The row index doubles as the horizontal position of each bubble.
    bubbles['x_pos'] = bubbles.index
    if bubbles.empty:
        return None

    fig = px.scatter(
        bubbles,
        x='x_pos',
        y='weight',
        size='weight',
        color='topic',
        text='word',
        hover_name='word',
        size_max=40,
        title='Topic Word Weights (Bubble Chart)',
        color_discrete_sequence=px.colors.qualitative.Bold,
        labels={'x_pos': 'Entity/Word Index', 'weight': 'Word Weight', 'topic': 'Topic ID'},
        custom_data=['word', 'weight', 'topic'],
    )

    hidden_axis = {'showgrid': False, 'showticklabels': False, 'zeroline': False, 'showline': False}
    fig.update_layout(
        xaxis_title="Entity/Word",
        yaxis_title="Word Weight",
        xaxis=hidden_axis,
        yaxis={'showgrid': True},
        showlegend=True,
        height=600,
        margin=dict(t=50, b=100, l=50, r=10),
        plot_bgcolor='#f9f9f9',
        paper_bgcolor='#f9f9f9',
    )

    hover_text = "<b>%{customdata[0]}</b><br>Weight: %{customdata[1]:.3f}<br>Topic: %{customdata[2]}<extra></extra>"
    fig.update_traces(
        textposition='middle center',
        textfont=dict(color='white', size=10),
        hovertemplate=hover_text,
        marker=dict(line=dict(width=1, color='DarkSlateGrey')),
    )
    return fig
def generate_network_graph(df, raw_text, entity_color_map):
    """Draw the entity co-occurrence network as a Plotly figure.

    Each unique entity text becomes a node placed on a jittered circle and
    sized by how often it occurs in `df`. Two nodes are linked when both
    texts appear (case-insensitive substring match) in the same sentence of
    `raw_text`. For duplicated texts the row with the highest score supplies
    the label and score. With fewer than two unique entities a figure that
    only carries an explanatory title is returned.
    """
    # 1. Node table: one row per unique text, plus its frequency.
    frequency_table = df['text'].value_counts().reset_index()
    frequency_table.columns = ['text', 'frequency']
    # Sort by score first so drop_duplicates keeps the best-scoring row.
    best_first = df.sort_values('score', ascending=False).reset_index(drop=True)
    node_rows = best_first.drop_duplicates(subset=['text'])[['text', 'label', 'score']]
    unique_entities = node_rows.merge(frequency_table, on='text', how='left')
    if unique_entities.shape[0] < 2:
        return go.Figure().update_layout(title="Not enough unique entities for a meaningful graph.")

    # 2. Node positions: evenly spaced on a circle with Gaussian jitter.
    node_total = len(unique_entities)
    angles = np.linspace(0, 2 * np.pi, node_total, endpoint=False)
    ring_radius = 10
    unique_entities['x'] = ring_radius * np.cos(angles) + np.random.normal(0, 0.5, node_total)
    unique_entities['y'] = ring_radius * np.sin(angles) + np.random.normal(0, 0.5, node_total)
    coords = unique_entities.set_index('text')[['x', 'y']].to_dict('index')

    # 3. Edges: unordered pairs of entities found in the same sentence.
    sentence_list = re.split(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s', raw_text)
    name_pairs = [(name, name.lower()) for name in unique_entities['text'].unique().tolist()]
    edges = set()
    for sentence in sentence_list:
        sentence_lower = sentence.lower()
        present = list({name for name, lowered in name_pairs if lowered in sentence_lower})
        for position, first in enumerate(present):
            for second in present[position + 1:]:
                edges.add(tuple(sorted((first, second))))

    # 4. Figure: one line trace for edges, one marker trace for nodes.
    edge_x = []
    edge_y = []
    for n1, n2 in edges:
        if n1 in coords and n2 in coords:
            # A trailing None breaks the line between consecutive edges.
            edge_x.extend([coords[n1]['x'], coords[n2]['x'], None])
            edge_y.extend([coords[n1]['y'], coords[n2]['y'], None])

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=edge_x, y=edge_y,
        line=dict(width=0.5, color='#888'),
        hoverinfo='none', mode='lines',
        name='Co-occurrence Edges', showlegend=False,
    ))
    node_colors = [entity_color_map.get(label, '#cccccc') for label in unique_entities['label']]
    fig.add_trace(go.Scatter(
        x=unique_entities['x'], y=unique_entities['y'],
        mode='markers+text', name='Entities',
        text=unique_entities['text'], textposition="top center", showlegend=False,
        marker=dict(
            size=unique_entities['frequency'] * 5 + 10,
            color=node_colors,
            line_width=1, line_color='black', opacity=0.9,
        ),
        textfont=dict(size=10),
        customdata=unique_entities[['label', 'score', 'frequency']],
        hovertemplate=("<b>%{text}</b><br>Label: %{customdata[0]}<br>Score: %{customdata[1]:.2f}<br>Frequency: %{customdata[2]}<extra></extra>"),
    ))

    # 5. Legend: one invisible marker trace per label, in first-seen order.
    for label in dict.fromkeys(unique_entities['label']):
        fig.add_trace(go.Scatter(
            x=[None], y=[None], mode='markers',
            marker=dict(size=10, color=entity_color_map.get(label, '#cccccc')),
            name=f"{label.capitalize()}", showlegend=True,
        ))

    hidden_axis = dict(showgrid=False, zeroline=False, showticklabels=False, range=[-15, 15])
    fig.update_layout(
        title='Entity Co-occurrence Network (Edges = Same Sentence)',
        showlegend=True, hovermode='closest',
        xaxis=hidden_axis,
        yaxis=dict(hidden_axis),
        plot_bgcolor='#f9f9f9', paper_bgcolor='#f9f9f9',
        margin=dict(t=50, b=10, l=10, r=10), height=600,
        annotations=[
            dict(
                text="When a line is drawn between two nodes (entities), it means those two entities co-occurred in the same sentence at least once.",
                xref="paper", yref="paper",
                x=0.5, y=0.95,  # just below the title
                showarrow=False,
                font=dict(size=12, color="gray"),
            )
        ],
    )
    return fig
def generate_entity_csv(df):
    """Serialise the entity table to UTF-8 CSV and return it as a BytesIO.

    Only the text, label, category, score, start and end columns are
    exported (in that order, without the index). The returned buffer is
    positioned at its start.
    """
    export_columns = ['text', 'label', 'category', 'score', 'start', 'end']
    csv_bytes = df[export_columns].to_csv(index=False).encode('utf-8')
    # A BytesIO built from initial bytes already starts at position 0.
    return BytesIO(csv_bytes)
# --- HTML REPORT GENERATION FUNCTION ---
def generate_html_report(df, text_input, elapsed_time, df_topic_data, entity_color_map, report_title="Entity and Topic Analysis Report", branding_html=""):
    """Generate a standalone HTML report with all analysis results.

    Args:
        df: Entity DataFrame with 'text', 'label', 'score', 'start', 'end'
            and 'category' columns.
        text_input: The analysed text (used for highlighting and the network).
        elapsed_time: Processing time in seconds, shown in the header.
        df_topic_data: Topic-model output, or None/empty when unavailable.
        entity_color_map: Mapping label -> colour for highlights and nodes.
        report_title: Page title and main heading; HTML-escaped before use.
        branding_html: HTML fragment inserted verbatim into the metadata box.

    Returns:
        The complete HTML document as a string.
    """
    import html  # stdlib; imported locally so the block stays self-contained

    def _embed(figure, **extra):
        """Render a Plotly figure as an HTML fragment that loads plotly.js from the CDN."""
        return figure.to_html(full_html=False, include_plotlyjs='cdn', **extra)

    # 1. Visualizations (Plotly HTML fragments)
    # 1a. Treemap
    fig_treemap = px.treemap(
        df,
        path=[px.Constant("All Entities"), 'category', 'label', 'text'],
        values='score',
        color='label',
        title="Entity Distribution by Category and Label",
        color_discrete_sequence=px.colors.qualitative.Bold
    )
    fig_treemap.update_layout(margin=dict(t=50, l=25, r=25, b=25))
    treemap_html = _embed(fig_treemap)

    # 1b. Pie chart
    grouped_counts = df['category'].value_counts().reset_index()
    grouped_counts.columns = ['Category', 'Count']
    color_seq = px.colors.qualitative.Pastel if len(grouped_counts) > 1 else px.colors.sequential.Cividis
    fig_pie = px.pie(grouped_counts, values='Count', names='Category', title='Distribution of Entities by Category', color_discrete_sequence=color_seq)
    fig_pie.update_layout(margin=dict(t=50, b=10))
    pie_html = _embed(fig_pie)

    # 1c. Bar chart (category count)
    fig_bar_category = px.bar(grouped_counts, x='Category', y='Count', color='Category', title='Total Entities per Category', color_discrete_sequence=color_seq)
    fig_bar_category.update_layout(xaxis={'categoryorder': 'total descending'}, margin=dict(t=50, b=100))
    bar_category_html = _embed(fig_bar_category)

    # 1d. Bar chart (most frequent entities); only entities seen more than once.
    word_counts = df['text'].value_counts().reset_index()
    word_counts.columns = ['Entity', 'Count']
    repeating_entities = word_counts[word_counts['Count'] > 1].head(10)
    bar_freq_html = '<p>No entities appear more than once in the text for visualization.</p>'
    if not repeating_entities.empty:
        fig_bar_freq = px.bar(repeating_entities, x='Entity', y='Count', color='Entity', title='Top 10 Most Frequent Entities', color_discrete_sequence=px.colors.sequential.Viridis)
        fig_bar_freq.update_layout(xaxis={'categoryorder': 'total descending'}, margin=dict(t=50, b=100))
        bar_freq_html = _embed(fig_bar_freq)

    # 1e. Network graph
    network_html = _embed(generate_network_graph(df, text_input, entity_color_map))

    # 1f. Topic modelling bubble chart (or an explanatory placeholder)
    topic_charts_html = '<h3>Topic Word Weights (Bubble Chart)</h3>'
    if df_topic_data is not None and not df_topic_data.empty:
        bubble_figure = create_topic_word_bubbles(df_topic_data)
        if bubble_figure:
            topic_charts_html += f'<div class="chart-box">{_embed(bubble_figure, config={"responsive": True})}</div>'
        else:
            topic_charts_html += '<p style="color: red;">Error: Topic modeling data was available but visualization failed.</p>'
    else:
        topic_charts_html += '<div class="chart-box" style="text-align: center; padding: 50px; background-color: #fff; border: 1px dashed #888888;">'
        topic_charts_html += '<p><strong>Topic Modeling requires more unique input.</strong></p>'
        topic_charts_html += '<p>Please enter text containing at least two unique entities to generate the Topic Bubble Chart.</p>'
        topic_charts_html += '</div>'

    # 2. Highlighted text; tag the wrapper div so the report CSS applies.
    highlighted_text_html = highlight_entities(text_input, df, entity_color_map).replace("div style", "div class='highlighted-text' style")

    # 3. Entity table with a colour gradient on the score column.
    # Styler.to_html() has no `classes`/`index` parameters (unknown keywords
    # are silently passed to the template), so the index is hidden on the
    # Styler and the CSS class is set through `table_attributes`.
    styled_df = (
        df[['text', 'label', 'score', 'start', 'end', 'category']]
        .style.background_gradient(cmap='YlGnBu', subset=['score'])
        .format({'score': '{:.4f}'})
        .hide(axis="index")
    )
    entity_table_html = styled_df.to_html(table_attributes='class="table table-striped"')

    # 4. Final document. The title is escaped because it lands in markup.
    safe_title = html.escape(str(report_title))
    html_content = f"""<!DOCTYPE html><html lang="en"><head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{safe_title}</title>
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<style>
body {{ font-family: 'Inter', sans-serif; margin: 0; padding: 20px; background-color: #f4f4f9; color: #333; }}
.container {{ max-width: 1200px; margin: 0 auto; background-color: #ffffff; padding: 30px; border-radius: 12px; box-shadow: 0 4px 12px rgba(0,0,0,0.1); }}
h1 {{ color: #007bff; border-bottom: 3px solid #007bff; padding-bottom: 10px; margin-top: 0; }}
h2 {{ color: #007bff; margin-top: 30px; border-bottom: 1px solid #ddd; padding-bottom: 5px; }}
h3 {{ color: #555; margin-top: 20px; }}
.metadata {{ background-color: #e6f0ff; padding: 15px; border-radius: 8px; margin-bottom: 20px; font-size: 0.9em; }}
.chart-box {{ background-color: #f9f9f9; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05); min-width: 0; margin-bottom: 20px; }}
table {{ width: 100%; border-collapse: collapse; margin-top: 15px; }}
/* Target the cells generated by pandas styling */
table td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
table th {{ border: 1px solid #ddd; padding: 8px; text-align: left; background-color: #f0f0f0; }}
.highlighted-text {{ border: 1px solid #888888; padding: 15px; border-radius: 5px; background-color: #ffffff; font-family: monospace; white-space: pre-wrap; margin-bottom: 20px; }}
</style>
</head>
<body>
<div class="container">
<h1>{safe_title}</h1>
<div class="metadata">
{branding_html}
<p><strong>Generated on:</strong> {time.strftime('%Y-%m-%d')}</p>
<p><strong>Processing Time:</strong> {elapsed_time:.2f} seconds</p>
</div>
<h2>Analyzed Text &amp; Extracted Entities</h2>
<h3>Original Text with Highlighted Entities</h3>
<div class="highlighted-text-container">
{highlighted_text_html}
</div>
<h2>2. Full Extracted Entities Table </h2>
{entity_table_html}
<h2>3. Data Visualizations</h2>
<h3>3.1 Entity Distribution Treemap</h3>
<div class="chart-box">{treemap_html}</div>
<h3>3.2 Comparative Charts (Pie, Category Count, Frequency) - *Stacked Vertically*</h3>
<div class="chart-box">{pie_html}</div>
<div class="chart-box">{bar_category_html}</div>
<h3>3.3 Most Frequent Entities</h3>
<div class="chart-box">{bar_freq_html}</div>
<h3>3.4 Entity Relationship Map (Edges = Same Sentence)</h3>
<div class="chart-box">{network_html}</div>
<h2>4. Topic Modelling</h2>
{topic_charts_html}
</div>
</body>
</html>
"""
    return html_content
def chunk_text(text, max_chunk_size=1500):
    """Split `text` into (chunk, start_offset) pairs of limited size.

    The text is cut at paragraph breaks and at whitespace following '.', '!'
    or '?'. The separators are kept, so the chunks concatenate back to the
    original text and each offset is the chunk's character position in it.
    A chunk only exceeds `max_chunk_size` characters when a single segment
    is itself longer than the limit.
    """
    pieces = [piece for piece in re.split(r'(\n\n|(?<=[.!?])\s+)', text) if piece]

    result = []
    buffer = ""
    buffer_start = 0
    for piece in pieces:
        fits = len(buffer) + len(piece) <= max_chunk_size
        if fits or not buffer:
            buffer += piece
            continue
        # The buffer is full: emit it and start a new one with this piece.
        result.append((buffer, buffer_start))
        buffer_start += len(buffer)
        buffer = piece

    if buffer:
        result.append((buffer, buffer_start))
    return result
def process_chunked_text(text, labels, model):
    """Run entity prediction chunk by chunk and merge the results.

    The text is split with `chunk_text` (3500-character limit) and
    `model.predict_entities` is called on every chunk. Each returned entity
    dict has its 'start' and 'end' shifted in place by the chunk offset, so
    the positions refer to the full text. Returns all entities in one list.
    """
    MAX_CHUNK_CHARS = 3500
    collected = []
    for chunk_body, base_offset in chunk_text(text, max_chunk_size=MAX_CHUNK_CHARS):
        for found in model.predict_entities(chunk_body, labels):
            found['start'] += base_offset
            found['end'] += base_offset
            collected.append(found)
    return collected
# This is the second set_page_config call in the script (the first one is at
# the top of the file). Streamlit releases that allow only a single call per
# page raise StreamlitAPIException here, so the call is guarded: where a
# second call is allowed it applies, otherwise the first configuration stays.
try:
    st.set_page_config(layout="wide", page_title="NER & Topic Report App")
except st.errors.StreamlitAPIException:
    pass

# --- Theme overrides and mobile-only warning banner ---
st.markdown(
    """
<style>
/* FIX: Aggressive theme override to ensure visibility */
body {
background-color: #f0f2f6 !important;
color: #333333 !important;
}
[data-testid="stAppViewBlock"] {
background-color: #ffffff !important;
}
@media (max-width: 600px) {
#mobile-warning-container {
display: block;
background-color: #ffcccc;
color: #cc0000;
padding: 10px;
border-radius: 5px;
text-align: center;
margin-bottom: 20px;
font-weight: bold;
border: 1px solid #cc0000;
}
}
@media (min-width: 601px) {
#mobile-warning-container {
display: none;
}
}
[data-testid="stConfigurableTabs"] button {
color: #333333 !important;
background-color: #f0f0f0;
border: 1px solid #cccccc;
}
[data-testid="stConfigurableTabs"] button[aria-selected="true"] {
color: #FFFFFF !important;
background-color: #007bff;
border-bottom: 2px solid #007bff;
}
.streamlit-expanderHeader {
color: #007bff;
}
</style>
<div id="mobile-warning-container">
⚠️ **Tip for Mobile Users:** For the best viewing experience of the charts and tables, please switch your browser to **"Desktop Site"** view.
</div>
""",
    unsafe_allow_html=True)

st.subheader("Entity and Topic Analysis Report Generator", divider="blue")
tab1, tab2 = st.tabs(["Embed", "Important Notes"])
with tab1:
    with st.expander("Embed"):
        st.write("Use the following code to embed the DataHarvest web app on your website. Feel free to adjust the width and height values to fit your page.")
        # HTML snippet shown to the user for copy/paste.
        code = '''
<iframe
src="https://aiecosystem-dataharvest.hf.space"
frameborder="0"
width="850"
height="450"
></iframe>
'''
        st.code(code, language="html")
with tab2:
    expander = st.expander("**Important Notes**")
    expander.markdown("""
**Named Entities (Fixed Mode):** This DataHarvest web app predicts nine (9) fixed labels: "person", "country", "city", "organization", "date", "time", "cardinal", "money", "position".
**Results:** Results are compiled into a single, comprehensive **HTML report** and a **CSV file** for easy download and sharing.
**How to Use:** Type or paste your text into the text area below, then click the 'Analyze Text' button.
""")
# NOTE(review): the address in the link below reads "[email protected]",
# which looks like an e-mail obfuscation placeholder -- confirm the real
# contact address.
st.markdown("For any errors or inquiries, please contact us at [[email protected]](mailto:[email protected])")
# --- Model Loading ---
@st.cache_resource
def load_ner_model(labels):
    """Load the GLiNER model, cached by Streamlit per `labels` argument.

    `labels` is forwarded as `gen_constraints`. If loading fails, the error
    is shown in the app and the script run is stopped.
    """
    try:
        model = GLiNER.from_pretrained(
            "knowledgator/gliner-multitask-large-v0.5",
            nested_ner=True,
            num_gen_sequences=2,
            gen_constraints=labels,
        )
    except Exception as e:
        st.error(f"Failed to load NER model. This may be due to a dependency issue or resource limits: {e}")
        st.stop()
    else:
        return model
# --- Sample text pre-filled in the input area ---
DEFAULT_TEXT = (
    "In June 2024, the founder, Dr. Emily Carter, officially announced a new, expansive partnership between "
    "TechSolutions Inc. and the European Space Agency (ESA). This strategic alliance represents a significant "
    "leap forward for commercial space technology across the entire **European Union**. The agreement, finalized "
    "on Monday in Paris, France, focuses specifically on jointly developing the next generation of the 'Astra' "
    "software platform. This version of the **Astra** platform is critical for processing and managing the vast amounts of data being sent "
    "back from the recent Mars rover mission. This project underscores the ESA's commitment to advancing "
    "space capabilities within the **European Union**. The core team, including lead engineer Marcus Davies, will hold "
    "their first collaborative workshop in Berlin, Germany, on August 15th. The community response on social "
    "media platform X (under the username @TechCEO) was overwhelmingly positive, with many major tech "
    "publications, including Wired Magazine, predicting a major impact on the space technology industry by the "
    "end of the year, further strengthening the technological standing of the **European Union**. The platform is designed to be compatible with both Windows and Linux operating systems. "
    "The initial funding, secured via a Series B round, totaled $50 million. Financial analysts from Morgan Stanley "
    "are closely monitoring the impact on TechSolutions Inc.'s Q3 financial reports, expected to be released to the "
    "general public by October 1st. The goal is to deploy the **Astra** v2 platform before the next solar eclipse event in 2026."
)

# --- Session state initialization ---
# Each key gets its default only when it is not already in the session, so
# values from earlier reruns are preserved.
_SESSION_DEFAULTS = (
    ('show_results', False),
    ('my_text_area', DEFAULT_TEXT),
    ('last_text', ""),
    ('results_df', pd.DataFrame()),
    ('elapsed_time', 0.0),
    ('topic_results', None),
    ('active_labels_list', FIXED_LABELS),
    ('is_custom_mode', "Fixed Labels"),  # doubles as the key of the mode radio
    ('custom_labels_input', DEFAULT_CUSTOM_LABELS),
    ('num_topics_slider', 5),
    ('num_top_words_slider', 10),
    ('last_num_topics', None),
    ('last_num_top_words', None),
    ('last_active_labels', None),
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
def clear_text():
    """Empty the text area and drop every stored analysis result."""
    cleared_state = {
        'my_text_area': "",
        'show_results': False,
        'last_text': "",
        'results_df': pd.DataFrame(),
        'elapsed_time': 0.0,
        'topic_results': None,
    }
    for state_key, blank_value in cleared_state.items():
        st.session_state[state_key] = blank_value
# --- Text area input ---
st.markdown("## ✍️ Text Input for Analysis")
word_limit = 2000
text = st.text_area(
    f"Type or paste your text below (max {word_limit} words), and then press Ctrl + Enter",
    height=250,
    key='my_text_area',
)
word_count = len(text.split())
st.markdown(f"**Word count:** {word_count}/{word_limit}")

# --- Custom/fixed label selector ---
st.markdown("---")
st.markdown("### 🏷️ Entity Label Mode Selection")
mode = st.radio(
    "Select Entity Recognition Mode:",
    ["Fixed Labels", "Custom Labels"],
    key='is_custom_mode',
    horizontal=True,
    help="Fixed Labels use a predefined set. Custom Labels let you define your own."
)
active_labels = []
if mode == "Fixed Labels":
    active_labels = FIXED_LABELS
    st.info(f"Fixed Labels active: **{', '.join(active_labels)}**")
else:
    # The initial value comes from st.session_state['custom_labels_input']
    # through `key`. Passing `value=` as well makes Streamlit warn that the
    # widget has both a default value and a Session State value.
    custom_labels_input = st.text_input(
        "Enter your custom labels, separated by commas (e.g., product, feature, ticket_id):",
        key='custom_labels_input',
        help="The labels must be non-empty and comma-separated."
    )
    # Normalise the user input: trimmed, lower-cased, blanks dropped.
    active_labels = [label.strip().lower() for label in custom_labels_input.split(',') if label.strip()]
    if not active_labels:
        st.error("Please enter at least one custom label.")
        active_labels = []  # keeps the Analyze button disabled
    else:
        st.info(f"Custom Labels active: **{', '.join(active_labels)}**")
st.session_state.active_labels_list = active_labels
current_num_topics = st.session_state.num_topics_slider
current_num_top_words = st.session_state.num_top_words_slider

# --- Buttons ---
col_results, col_clear = st.columns([1, 1])
with col_results:
    run_button = st.button("Analyze Text", key='run_results', use_container_width=True, type="primary", disabled=not active_labels)
with col_clear:
    st.button("Clear text", on_click=clear_text, use_container_width=True)

# --- Results trigger and processing ---
if run_button:
    if text.strip() and word_count <= word_limit:
        active_labels = st.session_state.active_labels_list
        # Re-run the model only when the text or the label set changed.
        should_rerun_full_analysis = (
            text.strip() != st.session_state.last_text.strip() or
            set(active_labels) != set(st.session_state.last_active_labels if st.session_state.last_active_labels else [])
        )
        if should_rerun_full_analysis:
            CHUNKING_THRESHOLD = 500  # words; longer inputs are processed in chunks
            should_chunk = word_count > CHUNKING_THRESHOLD
            mode_msg = "custom labels" if mode == "Custom Labels" else "fixed labels"
            if should_chunk:
                mode_msg += " with **chunking** for large text"
            with st.spinner(f"Analyzing text with {mode_msg}..."):
                start_time = time.time()
                # Load (or fetch the cached) model.
                model = load_ner_model(active_labels)
                # Extract entities.
                if should_chunk:
                    all_entities = process_chunked_text(text, active_labels, model)
                else:
                    all_entities = model.predict_entities(text, active_labels)
                end_time = time.time()
                elapsed_time = end_time - start_time

                # Prepare the results DataFrame.
                df = pd.DataFrame(all_entities)
                if not df.empty:
                    df = df.reset_index(drop=True)

                    def map_category(label):
                        """Map a label to its fixed category, 'User Defined Entities' or 'Other'."""
                        if label in REVERSE_FIXED_CATEGORY_MAPPING:
                            return REVERSE_FIXED_CATEGORY_MAPPING[label]
                        elif label in active_labels and label not in FIXED_LABELS:
                            # A custom label entered by the user.
                            return 'User Defined Entities'
                        else:
                            return 'Other'

                    df['category'] = df['label'].apply(map_category)
                    df['text'] = df['text'].apply(remove_trailing_punctuation)
                    # Topic modelling runs on the extracted entity strings.
                    df_topic_data = perform_topic_modeling(df, num_topics=current_num_topics, num_top_words=current_num_top_words)
                else:
                    # Keep the expected columns even when nothing was found:
                    # the results section reads df['label'] before it checks
                    # df.empty, which raises KeyError on a column-less frame.
                    df = pd.DataFrame(columns=['text', 'label', 'score', 'start', 'end', 'category'])
                    df_topic_data = None

            # Save results so later reruns can display them without the model.
            st.session_state.results_df = df
            st.session_state.topic_results = df_topic_data
            st.session_state.elapsed_time = elapsed_time
            st.session_state.last_text = text
            st.session_state.show_results = True
            st.session_state.last_active_labels = active_labels
            st.session_state.last_num_topics = current_num_topics
            st.session_state.last_num_top_words = current_num_top_words
        else:
            st.info("Results already calculated for the current text and settings.")
            st.session_state.show_results = True
    elif word_count > word_limit:
        st.error(f"Text too long! Please limit your input to {word_limit} words.")
        st.session_state.show_results = False
    elif not active_labels:
        st.error("Please ensure your custom label input is not empty.")
        st.session_state.show_results = False
    else:
        st.warning("Please enter some text to analyze.")
        st.session_state.show_results = False
# --- Display Download Link and Results ---
if st.session_state.show_results:
df = st.session_state.results_df
df_topic_data = st.session_state.topic_results
current_labels_in_df = df['label'].unique().tolist()
entity_color_map = get_dynamic_color_map(current_labels_in_df, FIXED_ENTITY_COLOR_MAP)
if df.empty:
st.warning("No entities were found in the provided text with the current label set.")
else:
st.subheader("1. Analysis Results", divider="blue")
# --- Function to Apply Conditional Coloring to Scores (For Streamlit UI only) ---
def color_score_gradient(df_input):
"""Applies a color gradient to the 'score' column using Pandas Styler."""
return df_input.style.background_gradient(
cmap='YlGnBu',
subset=['score']
).format(
{'score': '{:.4f}'}
)
# 1. Highlighted Text placed inside an Expander
# Collapsed by default; the label shows which labelling mode is active.
with st.expander(f"### 1. Analyzed Text with Highlighted Entities ({mode} Mode)", expanded=False):
# unsafe_allow_html=True lets the string returned by highlight_entities render as HTML
# (presumably coloured entity markup — confirm in highlight_entities).
st.markdown(
highlight_entities(st.session_state.last_text, df, entity_color_map),
unsafe_allow_html=True
)
# One row of df per extracted entity, so len(df) is the entity total.
st.markdown(f"**Total Entities Found:** {len(df)}")
# 2. Detailed Entity Analysis Tabs
st.markdown("### 2. Detailed Entity Analysis")
tab_category_details, tab_treemap_viz = st.tabs(["πŸ“‘ Entities Grouped by Category", "πŸ—ΊοΈ Treemap Distribution"])
# --- Section 2a: Detailed Tables by Category/Label ---
with tab_category_details:
st.markdown("#### Detailed Entities Table (Grouped by Category)")
# Get all unique categories present in the data (Fixed + User Defined)
unique_categories = list(df['category'].unique())
# Ensure fixed categories appear first if present, followed by custom/other
ordered_categories = []
# Add fixed categories in defined order
for fixed_cat in FIXED_CATEGORY_MAPPING.keys():
if fixed_cat in unique_categories:
ordered_categories.append(fixed_cat)
unique_categories.remove(fixed_cat)
# Add User Defined and Other at the end
if 'User Defined Entities' in unique_categories:
ordered_categories.append('User Defined Entities')
unique_categories.remove('User Defined Entities')
if 'Other' in unique_categories:
ordered_categories.append('Other')
unique_categories.remove('Other')
# Add any remaining categories (shouldn't happen with map_category, but for safety)
ordered_categories.extend(unique_categories)
tabs_category = st.tabs(ordered_categories)
for category, tab in zip(ordered_categories, tabs_category):
df_category = df[df['category'] == category][['text', 'label', 'score', 'start', 'end']].sort_values(by='score', ascending=False)
styled_df_category = color_score_gradient(df_category)
with tab:
st.markdown(f"##### {category} Entities ({len(df_category)} total)")
if not df_category.empty:
st.dataframe(styled_df_category, use_container_width=True)
else:
st.info(f"No entities of category **{category}** were found in the text.")
with st.expander("See Glossary of tags"):
st.write('''- **text**: ['entity extracted from your text data']- **label**: ['label (tag) assigned to a given extracted entity (custom or fixed)']- **category**: ['the grouping category (e.g., "Locations" or "User Defined Entities")']- **score**: ['accuracy score; how accurately a tag has been assigned to a given entity']- **start**: ['index of the start of the corresponding entity']- **end**: ['index of the end of the corresponding entity']''')
# --- Section 2b: Treemap Visualization ---
with tab_treemap_viz:
st.markdown("#### Treemap: Entity Distribution")
# Hierarchy: constant root -> category -> label -> entity text.
# Tile size comes from the 'score' column and tile colour from the label.
fig_treemap = px.treemap(
df,
path=[px.Constant("All Entities"), 'category', 'label', 'text'],
values='score',
color='label',
color_discrete_sequence=px.colors.qualitative.Bold
)
# Small uniform margins around the figure.
fig_treemap.update_layout(margin=dict(t=10, l=10, r=10, b=10))
st.plotly_chart(fig_treemap, use_container_width=True)
# 3. Comparative Charts
st.markdown("---")
st.markdown("### 3. Comparative Charts")
# Three side-by-side columns: pie, per-category bars, repeated-entity bars.
col1, col2, col3 = st.columns(3)
# Number of entities per category, reshaped into a two-column frame for Plotly.
grouped_counts = df['category'].value_counts().reset_index()
grouped_counts.columns = ['Category', 'Count']
# Pastel palette when more than one category is present; otherwise the Cividis sequential scale.
chart_color_seq = px.colors.qualitative.Pastel if len(grouped_counts) > 1 else px.colors.sequential.Cividis
with col1: # Pie Chart
fig_pie = px.pie(grouped_counts, values='Count', names='Category',title='Distribution of Entities by Category',color_discrete_sequence=chart_color_seq)
fig_pie.update_layout(margin=dict(t=30, b=10, l=10, r=10), height=350)
st.plotly_chart(fig_pie, use_container_width=True)
with col2: # Bar Chart by Category
st.markdown("#### Entity Count by Category")
fig_bar_category = px.bar(grouped_counts, x='Category', y='Count', color='Category', title='Total Entities per Category', color_discrete_sequence=chart_color_seq)
# The x axis already names each category, so the legend is hidden.
fig_bar_category.update_layout(margin=dict(t=30, b=10, l=10, r=10), height=350, showlegend=False)
st.plotly_chart(fig_bar_category, use_container_width=True)
with col3: # Bar Chart for Most Frequent Entities
st.markdown("#### Top 10 Most Frequent Entities")
# value_counts() sorts by descending frequency, so head(10) keeps the most repeated strings.
word_counts = df['text'].value_counts().reset_index()
word_counts.columns = ['Entity', 'Count']
# Only entity strings that occur more than once are charted.
repeating_entities = word_counts[word_counts['Count'] > 1].head(10)
if not repeating_entities.empty:
fig_bar_freq = px.bar(repeating_entities, x='Entity', y='Count', title='Top 10 Most Frequent Entities', color='Entity', color_discrete_sequence=px.colors.sequential.Viridis)
fig_bar_freq.update_layout(margin=dict(t=30, b=10, l=10, r=10), height=350, showlegend=False)
st.plotly_chart(fig_bar_freq, use_container_width=True)
else:
st.info("No entities were repeated enough for a Top 10 frequency chart.")
# 4. Advanced Analysis
st.markdown("---")
st.markdown("### 4. Advanced Analysis")
# --- A. Network Graph Section ---
with st.expander("πŸ”— Entity Co-occurrence Network Graph", expanded=True):
st.plotly_chart(generate_network_graph(df, st.session_state.last_text, entity_color_map), use_container_width=True)
# --- B. Topic Modeling Section ---
st.markdown("---")
with st.container(border=True):
st.markdown("#### πŸ’‘ Topic Modeling (LDA) Configuration and Results")
st.markdown("Adjust the settings below and click **'Re-Run Topic Model'** to instantly update the visualization based on the extracted entities.")
col_slider_topic, col_slider_words, col_rerun_btn = st.columns([1, 1, 0.5])
with col_slider_topic:
new_num_topics = st.slider(
"Number of Topics",
min_value=2,
max_value=10,
value=st.session_state.num_topics_slider,
step=1,
key='num_topics_slider_new',
help="The number of topics to discover (2 to 10)."
)
with col_slider_words:
new_num_top_words = st.slider(
"Number of Top Words",
min_value=5,
max_value=20,
value=st.session_state.num_top_words_slider,
step=1,
key='num_top_words_slider_new',
help="The number of top words to display per topic (5 to 20)."
)
def rerun_topic_model():
    """Apply the pending slider values and recompute the topic model.

    Copies ``num_topics_slider_new`` / ``num_top_words_slider_new`` into
    ``num_topics_slider`` / ``num_top_words_slider`` in session state.
    If the stored results frame is non-empty, it then re-runs
    ``perform_topic_modeling`` on it and stores the new topic data together
    with the parameters that produced it. Takes no arguments and returns
    nothing; all effects go through ``st.session_state``.
    """
    state = st.session_state
    # The slider values are applied even when there is nothing to re-model.
    state.num_topics_slider = state.num_topics_slider_new
    state.num_top_words_slider = state.num_top_words_slider_new
    if state.results_df.empty:
        return
    state.topic_results = perform_topic_modeling(
        df_entities=state.results_df,
        num_topics=state.num_topics_slider,
        num_top_words=state.num_top_words_slider
    )
    # Record which settings the stored topic data corresponds to.
    state.last_num_topics = state.num_topics_slider
    state.last_num_top_words = state.num_top_words_slider
with col_rerun_btn:
# Fixed-height empty div above the button
# (presumably to line it up vertically with the sliders — confirm visually).
st.markdown("<div style='height: 38px;'></div>", unsafe_allow_html=True)
# rerun_topic_model is registered as the click callback.
st.button("Re-Run Topic Model", on_click=rerun_topic_model, use_container_width=True, type="primary")
st.markdown("---")
# Show the applied LDA parameters as stored in session state.
st.markdown(f"""
**Current LDA Parameters:**
* Topics: **{st.session_state.num_topics_slider}**
* Top Words: **{st.session_state.num_top_words_slider}**
""")
# Re-read from session state: rerun_topic_model writes its new results there.
df_topic_data = st.session_state.topic_results
if df_topic_data is not None and not df_topic_data.empty:
st.plotly_chart(create_topic_word_bubbles(df_topic_data), use_container_width=True)
st.markdown("This chart visualizes the key words driving the identified topics, based on extracted entities.")
else:
# Shown when there is no topic data (None or an empty frame).
st.info("Topic Modeling requires at least two unique entities with a minimum frequency to perform statistical analysis.")
# 5. White-Label Configuration
st.markdown("---")
st.markdown("### 5. White-Label Report Configuration 🎨")
# The default title follows the labelling mode in use.
default_report_title = "Fixed Entity Analysis Report" if mode == "Fixed Labels" else "Custom Entity Analysis Report"
custom_report_title = st.text_input(
"Type Your Report Title (for HTML Report), and then press Enter.",
value=default_report_title
)
# A multi-line text area treats a plain Enter as a newline, so the label names
# Ctrl+Enter, which is the key combination that submits the value.
custom_branding_text_input = st.text_area(
"Type Your Brand Name or Tagline (Appears below the title in the report), and then press Ctrl+Enter.",
value="Analysis powered by My Own Brand",
key='custom_branding_input',
help="Enter your brand name or a short tagline. This text will be automatically styled and included below the main title."
)
# 6. Downloads
st.markdown("---")
st.markdown("### 6. Downloads")
col_csv, col_html = st.columns(2)
# CSV Download
csv_buffer = generate_entity_csv(df)
with col_csv:
st.download_button(
label="⬇️ Download Entities as CSV",
data=csv_buffer,
file_name="ner_entities_report.csv",
mime="text/csv",
use_container_width=True
)
# HTML Download (Passing custom white-label parameters)
branding_to_pass = f'<p style="font-size: 1.1em; font-weight: 500;">{custom_branding_text_input}</p>'
html_content = generate_html_report(
df,
st.session_state.last_text,
st.session_state.elapsed_time,
df_topic_data,
entity_color_map,
report_title=custom_report_title,
branding_html=branding_to_pass
)
html_bytes = html_content.encode('utf-8')
with col_html:
st.download_button(
label="⬇️ Download Full HTML Report",
data=html_bytes,
file_name="ner_topic_full_report.html",
mime="text/html",
use_container_width=True
)