import gradio as gr
import torch
import os
import sys
from PIL import Image, ImageDraw
from transformers import AutoModel, AutoProcessor, AutoTokenizer, GenerationConfig
from huggingface_hub import snapshot_download
import spaces
from typing import Optional, Tuple, Dict, Any, Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

# --- Model & Script Download ---
print("Downloading model snapshot to ensure all scripts are present...")
# Download the full model repo to ensure postprocessing.py is available locally
model_dir = snapshot_download(repo_id="nvidia/NVIDIA-Nemotron-Parse-v1.1")
print(f"Model downloaded to: {model_dir}")

# Add the model directory to sys.path so we can import postprocessing
sys.path.append(model_dir)

try:
    from postprocessing import extract_classes_bboxes, transform_bbox_to_original, postprocess_text
    print("✅ Successfully imported postprocessing functions.")
except ImportError as e:
    print(f"❌ Error importing postprocessing: {e}")
    raise e

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- Theme Definition ---
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)


class SteelBlueTheme(Soft):
    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().set(
            background_fill_primary="*primary_50",
            background_fill_primary_dark="*primary_900",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
            button_primary_text_color="white",
            button_primary_text_color_hover="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)",
            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)",
            slider_color="*secondary_500",
            slider_color_dark="*secondary_600",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
            button_large_padding="11px",
            color_accent_soft="*primary_100",
            block_label_background_fill="*primary_200",
        )


steel_blue_theme = SteelBlueTheme()

css = """
#main-title h1 { font-size: 2.3em !important; }
#output-title h2 { font-size: 2.1em !important; }
"""

# --- Model Loading ---
print("Loading Model components...")
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_dir,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
).to(device).eval()
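# Decoding defaults (e.g. max_new_tokens, stop criteria) are read from the model repo's
# generation config below; if that fails, fall back to a plain default configuration.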
try:
    generation_config = GenerationConfig.from_pretrained(model_dir, trust_remote_code=True)
except Exception as e:
    print(f"Warning: Could not load GenerationConfig: {e}. Using default.")
    generation_config = GenerationConfig(max_new_tokens=4096)

print("✅ Model loaded successfully.")


@spaces.GPU
def process_ocr_task(image):
    """
    Processes an image with NVIDIA-Nemotron-Parse-v1.1.
    """
    if image is None:
        return "Please upload an image first.", None

    task_prompt = ""
    inputs = processor(images=[image], text=task_prompt, return_tensors="pt").to(device)

    if device.type == 'cuda':
        inputs = {k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v for k, v in inputs.items()}

    print("🏃 Running inference...")
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            generation_config=generation_config,
        )

    generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]

    try:
        classes, bboxes, texts = extract_classes_bboxes(generated_text)
    except Exception as e:
        print(f"Error extracting boxes: {e}")
        return generated_text, image

    # Transform boxes to original image size
    bboxes = [transform_bbox_to_original(bbox, image.width, image.height) for bbox in bboxes]

    table_format = 'latex'
    text_format = 'markdown'
    blank_text_in_figures = False

    processed_texts = [
        postprocess_text(
            text,
            cls=cls,
            table_format=table_format,
            text_format=text_format,
            blank_text_in_figures=blank_text_in_figures,
        )
        for text, cls in zip(texts, classes)
    ]

    result_image = image.copy()
    draw = ImageDraw.Draw(result_image)
    color_map = {
        "Table": "red",
        "Figure": "blue",
        "Text": "green",
        "Title": "purple",
    }

    final_output_text = ""
    for cls, bbox, txt in zip(classes, bboxes, processed_texts):
        # Normalize coordinates to prevent PIL ValueError (x1 >= x0)
        x1, y1, x2, y2 = bbox
        xmin = min(x1, x2)
        ymin = min(y1, y2)
        xmax = max(x1, x2)
        ymax = max(y1, y2)

        color = color_map.get(cls, "red")
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=3)

        if cls == "Table":
            final_output_text += f"\n\n--- [Table] ---\n{txt}\n-----------------\n"
        elif cls == "Figure":
            final_output_text += "\n\n--- [Figure] ---\n(Figure Detected)\n-----------------\n"
        else:
            final_output_text += f"{txt}\n"

    if not final_output_text.strip() and generated_text:
        final_output_text = generated_text

    return final_output_text, result_image


# --- Gradio Interface ---
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **NVIDIA Nemotron Parse v1.1 [OCR/Parsing]**", elem_id="main-title")
    gr.Markdown("Upload a document image to extract text, tables, and layout structures using NVIDIA's state-of-the-art Parse model.")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image", sources=["upload", "clipboard"])
            submit_btn = gr.Button("Process Document", variant="primary")
            examples = gr.Examples(
                examples=["examples/1.jpg"],
                inputs=image_input,
                label="Examples",
            )
        with gr.Column(scale=2):
            output_text = gr.Textbox(label="Parsed Content (Markdown/LaTeX)", lines=20, show_copy_button=True)
            output_image = gr.Image(label="Detected Layout & Bounding Boxes", type="pil")

            with gr.Accordion("Technical Details", open=False):
                gr.Markdown("""
                **Model:** [nvidia/NVIDIA-Nemotron-Parse-v1.1](https://huggingface.co/nvidia/NVIDIA-Nemotron-Parse-v1.1)

                **Architecture:** Llama-3-Vila based.

                **Capabilities:** High-accuracy OCR, Table extraction (to LaTeX/HTML), Figure detection.
""") submit_btn.click( fn=process_ocr_task, inputs=[image_input], outputs=[output_text, output_image] ) if __name__ == "__main__": demo.queue(max_size=20).launch(share=True, mcp_server=True, ssr_mode=False)