import os
import time
import random

import gradio as gr
import torch
from PIL import Image, ImageOps
from transformers import AutoProcessor, Idefics3ForConditionalGeneration

# Device setup for computation
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define model and processor
model_id = "ibm-granite/granite-docling-258M"

# Load the processor and model from Hugging Face
processor = AutoProcessor.from_pretrained(model_id, use_auth_token=True)
model = Idefics3ForConditionalGeneration.from_pretrained(
    model_id,
    device_map=device,
    torch_dtype=torch.bfloat16,
    use_auth_token=True,
)
if not torch.cuda.is_available():
    model = model.to(device)


# Function to clean up special tokens in the model's response
def clean_model_response(text: str) -> str:
    special_tokens = [
        "<|end_of_text|>",
        "<|end|>",
        "<|assistant|>",
        "<|user|>",
        "<|system|>",
        "<end_of_utterance>",  # end-of-turn marker used by Idefics3-style processors
    ]
    cleaned = text
    for token in special_tokens:
        cleaned = cleaned.replace(token, "")
    cleaned = cleaned.strip()
    return cleaned if cleaned else "No response generated."


# Function to add random padding to images
def add_random_padding(image: Image.Image, min_percent: float = 0.1, max_percent: float = 0.10) -> Image.Image:
    image = image.convert("RGB")
    width, height = image.size
    pad_w_percent = random.uniform(min_percent, max_percent)
    pad_h_percent = random.uniform(min_percent, max_percent)
    pad_w = int(width * pad_w_percent)
    pad_h = int(height * pad_h_percent)
    corner_pixel = image.getpixel((0, 0))  # Top-left corner
    padded_image = ImageOps.expand(image, border=(pad_w, pad_h, pad_w, pad_h), fill=corner_pixel)
    return padded_image


# Function to generate model output for image and question
def generate_with_model(question: str, image_path: str, apply_padding: bool = False) -> str:
    try:
        # Open the image
        image = Image.open(image_path).convert("RGB")
        if apply_padding:
            image = add_random_padding(image)

        # Prepare the input messages for the model
        messages = [
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": question}]}
        ]
        prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

        # Tokenize inputs
        inputs = processor(text=prompt, images=[image], return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Generate output with the model
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=4096,
                temperature=0.0,
                pad_token_id=processor.tokenizer.eos_token_id,
            )
        generated_texts = processor.batch_decode(
            generated_ids[:, inputs["input_ids"].shape[1]:], skip_special_tokens=False
        )[0]

        cleaned_response = clean_model_response(generated_texts)
        return cleaned_response
    except Exception as e:
        return f"Error processing image: {e}"


# Gradio UI for uploading the image and asking questions
def handle_image_upload(uploaded_file: str | None, question: str) -> str:
    if uploaded_file is None:
        return "No image uploaded."
    # Generate result based on the uploaded image and the user's question
    response = generate_with_model(question.strip(), uploaded_file)
    return response


# Gradio interface setup
def build_interface():
    with gr.Blocks() as demo:
        gr.Markdown("# Granite Docling 258M Demo")

        # Upload Image
        upload_button = gr.UploadButton("📁 Upload Image", file_types=["image"])

        # Textbox to submit questions
        question_input = gr.Textbox(submit_btn=True, show_label=False, placeholder="Ask a question...", scale=4)

        # Textbox to display the model's response
        result_output = gr.Textbox(label="Model Response", interactive=False, lines=5)

        # Handle image upload and question submission
        upload_button.upload(handle_image_upload, inputs=[upload_button, question_input], outputs=result_output)
        question_input.submit(handle_image_upload, inputs=[upload_button, question_input], outputs=result_output)

    return demo


# Launch Gradio app
if __name__ == "__main__":
    demo = build_interface()
    demo.launch()