prithivMLmods committed on
Commit 211c5ed · verified · 1 Parent(s): a37899d

Create app.py

Files changed (1):
  app.py +230 -0
app.py ADDED
@@ -0,0 +1,230 @@
+import os
+import time
+from threading import Thread
+from typing import Iterable
+
+import gradio as gr
+import spaces
+import torch
+from PIL import Image
+
+from transformers import (
+    Qwen3VLForConditionalGeneration,
+    AutoProcessor,
+    TextIteratorStreamer,
+)
+
+from gradio.themes import Soft
+from gradio.themes.utils import colors, fonts, sizes
+
+# --- Theme Configuration ---
+colors.steel_blue = colors.Color(
+    name="steel_blue",
+    c50="#EBF3F8",
+    c100="#D3E5F0",
+    c200="#A8CCE1",
+    c300="#7DB3D2",
+    c400="#529AC3",
+    c500="#4682B4",
+    c600="#3E72A0",
+    c700="#36638C",
+    c800="#2E5378",
+    c900="#264364",
+    c950="#1E3450",
+)
+
+class SteelBlueTheme(Soft):
+    def __init__(
+        self,
+        *,
+        primary_hue: colors.Color | str = colors.gray,
+        secondary_hue: colors.Color | str = colors.steel_blue,
+        neutral_hue: colors.Color | str = colors.slate,
+        text_size: sizes.Size | str = sizes.text_lg,
+        font: fonts.Font | str | Iterable[fonts.Font | str] = (
+            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
+        ),
+        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
+            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
+        ),
+    ):
+        super().__init__(
+            primary_hue=primary_hue,
+            secondary_hue=secondary_hue,
+            neutral_hue=neutral_hue,
+            text_size=text_size,
+            font=font,
+            font_mono=font_mono,
+        )
+        super().set(
+            background_fill_primary="*primary_50",
+            background_fill_primary_dark="*primary_900",
+            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
+            body_background_fill_dark="linear-gradient(135deg, *primary_900, *primary_800)",
+            button_primary_text_color="white",
+            button_primary_text_color_hover="white",
+            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
+            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
+            button_primary_background_fill_dark="linear-gradient(90deg, *secondary_600, *secondary_800)",
+            button_primary_background_fill_hover_dark="linear-gradient(90deg, *secondary_500, *secondary_500)",
+            button_secondary_text_color="black",
+            button_secondary_text_color_hover="white",
+            button_secondary_background_fill="linear-gradient(90deg, *primary_300, *primary_300)",
+            button_secondary_background_fill_hover="linear-gradient(90deg, *primary_400, *primary_400)",
+            button_secondary_background_fill_dark="linear-gradient(90deg, *primary_500, *primary_600)",
+            button_secondary_background_fill_hover_dark="linear-gradient(90deg, *primary_500, *primary_500)",
+            slider_color="*secondary_500",
+            slider_color_dark="*secondary_600",
+            block_title_text_weight="600",
+            block_border_width="3px",
+            block_shadow="*shadow_drop_lg",
+            button_primary_shadow="*shadow_drop_lg",
+            button_large_padding="11px",
+            color_accent_soft="*primary_100",
+            block_label_background_fill="*primary_200",
+        )
+
+steel_blue_theme = SteelBlueTheme()
+
+css = """
+#main-title h1 {
+    font-size: 2.3em !important;
+}
+#output-title h2 {
+    font-size: 2.1em !important;
+}
+"""
+
+# --- Device & Model Setup ---
+MAX_MAX_NEW_TOKENS = 4096
+DEFAULT_MAX_NEW_TOKENS = 2048
+MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+print("CUDA_VISIBLE_DEVICES=", os.environ.get("CUDA_VISIBLE_DEVICES"))
+print("torch.__version__ =", torch.__version__)
+print("torch.version.cuda =", torch.version.cuda)
+print("cuda available:", torch.cuda.is_available())
+if torch.cuda.is_available():
+    print("current device:", torch.cuda.current_device())
+    print("device name:", torch.cuda.get_device_name(torch.cuda.current_device()))
+
+print("Using device:", device)
+
+MODEL_ID = "Qwen/Qwen3-VL-8B-Instruct"
+print(f"Loading model: {MODEL_ID}...")
+
+processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
+model = Qwen3VLForConditionalGeneration.from_pretrained(
+    MODEL_ID,
+    attn_implementation="flash_attention_2",
+    trust_remote_code=True,
+    torch_dtype=torch.float16
+).to(device).eval()
+
+print("Model loaded successfully.")
+
+# --- Generation Logic ---
+@spaces.GPU
+def generate_image(text: str, image: Image.Image,
+                   max_new_tokens: int, temperature: float, top_p: float,
+                   top_k: int, repetition_penalty: float):
+    """
+    Generates responses using the Qwen3-VL-8B-Instruct model.
+    Yields raw text and Markdown-formatted text.
+    """
+    if image is None:
+        yield "Please upload an image.", "Please upload an image."
+        return
+
+    # Prepare messages
+    messages = [{
+        "role": "user",
+        "content": [
+            {"type": "image"},
+            {"type": "text", "text": text},
+        ]
+    }]
+
+    # Apply template
+    prompt_full = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+    # Process inputs
+    inputs = processor(
+        text=[prompt_full],
+        images=[image],
+        return_tensors="pt",
+        padding=True
+    ).to(device)
+
+    # Setup streamer
+    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+
+    generation_kwargs = {
+        **inputs,
+        "streamer": streamer,
+        "max_new_tokens": max_new_tokens,
+        "do_sample": True,
+        "temperature": temperature,
+        "top_p": top_p,
+        "top_k": top_k,
+        "repetition_penalty": repetition_penalty,
+    }
+
+    # Start generation thread
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        # Clean specific tokens if necessary
+        buffer = buffer.replace("<|im_end|>", "")
+        time.sleep(0.01)
+        yield buffer, buffer
+
+# --- Gradio Interface ---
+image_examples = [
+    ["OCR the content perfectly.", "examples/3.jpg"],
+    ["Perform OCR on the image.", "examples/1.jpg"],
+    ["Extract the contents. [page].", "examples/2.jpg"],
+]
+
+with gr.Blocks(css=css, theme=steel_blue_theme) as demo:
+    gr.Markdown("# **Qwen3-VL-8B-Instruct**", elem_id="main-title")
+
+    with gr.Row():
+        with gr.Column(scale=2):
+            image_query = gr.Textbox(label="Query Input", placeholder="Enter your query here...")
+            image_upload = gr.Image(type="pil", label="Upload Image", height=290)
+
+            image_submit = gr.Button("Submit", variant="primary")
+
+            # Note: Ensure these example paths exist in your environment
+            gr.Examples(
+                examples=image_examples,
+                inputs=[image_query, image_upload]
+            )
+
+            with gr.Accordion("Advanced options", open=False):
+                max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+                temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.7)
+                top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
+                top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
+                repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.1)
+
+        with gr.Column(scale=3):
+            gr.Markdown("## Output", elem_id="output-title")
+            output = gr.Textbox(label="Raw Output Stream", interactive=True, lines=11)
+            with gr.Accordion("(Result.md)", open=False):
+                markdown_output = gr.Markdown(label="(Result.md)")
+
+    image_submit.click(
+        fn=generate_image,
+        inputs=[image_query, image_upload, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        outputs=[output, markdown_output]
+    )
+
+if __name__ == "__main__":
+    demo.queue(max_size=30).launch(ssr_mode=False, show_error=True)
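
For reference, a minimal sketch of driving the streaming generator without the Gradio UI. It assumes the file above is saved as app.py, that the Qwen/Qwen3-VL-8B-Instruct weights can be loaded on the local machine, and that a test image exists at examples/1.jpg; the import name app, the prompt string, and the sampling values below are illustrative assumptions, not part of the commit.

# Illustrative sketch (not part of the commit): exercise generate_image() directly.
# Assumes the file above is saved as app.py and examples/1.jpg exists locally.
from PIL import Image

import app  # importing app.py loads the processor and model once

image = Image.open("examples/1.jpg").convert("RGB")

final_text = ""
for raw_text, markdown_text in app.generate_image(
    "Perform OCR on the image.",  # prompt borrowed from the built-in examples
    image,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9,
    top_k=50,
    repetition_penalty=1.1,
):
    final_text = raw_text  # each yield carries the full decoded buffer so far

print(final_text)

Each yield returns the same buffer twice (the raw-text and Markdown views of the stream), so keeping only the last value gives the complete OCR result.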