# Scraped Hugging Face file-page header (kept as a comment so the module parses):
# author: prithivMLmods · commit 874a882 "Update app.py" (verified) · file size 8.49 kB
import gradio as gr
import torch
import os
import sys
from PIL import Image, ImageDraw
from transformers import AutoModel, AutoProcessor, AutoTokenizer, GenerationConfig
from huggingface_hub import snapshot_download
import spaces
from typing import Optional, Tuple, Dict, Any, Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# --- Model & Script Download ---
print("Downloading model snapshot to ensure all scripts are present...")
# Fetch the full repository (not only the weights): postprocessing.py lives in it
# and is needed below to decode the model's box/class/text output.
model_dir = snapshot_download(repo_id="nvidia/NVIDIA-Nemotron-Parse-v1.1")
print(f"Model downloaded to: {model_dir}")

# postprocessing.py is a loose module inside the snapshot, so expose the
# snapshot directory on the import path before importing from it.
sys.path.append(model_dir)
try:
    from postprocessing import (
        extract_classes_bboxes,
        postprocess_text,
        transform_bbox_to_original,
    )
except ImportError as import_error:
    # The app cannot work without these helpers: report and abort start-up.
    print(f"❌ Error importing postprocessing: {import_error}")
    raise
else:
    print("✅ Successfully imported postprocessing functions.")
