import gradio as gr
import torch
import requests
from transformers import AutoModel, AutoProcessor, AutoTokenizer, GenerationConfig
import spaces
from typing import Iterable, List, Tuple
import os
from PIL import Image, ImageDraw
import re
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
import json
# --- Environment and Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
print("torch.__version__ =", torch.__version__)
print("torch.version.cuda =", torch.version.cuda)
print("cuda available:", torch.cuda.is_available())
print("cuda device count:", torch.cuda.device_count())
if torch.cuda.is_available():
    print("current device:", torch.cuda.current_device())
    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
print(f"✅ Using device: {device}")
# --- Post-processing Functions (Integrated from Nemotron-Parse) ---
# NOTE: The following helper functions are adapted from the NVIDIA-Nemotron-Parse repository
# to make this script self-contained and runnable.
def extract_classes_bboxes(generated_text: str) -> Tuple[List[str], List[List[int]], List[str]]:
    """Extracts classes, bounding boxes, and associated text from the generated output."""
    classes, bboxes, texts = [], [], []
    # NOTE: the original tag names were lost in transit; <box> and <class> below are
    # assumptions — adjust them to the tags NVIDIA-Nemotron-Parse-v1.1 actually emits.
    # Regex to find <box>...</box> sections
    box_pattern = re.compile(r"<box>(.*?)</box>", re.DOTALL)
    # Regex to find <class>...</class> and [[x1,y1,x2,y2]] within each box
    class_pattern = re.compile(r"<class>(.*?)</class>")
    bbox_pattern = re.compile(r"\[\[(\d+,\d+,\d+,\d+)\]\]")
    for match in box_pattern.finditer(generated_text):
        box_content = match.group(1)
        class_match = class_pattern.search(box_content)
        bbox_match = bbox_pattern.search(box_content)
        if class_match and bbox_match:
            cls = class_match.group(1).strip()
            bbox_str = bbox_match.group(1).split(',')
            bbox = [int(coord) for coord in bbox_str]
            # Extract the text following the bbox
            text_content = box_content[bbox_match.end():].strip()
            classes.append(cls)
            bboxes.append(bbox)
            texts.append(text_content)
    return classes, bboxes, texts
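# Illustrative parse (tag names above are assumptions, not the confirmed model format):
#   extract_classes_bboxes("<box><class>Title</class>[[100,50,900,120]]Hello</box>")
#   -> (["Title"], [[100, 50, 900, 120]], ["Hello"])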
def transform_bbox_to_original(bbox: List[int], original_width: int, original_height: int,
                               target_width: int = 1000, target_height: int = 1000) -> List[int]:
    """Transforms a bounding box from the model's coordinate system to the original image dimensions."""
    x1, y1, x2, y2 = bbox
    original_x1 = int(x1 / target_width * original_width)
    original_y1 = int(y1 / target_height * original_height)
    original_x2 = int(x2 / target_width * original_width)
    original_y2 = int(y2 / target_height * original_height)
    return [original_x1, original_y1, original_x2, original_y2]
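# Worked example: model coordinates live on a 1000x1000 grid, so a predicted bbox
# [250, 500, 750, 900] on a 2000x1000 px page scales per axis to
# [250/1000*2000, 500/1000*1000, 750/1000*2000, 900/1000*1000] = [500, 500, 1500, 900].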
def postprocess_text(text: str, cls: str = 'text', table_format: str = 'markdown',
                     text_format: str = 'markdown', blank_text_in_figures: bool = False) -> str:
    """Post-processes the extracted text based on its class and desired format."""
    if blank_text_in_figures and cls.lower() == 'picture':
        return ""
    if cls.lower() == 'table':
        if table_format == 'html':
            # Reconstructed: the original wrapper markup was stripped in transit;
            # <table>...</table> mirrors the LaTeX branch below.
            return f"<table>\n{text}\n</table>"
        elif table_format == 'latex':
            return f"\\begin{{tabular}}\n{text}\n\\end{{tabular}}"
    if text_format == 'markdown':
        return text.strip()
    return text
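# Usage sketch: postprocess_text("a & b", cls="Table", table_format="latex")
# returns "\begin{tabular}\na & b\n\end{tabular}"; class names are lowercased
# before comparison, so "Table" and "table" behave identically.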
# --- Theme and CSS (Unchanged) ---
colors.steel_blue = colors.Color(name="steel_blue",c50="#EBF3F8",c100="#D3E5F0",c200="#A8CCE1",c300="#7DB3D2",c400="#529AC3",c500="#4682B4",c600="#3E72A0",c700="#36638C",c800="#2E5378",c900="#264364",c950="#1E3450")
class SteelBlueTheme(Soft):
    def __init__(self, *, primary_hue: colors.Color | str = colors.gray, secondary_hue: colors.Color | str = colors.steel_blue, neutral_hue: colors.Color | str = colors.slate, text_size: sizes.Size | str = sizes.text_lg, font: fonts.Font | str | Iterable[fonts.Font | str] = (fonts.GoogleFont("Outfit"), "Arial", "sans-serif"), font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace")):
        super().__init__(primary_hue=primary_hue, secondary_hue=secondary_hue, neutral_hue=neutral_hue, text_size=text_size, font=font, font_mono=font_mono)
        super().set(background_fill_primary="*primary_50", background_fill_primary_dark="*primary_900", body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)", body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)", button_primary_text_color="white", button_primary_text_color_hover="white", button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)", button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)", button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_700)", button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_600)", slider_color="*secondary_500", slider_color_dark="*secondary_600", block_title_text_weight="600", block_border_width="3px", block_shadow="*shadow_drop_lg", button_primary_shadow="*shadow_drop_lg", button_large_padding="11px", color_accent_soft="*primary_100", block_label_background_fill="*primary_200")
steel_blue_theme = SteelBlueTheme()
css = """
#main-title h1 { font-size: 2.3em !important; }
#output-title h2 { font-size: 2.1em !important; }
"""
# --- Model Loading ---
print("Loading model, tokenizer, and processor...")
model_path = "nvidia/NVIDIA-Nemotron-Parse-v1.1"
model = AutoModel.from_pretrained(
    model_path,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(model_path)
processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
generation_config = GenerationConfig.from_pretrained(model_path, trust_remote_code=True)
print("✅ Model loaded successfully.")
# --- Main Processing Function ---
@spaces.GPU
def process_document(image: Image.Image, task_type: str):
    """
    Processes an image with NVIDIA-Nemotron-Parse-v1.1.
    """
    if image is None:
        return "Please upload an image first.", None
    print(f"🏃 Running inference with task: {task_type}")
    # Define task prompts based on user selection.
    # NOTE: the actual prompt tokens were lost in transit; fill these in from the
    # NVIDIA-Nemotron-Parse-v1.1 model card before running.
    task_prompts = {
        "Markdown + BBoxes + Classes": "",
        "Markdown Only": "",
        "Plain Text + BBoxes + Classes": "",
        "Plain Text Only": "",
    }
    task_prompt = task_prompts.get(task_type, "")
    # Process image and generate output
    inputs = processor(images=[image], text=task_prompt, return_tensors="pt").to(device)
    outputs = model.generate(**inputs, generation_config=generation_config)
    generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]
    print("✅ Inference complete. Starting post-processing...")
    # Post-process the generated text
    classes, bboxes, texts = extract_classes_bboxes(generated_text)
    # Use the full decoded text for the main output textbox.
    # NOTE: the original code split on a separator token that was lost in transit;
    # str.split("") raises ValueError, so restore the real token here if the
    # decoded output echoes the prompt.
    full_text_output = generated_text
    # If bounding boxes were predicted, draw them on the image
    result_image_pil = None
    if bboxes:
        print(f"✅ Found {len(bboxes)} bounding box(es). Drawing on the original image.")
        bboxes_original = [transform_bbox_to_original(bbox, image.width, image.height) for bbox in bboxes]
        image_with_bboxes = image.copy().convert("RGB")
        draw = ImageDraw.Draw(image_with_bboxes)
        for bbox in bboxes_original:
            draw.rectangle(tuple(bbox), outline="red", width=3)
        result_image_pil = image_with_bboxes
    else:
        print("⚠️ No bounding boxes were predicted or found.")
        result_image_pil = image.copy()  # Return original image if no bboxes
    return full_text_output, result_image_pil
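# Quick local smoke test (hypothetical, bypassing the UI; uses one of the bundled examples):
#   text, annotated = process_document(Image.open("examples/1.jpg"), "Markdown Only")
#   print(text); annotated.save("annotated.png")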
# --- Gradio UI ---
with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **NVIDIA Nemotron Parse v1.1**", elem_id="main-title")
    gr.Markdown("### An advanced model for document layout analysis and text extraction.")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Document Image", sources=["upload", "clipboard"])
            task_type = gr.Dropdown(
                choices=["Markdown + BBoxes + Classes", "Markdown Only", "Plain Text + BBoxes + Classes", "Plain Text Only"],
                value="Markdown + BBoxes + Classes",
                label="Task Type"
            )
            submit_btn = gr.Button("Process Document", variant="primary")
            gr.Examples(
                examples=["examples/1.jpg", "examples/2.jpg", "examples/3.jpg"],
                inputs=image_input,
                label="Examples"
            )
        with gr.Column(scale=2):
            output_text = gr.Textbox(label="Extracted Text", lines=12, show_copy_button=True)
            output_image = gr.Image(label="Layout Detection (If Task Includes BBoxes)", type="pil")
            with gr.Accordion("Note", open=False):
                gr.Markdown("Inference is performed using Hugging Face transformers on available hardware. This app demonstrates the capabilities of the NVIDIA-Nemotron-Parse-v1.1 model.")
    submit_btn.click(
        fn=process_document,
        inputs=[image_input, task_type],
        outputs=[output_text, output_image]
    )
if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True, ssr_mode=False)