"""Compare Donut with an EasyOCR + regex + logistic-ranker pipeline on CORD v2.

The module loads the CORD v2 dataset and a Donut checkpoint at import time,
exposes helpers to normalise and score predicted receipt totals, and wires
everything into a small Gradio UI whose single button runs the experiment.
"""

import re
import json
import time
import math
from decimal import Decimal, InvalidOperation
from collections import Counter

import numpy as np
import torch
from PIL import Image
from datasets import load_dataset
from transformers import DonutProcessor, VisionEncoderDecoderModel
from sklearn.linear_model import LogisticRegression
from tqdm.auto import tqdm
import gradio as gr
import spaces
import easyocr

# ============================================================
# Global config
# ============================================================

DONUT_MODEL_ID = "naver-clova-ix/donut-base-finetuned-cord-v2"
# The CORD-v2 fine-tuned checkpoint is prompted with its task start token.
# An empty prompt gives the decoder nothing to start from, so every Donut
# prediction fails and is scored as a miss.
DONUT_TASK_PROMPT = "<s_cord-v2>"

TOTAL_KEYWORDS = [
    "total",
    "amount due",
    "grand total",
    "total due",
    "balance due",
    "amount",
    "total price",
    "totalprice",
]

DEFAULT_MAX_TRAIN_SAMPLES = 120
DEFAULT_MAX_TEST_SAMPLES = 50

# Used to filter out unreasonably large amounts (too many digits => invalid).
MAX_DIGITS_FOR_AMOUNT = 9  # e.g. 999,999,999
MAX_ABS_AMOUNT_FOR_METRIC = 1e9  # safety upper bound used during evaluation

# Numeric candidate inside an OCR line: optional sign, a digit, then any mix
# of digits / separators / whitespace, ending in a digit (so >= 2 digits).
_NUMBER_PATTERN = re.compile(r"[-+]?\d[\d,.\s]*\d")


# ============================================================
# Amount normalization and numeric conversion
# ============================================================

def normalize_amount_str(s: str):
    """Return a canonical numeric string for *s*, or ``None`` if unusable.

    Keeps only digits, ``.``, ``,`` and ``-``; rejects inputs with no digits
    or with more than ``MAX_DIGITS_FOR_AMOUNT`` digits; resolves commas as
    thousands separators or as a decimal point; keeps only the first ``.``.
    """
    if s is None:
        return None
    s = str(s).strip()
    if not s:
        return None

    # Remove spaces (including non-breaking spaces).
    s = s.replace(" ", "").replace("\u00a0", "")
    # Keep only digits, decimal point, comma and minus sign.
    s = re.sub(r"[^\d,.\-]", "", s)

    digits = re.sub(r"\D", "", s)
    # No digits at all: invalid.
    if len(digits) == 0:
        return None
    # Absurdly long digit runs are treated as invalid to avoid blow-ups.
    if len(digits) > MAX_DIGITS_FOR_AMOUNT:
        return None

    # Resolve comma / decimal point.
    if "," in s and "." in s:
        # Both present: treat commas as thousands separators.
        s = s.replace(",", "")
    elif "," in s and "." not in s:
        # Only commas: decide between thousands separator and decimal point
        # from the length of the group after the last comma.
        last = s.split(",")[-1]
        if len(last) == 3:
            s = s.replace(",", "")
        else:
            s = s.replace(",", ".")

    # With several decimal points, keep only the first one.
    if s.count(".") > 1:
        first = s.find(".")
        s = s[: first + 1] + s[first + 1:].replace(".", "")

    return s or None


def amount_to_float(s: str):
    """Convert an amount string to ``float``; return ``None`` when invalid.

    Invalid means: not normalisable, not parseable by ``Decimal``, not finite,
    or larger in magnitude than ``MAX_ABS_AMOUNT_FOR_METRIC``.
    """
    s_norm = normalize_amount_str(s)
    if not s_norm:
        return None
    try:
        val = float(Decimal(s_norm))
    except InvalidOperation:
        return None

    # Extra safety check: overly large values are treated as invalid.
    if not math.isfinite(val):
        return None
    if abs(val) > MAX_ABS_AMOUNT_FOR_METRIC:
        return None
    return val


# ============================================================
# Character-level F1 for numeric strings
# ============================================================

