Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import torch | |
| import requests | |
| from transformers import AutoModel, AutoProcessor, AutoTokenizer, GenerationConfig | |
| import spaces | |
| from typing import Iterable, List, Tuple | |
| import os | |
| from PIL import Image, ImageDraw | |
| import re | |
| from gradio.themes import Soft | |
| from gradio.themes.utils import colors, fonts, sizes | |
| import json | |
# --- Environment and Device Setup ---
# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Startup diagnostics: print what PyTorch sees so GPU problems show up in the logs.
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    # Only query the active device when CUDA is actually usable.
    _active_index = torch.cuda.current_device()
    print("current device:", _active_index)
    print("device name:", torch.cuda.get_device_name(_active_index))
print(f"✅ Using device: {device}")
# --- Post-processing Functions (Integrated from Nemotron-Parse) ---
# NOTE: The following helper functions are adapted from the NVIDIA-Nemotron-Parse repository
# to make this script self-contained and runnable.

# Patterns are compiled once at import time rather than on every call.
# A <box>...</box> section; DOTALL lets the content span several lines.
_BOX_PATTERN = re.compile(r"<box>(.*?)</box>", re.DOTALL)
# The <class>...</class> label inside a box.
_CLASS_PATTERN = re.compile(r"<class>(.*?)</class>")
# The [[x1,y1,x2,y2]] coordinates inside a box; whitespace around the numbers
# is tolerated so "[[1, 2, 3, 4]]" is parsed the same as "[[1,2,3,4]]".
_BBOX_PATTERN = re.compile(r"\[\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]\]")


def extract_classes_bboxes(generated_text: str) -> Tuple[List[str], List[List[int]], List[str]]:
    """Extract classes, bounding boxes, and associated text from the generated output.

    Every ``<box>...</box>`` section that contains both a ``<class>...</class>``
    tag and a ``[[x1,y1,x2,y2]]`` coordinate group contributes one entry to
    each returned list. Sections missing either part are skipped.

    Args:
        generated_text: Decoded model output to scan.

    Returns:
        Three parallel lists ``(classes, bboxes, texts)``:
        the stripped class label, the four integer coordinates, and the
        stripped text that follows the coordinate group up to the end of the
        box section.
    """
    classes: List[str] = []
    bboxes: List[List[int]] = []
    texts: List[str] = []

    for box_match in _BOX_PATTERN.finditer(generated_text):
        content = box_match.group(1)
        class_match = _CLASS_PATTERN.search(content)
        bbox_match = _BBOX_PATTERN.search(content)
        if class_match is None or bbox_match is None:
            # An entry needs both a label and coordinates.
            continue

        classes.append(class_match.group(1).strip())
        bboxes.append([int(coord) for coord in bbox_match.groups()])
        # Everything after the coordinate group is treated as the box's text.
        texts.append(content[bbox_match.end():].strip())

    return classes, bboxes, texts
def transform_bbox_to_original(bbox: List[int], original_width: int, original_height: int,
                               target_width: int = 1000, target_height: int = 1000) -> List[int]:
    """Transform a bounding box from the model's coordinate system to the original image dimensions.

    Args:
        bbox: ``[x1, y1, x2, y2]`` expressed on a ``target_width`` x
            ``target_height`` grid.
        original_width: Width of the image the box should be mapped onto.
        original_height: Height of the image the box should be mapped onto.
        target_width: Width of the grid ``bbox`` is expressed in.
        target_height: Height of the grid ``bbox`` is expressed in.

    Returns:
        ``[x1, y1, x2, y2]`` scaled to the original image, each coordinate
        truncated to an integer.

    Raises:
        ZeroDivisionError: If ``target_width`` or ``target_height`` is zero.
    """
    x1, y1, x2, y2 = bbox
    # Multiply before dividing: "x / target * original" first rounds the
    # fraction to a float (290 / 1000 * 100 == 28.999999999999996), so int()
    # would truncate to one pixel less than the exact result.
    return [
        int(x1 * original_width / target_width),
        int(y1 * original_height / target_height),
        int(x2 * original_width / target_width),
        int(y2 * original_height / target_height),
    ]
def postprocess_text(text: str, cls: str = 'text', table_format: str = 'markdown',
                     text_format: str = 'markdown', blank_text_in_figures: bool = False) -> str:
    """Post-process extracted text according to its layout class and the requested formats.

    Args:
        text: Raw text extracted for one layout element.
        cls: Layout class of the element; compared case-insensitively.
        table_format: For tables, ``'html'`` or ``'latex'`` wraps the text in
            the matching table markup; any other value leaves it unwrapped.
        text_format: ``'markdown'`` strips surrounding whitespace; any other
            value returns the text unchanged.
        blank_text_in_figures: When true, text belonging to a ``picture``
            element is replaced by an empty string.

    Returns:
        The processed text.
    """
    kind = cls.lower()

    # Figures can be blanked out entirely.
    if kind == 'picture' and blank_text_in_figures:
        return ""

    # Tables get wrapped only for the html / latex formats; other table
    # formats fall through to the generic text handling below.
    if kind == 'table' and table_format == 'html':
        return f"<table>\n{text}\n</table>"
    if kind == 'table' and table_format == 'latex':
        return f"\\begin{{tabular}}\n{text}\n\\end{{tabular}}"

    return text.strip() if text_format == 'markdown' else text
# --- Theme and CSS (Unchanged) ---
# Attach a custom "steel_blue" palette to gradio's colors module so it can be
# referenced as colors.steel_blue alongside the built-in hues.
colors.steel_blue = colors.Color(
    name="steel_blue",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)
