IdlecloudX committed
Commit a73378e · verified · 1 Parent(s): a4820f6

Upload 4 files

Files changed (4)
  1. app.py +301 -0
  2. optimization.py +132 -0
  3. optimization_utils.py +107 -0
  4. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,301 @@
+ # PyTorch 2.8 (temporary hack)
+ import os
+ os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
+
+ # Actual demo code
+ import spaces
+ import torch
+ from diffusers import FlowMatchEulerDiscreteScheduler
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
+ from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
+ from diffusers.utils.export_utils import export_to_video
+ import gradio as gr
+ import tempfile
+ import numpy as np
+ from PIL import Image
+ import random
+ import gc
+ from optimization import optimize_pipeline_
+
+ SECRET_KEY = os.environ.get("SECRET_KEY")
+
+ # Fail fast if the Space has no secret key configured
+ if not SECRET_KEY:
+     raise ValueError("Please set the SECRET_KEY environment variable.")
+
+ MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
+
+ # Configure all LoRAs here.
+ LORA_REPO_ID = "IdlecloudX/Flux_and_Wan_Lora"
+ LORA_SETS = {
+     "NF": {
+         "high_noise": {"file": "NSFW-22-H-e8.safetensors", "adapter_name": "nf_high"},
+         "low_noise": {"file": "NSFW-22-L-e8.safetensors", "adapter_name": "nf_low"}
+     },
+     "BP": {
+         "high_noise": {"file": "Wan2.2_BP-v1-HighNoise-I2V_T2V.safetensors", "adapter_name": "bp_high"},
+         "low_noise": {"file": "Wan2.2_BP-v1-LowNoise-I2V_T2V.safetensors", "adapter_name": "bp_low"}
+     },
+     "Py-v1": {
+         "high_noise": {"file": "WAN2.2-HighNoise_Pyv1-I2V_T2V.safetensors", "adapter_name": "py_high"},
+         "low_noise": {"file": "WAN2.2-LowNoise_Pyv1-I2V_T2V.safetensors", "adapter_name": "py_low"}
+     }
+ }
+
+ MAX_DIMENSION = 832
+ MIN_DIMENSION = 576
+
+ DIMENSION_MULTIPLE = 16
+ SQUARE_SIZE = 640
+ MAX_SEED = np.iinfo(np.int32).max
+ FIXED_FPS = 16
+ MIN_FRAMES_MODEL = 8
+ MAX_FRAMES_MODEL = 81
+ MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS, 1)
+ MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS, 1)
+
+
+ print("Loading model...")
+ pipe = WanImageToVideoPipeline.from_pretrained(
+     MODEL_ID,
+     transformer=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
+         subfolder='transformer',
+         torch_dtype=torch.bfloat16,
+         device_map='cuda',
+     ),
+     transformer_2=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
+         subfolder='transformer_2',
+         torch_dtype=torch.bfloat16,
+         device_map='cuda',
+     ),
+     torch_dtype=torch.bfloat16,
+ )
+ # Use the new scheduler
+ pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(pipe.scheduler.config, shift=8.0)
+ pipe.to('cuda')
+ print("Model loaded.")
+
+
+ print("Optimizing the pipeline...")
+ optimize_pipeline_(pipe,
+     image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)),
+     last_image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)),  # pass last_image for the first/last-frame feature
+     prompt='prompt',
+     height=MIN_DIMENSION,
+     width=MAX_DIMENSION,
+     num_frames=MAX_FRAMES_MODEL,
+ )
+ print("Optimization complete.")
+
+ for name, lora_set in LORA_SETS.items():
+     print(f"--- Loading LoRA set: {name} ---")
+
+     # Load the high-noise LoRA
+     high_noise_config = lora_set["high_noise"]
+     print(f"Loading high-noise LoRA: {high_noise_config['file']}...")
+     pipe.load_lora_weights(LORA_REPO_ID, weight_name=high_noise_config['file'], adapter_name=high_noise_config['adapter_name'])
+     print("High-noise LoRA loaded.")
+
+     # Load the low-noise LoRA
+     low_noise_config = lora_set["low_noise"]
+     print(f"Loading low-noise LoRA: {low_noise_config['file']}...")
+     pipe.load_lora_weights(LORA_REPO_ID, weight_name=low_noise_config['file'], adapter_name=low_noise_config['adapter_name'])
+     print("Low-noise LoRA loaded.")
+ print("All custom LoRAs loaded.")
+
+ # Aggressively free memory left over from warm-up and LoRA loading
+ for i in range(3):
+     gc.collect()
+     torch.cuda.synchronize()
+     torch.cuda.empty_cache()
+
+ default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
+ # Default Wan 2.2 negative prompt (kept in Chinese)
+ default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
+
+
+ def process_image_for_video(image: Image.Image) -> Image.Image:
+     width, height = image.size
+     if width == height:
+         return image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS)
+
+     aspect_ratio = width / height
+     new_width, new_height = width, height
+
+     if new_width > MAX_DIMENSION or new_height > MAX_DIMENSION:
+         scale = MAX_DIMENSION / (new_width if aspect_ratio > 1 else new_height)
+         new_width, new_height = new_width * scale, new_height * scale
+
+     if new_width < MIN_DIMENSION or new_height < MIN_DIMENSION:
+         scale = MIN_DIMENSION / (new_height if aspect_ratio > 1 else new_width)
+         new_width, new_height = new_width * scale, new_height * scale
+
+     final_width = int(round(new_width / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
+     final_height = int(round(new_height / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
+
+     final_width = max(final_width, MIN_DIMENSION if aspect_ratio < 1 else SQUARE_SIZE)
+     final_height = max(final_height, MIN_DIMENSION if aspect_ratio > 1 else SQUARE_SIZE)
+
+     return image.resize((final_width, final_height), Image.Resampling.LANCZOS)
+
+ def resize_and_crop_to_match(target_image, reference_image):
+     ref_width, ref_height = reference_image.size
+     target_width, target_height = target_image.size
+     scale = max(ref_width / target_width, ref_height / target_height)
+     new_width, new_height = int(target_width * scale), int(target_height * scale)
+     resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+     left, top = (new_width - ref_width) // 2, (new_height - ref_height) // 2
+     return resized.crop((left, top, left + ref_width, top + ref_height))
+
+ # Dynamic ZeroGPU duration: called by @spaces.GPU with the same arguments as
+ # generate_video, so the signature must mirror it.
+ def get_duration(
+     secret_key,
+     start_image_pil,
+     end_image_pil,
+     prompt,
+     steps,
+     negative_prompt,
+     duration_seconds,
+     guidance_scale,
+     guidance_scale_2,
+     seed,
+     randomize_seed,
+     selected_loras,
+     progress,
+ ):
+     return int(steps) * 15
+
+ @spaces.GPU(duration=get_duration)
+ def generate_video(
+     secret_key,
+     start_image_pil,
+     end_image_pil,
+     prompt,
+     steps=8,
+     negative_prompt=default_negative_prompt,
+     duration_seconds=3.5,
+     guidance_scale=1,
+     guidance_scale_2=1,
+     seed=42,
+     randomize_seed=False,
+     selected_loras=[],
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     if secret_key != SECRET_KEY:
+         raise gr.Error("Invalid secret key! Please enter the correct key.")
+
+     if start_image_pil is None or end_image_pil is None:
+         raise gr.Error("Please upload both a start frame and an end frame.")
+
+     progress(0.1, desc="Preprocessing images...")
+     processed_start_image = process_image_for_video(start_image_pil)
+     processed_end_image = resize_and_crop_to_match(end_image_pil, processed_start_image)
+     target_height, target_width = processed_start_image.height, processed_start_image.width
+
+     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+     num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+     num_inference_steps = int(steps)
+     switch_step = num_inference_steps // 2
+
+     progress(0.2, desc=f"Generating {num_frames} frames at {target_width}x{target_height} (seed: {current_seed})...")
+
+     # Activates the selected high-noise LoRAs at step 0 and switches to the
+     # low-noise LoRAs halfway through sampling.
+     class LoraSwitcher:
+         def __init__(self, selected_lora_names):
+             self.switched = False
+             self.high_noise_adapters = []
+             self.low_noise_adapters = []
+             if selected_lora_names:
+                 for name in selected_lora_names:
+                     if name in LORA_SETS:
+                         self.high_noise_adapters.append(LORA_SETS[name]["high_noise"]["adapter_name"])
+                         self.low_noise_adapters.append(LORA_SETS[name]["low_noise"]["adapter_name"])
+
+         def __call__(self, pipe, step_index, timestep, callback_kwargs):
+             if step_index == 0:
+                 self.switched = False
+                 if self.high_noise_adapters:
+                     print(f"Activating high-noise LoRAs: {self.high_noise_adapters}")
+                     pipe.set_adapters(self.high_noise_adapters, adapter_weights=[1.0] * len(self.high_noise_adapters))
+                 elif pipe.get_active_adapters():
+                     active_adapters = pipe.get_active_adapters()
+                     print(f"No LoRA selected; disabling leftover LoRAs by setting their weights to 0: {active_adapters}")
+                     pipe.set_adapters(active_adapters, adapter_weights=[0.0] * len(active_adapters))
+
+             if self.low_noise_adapters and step_index >= switch_step and not self.switched:
+                 print(f"Switching to low-noise LoRAs at step {step_index}: {self.low_noise_adapters}")
+                 pipe.set_adapters(self.low_noise_adapters, adapter_weights=[1.0] * len(self.low_noise_adapters))
+                 self.switched = True
+             return callback_kwargs
+
+     lora_switcher_callback = LoraSwitcher(selected_loras)
+
+     output_frames_list = pipe(
+         image=processed_start_image,
+         last_image=processed_end_image,
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         height=target_height,
+         width=target_width,
+         num_frames=num_frames,
+         guidance_scale=float(guidance_scale),
+         guidance_scale_2=float(guidance_scale_2),
+         num_inference_steps=num_inference_steps,
+         generator=torch.Generator(device="cuda").manual_seed(current_seed),
+         callback_on_step_end=lora_switcher_callback,
+     ).frames[0]
+
+     progress(0.9, desc="Encoding and saving video...")
+     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+         video_path = tmpfile.name
+     export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
+
+     progress(1.0, desc="Done!")
+     return video_path, current_seed
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Wan 2.2 First/Last Frame with Custom LoRA")
+     with gr.Row():
+         with gr.Column():
+             secret_key_input = gr.Textbox(label="Secret Key", placeholder="Enter your key here...", type="password")
+
+             with gr.Row():
+                 start_image_component = gr.Image(type="pil", label="Start Frame", sources=["upload", "clipboard"])
+                 end_image_component = gr.Image(type="pil", label="End Frame", sources=["upload", "clipboard"])
+
+             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
+             duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Video duration (seconds)", info=f"Clamped to the model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frame range at {FIXED_FPS}fps.")
+
+             # Keep the LoRA selector
+             lora_selection_checkbox = gr.CheckboxGroup(
+                 choices=list(LORA_SETS.keys()),
+                 label="LoRAs to apply (multi-select)",
+                 info="Select one or more LoRA styles to combine."
+             )
+
+             with gr.Accordion("Advanced settings", open=False):
+                 negative_prompt_input = gr.Textbox(label="Negative prompt", value=default_negative_prompt, lines=3)
+                 seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
+                 randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
+                 steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=8, label="Inference steps")
+                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance scale - high-noise stage")
+                 guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance scale 2 - low-noise stage")
+
+             generate_button = gr.Button("Generate Video", variant="primary")
+         with gr.Column():
+             video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
+
+     ui_inputs = [
+         secret_key_input,
+         start_image_component,
+         end_image_component,
+         prompt_input,
+         steps_slider,
+         negative_prompt_input,
+         duration_seconds_input,
+         guidance_scale_input,
+         guidance_scale_2_input,
+         seed_input,
+         randomize_seed_checkbox,
+         lora_selection_checkbox
+     ]
+     generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
+
+ if __name__ == "__main__":
+     demo.queue().launch(mcp_server=True)
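
Note: once the Space is running, the same endpoint can be driven without the UI. Below is a minimal sketch using gradio_client; the Space id, secret key, and image paths are placeholders, and it assumes the click handler keeps the default api name derived from generate_video.

from gradio_client import Client, handle_file

# All identifiers below are illustrative placeholders, not values from this commit.
client = Client("IdlecloudX/your-space-name")
video_path, used_seed = client.predict(
    "your-secret-key",          # secret_key
    handle_file("start.png"),   # start frame
    handle_file("end.png"),     # end frame
    "make this image come alive, cinematic motion, smooth animation",  # prompt
    8,                          # steps
    "",                         # negative_prompt (empty string overrides the default)
    3.5,                        # duration_seconds
    1, 1,                       # guidance_scale, guidance_scale_2
    42, True,                   # seed, randomize_seed
    ["NF"],                     # selected LoRA sets
    api_name="/generate_video", # assumed default endpoint name
)
print(video_path, used_seed)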
optimization.py ADDED
@@ -0,0 +1,132 @@
+ """Quantization and ahead-of-time (AOT) compilation helpers for the Wan 2.2 pipeline."""
+
+ from typing import Any
+ from typing import Callable
+ from typing import ParamSpec
+
+ import spaces
+ import torch
+ from torch.utils._pytree import tree_map_only
+ from torchao.quantization import quantize_
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
+ from torchao.quantization import Int8WeightOnlyConfig
+
+ from optimization_utils import capture_component_call
+ from optimization_utils import aoti_compile
+ from optimization_utils import drain_module_parameters
+
+
+ P = ParamSpec('P')
+
+ # --- New, more precise dynamic shape definitions ---
+
+ # The VAE temporal scale factor is 1, so latent_frames = num_frames. The range is [8, 81].
+ LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
+
+ # The transformer's patch_size is (1, 2, 2), meaning the latent height and width
+ # are effectively divided by 2. If the symbolic tracer assumes odd values are
+ # possible, this produces constraint failures.
+ #
+ # To avoid that, we define dynamic dims for the *patched* (i.e. post-division)
+ # sizes and express the input shape as 2x those dims. This mathematically
+ # guarantees to the compiler that the incoming latent dimensions are always
+ # even, which satisfies the constraint.
+
+ # Pixel size range used by the app: [480, 832]. The VAE scale factor is 8.
+ # Latent dimension range: [480/8, 832/8] = [60, 104].
+ # Patched latent dimension range: [60/2, 104/2] = [30, 52].
+ LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
+ LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
+
+ # Now define the dynamic shapes for the transformer's `hidden_states` input,
+ # whose shape is (batch_size, channels, num_frames, height, width).
+ TRANSFORMER_DYNAMIC_SHAPES = {
+     'hidden_states': {
+         2: LATENT_FRAMES_DIM,
+         3: 2 * LATENT_PATCHED_HEIGHT_DIM,  # guarantees an even height
+         4: 2 * LATENT_PATCHED_WIDTH_DIM,   # guarantees an even width
+     },
+ }
+
+ # --- End of definitions ---
+
+
+ INDUCTOR_CONFIGS = {
+     'conv_1x1_as_mm': True,
+     'epilogue_fusion': False,
+     'coordinate_descent_tuning': True,
+     'coordinate_descent_check_all_directions': True,
+     'max_autotune': True,
+     'triton.cudagraphs': True,
+ }
+
+
+ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
+
+     @spaces.GPU(duration=1500)
+     def compile_transformer():
+
+         # The LoRA fusion part is unchanged
+         pipeline.load_lora_weights(
+             "Kijai/WanVideo_comfy",
+             weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+             adapter_name="lightx2v"
+         )
+         kwargs_lora = {}
+         kwargs_lora["load_into_transformer_2"] = True
+         pipeline.load_lora_weights(
+             "Kijai/WanVideo_comfy",
+             weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+             adapter_name="lightx2v_2", **kwargs_lora
+         )
+         pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
+         pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
+         pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
+         pipeline.unload_lora_weights()
+
+         # Capture a single call to get the args/kwargs structure
+         with capture_component_call(pipeline, 'transformer') as call:
+             pipeline(*args, **kwargs)
+
+         dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
+         dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
+
+         # Quantization is unchanged
+         quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
+         quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
+
+         # --- Simplified compilation flow ---
+
+         exported_1 = torch.export.export(
+             mod=pipeline.transformer,
+             args=call.args,
+             kwargs=call.kwargs,
+             dynamic_shapes=dynamic_shapes,
+         )
+
+         exported_2 = torch.export.export(
+             mod=pipeline.transformer_2,
+             args=call.args,
+             kwargs=call.kwargs,
+             dynamic_shapes=dynamic_shapes,
+         )
+
+         compiled_1 = aoti_compile(exported_1, INDUCTOR_CONFIGS)
+         compiled_2 = aoti_compile(exported_2, INDUCTOR_CONFIGS)
+
+         # Return both compiled models
+         return compiled_1, compiled_2
+
+
+     # Quantize the text encoder (same as before)
+     quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
+
+     # Get the two compiled models with dynamic shapes
+     compiled_transformer_1, compiled_transformer_2 = compile_transformer()
+
+     # --- Simplified assignment flow ---
+
+     pipeline.transformer.forward = compiled_transformer_1
+     drain_module_parameters(pipeline.transformer)
+
+     pipeline.transformer_2.forward = compiled_transformer_2
+     drain_module_parameters(pipeline.transformer_2)
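
As a quick check on the shape arithmetic in the comments above, here is a small self-contained sketch (names are local to this snippet and not part of the commit) reproducing the pixel → latent → patched-latent bounds and the even-size guarantee:

# Sketch only: reproduces the bound arithmetic assumed by TRANSFORMER_DYNAMIC_SHAPES
# (VAE spatial downscale factor 8, transformer spatial patch size 2).
def patched_latent(pixels: int) -> int:
    latent = pixels // 8   # VAE: pixels -> latent
    return latent // 2     # patch_size (1, 2, 2): latent -> patched latent

assert patched_latent(480) == 30
assert patched_latent(832) == 52
# e.g. an 832x576 clip exports hidden_states with spatial sizes 2*52 and 2*36,
# so both stay even and inside the declared (min=30, max=52) patched ranges.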
optimization_utils.py ADDED
@@ -0,0 +1,107 @@
+ """Utilities for ZeroGPU-friendly ahead-of-time (AOT) compilation of pipeline components."""
+ import contextlib
+ from contextvars import ContextVar
+ from io import BytesIO
+ from typing import Any
+ from typing import cast
+ from unittest.mock import patch
+
+ import torch
+ from torch._inductor.package.package import package_aoti
+ from torch.export.pt2_archive._package import AOTICompiledModel
+ from torch.export.pt2_archive._package_weights import Weights
+
+
+ INDUCTOR_CONFIGS_OVERRIDES = {
+     'aot_inductor.package_constants_in_so': False,
+     'aot_inductor.package_constants_on_disk': True,
+     'aot_inductor.package': True,
+ }
+
+
+ class ZeroGPUWeights:
+     def __init__(self, constants_map: dict[str, torch.Tensor], to_cuda: bool = False):
+         if to_cuda:
+             self.constants_map = {name: tensor.to('cuda') for name, tensor in constants_map.items()}
+         else:
+             self.constants_map = constants_map
+     def __reduce__(self):
+         # Pickle as pinned, shared-memory CPU copies; they are moved back to CUDA on load.
+         constants_map: dict[str, torch.Tensor] = {}
+         for name, tensor in self.constants_map.items():
+             tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
+             constants_map[name] = tensor_.copy_(tensor).detach().share_memory_()
+         return ZeroGPUWeights, (constants_map, True)
+
+
+ class ZeroGPUCompiledModel:
+     def __init__(self, archive_file: torch.types.FileLike, weights: ZeroGPUWeights):
+         self.archive_file = archive_file
+         self.weights = weights
+         self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
+     def __call__(self, *args, **kwargs):
+         # Lazily load the AOTI archive and its constants on first use, then reuse it.
+         if (compiled_model := self.compiled_model.get()) is None:
+             compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
+             compiled_model.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
+             self.compiled_model.set(compiled_model)
+         return compiled_model(*args, **kwargs)
+     def __reduce__(self):
+         return ZeroGPUCompiledModel, (self.archive_file, self.weights)
+
+
+ def aoti_compile(
+     exported_program: torch.export.ExportedProgram,
+     inductor_configs: dict[str, Any] | None = None,
+ ):
+     inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
+     gm = cast(torch.fx.GraphModule, exported_program.module())
+     assert exported_program.example_inputs is not None
+     args, kwargs = exported_program.example_inputs
+     artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
+     archive_file = BytesIO()
+     files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
+     package_aoti(archive_file, files)
+     weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
+     zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights})
+     return ZeroGPUCompiledModel(archive_file, zerogpu_weights)
+
+
+ @contextlib.contextmanager
+ def capture_component_call(
+     pipeline: Any,
+     component_name: str,
+     component_method='forward',
+ ):
+
+     class CapturedCallException(Exception):
+         def __init__(self, *args, **kwargs):
+             super().__init__()
+             self.args = args
+             self.kwargs = kwargs
+
+     class CapturedCall:
+         def __init__(self):
+             self.args: tuple[Any, ...] = ()
+             self.kwargs: dict[str, Any] = {}
+
+     component = getattr(pipeline, component_name)
+     captured_call = CapturedCall()
+
+     def capture_call(*args, **kwargs):
+         raise CapturedCallException(*args, **kwargs)
+
+     with patch.object(component, component_method, new=capture_call):
+         try:
+             yield captured_call
+         except CapturedCallException as e:
+             captured_call.args = e.args
+             captured_call.kwargs = e.kwargs
+
+
+ def drain_module_parameters(module: torch.nn.Module):
+     # Swap the module's weights for empty placeholders; the real weights now live
+     # only in the compiled artifacts' constants.
+     state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
+     state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
+     module.load_state_dict(state_dict, assign=True)
+     for name, param in state_dict.items():
+         meta = state_dict_meta[name]
+         param.data = torch.Tensor([]).to(**meta)
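
For reference, a minimal usage sketch of capture_component_call (the pipe object and its inputs below are hypothetical placeholders): the named component's method is patched so the first call raises internally, and the context manager catches that exception and records the args/kwargs the pipeline tried to pass.

# Hypothetical usage; `pipe`, `image`, etc. are placeholders, not part of this commit.
with capture_component_call(pipe, 'transformer') as call:
    pipe(image=image, prompt='prompt', num_frames=81)  # the patched forward raises; the run stops here
print(sorted(call.kwargs))  # kwarg names later handed to torch.export.export in optimization.py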
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ git+https://github.com/linoytsaban/diffusers.git@wan22-loras
+
+ transformers
+ accelerate
+ safetensors
+ sentencepiece
+ peft
+ ftfy
+ imageio-ffmpeg
+ opencv-python
+ torchao==0.11.0