raahinaez committed on
Commit f736bfc · verified · 1 Parent(s): ced4216

Update app.py

Files changed (1)
  1. app.py +104 -47
app.py CHANGED
@@ -1,55 +1,112 @@
  import os
  import gradio as gr
- from docling.document_converter import DocumentConverter, PdfFormatOption
- from docling.datamodel.base_models import InputFormat
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
- import tempfile
-
- # Path to the model
- model_name = "ibm-granite/granite-docling-258M"
-
- # Load the OCR model from Hugging Face (assuming you have access to it)
- # In this case, let's load the model and tokenizer if needed
- ocr_model = AutoModelForSequenceClassification.from_pretrained(model_name)
- ocr_tokenizer = AutoTokenizer.from_pretrained(model_name)
-
- def pdf_to_markdown(file):
-     # Save uploaded file temporarily
-     tmp_path = file.name
-
-     # Convert PDF using Docling/VLM (Granite Docling)
-     converter = DocumentConverter(
-         format_options={
-             InputFormat.PDF: PdfFormatOption()
-         }
-     )
-
-     # Perform OCR using granite-docling model if the file contains scanned text
-     result = converter.convert(tmp_path)
-     doc = result.document
-
-     # Export to Markdown (or you can export to JSON via doc.model_dump())
-     md = doc.export_to_markdown()

-     return md

- # Define the output box size
- output_box = gr.Textbox(
-     label="Markdown Output",
-     lines=20,        # initial visible lines
-     max_lines=50,    # maximum scrollable lines
-     placeholder="Converted Markdown will appear here..."
- )

- # Create the Gradio Interface
- interface = gr.Interface(
-     fn=pdf_to_markdown,
-     inputs=gr.File(file_types=[".pdf"]),
-     outputs=output_box,
-     title="PDF → Markdown/JSON with Granite Docling (OCR)",
-     description="Upload a PDF (including scanned PDFs) and get parsed Markdown (or JSON) using Granite Docling via Docling, with OCR support."
  )

- # Launch the interface
  if __name__ == "__main__":
-     interface.launch()

  import os
+ import time  # currently unused
  import gradio as gr
+ import torch
+ import random
+ from PIL import Image, ImageOps
+ from transformers import AutoProcessor, Idefics3ForConditionalGeneration

+ # Device setup for computation
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ # Model checkpoint to load
+ model_id = "ibm-granite/granite-docling-258M"

+ # Load the processor and model from Hugging Face
+ processor = AutoProcessor.from_pretrained(model_id, use_auth_token=True)
+ model = Idefics3ForConditionalGeneration.from_pretrained(
+     model_id, device_map=device, torch_dtype=torch.bfloat16, use_auth_token=True
  )
+ if not torch.cuda.is_available():
+     model = model.to(device)
+
+ # Strip chat special tokens from the model's decoded response
+ def clean_model_response(text: str) -> str:
+     special_tokens = [
+         "<|end_of_text|>", "<|end|>", "<|assistant|>", "<|user|>", "<|system|>", "<pad>", "</s>", "<s>",
+     ]
+     cleaned = text
+     for token in special_tokens:
+         cleaned = cleaned.replace(token, "")
+     cleaned = cleaned.strip()
+     return cleaned if cleaned else "No response generated."
+
+ # Pad an image on all sides with the top-left corner color; note that with
+ # min_percent == max_percent the "random" pad is effectively fixed at 10%
+ def add_random_padding(image: Image.Image, min_percent: float = 0.1, max_percent: float = 0.10) -> Image.Image:
+     image = image.convert("RGB")
+     width, height = image.size
+     pad_w_percent = random.uniform(min_percent, max_percent)
+     pad_h_percent = random.uniform(min_percent, max_percent)
+     pad_w = int(width * pad_w_percent)
+     pad_h = int(height * pad_h_percent)
+     corner_pixel = image.getpixel((0, 0))  # Top-left corner
+     padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
+     return padded_image
+
+ # Generate model output for an image and a question
+ def generate_with_model(question: str, image_path: str, apply_padding: bool = False) -> str:
+     try:
+         # Open the image
+         image = Image.open(image_path).convert("RGB")
+         if apply_padding:
+             image = add_random_padding(image)
+
+         # Prepare the input messages for the model
+         messages = [
+             {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}
+         ]
+         prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
+
+         # Tokenize inputs
+         inputs = processor(text=prompt, images=[image], return_tensors="pt")
+         inputs = {k: v.to(device) for k, v in inputs.items()}
+
+         # Generate output with the model (greedy decoding; do_sample defaults to False)
+         with torch.no_grad():
+             generated_ids = model.generate(
+                 **inputs,
+                 max_new_tokens=4096,
+                 pad_token_id=processor.tokenizer.eos_token_id,
+             )
+
+         # Decode only the newly generated tokens, then strip special tokens
+         generated_texts = processor.batch_decode(generated_ids[:, inputs["input_ids"].shape[1]:], skip_special_tokens=False)[0]
+         cleaned_response = clean_model_response(generated_texts)
+         return cleaned_response
+
+     except Exception as e:
+         return f"Error processing image: {e}"
+
+ # Gradio handler: answer a question about the uploaded image
+ def handle_image_upload(uploaded_file: str | None, question: str) -> str:
+     if uploaded_file is None:
+         return "No image uploaded."
+
+     # Generate result based on the uploaded image and the user's question
+     response = generate_with_model(question.strip(), uploaded_file)
+     return response
+
+ # Gradio interface setup
+ def build_interface():
+     with gr.Blocks() as demo:
+         gr.Markdown("# Granite Docling 258M Demo")
+
+         # Upload Image
+         upload_button = gr.UploadButton("📁 Upload Image", file_types=["image"])
+
+         # Textbox to submit questions
+         question_input = gr.Textbox(submit_btn=True, show_label=False, placeholder="Ask a question...", scale=4)
+
+         # Textbox that displays the model response
+         result_output = gr.Textbox(label="Model Response", interactive=False, lines=5)
+
+         # Handle image upload and question submission
+         upload_button.upload(handle_image_upload, inputs=[upload_button, question_input], outputs=result_output)
+         question_input.submit(handle_image_upload, inputs=[upload_button, question_input], outputs=result_output)
+
+     return demo

+ # Launch Gradio app
  if __name__ == "__main__":
+     demo = build_interface()
+     demo.launch()
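
For quick review, here is a minimal smoke test of the new inference path outside the Gradio UI. It is a sketch only: it assumes app.py is importable from the working directory, and the image filename is illustrative rather than part of this commit; "Convert this page to docling." is the conversion prompt suggested on the granite-docling-258M model card.

    from app import generate_with_model

    # "sample_page.png" is a hypothetical test image; substitute any document page.
    result = generate_with_model(
        "Convert this page to docling.",
        "sample_page.png",
        apply_padding=True,  # also exercises add_random_padding
    )
    print(result)

Note that the old app exported Markdown via Docling, while the new code returns the model's raw DocTags markup. If Markdown output is still wanted, a follow-up commit could convert the DocTags string with the docling-core package, roughly as shown on the model card. A hedged sketch, where doctags is the string returned above and image is the corresponding PIL image:

    from docling_core.types.doc import DoclingDocument
    from docling_core.types.doc.document import DocTagsDocument

    # Pair the generated DocTags with the source image, then build a DoclingDocument
    doctags_doc = DocTagsDocument.from_doctags_and_image_pairs([doctags], [image])
    doc = DoclingDocument.load_from_doctags(doctags_doc)
    print(doc.export_to_markdown())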