def char_f1_for_pair(gt: str, pred: str) -> float:
    """Return the bag-of-characters F1 between normalised *gt* and *pred*.

    Returns ``0.0`` when either side cannot be normalised.
    """
    gt_norm = normalize_amount_str(gt)
    pred_norm = normalize_amount_str(pred)
    if not gt_norm or not pred_norm:
        return 0.0

    # Counter intersection keeps the per-character minimum count.
    overlap = sum((Counter(gt_norm) & Counter(pred_norm)).values())

    precision = overlap / len(pred_norm)
    recall = overlap / len(gt_norm)
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# ============================================================
# Evaluation metrics: EM, error stats, F1
# ============================================================

def evaluate_amount_predictions(y_true, y_pred, relaxed_tol=0.01):
    """Score predicted amount strings against ground-truth amount strings.

    Args:
        y_true: ground-truth amounts (strings); ``None`` entries are skipped.
        y_pred: predicted amounts (strings or ``None``), aligned with y_true.
        relaxed_tol: relative tolerance for the relaxed exact match, applied
            to ``max(1.0, abs(gt))``.

    Returns:
        dict with keys ``n_samples``, ``strict_em``, ``relaxed_em``, ``mae``,
        ``rmse``, ``mape`` and ``char_f1``. The error statistics are computed
        only over pairs where both sides convert to a valid float, and are
        ``0.0`` when there is no such pair.
    """
    strict_correct = 0
    relaxed_correct = 0
    total = 0

    numeric_gt = []
    numeric_pred = []
    char_f1_list = []

    for gt, pred in zip(y_true, y_pred):
        if gt is None:
            continue
        total += 1

        # Character-level F1.
        char_f1_list.append(0.0 if pred is None else char_f1_for_pair(gt, pred))

        gt_norm = normalize_amount_str(gt)
        pred_norm = normalize_amount_str(pred)
        gt_val = amount_to_float(gt)
        pred_val = amount_to_float(pred)

        # Only numerically valid pairs contribute to the error statistics.
        both_numeric = gt_val is not None and pred_val is not None
        if both_numeric:
            numeric_gt.append(gt_val)
            numeric_pred.append(pred_val)

        # Strict EM. Requiring a valid normalised ground truth prevents two
        # unparseable strings (both normalising to None) from counting as an
        # exact match.
        if gt_norm is not None and gt_norm == pred_norm:
            strict_correct += 1
            relaxed_correct += 1
            continue

        # Relaxed EM based on numeric closeness.
        if both_numeric:
            base = max(1.0, abs(gt_val))
            if abs(gt_val - pred_val) <= relaxed_tol * base:
                relaxed_correct += 1

    strict_em = strict_correct / total if total > 0 else 0.0
    relaxed_em = relaxed_correct / total if total > 0 else 0.0
    char_f1 = sum(char_f1_list) / total if total > 0 else 0.0

    if numeric_gt:
        pairs = list(zip(numeric_gt, numeric_pred))
        n_pairs = len(pairs)
        mae = float(sum(abs(a - b) for a, b in pairs) / n_pairs)
        rmse = float(math.sqrt(sum((a - b) ** 2 for a, b in pairs) / n_pairs))
        # Denominator is floored at 1.0 so tiny ground truths do not explode.
        mape = float(
            sum(abs(a - b) / max(1.0, abs(a)) for a, b in pairs) / n_pairs
        )
    else:
        mae = 0.0
        rmse = 0.0
        mape = 0.0

    return {
        "n_samples": total,
        "strict_em": strict_em,
        "relaxed_em": relaxed_em,
        "mae": mae,
        "rmse": rmse,
        "mape": mape,
        "char_f1": char_f1,
    }


# ============================================================
# Ground truth extraction from CORD v2
# ============================================================

def get_gt_total_from_cord_item(item):
    """Return the ground-truth total price of a CORD item, or ``None``.

    Prefers ``gt["gt_parse"]["total"]["total_price"]`` and falls back to a
    recursive search for the first string value stored under ``total_price``.
    ``item["ground_truth"]`` may be a dict or a JSON string.
    """
    gt = item["ground_truth"]
    if not isinstance(gt, dict):
        try:
            gt = json.loads(gt)
        except (TypeError, ValueError):
            return None

    # Preferred path. AttributeError covers any level not being a dict.
    try:
        val = gt.get("gt_parse", {}).get("total", {}).get("total_price", None)
        if val is not None:
            return val
    except AttributeError:
        pass

    # Fallback: recursive search for key "total_price".
    def _find_total_price(node):
        if isinstance(node, dict):
            for k, v in node.items():
                if k == "total_price" and isinstance(v, str):
                    return v
                found = _find_total_price(v)
                if found is not None:
                    return found
        elif isinstance(node, list):
            for item_ in node:
                found = _find_total_price(item_)
                if found is not None:
                    return found
        return None

    return _find_total_price(gt)


