import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification, pipeline
from difflib import SequenceMatcher
import re

# Load Czech GEC model
print("Loading Czech GEC ByT5 model...")
gec_tokenizer = AutoTokenizer.from_pretrained("ufal/byt5-large-geccc-mate")
gec_model = AutoModelForSeq2SeqLM.from_pretrained("ufal/byt5-large-geccc-mate")

# Check if CUDA is available and move the model to GPU if possible
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gec_model = gec_model.to(device)
print(f"GEC model loaded on {device}")

# Load punctuation model
print("Loading punctuation model...")
punct_tokenizer = AutoTokenizer.from_pretrained("kredor/punctuate-all")
punct_model = AutoModelForTokenClassification.from_pretrained("kredor/punctuate-all")
punct_pipeline = pipeline(
    "token-classification",
    model=punct_model,
    tokenizer=punct_tokenizer,
    device=0 if torch.cuda.is_available() else -1,
)
print("Punctuation model loaded!")


def gec_correct(input_text):
    """Generate 3 different GEC corrections"""
    if not input_text.strip():
        return ["", "", ""]

    # Log input length
    print(f"[GEC] Input text length: {len(input_text)} characters")

    configs = [
        {
            "name": "Conservative",
            "num_beams": 4,              # standard beam search
            "do_sample": False,
            "repetition_penalty": 1.0,   # no penalty - preserve the original
            "length_penalty": 1.0,
            "no_repeat_ngram_size": 0,   # allow natural repetitions
            "early_stopping": True,
            "max_new_tokens": 1000,      # sufficient for corrections
        },
        {
            "name": "Balanced",
            "num_beams": 8,              # moderate beam width
            "do_sample": False,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "early_stopping": True,
            "max_new_tokens": 1500,
        },
        {
            "name": "Exploratory",
            "num_beams": 12,             # wider beam search
            "do_sample": False,
            "repetition_penalty": 1.0,
            "length_penalty": 1.0,
            "no_repeat_ngram_size": 0,
            "early_stopping": False,     # explore the full beam
            "max_new_tokens": 2000,
        },
    ]
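
    # All three configurations use deterministic beam search (do_sample=False); they
    # differ only in beam width, early stopping, and output length budget, so for
    # short or already-correct inputs the three GEC variants may come out identical.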

    corrections = []

    # Allow longer inputs for ByT5: it tokenizes UTF-8 bytes, so roughly one token per
    # byte (Czech characters with diacritics take two bytes)
    max_input_length = 1024  # increased from 512
    inputs = gec_tokenizer(input_text, return_tensors="pt", max_length=max_input_length, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    print(f"[GEC] Tokenized input shape: {inputs['input_ids'].shape}")

    for config in configs:
        with torch.no_grad():
            # Everything except the name is a generation parameter; num_return_sequences
            # (if present) is handled explicitly below.
            gen_params = {k: v for k, v in config.items() if k not in ("name", "num_return_sequences")}

            print(f"[GEC] Generating {config['name']}")

            # Handle multiple returned sequences (none of the configs above request
            # them, but the selection logic below supports it)
            num_return = config.get("num_return_sequences", 1)
            if num_return > 1:
                gen_params["num_return_sequences"] = num_return

            outputs = gec_model.generate(
                **inputs,
                **gen_params
            )
            # For exploratory runs with multiple sequences, pick the best one
            if num_return > 1:
                best_text = ""
                best_length_diff = float('inf')
                for output in outputs:
                    decoded = gec_tokenizer.decode(output, skip_special_tokens=True)
                    # Pick the candidate whose length is closest to the input length
                    length_diff = abs(len(decoded) - len(input_text))
                    if length_diff < best_length_diff:
                        best_length_diff = length_diff
                        best_text = decoded
                corrected_text = best_text
            else:
                corrected_text = gec_tokenizer.decode(outputs[0], skip_special_tokens=True)

            print(f"[GEC] {config['name']} output length: {len(corrected_text)} characters")
            corrections.append(corrected_text)

    return corrections


def punct_correct(input_text):
    """Generate 3 different punctuation corrections using kredor/punctuate-all"""
    if not input_text.strip():
        return ["", "", ""]

    corrections = []

    # Run the punctuation pipeline.
    # The model was trained on lowercase text, so the input is lowercased here;
    # any punctuation already present is left in place.
    clean_text = input_text.lower()
    results = punct_pipeline(clean_text)

    # Build a mapping from words to predicted punctuation
    punct_map = {}
    current_word = ""
    current_punct = ""

    for i, result in enumerate(results):
        word = result['word'].replace('▁', '').strip()

        # Map the entity label to a punctuation mark.
        # NOTE: this assumes the pipeline reports generic LABEL_N ids; if the model
        # config defines named labels, compare against punct_model.config.id2label.
        entity = result['entity']
        if entity == 'LABEL_0':
            punct = ''   # no punctuation
        elif entity == 'LABEL_1':
            punct = '.'
        elif entity == 'LABEL_2':
            punct = ','
        elif entity == 'LABEL_3':
            punct = '?'
        elif entity == 'LABEL_4':
            punct = '-'
        elif entity == 'LABEL_5':
            punct = ':'
        else:
            punct = ''

        # Tokens that do not start with '▁' are continuations of the previous word
        # (SentencePiece subword pieces); glue them back together. The punctuation
        # kept for a word is the label predicted on its first piece.
        if not result['word'].startswith('▁') and i > 0:
            current_word += word
        else:
            # Save the previous word, if any
            if current_word:
                punct_map[current_word] = current_punct
            current_word = word
            current_punct = punct

    # Don't forget the last word
    if current_word:
        punct_map[current_word] = current_punct
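
    # Worked example of the merge above (hypothetical tokenization): if "mléko" were
    # split into the pieces '▁mlé' and 'ko', they are joined back into "mléko", and the
    # word keeps whatever punctuation was predicted for '▁mlé', e.g. '.' at the end of
    # a sentence.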

    # Reconstruct the text with punctuation attached.
    # NOTE: punct_map is keyed by the word string, so repeated words share a single
    # entry and the last prediction wins for every occurrence of that word.
    words = clean_text.split()
    punctuated_words = []
    for word in words:
        # Check whether we have punctuation for this word
        if word in punct_map and punct_map[word]:
            punctuated_words.append(word + punct_map[word])
        else:
            punctuated_words.append(word)

    # Join words
    base_result = ' '.join(punctuated_words)

    # Three variations
    # 1. Conservative - punctuation only
    corrections.append(base_result)

    # 2. Sentence-level capitalization
    sentences = re.split(r'(?<=[.?!])\s+', base_result)
    capitalized = ' '.join(s[0].upper() + s[1:] if s else s for s in sentences)
    corrections.append(capitalized)

    # 3. Clean formatting - remove stray spaces before punctuation marks
    clean = capitalized
    for p in [',', '.', '?', ':', '!', ';']:
        clean = clean.replace(f' {p}', p)
    corrections.append(clean)

    return corrections
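
# Note on the three punctuation variants above: because punctuation is appended
# directly to each word, the space-stripping pass in variant 3 only changes anything
# if the text already contained spaces before punctuation marks, so variants 2 and 3
# will often be identical.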


def calculate_similarity(text1, text2):
    """Calculate similarity percentage between two texts"""
    return round(SequenceMatcher(None, text1, text2).ratio() * 100, 2)
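
# difflib's SequenceMatcher compares the two strings character by character: ratio()
# returns 2*M/T, where M is the number of matching characters and T the combined
# length of both strings. For example, SequenceMatcher(None, "ahoj", "ahoj!").ratio()
# is 8/9 ≈ 0.889, which the helper above reports as 88.89%.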


def remove_commas(text):
    """Remove all commas from text"""
    return text.replace(",", "")


def process_pipeline(input_text, progress=gr.Progress()):
    """Process text through both models to get 9 outputs"""
    if not input_text.strip():
        return [""] * 9 + ["Please enter text to process"]

    all_outputs = []

    # Step 1: GEC corrections
    progress(0.2, desc="Generating GEC corrections...")
    gec_outputs = gec_correct(input_text)

    # Step 2: apply punctuation to each GEC output
    for i, gec_text in enumerate(gec_outputs):
        progress(0.4 + i * 0.2, desc=f"Processing punctuation for GEC variant {i + 1}...")
        punct_outputs = punct_correct(gec_text)
        all_outputs.extend(punct_outputs)

    progress(1.0, desc="Processing complete!")
    status = "✅ Generated 9 correction variants successfully!"
    return all_outputs + [status]
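
# The 9 outputs are ordered GEC-major: the three punctuation variants of the
# Conservative GEC output come first, then those of the Balanced output, then the
# Exploratory ones. The UI below and the report in benchmark_pipeline()
# (index // 3 for the GEC variant, index % 3 for the punctuation variant) rely on
# this ordering.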


def benchmark_pipeline(original_text, corrupted_text, progress=gr.Progress()):
    """Benchmark the pipeline against original text"""
    if not original_text.strip() or not corrupted_text.strip():
        return [""] * 9 + [""] * 9 + ["Please enter both original and corrupted texts"]

    # Process the corrupted text through the pipeline
    progress(0.1, desc="Processing corrupted text...")
    all_outputs = []

    # GEC corrections
    progress(0.2, desc="Generating GEC corrections...")
    gec_outputs = gec_correct(corrupted_text)

    # Apply punctuation to each
    for i, gec_text in enumerate(gec_outputs):
        progress(0.4 + i * 0.15, desc=f"Processing punctuation for variant {i + 1}...")
        punct_outputs = punct_correct(gec_text)
        all_outputs.extend(punct_outputs)

    # Calculate similarities
    progress(0.9, desc="Calculating similarities...")
    similarities = []
    best_score = 0
    best_idx = 0
    for i, output in enumerate(all_outputs):
        score = calculate_similarity(original_text, output)
        similarities.append(f"{score}%")
        if score > best_score:
            best_score = score
            best_idx = i

    # Generate the report
    gec_names = ["Conservative GEC", "Balanced GEC", "Exploratory GEC"]
    punct_names = ["Conservative Punct", "Sentence Boundaries", "Balanced Punct"]
    report = "📊 **Benchmark Results**\n\n"
    report += f"**Best Match:** Variant {best_idx + 1} ({gec_names[best_idx // 3]} → {punct_names[best_idx % 3]}) with {best_score}% similarity\n\n"
    report += "**All Scores:**\n"
    for i in range(9):
        gec_idx = i // 3
        punct_idx = i % 3
        report += f"{i + 1}. {gec_names[gec_idx]} → {punct_names[punct_idx]}: {similarities[i]}\n"

    progress(1.0, desc="Benchmark complete!")
    return all_outputs + similarities + [report]


# Create Gradio interface
with gr.Blocks(title="Czech GEC + Punctuation Pipeline", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🇨🇿 Czech Grammar & Punctuation Correction Pipeline

This pipeline combines two powerful models:
1. **ByT5 Czech GEC** - Corrects grammatical errors
2. **XLM-RoBERTa Punctuation** - Adds punctuation and capitalization

Each model generates 3 variants, resulting in **9 total output combinations**.
""")

    with gr.Tabs():
        with gr.TabItem("🚀 Pipeline Mode"):
            gr.Markdown("### Enter Czech text for correction through both models")
            with gr.Row():
                pipeline_input = gr.Textbox(
                    label="Input Czech Text",
                    placeholder="Zadejte český text s gramatickými chybami a bez interpunkce...",
                    lines=8,
                    scale=2
                )
            process_btn = gr.Button("🔄 Process Through Pipeline", variant="primary", size="lg")

            gr.Markdown("### 📊 Pipeline Outputs (GEC → Punctuation)")

            # Create a 3x3 grid for the outputs
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Conservative GEC →")
                    output_1_1 = gr.Textbox(label="→ Conservative Punct", lines=4, interactive=True)
                    output_1_2 = gr.Textbox(label="→ Sentence Boundaries", lines=4, interactive=True)
                    output_1_3 = gr.Textbox(label="→ Balanced Punct", lines=4, interactive=True)
                with gr.Column():
                    gr.Markdown("#### Balanced GEC →")
                    output_2_1 = gr.Textbox(label="→ Conservative Punct", lines=4, interactive=True)
                    output_2_2 = gr.Textbox(label="→ Sentence Boundaries", lines=4, interactive=True)
                    output_2_3 = gr.Textbox(label="→ Balanced Punct", lines=4, interactive=True)
                with gr.Column():
                    gr.Markdown("#### Exploratory GEC →")
                    output_3_1 = gr.Textbox(label="→ Conservative Punct", lines=4, interactive=True)
                    output_3_2 = gr.Textbox(label="→ Sentence Boundaries", lines=4, interactive=True)
                    output_3_3 = gr.Textbox(label="→ Balanced Punct", lines=4, interactive=True)

            pipeline_status = gr.Markdown("")

        with gr.TabItem("📊 Benchmark Mode"):
            gr.Markdown("""
### Benchmark the Pipeline
1. Paste your **original correct text**
2. Create a **corrupted version** (remove commas, introduce errors)
3. Run benchmark to see how well the pipeline recovers the original
""")
            with gr.Row():
                with gr.Column():
                    original_text = gr.Textbox(
                        label="Original Text (Ground Truth)",
                        placeholder="Paste the correct Czech text here...",
                        lines=6
                    )
                    remove_commas_btn = gr.Button("🗑️ Remove Commas → Copy to Corrupted", size="sm")
                with gr.Column():
                    corrupted_text = gr.Textbox(
                        label="Corrupted Text (Input for Pipeline)",
                        placeholder="Paste or edit the corrupted version here...",
                        lines=6,
                        interactive=True
                    )
            benchmark_btn = gr.Button("📊 Run Benchmark", variant="primary", size="lg")

            gr.Markdown("### 🎯 Benchmark Results")

            # Outputs with similarity scores
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Conservative GEC →")
                    bench_output_1_1 = gr.Textbox(label="→ Conservative Punct", lines=3)
                    bench_sim_1_1 = gr.Textbox(label="Similarity", lines=1)
                    bench_output_1_2 = gr.Textbox(label="→ Sentence Boundaries", lines=3)
                    bench_sim_1_2 = gr.Textbox(label="Similarity", lines=1)
                    bench_output_1_3 = gr.Textbox(label="→ Balanced Punct", lines=3)
                    bench_sim_1_3 = gr.Textbox(label="Similarity", lines=1)
                with gr.Column():
                    gr.Markdown("#### Balanced GEC →")
                    bench_output_2_1 = gr.Textbox(label="→ Conservative Punct", lines=3)
                    bench_sim_2_1 = gr.Textbox(label="Similarity", lines=1)
                    bench_output_2_2 = gr.Textbox(label="→ Sentence Boundaries", lines=3)
                    bench_sim_2_2 = gr.Textbox(label="Similarity", lines=1)
                    bench_output_2_3 = gr.Textbox(label="→ Balanced Punct", lines=3)
                    bench_sim_2_3 = gr.Textbox(label="Similarity", lines=1)
                with gr.Column():
                    gr.Markdown("#### Exploratory GEC →")
                    bench_output_3_1 = gr.Textbox(label="→ Conservative Punct", lines=3)
                    bench_sim_3_1 = gr.Textbox(label="Similarity", lines=1)
                    bench_output_3_2 = gr.Textbox(label="→ Sentence Boundaries", lines=3)
                    bench_sim_3_2 = gr.Textbox(label="Similarity", lines=1)
                    bench_output_3_3 = gr.Textbox(label="→ Balanced Punct", lines=3)
                    bench_sim_3_3 = gr.Textbox(label="Similarity", lines=1)

            benchmark_report = gr.Markdown("")

    # Examples
    gr.Examples(
        examples=[
            ["včera jsem šel do obchodu a koupil jsem si rohlíky máslo a mléko bylo to levné"],
            ["programování je zajímavé učím se každý den něco nového a baví mě to"],
            ["česká republika je krásná země ve střední evropě má bohatou historii"]
        ],
        inputs=pipeline_input,
        label="Example inputs (click to try)"
    )

    # Event handlers
    pipeline_outputs = [
        output_1_1, output_1_2, output_1_3,
        output_2_1, output_2_2, output_2_3,
        output_3_1, output_3_2, output_3_3,
        pipeline_status
    ]
    process_btn.click(
        fn=process_pipeline,
        inputs=pipeline_input,
        outputs=pipeline_outputs
    )

    benchmark_outputs = [
        bench_output_1_1, bench_output_1_2, bench_output_1_3,
        bench_output_2_1, bench_output_2_2, bench_output_2_3,
        bench_output_3_1, bench_output_3_2, bench_output_3_3,
        bench_sim_1_1, bench_sim_1_2, bench_sim_1_3,
        bench_sim_2_1, bench_sim_2_2, bench_sim_2_3,
        bench_sim_3_1, bench_sim_3_2, bench_sim_3_3,
        benchmark_report
    ]
    benchmark_btn.click(
        fn=benchmark_pipeline,
        inputs=[original_text, corrupted_text],
        outputs=benchmark_outputs
    )

    remove_commas_btn.click(
        fn=remove_commas,
        inputs=original_text,
        outputs=corrupted_text
    )
| gr.Markdown(""" | |
| --- | |
| **Models:** | |
| - GEC: [ufal/byt5-large-geccc-mate](https://huggingface.co/ufal/byt5-large-geccc-mate) | |
| - Punctuation: [kredor/punctuate-all](https://huggingface.co/kredor/punctuate-all) | |
| """) | |
| # Launch the app | |
| if __name__ == "__main__": | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=False | |
| ) |