Spaces:
Sleeping
Sleeping
| import argparse | |
| import gradio as gr | |
| import os | |
| from PIL import Image | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| from serve.frontend import reload_javascript | |
| from serve.utils import ( | |
| configure_logger, | |
| ) | |
| from serve.gradio_utils import ( | |
| cancel_outputing, | |
| delete_last_conversation, | |
| reset_state, | |
| reset_textbox, | |
| transfer_input, | |
| wrap_gen_fn, | |
| ) | |
| from serve.chat_utils import compress_video_to_base64 | |
| from serve.examples import get_examples | |
| import logging | |
# --- UI text ---------------------------------------------------------------
TITLE = """<h1 align="left" style="min-width:200px; margin-top:0;">Chat with Video-XL-2 </h1>"""
DESCRIPTION_TOP = """<a href="https://unabletousegit.github.io/video-xl2.github.io" target="_blank">Video-XL-2</a>, a better, faster, and high-frame-count model for long video understanding."""
DESCRIPTION = ""

# --- Runtime globals --------------------------------------------------------
# Directory containing this file; passed to get_examples() when building the UI.
ROOT_DIR = os.path.split(os.path.abspath(__file__))[0]
# Model cache: model name -> (model, tokenizer), populated by fetch_model().
DEPLOY_MODELS = {}
logger = configure_logger()
# Presumably the model's visual placeholder token (not referenced by the code shown here).
DEFAULT_IMAGE_TOKEN = "<image>"
def parse_args():
    """Parse the command-line options of the demo server.

    Returns:
        argparse.Namespace with ``model``, ``local_path``, ``ip`` and ``port``.
    """
    cli = argparse.ArgumentParser()
    option_specs = (
        (("--model",), {"type": str, "default": "Video-XL-2"}),
        (
            ("--local-path",),
            {
                "type": str,
                "default": "/share/project/minghao/Share_1/Models/Video-XL-2",
                "help": "huggingface ckpt, optional",
            },
        ),
        (("--ip",), {"type": str, "default": "0.0.0.0"}),
        (("--port",), {"type": int, "default": 7860}),
    )
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
_DEFAULT_MODEL_PATH = "/share/project/minghao/Share_1/Models/Video-XL-2"


