asdfasdfdsafdsa committed on
Commit
89b7ad2
·
verified ·
1 Parent(s): 39958b2

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +4 -3
  2. app.py +76 -16
  3. requirements.txt +1 -4
README.md CHANGED
@@ -50,9 +50,10 @@ A comprehensive pipeline that combines grammatical error correction with punctua
50
  - ByT5-large model fine-tuned on Czech GEC corpus
51
  - Handles complex grammatical errors
52
 
53
- - **Punctuation**: [1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase](https://huggingface.co/1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase)
54
- - XLM-RoBERTa for punctuation restoration
55
- - Supports capitalization and sentence boundaries
 
56
 
57
  ## 💡 Use Cases
58
 
 
50
  - ByT5-large model fine-tuned on Czech GEC corpus
51
  - Handles complex grammatical errors
52
 
53
+ - **Punctuation**: [kredor/punctuate-all](https://huggingface.co/kredor/punctuate-all)
54
+ - Token classification model for punctuation restoration
55
+ - Supports Czech and 11 other languages
56
+ - Adds punctuation marks: . , ? - :
57
 
58
  ## 💡 Use Cases
59
 
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
- from punctuators.models import PunctCapSegModelONNX
5
  from difflib import SequenceMatcher
6
  import re
7
 
@@ -17,9 +16,9 @@ print(f"GEC model loaded on {device}")
17
 
18
  # Load punctuation model
19
  print("Loading punctuation model...")
20
- punct_model = PunctCapSegModelONNX.from_pretrained(
21
- "1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase"
22
- )
23
  print("Punctuation model loaded!")
24
 
25
  def gec_correct(input_text):
@@ -117,23 +116,84 @@ def gec_correct(input_text):
117
  return corrections
118
 
119
  def punct_correct(input_text):
120
- """Generate 3 different punctuation corrections"""
121
  if not input_text.strip():
122
  return ["", "", ""]
123
 
124
  corrections = []
125
 
126
- # Conservative - no sentence boundaries
127
- result = punct_model.infer(texts=[input_text], apply_sbd=False)
128
- corrections.append(result[0])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
 
130
- # With sentence boundaries
131
- result = punct_model.infer(texts=[input_text], apply_sbd=True)
132
- corrections.append("\n".join(result[0]) if isinstance(result[0], list) else result[0])
 
133
 
134
- # Balanced
135
- result = punct_model.infer(texts=[input_text], apply_sbd=False)
136
- corrections.append(result[0])
 
 
137
 
138
  return corrections
139
 
@@ -382,7 +442,7 @@ with gr.Blocks(title="Czech GEC + Punctuation Pipeline", theme=gr.themes.Soft())
382
  ---
383
  **Models:**
384
  - GEC: [ufal/byt5-large-geccc-mate](https://huggingface.co/ufal/byt5-large-geccc-mate)
385
- - Punctuation: [1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase](https://huggingface.co/1-800-BAD-CODE/xlm-roberta_punctuation_fullstop_truecase)
386
  """)
387
 
388
  # Launch the app
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTokenClassification, pipeline
 
4
  from difflib import SequenceMatcher
5
  import re
6
 
 
16
 
17
  # Load punctuation model
18
  print("Loading punctuation model...")
19
+ punct_tokenizer = AutoTokenizer.from_pretrained("kredor/punctuate-all")
20
+ punct_model = AutoModelForTokenClassification.from_pretrained("kredor/punctuate-all")
21
+ punct_pipeline = pipeline("token-classification", model=punct_model, tokenizer=punct_tokenizer, device=0 if torch.cuda.is_available() else -1)
22
  print("Punctuation model loaded!")
23
 
24
  def gec_correct(input_text):
 
116
  return corrections
117
 
118
  def punct_correct(input_text):
119
+ """Generate 3 different punctuation corrections using kredor/punctuate-all"""
120
  if not input_text.strip():
121
  return ["", "", ""]
122
 
123
  corrections = []
124
 
125
+ # Process with the punctuation pipeline
126
+ # The model expects lowercase input without punctuation
127
+ clean_text = input_text.lower()
128
+ results = punct_pipeline(clean_text)
129
+
130
+ # Build a mapping of token positions to punctuation
131
+ punct_map = {}
132
+ current_word = ""
133
+ current_punct = ""
134
+
135
+ for i, result in enumerate(results):
136
+ word = result['word'].replace('▁', '').strip()
137
+
138
+ # Get punctuation from entity label
139
+ entity = result['entity']
140
+ if entity == 'LABEL_0':
141
+ punct = '' # No punctuation
142
+ elif entity == 'LABEL_1':
143
+ punct = '.'
144
+ elif entity == 'LABEL_2':
145
+ punct = ','
146
+ elif entity == 'LABEL_3':
147
+ punct = '?'
148
+ elif entity == 'LABEL_4':
149
+ punct = '-'
150
+ elif entity == 'LABEL_5':
151
+ punct = ':'
152
+ else:
153
+ punct = ''
154
+
155
+ # Check if this is a continuation of previous word (subword token)
156
+ if not result['word'].startswith('▁') and i > 0:
157
+ current_word += word
158
+ else:
159
+ # Save previous word if exists
160
+ if current_word:
161
+ punct_map[current_word] = current_punct
162
+ current_word = word
163
+ current_punct = punct
164
+
165
+ # Don't forget the last word
166
+ if current_word:
167
+ punct_map[current_word] = current_punct
168
+
169
+ # Reconstruct text with punctuation
170
+ words = clean_text.split()
171
+ punctuated_words = []
172
+
173
+ for word in words:
174
+ # Check if we have punctuation for this word
175
+ if word in punct_map and punct_map[word]:
176
+ punctuated_words.append(word + punct_map[word])
177
+ else:
178
+ punctuated_words.append(word)
179
+
180
+ # Join words
181
+ base_result = ' '.join(punctuated_words)
182
+
183
+ # Three variations
184
+ # 1. Conservative - just punctuation
185
+ corrections.append(base_result)
186
 
187
+ # 2. With first letter and sentence capitalization
188
+ sentences = re.split(r'(?<=[.?!])\s+', base_result)
189
+ capitalized = ' '.join(s[0].upper() + s[1:] if s else s for s in sentences)
190
+ corrections.append(capitalized)
191
 
192
+ # 3. Clean formatting
193
+ clean = capitalized
194
+ for p in [',', '.', '?', ':', '!', ';']:
195
+ clean = clean.replace(f' {p}', p)
196
+ corrections.append(clean)
197
 
198
  return corrections
199
 
 
442
  ---
443
  **Models:**
444
  - GEC: [ufal/byt5-large-geccc-mate](https://huggingface.co/ufal/byt5-large-geccc-mate)
445
+ - Punctuation: [kredor/punctuate-all](https://huggingface.co/kredor/punctuate-all)
446
  """)
447
 
448
  # Launch the app
requirements.txt CHANGED
@@ -1,6 +1,3 @@
1
  gradio>=4.0.0
2
  torch>=2.0.0
3
- transformers>=4.30.0
4
- punctuators==0.0.7
5
- onnx>=1.14.0
6
- onnxruntime>=1.15.0
 
1
  gradio>=4.0.0
2
  torch>=2.0.0
3
+ transformers>=4.30.0