import sys
import uuid
from pathlib import Path

from hydra import compose, initialize
from omegaconf import OmegaConf
from PIL import Image
import gradio as gr
import torch
import numpy as np
from torchvision import transforms
from einops import rearrange
from huggingface_hub import hf_hub_download
import spaces

# Make the repository root importable before loading project modules.
sys.path.append(str(Path(__file__).resolve().parent.parent))

# pylint: disable=wrong-import-position
from algorithms.wan.wan_i2v import WanImageToVideo
from utils.video_utils import numpy_to_mp4_bytes

DEVICE = "cuda"
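
# What follows assembles the demo: load_model() pulls the fine-tuned LVP
# checkpoint plus the frozen Wan2.1 components (UMT5 text encoder, VAE, CLIP)
# from the Hugging Face Hub, then composes the algorithm config with Hydra
# overrides before instantiating the WanImageToVideo pipeline.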
def load_model() -> WanImageToVideo:
    print("Downloading model...")
    ckpt_path = hf_hub_download(
        repo_id="KempnerInstituteAI/LVP",
        filename="checkpoints/LVP_14B_inference.ckpt",
        cache_dir="./huggingface",
    )
    umt5_path = hf_hub_download(
        repo_id="Wan-AI/Wan2.1-I2V-14B-480P",
        filename="models_t5_umt5-xxl-enc-bf16.pth",
        cache_dir="./huggingface",
    )
    vae_path = hf_hub_download(
        repo_id="Wan-AI/Wan2.1-I2V-14B-480P",
        filename="Wan2.1_VAE.pth",
        cache_dir="./huggingface",
    )
    clip_path = hf_hub_download(
        repo_id="Wan-AI/Wan2.1-I2V-14B-480P",
        filename="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
        cache_dir="./huggingface",
    )
    config_path = hf_hub_download(
        repo_id="Wan-AI/Wan2.1-I2V-14B-480P",
        filename="config.json",
        cache_dir="./huggingface/Wan2.1-I2V-14B-480P",
    )
    with initialize(version_base=None, config_path="./configurations"):
        cfg = compose(
            config_name="config",
            overrides=[
                "experiment=exp_video",
                "algorithm=wan_i2v",
                "dataset=dummy",
                "experiment.tasks=[test]",
                "algorithm.sample_steps=40",
                "algorithm.load_prompt_embed=False",
                f"algorithm.model.tuned_ckpt_path={ckpt_path}",
                f"algorithm.text_encoder.ckpt_path={umt5_path}",
                f"algorithm.vae.ckpt_path={vae_path}",
                f"algorithm.clip.ckpt_path={clip_path}",
                f"algorithm.model.ckpt_path={Path(config_path).parent}",
            ],
        )
    OmegaConf.resolve(cfg)
    cfg = cfg.algorithm
    print("Initializing model...")
    _model = WanImageToVideo(cfg)
    print("Configuring model...")
    _model.configure_model()
    _model = _model.eval().to(DEVICE)
    _model.vae_scale = [_model.vae_mean, _model.vae_inv_std]
    return _model
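
# Note: with scale=(1.0, 1.0) and a fixed ratio, RandomResizedCrop below is
# effectively deterministic -- it crops the input to the target aspect ratio
# and resizes it to (height, width) with bicubic interpolation.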
def load_transform(height: int, width: int):
    return transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.RandomResizedCrop(
                size=(height, width),
                scale=(1.0, 1.0),
                ratio=(width / height, width / height),
                interpolation=transforms.InterpolationMode.BICUBIC,
            ),
        ]
    )

model = load_model()
print("Model loaded successfully")
transform = load_transform(model.height, model.width)
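
# On ZeroGPU Spaces, `@spaces.GPU` accepts a callable `duration` that is
# called with the same arguments as the decorated function and returns a
# per-call GPU time budget in seconds; get_duration() fills that role for
# infer_i2v() below.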
def get_duration(
    image: str,
    prompt: str,
    sample_steps: int,
    lang_guidance: float,
    hist_guidance: float,
    progress: gr.Progress,
) -> int:
    # Each active guidance term adds one forward pass per sampling step; the
    # estimate assumes one pass is shared when both scales are equal and
    # non-zero, plus a flat 20-second overhead.
    step_duration = 5
    multiplier = (
        1
        + int(lang_guidance > 0)
        + int(hist_guidance > 0)
        - int(lang_guidance == hist_guidance and lang_guidance > 0)
    )
    return int(20 + sample_steps * multiplier * step_duration)

# Assumed ZeroGPU decorator: `import spaces` and get_duration() above imply
# the standard @spaces.GPU pattern for Spaces running on ZeroGPU.
@spaces.GPU(duration=get_duration)
def infer_i2v(
    image: str,
    prompt: str,
    sample_steps: int,
    lang_guidance: float,
    hist_guidance: float,
    progress: gr.Progress = gr.Progress(),
) -> str:
    """Run I2V inference, given an image path, prompt, and sampling parameters."""
    image = transform(Image.open(image).convert("RGB"))
    # Noise-initialize all frames, then overwrite frame 0 with the conditioning image.
    videos = torch.randn(1, model.n_frames, 3, model.height, model.width, device=DEVICE)
    videos[:, 0] = image[None]
    batch = {
        "videos": videos,
        "prompts": [prompt],
        "has_bbox": torch.zeros(1, 2, device=DEVICE).bool(),
        "bbox_render": torch.zeros(1, 2, model.height, model.width, device=DEVICE),
    }
    model.hist_guidance = hist_guidance
    model.lang_guidance = lang_guidance
    model.sample_steps = sample_steps
    pbar = progress.tqdm(range(sample_steps), desc="Sampling")
    video = rearrange(
        model.sample_seq(batch, pbar=pbar).squeeze(0), "t c h w -> t h w c"
    )
    # Map frames from [-1, 1] to uint8 RGB and encode them as an MP4.
    video = video.float().cpu().numpy()
    video = np.clip(video * 0.5 + 0.5, 0, 1)
    video = (video * 255).astype(np.uint8)
    video_bytes = numpy_to_mp4_bytes(video, fps=model.cfg.logging.fps)
    videos_dir = Path("./videos")
    videos_dir.mkdir(exist_ok=True)
    video_path = videos_dir / f"{uuid.uuid4()}.mp4"
    with open(video_path, "wb") as f:
        f.write(video_bytes)
    return video_path.as_posix()
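
# A minimal smoke test, bypassing the UI (the image path and prompt below are
# hypothetical, for illustration only):
#
#   infer_i2v("examples/00_sample.jpg", "pick up the red block", 30, 2.5, 1.5)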

examples_dir = Path("examples")
examples = []
if examples_dir.exists():
    for image_path in sorted(examples_dir.iterdir()):
        if not image_path.is_file():
            continue
        # The prompt is derived from the filename: drop the first two
        # characters of the stem (a sort prefix) and replace underscores
        # with spaces.
        examples.append(
            [image_path.as_posix(), image_path.stem[2:].replace("_", " ")]
        )
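
# UI layout: a sidebar with project links and troubleshooting notes, one
# column for the input image and prompt, one for the sampling sliders, one
# for the output video, and an example gallery that pre-fills the inputs.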
if __name__ == "__main__":
    with gr.Blocks() as demo:
        gr.HTML(
            """
            <style>
            .header-button-row {
                gap: 4px !important;
            }
            .header-button-row div {
                width: 131.0px !important;
            }
            .header-button-column {
                width: 131.0px !important;
                gap: 5px !important;
            }
            .header-button a {
                border: 1px solid #e4e4e7;
            }
            .header-button .button-icon {
                margin-right: 8px;
            }
            #sample-gallery table {
                width: 100% !important;
            }
            #sample-gallery td:first-child {
                width: 25% !important;
            }
            #sample-gallery .border.table,
            #sample-gallery .container.table,
            #sample-gallery .container {
                max-height: none !important;
                height: auto !important;
                max-width: none !important;
                width: 100% !important;
            }
            #sample-gallery img {
                width: 100% !important;
                height: auto !important;
                object-fit: contain !important;
            }
            </style>
            """
        )
        with gr.Sidebar():
            gr.Markdown("# Large Video Planner")
            gr.Markdown(
                "### Official Interactive Demo for"
                " [_Large Video Planner Enables Generalizable Robot Control_](todo)"
            )
            gr.Markdown("---")
            gr.Markdown("#### Links ↓")
            with gr.Row(elem_classes=["header-button-row"]):
                with gr.Column(elem_classes=["header-button-column"], min_width=0):
                    gr.Button(
                        value="Website",
                        link="https://www.boyuan.space/large-video-planner/",
                        icon="https://simpleicons.org/icons/googlechrome.svg",
                        elem_classes=["header-button"],
                        size="md",
                        min_width=0,
                    )
                    gr.Button(
                        value="Paper",
                        link="todo",
                        icon="https://simpleicons.org/icons/arxiv.svg",
                        elem_classes=["header-button"],
                        size="md",
                        min_width=0,
                    )
                with gr.Column(elem_classes=["header-button-column"], min_width=0):
                    gr.Button(
                        value="Code",
                        link="https://github.com/buoyancy99/large-video-planner",
                        icon="https://simpleicons.org/icons/github.svg",
                        elem_classes=["header-button"],
                        size="md",
                        min_width=0,
                    )
                    gr.Button(
                        value="Weights",
                        link="https://huggingface.co/large-video-planner/LVP",
                        icon="https://simpleicons.org/icons/huggingface.svg",
                        elem_classes=["header-button"],
                        size="md",
                        min_width=0,
                    )
            gr.Markdown("---")
            gr.Markdown("#### Troubleshooting ↓")
            with gr.Group():
                with gr.Accordion("Error or Unexpected Results?", open=False):
                    gr.Markdown(
                        "Please refresh the page and try again, and avoid"
                        " clicking the same button multiple times."
                    )
                with gr.Accordion("Too Slow or No GPU Allocation?", open=False):
                    gr.Markdown(
                        "This demo may respond slowly because it runs a large,"
                        " non-distilled model. Consider running the demo locally"
                        " (click the dots in the top-right corner), or subscribe"
                        " to Hugging Face Pro for an increased GPU quota."
                    )
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image", type="filepath")
                prompt_input = gr.Textbox(label="Prompt", lines=2, max_lines=2)
            with gr.Column():
                sample_steps_slider = gr.Slider(
                    label="Sampling Steps",
                    minimum=10,
                    maximum=50,
                    value=30,
                    step=1,
                )
                lang_guidance_slider = gr.Slider(
                    label="Language Guidance (recommended 1.5-2.5)",
                    minimum=0,
                    maximum=5,
                    value=2.5,
                    step=0.1,
                )
                hist_guidance_slider = gr.Slider(
                    label="History Guidance (recommended 1.0-2.0)",
                    minimum=0,
                    maximum=5,
                    value=1.5,
                    step=0.1,
                )
                run_button = gr.Button("Generate Video")
            with gr.Column():
                video_output = gr.Video(label="Generated Video")
        gr.Examples(
            examples=examples,
            inputs=[image_input, prompt_input],
            outputs=[video_output],
            run_on_click=False,
            elem_id="sample-gallery",
        )
        run_button.click(  # pylint: disable=no-member
            fn=infer_i2v,
            inputs=[
                image_input,
                prompt_input,
                sample_steps_slider,
                lang_guidance_slider,
                hist_guidance_slider,
            ],
            outputs=video_output,
        )
    demo.launch(share=True)