# ============================================================
# Load dataset and Donut model
# ============================================================

print("Loading CORD v2 dataset ...")
cord = load_dataset("naver-clova-ix/cord-v2")
print(cord)

print("Loading Donut model and processor (CPU init) ...")
processor = DonutProcessor.from_pretrained(DONUT_MODEL_ID)
model = VisionEncoderDecoderModel.from_pretrained(DONUT_MODEL_ID)


# ============================================================
# Donut prediction
# ============================================================

def donut_predict_total(image: Image.Image, device: str = "cpu"):
    """Run Donut on *image* and return the predicted ``total.total_price``.

    The module-level ``model`` must already live on *device*. Returns ``None``
    when the generated sequence cannot be converted to JSON or does not have
    the expected ``total`` / ``total_price`` structure.
    """
    decoder_input_ids = processor.tokenizer(
        DONUT_TASK_PROMPT,
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids

    pixel_values = processor(image, return_tensors="pt").pixel_values

    with torch.no_grad():
        outputs = model.generate(
            pixel_values.to(device),
            decoder_input_ids=decoder_input_ids.to(device),
            max_length=model.decoder.config.max_position_embeddings,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
            use_cache=True,
            bad_words_ids=[[processor.tokenizer.unk_token_id]],
            return_dict_in_generate=True,
        )

    seq = processor.batch_decode(outputs.sequences)[0]
    seq = seq.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    # Drop the first <...> tag, i.e. the task start token echoed by the model.
    seq = re.sub(r"<.*?>", "", seq, count=1).strip()

    try:
        pred_json = processor.token2json(seq)
    except Exception:
        return None

    try:
        return pred_json.get("total", {}).get("total_price")
    except AttributeError:
        # The parsed output (or its "total" entry) was not a dict.
        return None


# ============================================================
# OCR candidates via EasyOCR
# ============================================================

def extract_ocr_candidates(image: Image.Image, reader):
    """Extract numeric candidates from *image* with EasyOCR.

    Each EasyOCR detection ``(bbox, text, conf)`` is treated as one line, and
    every match of the numeric pattern inside a line becomes a candidate.

    Returns:
        list of dicts with keys ``amount_str``, ``value``, ``line_text``,
        ``line_y_norm``, ``contains_total_kw``, ``is_last_num_in_line`` and
        ``is_largest_in_page``; empty when nothing is detected.
    """
    img_rgb = np.array(image.convert("RGB"))
    results = reader.readtext(img_rgb)
    if not results:
        return []

    line_objs = []
    for bbox, text, conf in results:
        text = str(text).strip()
        if not text:
            continue
        ys = [p[1] for p in bbox]
        top = min(ys)
        height = max(ys) - min(ys)
        if height <= 0:
            height = 1
        line_objs.append({"text": text, "top": top, "height": height})

    h = img_rgb.shape[0]

    candidates = []
    for line in line_objs:
        text = line["text"]
        lower = text.lower()
        contains_kw = any(kw in lower for kw in TOTAL_KEYWORDS)

        matches = list(_NUMBER_PATTERN.finditer(text))
        if not matches:
            continue

        # Vertical centre of the line, normalised by the image height.
        line_center_y = line["top"] + line["height"] / 2.0
        line_y_norm = line_center_y / float(h)

        last_index = len(matches) - 1
        for mi, m in enumerate(matches):
            amount_str = m.group(0)
            candidates.append(
                {
                    "amount_str": amount_str,
                    "value": amount_to_float(amount_str),
                    "line_text": text,
                    "line_y_norm": line_y_norm,
                    "contains_total_kw": int(contains_kw),
                    "is_last_num_in_line": int(mi == last_index),
                    "is_largest_in_page": 0,
                }
            )

    # Flag every candidate that ties for the largest valid value on the page.
    valid_values = [c["value"] for c in candidates if c["value"] is not None]
    if valid_values:
        max_val = max(valid_values)
        for c in candidates:
            if c["value"] is not None and c["value"] == max_val:
                c["is_largest_in_page"] = 1

    return candidates


def candidate_to_feature_vec(cand):
    """Turn one OCR candidate dict into the 8-element ranker feature list.

    Features, in order: clipped value, signed ``log1p`` of the magnitude,
    normalised line position, its complement, keyword flag, last-number flag,
    largest-on-page flag, and the number of digits in ``amount_str``.
    """
    v = cand.get("value", None)
    if v is None or not isinstance(v, (int, float)) or not math.isfinite(v):
        v = 0.0

    # Clip amount for feature stability.
    v_clipped = max(-1e7, min(1e7, v))
    v_log = math.log1p(abs(v))
    if v < 0:
        v_log = -v_log

    line_y_norm = float(cand["line_y_norm"])
    length_digits = float(len(re.sub(r"\D", "", str(cand["amount_str"]))))

    return [
        float(v_clipped),
        float(v_log),
        line_y_norm,
        float(1.0 - line_y_norm),
        float(cand["contains_total_kw"]),
        float(cand["is_last_num_in_line"]),
        float(cand["is_largest_in_page"]),
        length_digits,
    ]


# ============================================================
# Train logistic ranker with time budget
# ============================================================

def train_logistic_ranker_on_cord(
    train_dataset,
    reader,
    max_samples=200,
    time_budget_sec=None,
):
    """Train the logistic-regression candidate ranker.

    Every OCR candidate of every processed document becomes one training row,
    labelled 1 when its value is within 1% (relative to ``max(1, |gt|)``) of
    the ground-truth total. Only documents with at least one positive
    candidate count towards *max_samples*.

    Args:
        train_dataset: iterable of CORD items.
        reader: EasyOCR reader used to extract candidates.
        max_samples: stop once this many documents with a positive were used.
        time_budget_sec: optional wall-clock budget for building the set.

    Returns:
        ``(clf, n_used_docs)``; ``clf`` is ``None`` when no candidates were
        collected or when only one class is present.
    """
    X = []
    y = []
    n_used_docs = 0
    start_time = time.time()

    for item in tqdm(train_dataset, desc="Building ranker training set"):
        if time_budget_sec is not None and (time.time() - start_time) > time_budget_sec:
            print("Time budget exceeded in ranker training, stop early.")
            break

        if n_used_docs >= max_samples:
            break

        gt_total = get_gt_total_from_cord_item(item)
        if gt_total is None:
            continue
        gt_val = amount_to_float(gt_total)
        if gt_val is None:
            continue

        image = item["image"].convert("RGB")
        cands = extract_ocr_candidates(image, reader)
        if not cands:
            continue

        base = max(1.0, abs(gt_val))
        has_positive = False
        for cand in cands:
            cand_val = cand["value"]
            if cand_val is None:
                label = 0
            else:
                label = int(abs(cand_val - gt_val) <= 0.01 * base)
            if label == 1:
                has_positive = True

            X.append(candidate_to_feature_vec(cand))
            y.append(label)

        if has_positive:
            n_used_docs += 1

    if len(X) == 0:
        print("No OCR candidates collected, skip LogisticRegression training (use fallback).")
        return None, n_used_docs

    X = np.array(X, dtype=np.float32)
    y = np.array(y, dtype=int)
    print("Training samples (candidates):", X.shape, "positives:", y.sum())
    print("Training docs actually used:", n_used_docs)

    if len(np.unique(y)) < 2:
        print("Warning: not enough positive/negative samples to train LogisticRegression (use fallback).")
        return None, n_used_docs

    clf = LogisticRegression(
        max_iter=1000,
        class_weight="balanced",
        n_jobs=-1,
        solver="lbfgs",
    )
    clf.fit(X, y)
    return clf, n_used_docs


# ============================================================
# OCR ranker prediction
# ============================================================

def ocr_ranker_predict_total(image: Image.Image, reader, clf: LogisticRegression):
    """Predict the total amount string using OCR candidates and the ranker.

    With a trained *clf*, returns the candidate with the highest positive
    probability. With ``clf is None``, falls back to a heuristic: prefer
    keyword lines, then last-number-in-line, then lower position on the page.
    Returns ``None`` when there are no candidates.
    """
    cands = extract_ocr_candidates(image, reader)
    if not cands:
        return None

    if clf is None:
        kw_cands = [c for c in cands if c["contains_total_kw"] == 1]
        use_cands = kw_cands if kw_cands else cands
        use_cands = sorted(
            use_cands,
            key=lambda x: (
                x["contains_total_kw"],
                x["is_last_num_in_line"],
                x["line_y_norm"],
            ),
            reverse=True,
        )
        return use_cands[0]["amount_str"]

    X = np.array([candidate_to_feature_vec(c) for c in cands], dtype=np.float32)
    probs = clf.predict_proba(X)[:, 1]
    best_idx = int(np.argmax(probs))
    return cands[best_idx]["amount_str"]


# ============================================================
# Evaluate both models with time budget and latency stats
# ============================================================

def compare_models_on_cord_test(
    test_dataset,
    reader,
    ocr_clf,
    max_samples=100,
    donut_device: str = "cuda",
    time_budget_sec=None,
):
    """Evaluate Donut and the OCR ranker on the same test documents.

    Args:
        test_dataset: iterable of CORD items.
        reader: EasyOCR reader.
        ocr_clf: trained ranker, or ``None`` to use the heuristic fallback.
        max_samples: maximum number of documents to evaluate.
        donut_device: device passed to ``donut_predict_total``.
        time_budget_sec: optional wall-clock budget for the whole loop.

    Returns:
        dict with ``gts``, ``donut_preds``, ``ocr_preds``, ``donut_metrics``,
        ``ocr_metrics`` and ``timing`` (``n_test_used``, ``eval_wall_time``,
        ``donut_time_total``, ``ocr_time_total``).
    """
    donut_preds = []
    ocr_preds = []
    gts = []

    used = 0
    start_time = time.time()

    donut_time_total = 0.0
    ocr_time_total = 0.0

    for item in tqdm(test_dataset, desc="Evaluating models on test set"):
        if time_budget_sec is not None and (time.time() - start_time) > time_budget_sec:
            print("Time budget exceeded in evaluation, stop early.")
            break

        if used >= max_samples:
            break

        gt_total = get_gt_total_from_cord_item(item)
        if gt_total is None:
            continue

        image = item["image"].convert("RGB")

        # Donut latency. A failure is scored as a miss (best effort), but it
        # is reported and its elapsed time still counts, because the average
        # latency is later divided by every evaluated document.
        t0 = time.time()
        try:
            pred_donut = donut_predict_total(image, device=donut_device)
        except Exception as exc:
            print(f"Donut prediction failed: {exc!r}")
            pred_donut = None
        finally:
            donut_time_total += time.time() - t0

        # OCR+Ranker latency (same failure policy as above).
        t1 = time.time()
        try:
            pred_ocr = ocr_ranker_predict_total(image, reader, ocr_clf)
        except Exception as exc:
            print(f"OCR+Ranker prediction failed: {exc!r}")
            pred_ocr = None
        finally:
            ocr_time_total += time.time() - t1

        gts.append(gt_total)
        donut_preds.append(pred_donut)
        ocr_preds.append(pred_ocr)
        used += 1

    donut_metrics = evaluate_amount_predictions(gts, donut_preds)
    ocr_metrics = evaluate_amount_predictions(gts, ocr_preds)

    timing = {
        "n_test_used": used,
        "eval_wall_time": time.time() - start_time,
        "donut_time_total": donut_time_total,
        "ocr_time_total": ocr_time_total,
    }

    return {
        "gts": gts,
        "donut_preds": donut_preds,
        "ocr_preds": ocr_preds,
        "donut_metrics": donut_metrics,
        "ocr_metrics": ocr_metrics,
        "timing": timing,
    }


# ============================================================
# ZeroGPU experiment function with extended time budgets
# ============================================================

def _format_metric_lines(metrics):
    """Return the six report lines for one metrics dict."""
    return [
        f" Strict EM : {metrics['strict_em']:.4f}",
        f" Relaxed EM : {metrics['relaxed_em']:.4f}",
        f" MAE : {metrics['mae']:.2f}",
        f" RMSE : {metrics['rmse']:.2f}",
        f" MAPE : {metrics['mape']:.4f}",
        f" Char-level F1 : {metrics['char_f1']:.4f}",
    ]


@spaces.GPU(duration=200)
def run_experiment(train_samples, test_samples):
    """Train the ranker, evaluate both models and return a text report.

    Args:
        train_samples: requested number of training documents (cast to int).
        test_samples: requested number of test documents (cast to int).

    Returns:
        Multi-line string with sample counts, timings, metrics for both
        models and the first five predictions.
    """
    train_samples = int(train_samples)
    test_samples = int(test_samples)

    t_global_start = time.time()

    # Donut + EasyOCR on GPU if available.
    use_gpu = torch.cuda.is_available()
    device = "cuda" if use_gpu else "cpu"
    model.to(device)
    reader = easyocr.Reader(["en"], gpu=use_gpu)

    # Time budgets inside this GPU call.
    # NOTE(review): 200 s + 60 s exceeds the duration=200 passed to
    # spaces.GPU above; presumably the call can be cut off before evaluation
    # finishes - confirm against the ZeroGPU duration semantics.
    TRAIN_TIME_BUDGET = 200.0  # 5x the original 40
    EVAL_TIME_BUDGET = 60.0  # evaluation still uses 60 seconds; adjust as needed

    # 1) train ranker
    t_train_start = time.time()
    clf, n_train_used = train_logistic_ranker_on_cord(
        cord["train"],
        reader=reader,
        max_samples=train_samples,
        time_budget_sec=TRAIN_TIME_BUDGET,
    )
    train_time = time.time() - t_train_start

    # 2) evaluate both models
    t_eval_start = time.time()
    results = compare_models_on_cord_test(
        cord["test"],
        reader=reader,
        ocr_clf=clf,
        max_samples=test_samples,
        donut_device=device,
        time_budget_sec=EVAL_TIME_BUDGET,
    )
    eval_time = time.time() - t_eval_start

    total_time = time.time() - t_global_start

    dm = results["donut_metrics"]
    om = results["ocr_metrics"]
    timing = results["timing"]

    n_test_used = timing["n_test_used"]
    donut_time_total = timing["donut_time_total"]
    ocr_time_total = timing["ocr_time_total"]

    avg_donut_latency = donut_time_total / n_test_used if n_test_used > 0 else 0.0
    avg_ocr_latency = ocr_time_total / n_test_used if n_test_used > 0 else 0.0

    lines = [
        "Donut vs OCR+Regex+Ranker on CORD v2",
        "-----------------------------------",
        f"Train samples requested : {train_samples}",
        f"Train docs actually used: {n_train_used}",
        f"Test samples requested : {test_samples}",
        f"Test docs actually used : {n_test_used}",
        "",
        f"Training time (s): {train_time:.2f}",
        f"Evaluation time (s): {eval_time:.2f}",
        f"Total run time (s): {total_time:.2f}",
        "",
        f"Average Donut latency per test doc (s): {avg_donut_latency:.4f}",
        f"Average OCR+Ranker latency per test doc (s): {avg_ocr_latency:.4f}",
        "",
        "Donut (naver-clova-ix/donut-base-finetuned-cord-v2)",
    ]
    lines.extend(_format_metric_lines(dm))
    lines.append("")
    lines.append("OCR + EasyOCR + Regex + Logistic Ranker")
    lines.extend(_format_metric_lines(om))
    lines.append("")
    lines.append("Sample predictions (first 5)")
    lines.append("-----------------------------------")

    samples = zip(results["gts"], results["donut_preds"], results["ocr_preds"])
    for i, (gt, donut_pred, ocr_pred) in enumerate(samples):
        if i >= 5:
            break
        lines.append(
            f"[{i}] GT: {gt} | "
            f"Donut: {donut_pred} | "
            f"OCR+Ranker: {ocr_pred}"
        )

    return "\n".join(lines)


# ============================================================
# Minimal Gradio UI
# ============================================================

with gr.Blocks() as demo:
    # Built from explicit literals so the heading is guaranteed its own line
    # and no source indentation leaks into the Markdown.
    gr.Markdown(
        "### Donut vs OCR+Regex+Ranker · CORD v2\n"
        "Compare a pretrained Donut model with an EasyOCR + Regex + "
        "Logistic Regression pipeline on receipt total extraction."
    )

    with gr.Row():
        train_samples_input = gr.Slider(
            minimum=40,
            maximum=300,
            value=DEFAULT_MAX_TRAIN_SAMPLES,
            step=10,
            label="Train samples for OCR Ranker",
        )
        test_samples_input = gr.Slider(
            minimum=30,
            maximum=200,
            value=DEFAULT_MAX_TEST_SAMPLES,
            step=10,
            label="Test samples",
        )

    run_button = gr.Button("Run experiment")

    output_box = gr.Textbox(
        lines=24,
        label="Metrics & sample predictions",
    )

    run_button.click(
        fn=run_experiment,
        inputs=[train_samples_input, test_samples_input],
        outputs=[output_box],
    )

if __name__ == "__main__":
    demo.launch()