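"""Qalam: a multilingual OCR demo built with Streamlit.

Each language is routed to a dedicated model: TrOCR (UBC-NLP/Qalam) for
Arabic, Nougat for English and for whole PDFs, Tesseract for French, and
Donut for Korean and Chinese. Images can be cropped in the browser before
recognition, and results are rendered inline and offered as a download.
"""
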
import streamlit as st
from streamlit_cropper import st_cropper
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel, DonutProcessor, NougatProcessor
import torch
import re
import pytesseract
from io import BytesIO
import openai
import requests
from nougat.dataset.rasterize import rasterize_paper
import uuid
import os
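
# Rough dependency list (inferred from the imports above, not a pinned
# spec): streamlit, streamlit-cropper, transformers, torch, pytesseract
# (plus the Tesseract binary), pillow, requests, openai, and the
# nougat-ocr package that provides nougat.dataset.rasterize.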


def get_pdf(pdf_link):
    """Download a PDF from `pdf_link` to a uniquely named local file."""
    unique_filename = f"{os.getcwd()}/downloaded_paper_{uuid.uuid4().hex}.pdf"
    response = requests.get(pdf_link)
    if response.status_code == 200:
        with open(unique_filename, 'wb') as pdf_file:
            pdf_file.write(response.content)
        print("PDF downloaded successfully.")
    else:
        print("Failed to download the PDF.")
    return unique_filename


def predict_arabic(img, model_name="UBC-NLP/Qalam"):
    """Recognize Arabic text with the Qalam TrOCR model."""
    processor = TrOCRProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    image = img.convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values, max_length=256)
    generated_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True)[0]
    return generated_text


def predict_donut(img, model_name="naver-clova-ix/donut-base-finetuned-cord-v2"):
    """Recognize text with a Donut model (used here for Korean and Chinese)."""
    processor = DonutProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    task_prompt = "<s_cord-v2>"
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    image = img.convert("RGB")
    pixel_values = processor(image, return_tensors="pt").pixel_values
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, "")
    # Strip the remaining task and field tags from the raw Donut output.
    sequence = re.sub(r"<.*?>", "", sequence).strip()
    return sequence


def predict_nougat(img, model_name="facebook/nougat-small"):
    """Transcribe a single page image to markup with Nougat."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    processor = NougatProcessor.from_pretrained(model_name)
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    model.to(device)
    image = img.convert("RGB")
    pixel_values = processor(image, return_tensors="pt",
                             data_format="channels_first").pixel_values
    # Generate the transcription (up to 1500 new tokens per page).
    outputs = model.generate(
        pixel_values.to(device),
        min_length=1,
        max_new_tokens=1500,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
    )
    page_sequence = processor.batch_decode(
        outputs, skip_special_tokens=True)[0]
    # Optionally: processor.post_process_generation(page_sequence, fix_markdown=False)
    return page_sequence


def inference_nougat(pdf_file, pdf_link):
    """Run Nougat over every page of a PDF (uploaded file or URL)."""
    if pdf_file is None:
        if pdf_link == '':
            print("No file was uploaded and no link was provided.")
            return "No data provided. Upload a PDF file or provide a PDF link and try again!"
        file_name = get_pdf(pdf_link)
    else:
        # Streamlit uploads arrive as in-memory buffers, so persist the
        # bytes to disk where rasterize_paper can read them by path.
        file_name = f"{os.getcwd()}/uploaded_paper_{uuid.uuid4().hex}.pdf"
        with open(file_name, 'wb') as f:
            f.write(pdf_file.getbuffer())
    images = rasterize_paper(file_name, return_pil=True)
    sequence = ""
    # Infer every page and concatenate the results.
    for image in images:
        sequence += predict_nougat(image)
    # Rewrite LaTeX delimiters as $ / $$ so st.write renders the math.
    content = sequence.replace(r'\(', '$').replace(
        r'\)', '$').replace(r'\[', '$$').replace(r'\]', '$$')
    return content


def predict_tesseract(img):
    """Recognize text with Tesseract; `img` is already a PIL image."""
    text = pytesseract.image_to_string(img)
    return text
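

# Streamlit UI: page configuration, sidebar controls, and the OCR workflow.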
st.set_page_config(
    page_title="Qalam: A Multilingual OCR System",
    page_icon="🖊️",
    layout="wide",
    initial_sidebar_state="expanded",
    menu_items={
        'Get Help': 'https://www.extremelycoolapp.com/help',
        'Report a bug': "https://www.extremelycoolapp.com/bug",
        'About': "# Qalam: A Multilingual OCR System"
    }
)

# Sidebar: upload an image or PDF and set the cropping options.
st.header("Qalam: A Multilingual OCR System")
st.sidebar.header("Configuration and Image Upload")
st.sidebar.subheader("Adjust Image Enhancement Options")
img_file = st.sidebar.file_uploader(
    label='Upload a file', type=['png', 'jpg', 'pdf'])
realtime_update = st.sidebar.checkbox(label="Update in Real Time", value=True)
aspect_choice = st.sidebar.radio(label="Aspect Ratio", options=["Free"])
aspect_dict = {
    "Free": None
}
aspect_ratio = aspect_dict[aspect_choice]

st.sidebar.subheader("Select OCR Language and Model")
Lng = st.sidebar.selectbox(label="Language", options=[
    "Arabic", "English", "French", "Korean", "Chinese"])
Models = {
    "Arabic": "Qalam",
    "English": "Nougat",
    "French": "Tesseract",
    "Korean": "Donut",
    "Chinese": "Donut"
}
st.sidebar.markdown(f"### Selected Model: {Models[Lng]}")

if img_file:
    if img_file.type != "application/pdf":
        img = Image.open(img_file)
        if not realtime_update:
            st.write("Double click to save crop")
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Input: Upload and Crop Your Image")
            # Get a cropped image back from the frontend widget.
            cropped_img = st_cropper(
                img,
                realtime_update=realtime_update,
                box_color="#FF0000",
                aspect_ratio=aspect_ratio,
                should_resize_image=True,
            )
        with col2:
            st.subheader("Output: Preview and Analyze")
            st.image(cropped_img)
            button = st.button("Run OCR")
            if button:
                with st.spinner('Running OCR...'):
                    if Lng == "Arabic":
                        ocr_text = predict_arabic(cropped_img)
                    elif Lng == "English":
                        ocr_text = predict_nougat(cropped_img)
                    elif Lng == "French":
                        ocr_text = predict_tesseract(cropped_img)
                    else:  # Korean and Chinese both use the Donut model.
                        ocr_text = predict_donut(cropped_img)
                    st.subheader(f"OCR Results for {Lng}")
                    st.write(ocr_text)
                    text_file = BytesIO(ocr_text.encode())
                    st.download_button('Download Text', text_file,
                                       file_name='ocr_text.txt')
    else:
        button = st.sidebar.button("Run OCR")
        if button:
            with st.spinner('Running OCR...'):
                ocr_text = inference_nougat(img_file, "")
                st.subheader("OCR Results for the PDF file")
                st.write(ocr_text)
                text_file = BytesIO(ocr_text.encode())
                st.download_button('Download Text', text_file,
                                   file_name='ocr_text.txt')

# Optional chat-over-OCR feature, left disabled: streams GPT responses
# about the recognized text via the OpenAI API.
# openai.api_key = ""
# if "openai_model" not in st.session_state:
#     st.session_state["openai_model"] = "gpt-3.5-turbo"
# if "messages" not in st.session_state:
#     st.session_state.messages = []
# for message in st.session_state.messages:
#     with st.chat_message(message["role"]):
#         st.markdown(message["content"])
# if prompt := st.chat_input("How can I help?"):
#     st.session_state.messages.append({"role": "user", "content": ocr_text + prompt})
#     with st.chat_message("user"):
#         st.markdown(prompt)
#     with st.chat_message("assistant"):
#         message_placeholder = st.empty()
#         full_response = ""
#         for response in openai.ChatCompletion.create(
#             model=st.session_state["openai_model"],
#             messages=[
#                 {"role": m["role"], "content": m["content"]}
#                 for m in st.session_state.messages
#             ],
#             stream=True,
#         ):
#             full_response += response.choices[0].delta.get("content", "")
#             message_placeholder.markdown(full_response + "▌")
#         message_placeholder.markdown(full_response)
#     st.session_state.messages.append({"role": "assistant", "content": full_response})
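
# To run locally (a sketch; the script name "app.py" is an assumption):
#   streamlit run app.py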