import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification, pipeline
from difflib import SequenceMatcher
import re
# Load Czech GEC model
print("Loading Czech GEC ByT5 model...")
gec_tokenizer = AutoTokenizer.from_pretrained("ufal/byt5-large-geccc-mate")
gec_model = AutoModelForSeq2SeqLM.from_pretrained("ufal/byt5-large-geccc-mate")
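# Note: ByT5 operates directly on UTF-8 bytes (no subword vocabulary), so the
# sequence lengths used below are roughly byte/character counts.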
# Check if CUDA is available and move model to GPU if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gec_model = gec_model.to(device)
print(f"GEC model loaded on {device}")
# Load punctuation model
print("Loading punctuation model...")
punct_tokenizer = AutoTokenizer.from_pretrained("kredor/punctuate-all")
punct_model = AutoModelForTokenClassification.from_pretrained("kredor/punctuate-all")
punct_pipeline = pipeline("token-classification", model=punct_model, tokenizer=punct_tokenizer, device=0 if torch.cuda.is_available() else -1)
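# The pipeline's device argument is an int: 0 selects the first CUDA device, -1 the CPU.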
print("Punctuation model loaded!")
def gec_correct(input_text):
"""Generate 3 different GEC corrections"""
if not input_text.strip():
return ["", "", ""]
# Log input length
print(f"[GEC] Input text length: {len(input_text)} characters")
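    # Three decoding configurations: all use deterministic beam search and differ
    # only in beam width, generation budget, and early stopping, trading speed for
    # search breadth.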
configs = [
{
"name": "Conservative",
"num_beams": 4, # Standard beam search
"do_sample": False,
"repetition_penalty": 1.0, # No penalty - preserve original
"length_penalty": 1.0,
"no_repeat_ngram_size": 0, # Allow natural repetitions
"early_stopping": True,
"max_new_tokens": 1000 # Sufficient for corrections
},
{
"name": "Balanced",
"num_beams": 8, # Moderate beams
"do_sample": False,
"repetition_penalty": 1.0,
"length_penalty": 1.0,
"no_repeat_ngram_size": 0,
"early_stopping": True,
"max_new_tokens": 1500
},
{
"name": "Exploratory",
"num_beams": 12, # Higher beam search
"do_sample": False,
"repetition_penalty": 1.0,
"length_penalty": 1.0,
"no_repeat_ngram_size": 0,
"early_stopping": False, # Full exploration
"max_new_tokens": 2000
}
]
corrections = []
# Increase max_length for ByT5 (uses byte-level tokenization)
max_input_length = 1024 # Increased from 512
inputs = gec_tokenizer(input_text, return_tensors="pt", max_length=max_input_length, truncation=True)
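    # Inputs longer than max_input_length bytes are silently truncated here.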
inputs = {k: v.to(device) for k, v in inputs.items()}
print(f"[GEC] Tokenized input shape: {inputs['input_ids'].shape}")
for config in configs:
with torch.no_grad():
            # Copy generation parameters from the config, dropping the descriptive
            # "name" key; everything else (beams, penalties, max_new_tokens, ...)
            # is passed straight to generate().
            gen_params = {k: v for k, v in config.items() if k != "name"}
            # ByT5 uses byte-level tokens (1 char ≈ 1 token), so max_new_tokens is
            # roughly a character budget for the corrected text.
            print(f"[GEC] Generating {config['name']}")
            # Number of candidate sequences to rank (1 unless a config requests more)
            num_return = config.get("num_return_sequences", 1)
outputs = gec_model.generate(
**inputs,
**gen_params
)
# For exploratory with multiple sequences, pick the best one
if num_return > 1:
best_text = ""
best_length_diff = float('inf')
for output in outputs:
decoded = gec_tokenizer.decode(output, skip_special_tokens=True)
# Pick the one closest to expected length
length_diff = abs(len(decoded) - len(input_text))
if length_diff < best_length_diff:
best_length_diff = length_diff
best_text = decoded
corrected_text = best_text
else:
corrected_text = gec_tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"[GEC] {config['name']} output length: {len(corrected_text)} characters")
corrections.append(corrected_text)
return corrections
def punct_correct(input_text):
"""Generate 3 different punctuation corrections using kredor/punctuate-all"""
if not input_text.strip():
return ["", "", ""]
corrections = []
# Process with the punctuation pipeline
# The model expects lowercase input without punctuation
clean_text = input_text.lower()
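    # Only case is normalized here; any punctuation already present in the input
    # is left in place and passed through to the pipeline.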
results = punct_pipeline(clean_text)
# Build a mapping of token positions to punctuation
punct_map = {}
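    # NOTE: punct_map is keyed by the word string, so repeated words share a single
    # punctuation decision (the last prediction wins) in this simple reconstruction.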
current_word = ""
current_punct = ""
for i, result in enumerate(results):
word = result['word'].replace('▁', '').strip()
# Get punctuation from entity label
entity = result['entity']
if entity == 'LABEL_0':
punct = '' # No punctuation
elif entity == 'LABEL_1':
punct = '.'
elif entity == 'LABEL_2':
punct = ','
elif entity == 'LABEL_3':
punct = '?'
elif entity == 'LABEL_4':
punct = '-'
elif entity == 'LABEL_5':
punct = ':'
else:
punct = ''
# Check if this is a continuation of previous word (subword token)
if not result['word'].startswith('▁') and i > 0:
current_word += word
else:
# Save previous word if exists
if current_word:
punct_map[current_word] = current_punct
current_word = word
current_punct = punct
# Don't forget the last word
if current_word:
punct_map[current_word] = current_punct
# Reconstruct text with punctuation
words = clean_text.split()
punctuated_words = []
for word in words:
# Check if we have punctuation for this word
if word in punct_map and punct_map[word]:
punctuated_words.append(word + punct_map[word])
else:
punctuated_words.append(word)
# Join words
base_result = ' '.join(punctuated_words)
# Three variations
# 1. Conservative - just punctuation
corrections.append(base_result)
# 2. With first letter and sentence capitalization
sentences = re.split(r'(?<=[.?!])\s+', base_result)
capitalized = ' '.join(s[0].upper() + s[1:] if s else s for s in sentences)
corrections.append(capitalized)
# 3. Clean formatting
clean = capitalized
for p in [',', '.', '?', ':', '!', ';']:
clean = clean.replace(f' {p}', p)
corrections.append(clean)
return corrections
def calculate_similarity(text1, text2):
"""Calculate similarity percentage between two texts"""
return round(SequenceMatcher(None, text1, text2).ratio() * 100, 2)
def remove_commas(text):
"""Remove all commas from text"""
    return text.replace(",", "")
def process_pipeline(input_text, progress=gr.Progress()):
"""Process text through both models to get 9 outputs"""
if not input_text.strip():
return [""] * 9 + ["Please enter text to process"]
all_outputs = []
status_text = []
# Step 1: GEC corrections
progress(0.2, desc="Generating GEC corrections...")
gec_outputs = gec_correct(input_text)
# Step 2: Apply punctuation to each GEC output
for i, gec_text in enumerate(gec_outputs):
progress(0.4 + i*0.2, desc=f"Processing punctuation for GEC variant {i+1}...")
punct_outputs = punct_correct(gec_text)
        all_outputs.extend(punct_outputs)
progress(1.0, desc="Processing complete!")
    status = "✅ Generated 9 correction variants successfully!"
return all_outputs + [status]
def benchmark_pipeline(original_text, corrupted_text, progress=gr.Progress()):
"""Benchmark the pipeline against original text"""
if not original_text.strip() or not corrupted_text.strip():
return [""] * 9 + [""] * 9 + ["Please enter both original and corrupted texts"]
# Process the corrupted text through pipeline
progress(0.1, desc="Processing corrupted text...")
all_outputs = []
# GEC corrections
progress(0.2, desc="Generating GEC corrections...")
gec_outputs = gec_correct(corrupted_text)
# Apply punctuation to each
for i, gec_text in enumerate(gec_outputs):
progress(0.4 + i*0.15, desc=f"Processing punctuation for variant {i+1}...")
punct_outputs = punct_correct(gec_text)
all_outputs.extend(punct_outputs)
# Calculate similarities
progress(0.9, desc="Calculating similarities...")
similarities = []
best_score = 0
best_idx = 0
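    # calculate_similarity() returns difflib.SequenceMatcher's character-level ratio
    # as a percentage. Outputs are ordered GEC-variant-major: index // 3 selects the
    # GEC config and index % 3 the punctuation variant.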
for i, output in enumerate(all_outputs):
score = calculate_similarity(original_text, output)
similarities.append(f"{score}%")
if score > best_score:
best_score = score
best_idx = i
# Generate report
gec_names = ["Conservative GEC", "Balanced GEC", "Exploratory GEC"]
punct_names = ["Conservative Punct", "Sentence Boundaries", "Balanced Punct"]
report = f"📊 **Benchmark Results**\n\n"
    report += f"**Best Match:** Variant {best_idx + 1} ({gec_names[best_idx // 3]} + {punct_names[best_idx % 3]}) with {best_score}% similarity\n\n"
report += "**All Scores:**\n"
for i in range(9):
gec_idx = i // 3
punct_idx = i % 3
        report += f"{i+1}. {gec_names[gec_idx]} + {punct_names[punct_idx]}: {similarities[i]}\n"
progress(1.0, desc="Benchmark complete!")
return all_outputs + similarities + [report]
# Create Gradio interface
with gr.Blocks(title="Czech GEC + Punctuation Pipeline", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🇨🇿 Czech Grammar & Punctuation Correction Pipeline
This pipeline combines two powerful models:
1. **ByT5 Czech GEC** - Corrects grammatical errors
2. **XLM-RoBERTa Punctuation** - Adds punctuation and capitalization
Each model generates 3 variants, resulting in **9 total output combinations**.
""")
with gr.Tabs():
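        # The grids below mirror the 3x3 output structure: one column per GEC variant,
        # three punctuation variants per column.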
with gr.TabItem("🚀 Pipeline Mode"):
gr.Markdown("### Enter Czech text for correction through both models")
with gr.Row():
pipeline_input = gr.Textbox(
label="Input Czech Text",
placeholder="Zadejte český text s gramatickými chybami a bez interpunkce...",
lines=8,
scale=2
)
process_btn = gr.Button("🔄 Process Through Pipeline", variant="primary", size="lg")
gr.Markdown("### 📊 Pipeline Outputs (GEC → Punctuation)")
# Create 3x3 grid for outputs
with gr.Row():
with gr.Column():
gr.Markdown("#### Conservative GEC →")
output_1_1 = gr.Textbox(label="→ Conservative Punct", lines=4, interactive=True)
output_1_2 = gr.Textbox(label="→ Sentence Boundaries", lines=4, interactive=True)
output_1_3 = gr.Textbox(label="→ Balanced Punct", lines=4, interactive=True)
with gr.Column():
gr.Markdown("#### Balanced GEC →")
output_2_1 = gr.Textbox(label="→ Conservative Punct", lines=4, interactive=True)
output_2_2 = gr.Textbox(label="→ Sentence Boundaries", lines=4, interactive=True)
output_2_3 = gr.Textbox(label="→ Balanced Punct", lines=4, interactive=True)
with gr.Column():
gr.Markdown("#### Exploratory GEC →")
output_3_1 = gr.Textbox(label="→ Conservative Punct", lines=4, interactive=True)
output_3_2 = gr.Textbox(label="→ Sentence Boundaries", lines=4, interactive=True)
output_3_3 = gr.Textbox(label="→ Balanced Punct", lines=4, interactive=True)
pipeline_status = gr.Markdown("")
with gr.TabItem("📊 Benchmark Mode"):
gr.Markdown("""
### Benchmark the Pipeline
1. Paste your **original correct text**
2. Create a **corrupted version** (remove commas, introduce errors)
3. Run benchmark to see how well the pipeline recovers the original
""")
with gr.Row():
with gr.Column():
original_text = gr.Textbox(
label="Original Text (Ground Truth)",
placeholder="Paste the correct Czech text here...",
lines=6
)
remove_commas_btn = gr.Button("🗑️ Remove Commas → Copy to Corrupted", size="sm")
with gr.Column():
corrupted_text = gr.Textbox(
label="Corrupted Text (Input for Pipeline)",
placeholder="Paste or edit the corrupted version here...",
lines=6,
interactive=True
)
benchmark_btn = gr.Button("📊 Run Benchmark", variant="primary", size="lg")
gr.Markdown("### 🎯 Benchmark Results")
# Outputs with similarity scores
with gr.Row():
with gr.Column():
gr.Markdown("#### Conservative GEC →")
bench_output_1_1 = gr.Textbox(label="→ Conservative Punct", lines=3)
bench_sim_1_1 = gr.Textbox(label="Similarity", lines=1)
bench_output_1_2 = gr.Textbox(label="→ Sentence Boundaries", lines=3)
bench_sim_1_2 = gr.Textbox(label="Similarity", lines=1)
bench_output_1_3 = gr.Textbox(label="→ Balanced Punct", lines=3)
bench_sim_1_3 = gr.Textbox(label="Similarity", lines=1)
with gr.Column():
gr.Markdown("#### Balanced GEC →")
bench_output_2_1 = gr.Textbox(label="→ Conservative Punct", lines=3)
bench_sim_2_1 = gr.Textbox(label="Similarity", lines=1)
bench_output_2_2 = gr.Textbox(label="→ Sentence Boundaries", lines=3)
bench_sim_2_2 = gr.Textbox(label="Similarity", lines=1)
bench_output_2_3 = gr.Textbox(label="→ Balanced Punct", lines=3)
bench_sim_2_3 = gr.Textbox(label="Similarity", lines=1)
with gr.Column():
gr.Markdown("#### Exploratory GEC →")
bench_output_3_1 = gr.Textbox(label="→ Conservative Punct", lines=3)
bench_sim_3_1 = gr.Textbox(label="Similarity", lines=1)
bench_output_3_2 = gr.Textbox(label="→ Sentence Boundaries", lines=3)
bench_sim_3_2 = gr.Textbox(label="Similarity", lines=1)
bench_output_3_3 = gr.Textbox(label="→ Balanced Punct", lines=3)
bench_sim_3_3 = gr.Textbox(label="Similarity", lines=1)
benchmark_report = gr.Markdown("")
# Examples
gr.Examples(
examples=[
["včera jsem šel do obchodu a koupil jsem si rohlíky máslo a mléko bylo to levné"],
["programování je zajímavé učím se každý den něco nového a baví mě to"],
["česká republika je krásná země ve střední evropě má bohatou historii"]
],
inputs=pipeline_input,
label="Example inputs (click to try)"
)
# Event handlers
pipeline_outputs = [
output_1_1, output_1_2, output_1_3,
output_2_1, output_2_2, output_2_3,
output_3_1, output_3_2, output_3_3,
pipeline_status
]
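    # The order above must match the list returned by process_pipeline():
    # nine texts (GEC-variant-major) followed by the status message.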
process_btn.click(
fn=process_pipeline,
inputs=pipeline_input,
outputs=pipeline_outputs
)
benchmark_outputs = [
bench_output_1_1, bench_output_1_2, bench_output_1_3,
bench_output_2_1, bench_output_2_2, bench_output_2_3,
bench_output_3_1, bench_output_3_2, bench_output_3_3,
bench_sim_1_1, bench_sim_1_2, bench_sim_1_3,
bench_sim_2_1, bench_sim_2_2, bench_sim_2_3,
bench_sim_3_1, bench_sim_3_2, bench_sim_3_3,
benchmark_report
]
benchmark_btn.click(
fn=benchmark_pipeline,
inputs=[original_text, corrupted_text],
outputs=benchmark_outputs
)
remove_commas_btn.click(
fn=remove_commas,
inputs=original_text,
outputs=corrupted_text
)
gr.Markdown("""
---
**Models:**
- GEC: [ufal/byt5-large-geccc-mate](https://huggingface.co/ufal/byt5-large-geccc-mate)
- Punctuation: [kredor/punctuate-all](https://huggingface.co/kredor/punctuate-all)
""")
# Launch the app
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)