def fetch_model(model_name: str, local_model_path: "str | None" = None):
    """Load (once) and return the ``(model, tokenizer)`` pair for ``model_name``.

    Loaded pairs are cached in the module-level ``DEPLOY_MODELS`` dict keyed by
    ``model_name``, so only the first call for a given name loads weights.

    Args:
        model_name: Cache key for the loaded model (the ``--model`` CLI value).
        local_model_path: Checkpoint directory to load from. When omitted, the
            ``local_path`` attribute of the module-level ``args`` namespace (the
            ``--local-path`` CLI option) is used if available, otherwise
            ``_DEFAULT_MODEL_PATH``.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    cached = DEPLOY_MODELS.get(model_name)
    if cached is not None:
        print(f"{model_name} has been loaded.")
        return cached

    if local_model_path is None:
        # `args` is only bound at module scope when this file runs as a script
        # (see the __main__ guard); previously --local-path was parsed but ignored.
        cli_args = globals().get("args")
        local_model_path = getattr(cli_args, "local_path", None) or _DEFAULT_MODEL_PATH

    print(f"{model_name} is loading...")
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        local_model_path,
        trust_remote_code=True,
        device_map=device,
        quantization_config=None,
        attn_implementation="sdpa",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    DEPLOY_MODELS[model_name] = (model, tokenizer)
    print(f"Load {model_name} successfully...")
    return DEPLOY_MODELS[model_name]
def preview_images(files) -> list[str]:
    """Collect the on-disk paths of uploaded files for the preview gallery.

    Args:
        files: Value of the ``gr.Files`` component, or ``None`` when empty.

    Returns:
        The ``.name`` attribute of every uploaded file, in upload order.
    """
    return [] if files is None else [uploaded.name for uploaded in files]
def predict(
    text,
    images,
    chatbot,
    history,
    top_p,
    temperature,
    max_generate_length,
    max_context_length_tokens,
    video_nframes,
    chunk_size: int = 512,
):
    """
    Generate the model's answer to ``text`` about the conversation's video.

    Gradio event handler written as a generator; every exit path yields exactly
    one ``(chatbot, history, status_message)`` triple.

    Args:
        text (str): The user's question.
        images (list | None): Files from the upload widget. Only ``images[0].name``
            (the path of the first file) is used, and only while ``history`` has no
            video recorded yet.
        chatbot (list): Display pairs ``[user_message, response]``; appended in place.
        history (dict): Per-session state. ``"video_path"`` is the video bound to the
            conversation, ``"context"`` the chat history returned by ``model.chat``.
            Any non-dict value is treated as a fresh, empty session.
        top_p (float): Nucleus-sampling top-p.
        temperature (float): Sampling temperature; values <= 1e-2 select greedy decoding.
        max_generate_length (int): Maximum number of new tokens to generate.
        max_context_length_tokens (int): Not used here; kept for the Gradio input wiring.
        video_nframes (int): Passed to ``model.chat`` as ``max_num_frames``.
        chunk_size (int): Not used here; kept for interface compatibility.
    """
    print("running the prediction function")
    # NOTE(review): the reset / "No Model Found" paths can leave a list in the
    # session state; treat anything but a dict as a new session instead of
    # crashing on dict indexing below.
    if not isinstance(history, dict):
        history = {}

    # Reject empty input before paying for a (possibly first) model load.
    if not text:
        yield chatbot, history, "Empty context."
        return

    # The conversation stays bound to the first video it saw; a new upload is
    # only picked up when no video has been recorded for this session yet.
    uploaded_path = images[0].name if images else None
    video_path = history.get("video_path") or uploaded_path
    if not video_path:
        yield chatbot, history, "Please upload a video first."
        return
    # No prior model context => first turn (or a regenerated first turn).
    is_first_turn = not history.get("context")
    history["video_path"] = video_path

    try:
        logger.info("fetching model")
        model, tokenizer = fetch_model(args.model)
        logger.info("model fetched")
    except KeyError:
        logger.info("no model found")
        yield [[text, "No Model Found"]], {}, "No Model Found"
        return

    gen_kwargs = {
        "do_sample": temperature > 1e-2,
        "temperature": temperature,
        "top_p": top_p,
        "num_beams": 1,
        "use_cache": True,
        "max_new_tokens": max_generate_length,
    }
    response, new_context = model.chat(
        video_path,
        tokenizer,
        text,
        chat_history=history.get("context"),
        return_history=True,
        max_num_frames=video_nframes,
        sample_fps=None,
        max_sample_fps=None,
        generation_config=gen_kwargs,
    )

    user_message = text
    if is_first_turn:
        # Embed the video inline so it shows up once at the top of the transcript.
        b64 = compress_video_to_base64(video_path)
        user_message = (
            f'<video controls style="max-width:300px;height:auto;" '
            f'src="data:video/mp4;base64,{b64}"></video>' + text
        )
    chatbot.append([user_message, response])
    history["context"] = new_context
    logger.info("flushed result to gradio")
    print(
        f"temperature: {temperature}, "
        f"top_p: {top_p}, "
        f"max_generate_length: {max_generate_length}"
    )
    yield chatbot, history, "Generate: Success"
def retry(
    text,  # Current textbox content; ignored, the question comes from the context.
    images,
    chatbot,
    full_history,  # Session state dict maintained by predict().
    top_p,
    temperature,
    max_generate_length,
    max_context_length_tokens,
    video_nframes,
    chunk_size: int = 512,
):
    """
    Regenerate the answer to the most recent user question.

    Removes the last exchange from both the visible ``chatbot`` and the model
    context stored in ``full_history["context"]``, then re-runs ``predict`` with
    the removed user question. Yields ``(chatbot, state, status_message)`` triples.

    Args:
        text: Ignored (kept for the shared Gradio input wiring).
        images: Uploaded files, forwarded unchanged to ``predict``.
        chatbot (list): Display pairs; the last entry is dropped before regenerating.
        full_history (dict): Session state with ``"video_path"`` and ``"context"``.
        top_p, temperature, max_generate_length, max_context_length_tokens,
        video_nframes, chunk_size: Forwarded unchanged to ``predict``.
    """
    context = full_history.get("context") if isinstance(full_history, dict) else None
    # Need at least one [user, assistant] pair; also covers a fresh session whose
    # context is missing or None. Yield the full state, never the bare context list.
    if not context or len(context) < 2:
        yield chatbot, full_history, "Empty context"
        return

    # NOTE(review): assumes the context is a list of {"role", "content"} messages
    # ending with [..., user, assistant] -- confirm against model.chat.
    last_user_input = context[-2]["content"]
    # Remove both the assistant reply and the user turn: predict() re-submits the
    # question, so keeping the user message would duplicate it in the context
    # (assuming model.chat appends the query to chat_history).
    del context[-2:]
    if chatbot:
        chatbot.pop()
    # An emptied context is stored as None, matching a brand-new session.
    full_history["context"] = context or None

    yield from predict(
        last_user_input,
        images,
        chatbot,
        full_history,
        top_p,
        temperature,
        max_generate_length,
        max_context_length_tokens,
        video_nframes,
        chunk_size,
    )
def build_demo(args: argparse.Namespace) -> gr.Blocks:
    """Build the Gradio Blocks UI for the Video-XL-2 chat demo.

    Args:
        args: Parsed CLI options. Not read here; kept for interface compatibility.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(theme=gr.themes.Soft(), delete_cache=(1800, 1800)) as demo:
        # Per-session state; predict()/retry() store {"video_path", "context"} here.
        history = gr.State(dict())
        # Snapshot of the textbox / uploads, filled by transfer_input on submit.
        input_text = gr.State()
        input_images = gr.State()

        with gr.Row():
            gr.HTML(TITLE)
            status_display = gr.Markdown("Success", elem_id="status_display")
        gr.Markdown(DESCRIPTION_TOP)

        with gr.Row(equal_height=True):
            with gr.Column(scale=4):
                with gr.Row():
                    chatbot = gr.Chatbot(
                        elem_id="Video-XL-2_Demo-chatbot",
                        show_share_button=True,
                        bubble_full_width=False,
                        height=600,
                    )
                with gr.Row():
                    with gr.Column(scale=4):
                        text_box = gr.Textbox(show_label=False, placeholder="Enter text", container=False)
                    with gr.Column(min_width=70):
                        submit_btn = gr.Button("Send")
                    with gr.Column(min_width=70):
                        cancel_btn = gr.Button("Stop")
                with gr.Row():
                    # Labels restored from mis-encoded emoji (were "π§Ή", "π", "ποΈ").
                    empty_btn = gr.Button("🧹 New Conversation")
                    retry_btn = gr.Button("🔄 Regenerate")
                    del_last_btn = gr.Button("🗑️ Remove Last Turn")
            with gr.Column():
                gr.Markdown("Note: you can upload images or videos!")
                upload_images = gr.Files(file_types=["image", "video"], show_label=True)
                gallery = gr.Gallery(columns=[3], height="200px", show_label=True)
                upload_images.change(preview_images, inputs=upload_images, outputs=gallery)
                # Parameter Setting tab: controls the generation parameters.
                with gr.Tab(label="Parameter Setting"):
                    top_p = gr.Slider(minimum=0, maximum=1.0, value=0.001, step=0.05, interactive=True, label="Top-p")
                    temperature = gr.Slider(
                        minimum=0, maximum=1.0, value=0.01, step=0.1, interactive=True, label="Temperature"
                    )
                    max_generate_length = gr.Slider(
                        minimum=512, maximum=8192, value=4096, step=64, interactive=True, label="Max Generate Length"
                    )
                    max_context_length_tokens = gr.Slider(
                        minimum=512, maximum=65536, value=16384, step=64, interactive=True, label="Max Context Length Tokens"
                    )
                    video_nframes = gr.Slider(
                        minimum=1, maximum=128, value=128, step=1, interactive=True, label="Video Nframes"
                    )

        show_images = gr.HTML(visible=False)
        gr.Markdown("This demo is based on `moonshotai/Kimi-VL-A3B-Thinking` & `deepseek-ai/deepseek-vl2-small` and extends it by adding support for video input.")
        gr.Examples(
            examples=get_examples(ROOT_DIR),
            inputs=[upload_images, show_images, text_box],
        )
        gr.Markdown()

        # Order must match the positional parameters of predict() / retry().
        input_widgets = [
            input_text,
            input_images,
            chatbot,
            history,
            top_p,
            temperature,
            max_generate_length,
            max_context_length_tokens,
            video_nframes,
        ]
        output_widgets = [chatbot, history, status_display]

        transfer_input_args = dict(
            fn=transfer_input,
            inputs=[text_box, upload_images],
            outputs=[input_text, input_images, text_box, upload_images, submit_btn],
            show_progress=True,
        )
        predict_args = dict(fn=predict, inputs=input_widgets, outputs=output_widgets, show_progress=True)
        retry_args = dict(fn=retry, inputs=input_widgets, outputs=output_widgets, show_progress=True)
        reset_args = dict(fn=reset_textbox, inputs=[], outputs=[text_box, status_display])

        # Every long-running generation goes here so the "Stop" button can cancel it.
        predict_events = [
            text_box.submit(**transfer_input_args).then(**predict_args),
            submit_btn.click(**transfer_input_args).then(**predict_args),
            retry_btn.click(**retry_args),
        ]

        empty_btn.click(reset_state, outputs=output_widgets, show_progress=True)
        empty_btn.click(**reset_args)
        del_last_btn.click(delete_last_conversation, [chatbot, history], output_widgets, show_progress=True)
        cancel_btn.click(cancel_outputing, [], [status_display], cancels=predict_events)

    demo.title = "Video-XL-2_Demo Chatbot"
    return demo
def main(args: argparse.Namespace):
    """Build the demo UI and serve it on ``args.ip``:``args.port``.

    Args:
        args: Parsed CLI options (see ``parse_args``).
    """
    demo = build_demo(args)
    reload_javascript()
    # Resolve against this file's directory rather than the current working
    # directory, so the favicon is found wherever the server is launched from.
    favicon_path = os.path.join(ROOT_DIR, "serve", "assets", "favicon.ico")
    demo.queue().launch(
        favicon_path=favicon_path if os.path.exists(favicon_path) else None,
        server_name=args.ip,
        server_port=args.port,
    )


if __name__ == "__main__":
    # `args` is intentionally bound at module scope: predict() reads `args.model`
    # as a global.
    args = parse_args()
    main(args)