class SteelBlueTheme(Soft):
    """Variant of gradio's Soft theme using gray as primary and steel blue as secondary hue."""

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        """Build the base Soft theme with the given hues/fonts, then apply the style overrides."""
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )

        # Overrides applied on top of Soft; "*name" values refer to other theme variables.
        style_overrides = {
            # Page and block backgrounds.
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
            # Primary button.
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            # Sliders.
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            # Blocks, shadows and spacing.
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(**style_overrides)
# Theme instance handed to gr.Blocks.
steel_blue_theme = SteelBlueTheme()

# Extra CSS handed to gr.Blocks; the selectors target elements by elem_id.
css = """
#main-title h1 { font-size: 2.3em !important; }
#output-title h2 { font-size: 2.1em !important; }
"""
# --- Model Loading ---
# Everything below runs once at import time; the loaded objects are shared by
# every request handled by process_document.
print("Loading model, tokenizer, and processor...")
model_path = "nvidia/NVIDIA-Nemotron-Parse-v1.1"

# Load the weights in bfloat16, move them to the selected device, and switch
# the module to evaluation mode.
model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
model = model.to(device).eval()

tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
print("✅ Model loaded successfully.")
# --- Main Processing Function ---
# Prompt sent to the model for each task label offered in the UI.
_TASK_PROMPTS = {
    "Markdown + BBoxes + Classes": "</s><s><predict_bbox><predict_classes><output_markdown>",
    "Markdown Only": "</s><s><output_markdown>",
    "Plain Text + BBoxes + Classes": "</s><s><predict_bbox><predict_classes><output_text>",
    "Plain Text Only": "</s><s><output_text>",
}
# Used when the task label is not one of the keys above.
_DEFAULT_TASK_PROMPT = "</s><s><output_markdown>"


# On Hugging Face ZeroGPU hardware a GPU is attached only while a function
# decorated with spaces.GPU is running, so the inference entry point needs it.
@spaces.GPU
def process_document(image: Image.Image, task_type: str):
    """Process an image with NVIDIA-Nemotron-Parse-v1.1.

    Args:
        image: Document image to parse, or ``None`` when nothing was uploaded.
        task_type: Task label selecting the prompt; unknown labels fall back
            to the markdown-only prompt.

    Returns:
        A ``(text, image)`` tuple. ``text`` is the generated output after the
        last ``<s>`` marker. ``image`` is an RGB copy of the input with red
        rectangles for every parsed bounding box, or a plain copy of the input
        when no boxes were found. When ``image`` is ``None`` the tuple is
        ``("Please upload an image first.", None)``.
    """
    if image is None:
        return "Please upload an image first.", None

    print(f"🏃 Running inference with task: {task_type}")
    task_prompt = _TASK_PROMPTS.get(task_type, _DEFAULT_TASK_PROMPT)

    # Process image and generate output
    inputs = processor(images=[image], text=task_prompt, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, generation_config=generation_config)
    generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    print("✅ Inference complete. Starting post-processing...")

    # Only the boxes are needed here; the parsed classes and texts are not used.
    _, bboxes, _ = extract_classes_bboxes(generated_text)

    # Text shown in the output textbox: everything after the last "<s>" marker.
    full_text_output = generated_text.split("<s>")[-1]

    if not bboxes:
        print("⚠️ No bounding boxes were predicted or found.")
        # Return a copy of the original image if there is nothing to draw.
        return full_text_output, image.copy()

    print(f"✅ Found {len(bboxes)} bounding box(es). Drawing on the original image.")
    image_with_bboxes = image.copy().convert("RGB")
    draw = ImageDraw.Draw(image_with_bboxes)
    for bbox in bboxes:
        x1, y1, x2, y2 = transform_bbox_to_original(bbox, image.width, image.height)
        # ImageDraw.rectangle requires x1 >= x0 and y1 >= y0; order the corners
        # so a box emitted with swapped corners does not abort the request.
        left, right = min(x1, x2), max(x1, x2)
        top, bottom = min(y1, y2), max(y1, y2)
        draw.rectangle((left, top, right, bottom), outline="red", width=3)

    return full_text_output, image_with_bboxes
# --- Gradio UI ---
# NOTE(review): the nesting below was reconstructed from a copy of the file that
# had lost its indentation — confirm the placement of the "Note" accordion and
# of the click wiring against the deployed app.
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **NVIDIA Nemotron Parse v1.1**", elem_id="main-title")
    gr.Markdown("### An advanced model for document layout analysis and text extraction.")

    with gr.Row():
        # Left column: inputs and example images.
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Document Image",
                sources=["upload", "clipboard"],
            )
            task_type = gr.Dropdown(
                choices=[
                    "Markdown + BBoxes + Classes",
                    "Markdown Only",
                    "Plain Text + BBoxes + Classes",
                    "Plain Text Only",
                ],
                value="Markdown + BBoxes + Classes",
                label="Task Type",
            )
            submit_btn = gr.Button("Process Document", variant="primary")
            gr.Examples(
                examples=["examples/1.jpg", "examples/2.jpg", "examples/3.jpg"],
                inputs=image_input,
                label="Examples",
            )

        # Right column: extracted text and the annotated image.
        with gr.Column(scale=2):
            output_text = gr.Textbox(label="Extracted Text", lines=12, show_copy_button=True)
            output_image = gr.Image(label="Layout Detection (If Task Includes BBoxes)", type="pil")
            with gr.Accordion("Note", open=False):
                gr.Markdown("Inference is performed using Hugging Face transformers on available hardware. This app demonstrates the capabilities of the NVIDIA-Nemotron-Parse-v1.1 model.")

    # Run process_document on click and route its two return values to the outputs.
    submit_btn.click(
        fn=process_document,
        inputs=[image_input, task_type],
        outputs=[output_text, output_image],
    )
if __name__ == "__main__":
    # Enable request queuing (at most 20 waiting requests), then start the app.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch(share=True, ssr_mode=False)