# --- Device Setup ---
# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# --- Theme Definition ---
# Custom "steel_blue" palette, attached to gradio's colors module so the theme
# below can use it like a built-in hue. Keys go from lightest (c50) to darkest (c950).
_STEEL_BLUE_SHADES = {
    "c50": "#EBF3F8",
    "c100": "#D3E5F0",
    "c200": "#A8CCE1",
    "c300": "#7DB3D2",
    "c400": "#529AC3",
    "c500": "#4682B4",
    "c600": "#3E72A0",
    "c700": "#36638C",
    "c800": "#2E5378",
    "c900": "#264364",
    "c950": "#1E3450",
}
colors.steel_blue = colors.Color(name="steel_blue", **_STEEL_BLUE_SHADES)
class SteelBlueTheme(Soft):
    """Gradio ``Soft`` theme with a steel-blue accent.

    Every constructor argument is keyword-only and is forwarded unchanged to
    ``Soft.__init__``. Afterwards a fixed set of theme-variable overrides is
    applied: gradient page and primary-button backgrounds, white primary-button
    text, steel-blue sliders, and heavier block borders and shadows.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.steel_blue,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        base_options = dict(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().__init__(**base_options)

        # Values of the form "*name" refer to other theme variables
        # (e.g. "*secondary_500" is shade 500 of the secondary hue).
        overrides = {
            # Page and block backgrounds (light / dark mode).
            "background_fill_primary": "*primary_50",
            "background_fill_primary_dark": "*primary_900",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "body_background_fill_dark": "linear-gradient(135deg, *primary_900, *primary_800)",
            # Primary buttons: white text on a steel-blue gradient.
            "button_primary_text_color": "white",
            "button_primary_text_color_hover": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_dark": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "button_primary_background_fill_hover_dark": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            # Sliders follow the accent colour.
            "slider_color": "*secondary_500",
            "slider_color_dark": "*secondary_600",
            # Block chrome, shadows and spacing.
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
            "button_large_padding": "11px",
            "color_accent_soft": "*primary_100",
            "block_label_background_fill": "*primary_200",
        }
        super().set(**overrides)
# One shared theme instance for the Blocks UI defined further down.
steel_blue_theme = SteelBlueTheme()

# Heading-size overrides, targeted by component elem_id.
# NOTE(review): no component in this file sets elem_id="output-title", so that
# rule appears unused — confirm before removing it.
_CSS_RULES = (
    "#main-title h1 { font-size: 2.3em !important; }",
    "#output-title h2 { font-size: 2.1em !important; }",
)
css = "\n" + "\n".join(_CSS_RULES) + "\n"
# --- Model Loading ---
print("Loading Model components...")
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
# Derive the dtype from the already-chosen `device` (rather than querying CUDA
# again): bfloat16 on GPU, float32 on CPU. process_ocr_task casts float32
# inputs to bfloat16 only when device.type == "cuda", so the two must agree.
model = AutoModel.from_pretrained(
    model_dir,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16 if device.type == "cuda" else torch.float32,
).to(device).eval()

try:
    # GenerationConfig is plain generation settings read from the snapshot.
    # trust_remote_code is a model/processor loading option, not a generation
    # field; passing it here would only be stored as a stray config attribute.
    generation_config = GenerationConfig.from_pretrained(model_dir)
except Exception as e:
    # Best effort: the app still works with defaults if the snapshot has no
    # usable generation config.
    print(f"Warning: Could not load GenerationConfig: {e}. Using default.")
    generation_config = GenerationConfig(max_new_tokens=4096)
print("✅ Model loaded successfully.")
# Task prompt for the parser. The special tokens ask for bounding boxes,
# class labels and markdown-formatted text in a single generation pass.
_TASK_PROMPT = "</s><s><predict_bbox><predict_classes><output_markdown>"

# Outline colour per layout class; any class not listed falls back to red.
_LAYOUT_COLORS = {
    "Table": "red",
    "Figure": "blue",
    "Text": "green",
    "Title": "purple",
}


def _normalize_box(bbox):
    """Return ``(xmin, ymin, xmax, ymax)`` for a 4-value box given in any corner order.

    PIL's ``ImageDraw.rectangle`` raises ``ValueError`` when x1 < x0 or y1 < y0,
    so the corners are reordered before drawing.
    """
    x1, y1, x2, y2 = bbox
    return min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2)


def _format_segment(cls, txt):
    """Render one parsed region for the text output panel.

    Tables are fenced with a header and footer. Figures are replaced by a
    placeholder, so their text is dropped. Every other class is emitted as-is,
    followed by a newline.
    """
    if cls == "Table":
        return f"\n\n--- [Table] ---\n{txt}\n-----------------\n"
    if cls == "Figure":
        return "\n\n--- [Figure] ---\n(Figure Detected)\n-----------------\n"
    return f"{txt}\n"


@spaces.GPU
def process_ocr_task(image):
    """Parse a document image with NVIDIA-Nemotron-Parse-v1.1.

    Args:
        image: PIL image from the Gradio input, or ``None`` if nothing was uploaded.

    Returns:
        A ``(text, image)`` tuple:

        - ``text``: the parsed content of each detected region, concatenated in
          model order. Tables are fenced and figures replaced by a placeholder.
          If that comes out blank, the raw decoded model output is returned instead.
        - ``image``: a copy of the input with a coloured outline per region.

        With no image, returns a prompt message and ``None``. If the decoded
        output cannot be split into classes/boxes/texts, returns the raw decoded
        text and the unannotated input image.
    """
    if image is None:
        return "Please upload an image first.", None

    inputs = processor(images=[image], text=_TASK_PROMPT, return_tensors="pt").to(device)
    if device.type == "cuda":
        # The model is loaded in bfloat16 on CUDA, so float32 tensors are cast to
        # match. Non-tensor entries (if any) are passed through untouched.
        inputs = {
            key: value.to(torch.bfloat16)
            if torch.is_tensor(value) and value.dtype == torch.float32
            else value
            for key, value in inputs.items()
        }

    print("🏃 Running inference...")
    with torch.no_grad():
        outputs = model.generate(**inputs, generation_config=generation_config)
    generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0]

    try:
        classes, bboxes, texts = extract_classes_bboxes(generated_text)
    except Exception as e:
        # Best effort: show the raw model output rather than failing the request.
        print(f"Error extracting boxes: {e}")
        return generated_text, image

    # Transform boxes to original image size.
    bboxes = [transform_bbox_to_original(bbox, image.width, image.height) for bbox in bboxes]
    processed_texts = [
        postprocess_text(
            text,
            cls=cls,
            table_format="latex",
            text_format="markdown",
            blank_text_in_figures=False,
        )
        for text, cls in zip(texts, classes)
    ]

    result_image = image.copy()
    draw = ImageDraw.Draw(result_image)
    # Collect the pieces and join once, instead of repeated string `+=`.
    segments = []
    for cls, bbox, txt in zip(classes, bboxes, processed_texts):
        draw.rectangle(_normalize_box(bbox), outline=_LAYOUT_COLORS.get(cls, "red"), width=3)
        segments.append(_format_segment(cls, txt))
    final_output_text = "".join(segments)

    # Nothing usable came out of the per-region pass: fall back to the raw output.
    if not final_output_text.strip() and generated_text:
        final_output_text = generated_text
    return final_output_text, result_image
# --- Gradio Interface ---
# Contents of the collapsible "Technical Details" panel.
# NOTE(review): the "Architecture" line is not backed by anything in this file;
# verify it against the model card before relying on it.
_TECHNICAL_DETAILS_MD = """
**Model:** [nvidia/NVIDIA-Nemotron-Parse-v1.1](https://huggingface.co/nvidia/NVIDIA-Nemotron-Parse-v1.1)
**Architecture:** Llama-3-Vila based.
**Capabilities:** High-accuracy OCR, Table extraction (to LaTeX/HTML), Figure detection.
"""

with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
    gr.Markdown("# **NVIDIA Nemotron Parse v1.1 [OCR/Parsing]**", elem_id="main-title")
    gr.Markdown(
        "Upload a document image to extract text, tables, and layout structures "
        "using NVIDIA's state-of-the-art Parse model."
    )

    with gr.Row():
        # Left column: image input, trigger button and a bundled example.
        with gr.Column(scale=1):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                sources=["upload", "clipboard"],
            )
            submit_btn = gr.Button("Process Document", variant="primary")
            examples = gr.Examples(
                examples=["examples/1.jpg"],
                inputs=image_input,
                label="Examples",
            )

        # Right column: parsed text and the annotated layout image.
        with gr.Column(scale=2):
            output_text = gr.Textbox(
                label="Parsed Content (Markdown/LaTeX)",
                lines=20,
                show_copy_button=True,
            )
            output_image = gr.Image(label="Detected Layout & Bounding Boxes", type="pil")

    with gr.Accordion("Technical Details", open=False):
        gr.Markdown(_TECHNICAL_DETAILS_MD)

    # The button sends the uploaded image to the parser and fills both outputs.
    submit_btn.click(
        process_ocr_task,
        inputs=[image_input],
        outputs=[output_text, output_image],
    )
if __name__ == "__main__":
    # Enable the request queue (at most 20 pending events), then serve the UI,
    # also exposing it as an MCP server, with server-side rendering disabled.
    demo.queue(max_size=20)
    demo.launch(share=True, mcp_server=True, ssr_mode=False)