REVISION = "bce9358ca7928fc17c0c82d5fa2253aa681a4624"
try:
    import spaces

    IN_SPACES = True
except ImportError:
    from functools import wraps
    import inspect

    class spaces:
        # Minimal stand-in for the `spaces` package so the app also runs
        # outside Hugging Face Spaces, where the import fails.
        def GPU(duration):
            def decorator(func):
                if inspect.isgeneratorfunction(func):
                    # If the decorated function is a generator, yield from it.
                    @wraps(func)  # preserve the original function's metadata
                    def wrapper(*args, **kwargs):
                        yield from func(*args, **kwargs)
                else:
                    # For regular functions, just return the result. (A single
                    # shared wrapper containing `yield` would itself always be
                    # a generator, so plain functions need a separate path.)
                    @wraps(func)
                    def wrapper(*args, **kwargs):
                        return func(*args, **kwargs)
                return wrapper

            return decorator

    IN_SPACES = False
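# With either implementation, a handler could be decorated the same way, e.g.
# `@spaces.GPU(duration=120)`; the fallback above just calls the function
# directly. (Illustrative only; this app does not currently decorate any
# handler with spaces.GPU.)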
import torch
import os
import gradio as gr
import json
from queue import Queue
from threading import Thread
from transformers import AutoModelForCausalLM
from PIL import ImageDraw
from torchvision.transforms.v2 import Resize
| os.environ["HF_TOKEN"] = os.environ.get("TOKEN_FROM_SECRET") or True | |
moondream = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream-next",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map={"": "cuda"},
    revision=REVISION,
)
moondream.eval()
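# A quick sanity check of the loaded model can be run from a Python shell
# (illustrative sketch; "example.jpg" is a hypothetical local file, and the
# non-streaming query API is used the same way as in localized_query below):
#
#     from PIL import Image
#     print(moondream.query(Image.open("example.jpg"), "Describe this image.")["answer"])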
def convert_to_entities(text, coords):
    """
    Converts a string with special markers into an entity representation.

    Markers:
    - <|coord|> pairs indicate coordinate markers
    - <|start_ground_points|> indicates the start of grounding
    - <|start_ground_text|> indicates the start of a ground term
    - <|end_ground|> indicates the end of a ground term

    Returns:
    - Dictionary with cleaned text and entities with their character positions
    """
    cleaned_text = ""
    entities = []
    entity = []

    # Track current position in cleaned text
    current_pos = 0
    # Track whether we're currently processing an entity
    in_entity = False
    entity_start = 0

    i = 0
    while i < len(text):
        # Check for markers
        if text[i : i + 9] == "<|coord|>":
            i += 9
            entity.append(coords.pop(0))
            continue
        elif text[i : i + 23] == "<|start_ground_points|>":
            in_entity = True
            entity_start = current_pos
            i += 23
            continue
        elif text[i : i + 21] == "<|start_ground_text|>":
            entity_start = current_pos
            i += 21
            continue
        elif text[i : i + 14] == "<|end_ground|>":
            # Store entity position
            entities.append(
                {
                    "entity": json.dumps(entity),
                    "start": entity_start,
                    "end": current_pos,
                }
            )
            entity = []
            in_entity = False
            i += 14
            continue

        # Add character to cleaned text
        cleaned_text += text[i]
        current_pos += 1
        i += 1

    return {"text": cleaned_text, "entities": entities}
def answer_question(img, prompt, reasoning):
    # Guard against a cleared image, matching the other handlers below.
    if img is None:
        yield "", {"text": "", "entities": []}
        return
    buffer = ""
    resp = moondream.query(img, prompt, stream=True, reasoning=reasoning)
    reasoning_text = resp["reasoning"]["text"] if reasoning else "[reasoning disabled]"
    entities = (
        [
            {"start": g["start_idx"], "end": g["end_idx"], "entity": json.dumps(g["points"])}
            for g in resp["reasoning"]["grounding"]
        ]
        if reasoning
        else []
    )
    for new_text in resp["answer"]:
        buffer += new_text
        yield buffer.strip(), {"text": reasoning_text, "entities": entities}
def caption(img, mode):
    if img is None:
        yield ""
        return
    buffer = ""
    if mode == "Short":
        length = "short"
    elif mode == "Long":
        length = "long"
    else:
        length = "normal"
    for t in moondream.caption(img, length=length, stream=True)["caption"]:
        buffer += t
        yield buffer.strip()
def detect(img, object_name, eos_bias):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return
    eos_bias = float(eos_bias)
    objs = moondream.detect(img, object_name, settings={"eos_bias": eos_bias})["objects"]
    w, h = img.size
    if w > 768 or h > 768:
        img = Resize(768)(img)
        w, h = img.size
    draw_image = ImageDraw.Draw(img)
    for o in objs:
        # Box coordinates are normalized to [0, 1]; scale to pixel space.
        draw_image.rectangle(
            (o["x_min"] * w, o["y_min"] * h, o["x_max"] * w, o["y_max"] * h),
            outline="red",
            width=3,
        )
    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )
def point(img, object_name):
    if img is None:
        yield "", gr.update(visible=False, value=None)
        return
    w, h = img.size
    if w > 768 or h > 768:
        img = Resize(768)(img)
        w, h = img.size
    objs = moondream.point(img, object_name, settings={"max_objects": 200})["points"]
    draw_image = ImageDraw.Draw(img)
    for o in objs:
        # Point coordinates are normalized to [0, 1]; scale to pixel space.
        draw_image.ellipse(
            (o["x"] * w - 5, o["y"] * h - 5, o["x"] * w + 5, o["y"] * h + 5),
            fill="red",
            outline="blue",
            width=2,
        )
    yield {"text": f"{len(objs)} detected", "entities": []}, gr.update(
        visible=True, value=img
    )
def localized_query(img, x, y, question):
    if img is None:
        yield "", {"text": "", "entities": []}, gr.update(visible=False, value=None)
        return
    answer = moondream.query(img, question, spatial_refs=[(x, y)])["answer"]
    w, h = img.size
    x, y = x * w, y * h
    img_clone = img.copy()
    draw = ImageDraw.Draw(img_clone)
    draw.ellipse(
        (x - 5, y - 5, x + 5, y + 5),
        fill="red",
        outline="blue",
    )
    yield answer, {"text": "", "entities": []}, gr.update(visible=True, value=img_clone)
| js = "" | |
| css = """ | |
| .output-text span p { | |
| font-size: 1.4rem !important; | |
| } | |
| .chain-of-thought { | |
| opacity: 0.7 !important; | |
| } | |
| .chain-of-thought span.label { | |
| display: none; | |
| } | |
| .chain-of-thought span.textspan { | |
| padding-right: 0; | |
| } | |
| """ | |
with gr.Blocks(title="moondream vl (new)", css=css, js=js) as demo:
    if IN_SPACES:
        # gr.HTML("<style>body, body gradio-app { background: none !important; }</style>")
        pass

    gr.Markdown(
        """
        # 🌔 test space, pls ignore
        """
    )
    mode_radio = gr.Radio(
        ["Caption", "Query", "Detect", "Point", "Localized"],
        show_label=False,
        value=lambda: "Caption",
    )
    input_image = gr.State(None)

    with gr.Row():
        with gr.Column():
            # Re-render the input controls whenever the selected mode changes.
            @gr.render(inputs=[mode_radio])
            def show_inputs(mode):
| if mode == "Query": | |
| with gr.Group(): | |
| with gr.Row(): | |
| prompt = gr.Textbox( | |
| label="Input", | |
| value="How many people are in this image?", | |
| scale=4, | |
| ) | |
| submit = gr.Button("Submit") | |
| reasoning = gr.Checkbox(label="Enable reasoning") | |
| img = gr.Image(type="pil", label="Upload an Image") | |
| submit.click(answer_question, [img, prompt, reasoning], [output, thought]) | |
| prompt.submit(answer_question, [img, prompt, reasoning], [output, thought]) | |
| reasoning.change(answer_question, [img, prompt, reasoning], [output, thought]) | |
| img.change(answer_question, [img, prompt, reasoning], [output, thought]) | |
| img.change(lambda img: img, [img], [input_image]) | |
| elif mode == "Caption": | |
| with gr.Group(): | |
| with gr.Row(): | |
| caption_mode = gr.Radio( | |
| ["Short", "Normal", "Long"], | |
| label="Caption Length", | |
| value=lambda: "Normal", | |
| scale=4, | |
| ) | |
| submit = gr.Button("Submit") | |
| img = gr.Image(type="pil", label="Upload an Image") | |
| submit.click(caption, [img, caption_mode], output) | |
| img.change(caption, [img, caption_mode], output) | |
| elif mode == "Detect": | |
| with gr.Group(): | |
| with gr.Row(): | |
| prompt = gr.Textbox( | |
| label="Object", | |
| value="Cat", | |
| scale=4, | |
| ) | |
| submit = gr.Button("Submit") | |
| img = gr.Image(type="pil", label="Upload an Image") | |
| eos_bias = gr.Textbox(label="EOS Bias", value="0") | |
| submit.click(detect, [img, prompt, eos_bias], [thought, ann]) | |
| prompt.submit(detect, [img, prompt, eos_bias], [thought, ann]) | |
| img.change(detect, [img, prompt, eos_bias], [thought, ann]) | |
| elif mode == "Point": | |
| with gr.Group(): | |
| with gr.Row(): | |
| prompt = gr.Textbox( | |
| label="Object", | |
| value="Cat", | |
| scale=4, | |
| ) | |
| submit = gr.Button("Submit") | |
| img = gr.Image(type="pil", label="Upload an Image") | |
| submit.click(point, [img, prompt], [thought, ann]) | |
| prompt.submit(point, [img, prompt], [thought, ann]) | |
| img.change(point, [img, prompt], [thought, ann]) | |
| elif mode == "Localized": | |
| with gr.Group(): | |
| with gr.Row(): | |
| prompt = gr.Textbox( | |
| label="Input", | |
| value="What is this?", | |
| scale=4, | |
| ) | |
| submit = gr.Button("Submit") | |
| img = gr.Image(type="pil", label="Upload an Image") | |
| x_slider = gr.Slider(label="x", minimum=0, maximum=1) | |
| y_slider = gr.Slider(label="y", minimum=0, maximum=1) | |
| submit.click(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) | |
| prompt.submit(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) | |
| x_slider.change(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) | |
| y_slider.change(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) | |
| img.change(localized_query, [img, x_slider, y_slider, prompt], [output, thought, ann]) | |
                    def select_handler(image, evt: gr.SelectData):
                        # `image` is the PIL image passed in from the component;
                        # convert the click position to normalized coordinates.
                        w, h = image.size
                        return [evt.index[0] / w, evt.index[1] / h]
                    img.select(select_handler, img, [x_slider, y_slider])
                else:
                    gr.Markdown("Coming soon!")
        with gr.Column():
            thought = gr.HighlightedText(
                elem_classes=["chain-of-thought"],
                label="Thinking tokens",
                interactive=False,
            )
            output = gr.Markdown(label="Response", elem_classes=["output-text"], line_breaks=True)
            ann = gr.Image(visible=False)
    def on_select(img, evt: gr.SelectData):
        if img is None or evt.value[1] is None:
            return gr.update(visible=False, value=None)
        w, h = img.size
        if w > 768 or h > 768:
            img = Resize(768)(img)
            w, h = img.size
        points = json.loads(evt.value[1])
        img_clone = img.copy()
        draw = ImageDraw.Draw(img_clone)
        for pt in points:
            x = int(pt[0] * w)
            y = int(pt[1] * h)
            draw.ellipse(
                (x - 3, y - 3, x + 3, y + 3),
                fill="red",
                outline="red",
            )
        return gr.update(visible=True, value=img_clone)
    thought.select(on_select, [input_image], [ann])
    input_image.change(lambda: gr.update(visible=False), [], [ann])
    mode_radio.change(
        lambda: ("", "", gr.update(visible=False, value=None)),
        [],
        [output, thought, ann],
    )

demo.queue().launch()