IdlecloudX committed
Commit 64e881e · verified · 1 Parent(s): d15788a

Upload 4 files

Files changed (4)
  1. app.py +268 -0
  2. optimization.py +130 -0
  3. optimization_utils.py +107 -0
  4. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,268 @@
+ # PyTorch 2.8 (temporary hack)
+ import os
+ os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
+
+ # Actual demo code
+ import spaces
+ import torch
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
+ from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
+ from diffusers.utils.export_utils import export_to_video
+ import gradio as gr
+ import tempfile
+ import numpy as np
+ from PIL import Image
+ import random
+ import gc
+ from optimization import optimize_pipeline_
+ from huggingface_hub import hf_hub_download
+
+ SECRET_KEY = os.environ.get("SECRET_KEY")
+
+ # Fail fast if no secret key is configured for the Space
+ if not SECRET_KEY:
+     raise ValueError("Please set the SECRET_KEY environment variable.")
+
+ MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
+
+ # Configure all LoRAs here.
+ LORA_REPO_ID = "IdlecloudX/Flux_and_Wan_Lora"
+ LORA_SETS = {
+     "NF": {
+         "high_noise": {"file": "NSFW-22-H-e8.safetensors", "adapter_name": "nf_high"},
+         "low_noise": {"file": "NSFW-22-L-e8.safetensors", "adapter_name": "nf_low"}
+     },
+     "BP": {
+         "high_noise": {"file": "Wan2.2_BP-v1-HighNoise-I2V_T2V.safetensors", "adapter_name": "bp_high"},
+         "low_noise": {"file": "Wan2.2_BP-v1-LowNoise-I2V_T2V.safetensors", "adapter_name": "bp_low"}
+     },
+     "Py-v1": {
+         "high_noise": {"file": "WAN2.2-HighNoise_Pyv1-I2V_T2V.safetensors", "adapter_name": "py_high"},
+         "low_noise": {"file": "WAN2.2-LowNoise_Pyv1-I2V_T2V.safetensors", "adapter_name": "py_low"}
+     }
+ }
+
+ LANDSCAPE_WIDTH = 832
+ LANDSCAPE_HEIGHT = 576
+ MAX_SEED = np.iinfo(np.int32).max
+
+ FIXED_FPS = 16
+ MIN_FRAMES_MODEL = 8
+ MAX_FRAMES_MODEL = 81
+
+ MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
+ MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)
+
+
+ pipe = WanImageToVideoPipeline.from_pretrained(MODEL_ID,
+     transformer=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
+         subfolder='transformer',
+         torch_dtype=torch.bfloat16,
+         device_map='cuda',
+     ),
+     transformer_2=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
+         subfolder='transformer_2',
+         torch_dtype=torch.bfloat16,
+         device_map='cuda',
+     ),
+     torch_dtype=torch.bfloat16,
+ ).to('cuda')
+
+
+ print("Optimizing pipeline...")
+ optimize_pipeline_(pipe,
+     image=Image.new('RGB', (LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT)),
+     prompt='prompt',
+     height=LANDSCAPE_HEIGHT,
+     width=LANDSCAPE_WIDTH,
+     num_frames=MAX_FRAMES_MODEL,
+ )
+ print("Optimization done.")
+
+ for name, lora_set in LORA_SETS.items():
+     print(f"--- Loading LoRA set: {name} ---")
+
+     # Load the high-noise LoRA
+     high_noise_config = lora_set["high_noise"]
+     print(f"Loading high-noise LoRA: {high_noise_config['file']}...")
+     pipe.load_lora_weights(LORA_REPO_ID, weight_name=high_noise_config['file'], adapter_name=high_noise_config['adapter_name'])
+     print("High-noise LoRA loaded.")
+
+     # Load the low-noise LoRA
+     low_noise_config = lora_set["low_noise"]
+     print(f"Loading low-noise LoRA: {low_noise_config['file']}...")
+     pipe.load_lora_weights(LORA_REPO_ID, weight_name=low_noise_config['file'], adapter_name=low_noise_config['adapter_name'])
+     print("Low-noise LoRA loaded.")
+ print("All custom LoRAs loaded.")
+
+ for i in range(3):
+     gc.collect()
+     torch.cuda.synchronize()
+     torch.cuda.empty_cache()
+
+ default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
+ # Standard Wan negative prompt (kept in Chinese as provided)
+ default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
+
+
+ def resize_image(image: Image.Image) -> Image.Image:
+     if image.height > image.width:
+         transposed = image.transpose(Image.Transpose.ROTATE_90)
+         resized = resize_image_landscape(transposed)
+         return resized.transpose(Image.Transpose.ROTATE_270)
+     return resize_image_landscape(image)
+
+
+ def resize_image_landscape(image: Image.Image) -> Image.Image:
+     target_aspect = LANDSCAPE_WIDTH / LANDSCAPE_HEIGHT
+     width, height = image.size
+     in_aspect = width / height
+     if in_aspect > target_aspect:
+         new_width = round(height * target_aspect)
+         left = (width - new_width) // 2
+         image = image.crop((left, 0, left + new_width, height))
+     else:
+         new_height = round(width / target_aspect)
+         top = (height - new_height) // 2
+         image = image.crop((0, top, width, top + new_height))
+     return image.resize((LANDSCAPE_WIDTH, LANDSCAPE_HEIGHT), Image.LANCZOS)
+
+ def get_duration(
+     secret_key,
+     input_image,
+     prompt,
+     steps,
+     negative_prompt,
+     duration_seconds,
+     guidance_scale,
+     guidance_scale_2,
+     seed,
+     randomize_seed,
+     selected_loras,
+     progress,
+ ):
+     return int(steps) * 15
+
+ @spaces.GPU(duration=get_duration)
+ def generate_video(
+     secret_key,
+     input_image,
+     prompt,
+     steps=4,
+     negative_prompt=default_negative_prompt,
+     duration_seconds=MAX_DURATION,
+     guidance_scale=1,
+     guidance_scale_2=1,
+     seed=42,
+     randomize_seed=False,
+     selected_loras=[],
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     if secret_key != SECRET_KEY:
+         raise gr.Error("Invalid secret key! Please enter the correct key.")
+
+     if input_image is None:
+         raise gr.Error("Please upload an input image.")
+
+     num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+     resized_image = resize_image(input_image)
+
+     num_inference_steps = int(steps)
+     switch_step = num_inference_steps // 2
+
+     class LoraSwitcher:
+         def __init__(self, selected_lora_names):
+             self.switched = False
+             self.high_noise_adapters = []
+             self.low_noise_adapters = []
+
+             if selected_lora_names:
+                 for name in selected_lora_names:
+                     if name in LORA_SETS:
+                         self.high_noise_adapters.append(LORA_SETS[name]["high_noise"]["adapter_name"])
+                         self.low_noise_adapters.append(LORA_SETS[name]["low_noise"]["adapter_name"])
+
+         def __call__(self, pipe, step_index, timestep, callback_kwargs):
+             # Set the correct LoRA state on the first step
+             if step_index == 0:
+                 self.switched = False
+                 # If the user selected LoRAs, activate the high-noise versions
+                 if self.high_noise_adapters:
+                     print(f"Activating high-noise LoRAs: {self.high_noise_adapters}")
+                     pipe.set_adapters(self.high_noise_adapters, adapter_weights=[1.0] * len(self.high_noise_adapters))
+                 # If no LoRA was selected, disable any leftover LoRAs by zeroing their weights
+                 elif pipe.active_adapters:
+                     print(f"No LoRA selected; zeroing weights of leftover LoRAs: {pipe.active_adapters}")
+                     pipe.set_adapters(pipe.active_adapters, adapter_weights=[0.0] * len(pipe.active_adapters))
+
+             # At the switch point, swap to the low-noise LoRAs (only when LoRAs are selected)
+             if self.low_noise_adapters and step_index >= switch_step and not self.switched:
+                 print(f"Switching to low-noise LoRAs at step {step_index}: {self.low_noise_adapters}")
+                 pipe.set_adapters(self.low_noise_adapters, adapter_weights=[1.0] * len(self.low_noise_adapters))
+                 self.switched = True
+             return callback_kwargs
+     # --- end of modified block ---
+
+     lora_switcher_callback = LoraSwitcher(selected_loras)
+
+     output_frames_list = pipe(
+         image=resized_image,
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         height=resized_image.height,
+         width=resized_image.width,
+         num_frames=num_frames,
+         guidance_scale=float(guidance_scale),
+         guidance_scale_2=float(guidance_scale_2),
+         num_inference_steps=num_inference_steps,
+         generator=torch.Generator(device="cuda").manual_seed(current_seed),
+         callback_on_step_end=lora_switcher_callback,
+     ).frames[0]
+
+     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+         video_path = tmpfile.name
+
+     export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
+
+     return video_path, current_seed
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Fast 4 steps Wan 2.2 I2V (14B) with Lightning LoRA")
+     gr.Markdown("run Wan 2.2 in just 4-8 steps, with [Lightning LoRA](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Wan22-Lightning), fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
+     with gr.Row():
+         with gr.Column():
+             secret_key_input = gr.Textbox(label="Secret Key", placeholder="Enter your key here...", type="password")
+
+             input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
+             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
+             duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
+
+             lora_selection_checkbox = gr.CheckboxGroup(
+                 choices=list(LORA_SETS.keys()),
+                 label="LoRAs to apply (multiple selections allowed)",
+                 info="Select one or more LoRA styles to combine."
+             )
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
+                 seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
+                 randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
+                 steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
+                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage")
+                 guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")
+
+             generate_button = gr.Button("Generate Video", variant="primary")
+         with gr.Column():
+             video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
+
+     ui_inputs = [
+         secret_key_input,
+         input_image_component, prompt_input, steps_slider,
+         negative_prompt_input, duration_seconds_input,
+         guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox,
+         lora_selection_checkbox
+     ]
+     generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
+
+ if __name__ == "__main__":
+     demo.queue().launch(mcp_server=True)
optimization.py ADDED
@@ -0,0 +1,130 @@
+ """
+ """
+
+ from typing import Any
+ from typing import Callable
+ from typing import ParamSpec
+
+ import spaces
+ import torch
+ from torch.utils._pytree import tree_map_only
+ from torchao.quantization import quantize_
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
+ from torchao.quantization import Int8WeightOnlyConfig
+
+ from optimization_utils import capture_component_call
+ from optimization_utils import aoti_compile
+ from optimization_utils import ZeroGPUCompiledModel
+ from optimization_utils import drain_module_parameters
+
+
+ P = ParamSpec('P')
+
+
+ TRANSFORMER_NUM_FRAMES_DIM = torch.export.Dim('num_frames', min=3, max=21)
+
+ TRANSFORMER_DYNAMIC_SHAPES = {
+     'hidden_states': {
+         2: TRANSFORMER_NUM_FRAMES_DIM,
+     },
+ }
+
+ INDUCTOR_CONFIGS = {
+     'conv_1x1_as_mm': True,
+     'epilogue_fusion': False,
+     'coordinate_descent_tuning': True,
+     'coordinate_descent_check_all_directions': True,
+     'max_autotune': True,
+     'triton.cudagraphs': True,
+ }
+
+
+ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
+
+     @spaces.GPU(duration=1500)
+     def compile_transformer():
+
+         # Fuse the Lightning (lightx2v) step-distillation LoRA into both transformers
+         pipeline.load_lora_weights(
+             "Kijai/WanVideo_comfy",
+             weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+             adapter_name="lightx2v"
+         )
+         kwargs_lora = {}
+         kwargs_lora["load_into_transformer_2"] = True
+         pipeline.load_lora_weights(
+             "Kijai/WanVideo_comfy",
+             weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
+             adapter_name="lightx2v_2", **kwargs_lora
+         )
+         pipeline.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
+         pipeline.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
+         pipeline.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
+         pipeline.unload_lora_weights()
+
+         # Run the pipeline once to capture the exact arguments the transformer is called with
+         with capture_component_call(pipeline, 'transformer') as call:
+             pipeline(*args, **kwargs)
+
+         dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
+         dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
+
+         # Quantize both transformers to fp8 (dynamic activations, fp8 weights)
+         quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
+         quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
+
+         hidden_states: torch.Tensor = call.kwargs['hidden_states']
+         hidden_states_transposed = hidden_states.transpose(-1, -2).contiguous()
+         if hidden_states.shape[-1] > hidden_states.shape[-2]:
+             hidden_states_landscape = hidden_states
+             hidden_states_portrait = hidden_states_transposed
+         else:
+             hidden_states_landscape = hidden_states_transposed
+             hidden_states_portrait = hidden_states
+
+         exported_landscape_1 = torch.export.export(
+             mod=pipeline.transformer,
+             args=call.args,
+             kwargs=call.kwargs | {'hidden_states': hidden_states_landscape},
+             dynamic_shapes=dynamic_shapes,
+         )
+
+         exported_portrait_2 = torch.export.export(
+             mod=pipeline.transformer_2,
+             args=call.args,
+             kwargs=call.kwargs | {'hidden_states': hidden_states_portrait},
+             dynamic_shapes=dynamic_shapes,
+         )
+
+         compiled_landscape_1 = aoti_compile(exported_landscape_1, INDUCTOR_CONFIGS)
+         compiled_portrait_2 = aoti_compile(exported_portrait_2, INDUCTOR_CONFIGS)
+
+         # Reuse each compiled archive with the other transformer's weights so both
+         # transformers get a landscape and a portrait variant
+         compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
+         compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)
+
+         return (
+             compiled_landscape_1,
+             compiled_landscape_2,
+             compiled_portrait_1,
+             compiled_portrait_2,
+         )
+
+     quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
+     cl1, cl2, cp1, cp2 = compile_transformer()
+
+     def combined_transformer_1(*args, **kwargs):
+         hidden_states: torch.Tensor = kwargs['hidden_states']
+         if hidden_states.shape[-1] > hidden_states.shape[-2]:
+             return cl1(*args, **kwargs)
+         else:
+             return cp1(*args, **kwargs)
+
+     def combined_transformer_2(*args, **kwargs):
+         hidden_states: torch.Tensor = kwargs['hidden_states']
+         if hidden_states.shape[-1] > hidden_states.shape[-2]:
+             return cl2(*args, **kwargs)
+         else:
+             return cp2(*args, **kwargs)
+
+     pipeline.transformer.forward = combined_transformer_1
+     drain_module_parameters(pipeline.transformer)
+
+     pipeline.transformer_2.forward = combined_transformer_2
+     drain_module_parameters(pipeline.transformer_2)
optimization_utils.py ADDED
@@ -0,0 +1,107 @@
+ """
+ """
+ import contextlib
+ from contextvars import ContextVar
+ from io import BytesIO
+ from typing import Any
+ from typing import cast
+ from unittest.mock import patch
+
+ import torch
+ from torch._inductor.package.package import package_aoti
+ from torch.export.pt2_archive._package import AOTICompiledModel
+ from torch.export.pt2_archive._package_weights import Weights
+
+
+ INDUCTOR_CONFIGS_OVERRIDES = {
+     'aot_inductor.package_constants_in_so': False,
+     'aot_inductor.package_constants_on_disk': True,
+     'aot_inductor.package': True,
+ }
+
+
+ class ZeroGPUWeights:
+     def __init__(self, constants_map: dict[str, torch.Tensor], to_cuda: bool = False):
+         if to_cuda:
+             self.constants_map = {name: tensor.to('cuda') for name, tensor in constants_map.items()}
+         else:
+             self.constants_map = constants_map
+     def __reduce__(self):
+         # Pickle the constants as shared, pinned CPU tensors; they are moved to CUDA on unpickling
+         constants_map: dict[str, torch.Tensor] = {}
+         for name, tensor in self.constants_map.items():
+             tensor_ = torch.empty_like(tensor, device='cpu').pin_memory()
+             constants_map[name] = tensor_.copy_(tensor).detach().share_memory_()
+         return ZeroGPUWeights, (constants_map, True)
+
+
+ class ZeroGPUCompiledModel:
+     def __init__(self, archive_file: torch.types.FileLike, weights: ZeroGPUWeights):
+         self.archive_file = archive_file
+         self.weights = weights
+         self.compiled_model: ContextVar[AOTICompiledModel | None] = ContextVar('compiled_model', default=None)
+     def __call__(self, *args, **kwargs):
+         # Lazily load the AOTI package and its constants on first call
+         if (compiled_model := self.compiled_model.get()) is None:
+             compiled_model = cast(AOTICompiledModel, torch._inductor.aoti_load_package(self.archive_file))
+             compiled_model.load_constants(self.weights.constants_map, check_full_update=True, user_managed=True)
+             self.compiled_model.set(compiled_model)
+         return compiled_model(*args, **kwargs)
+     def __reduce__(self):
+         return ZeroGPUCompiledModel, (self.archive_file, self.weights)
+
+
+ def aoti_compile(
+     exported_program: torch.export.ExportedProgram,
+     inductor_configs: dict[str, Any] | None = None,
+ ):
+     inductor_configs = (inductor_configs or {}) | INDUCTOR_CONFIGS_OVERRIDES
+     gm = cast(torch.fx.GraphModule, exported_program.module())
+     assert exported_program.example_inputs is not None
+     args, kwargs = exported_program.example_inputs
+     artifacts = torch._inductor.aot_compile(gm, args, kwargs, options=inductor_configs)
+     archive_file = BytesIO()
+     files: list[str | Weights] = [file for file in artifacts if isinstance(file, str)]
+     package_aoti(archive_file, files)
+     weights, = (artifact for artifact in artifacts if isinstance(artifact, Weights))
+     zerogpu_weights = ZeroGPUWeights({name: weights.get_weight(name)[0] for name in weights})
+     return ZeroGPUCompiledModel(archive_file, zerogpu_weights)
+
+
+ @contextlib.contextmanager
+ def capture_component_call(
+     pipeline: Any,
+     component_name: str,
+     component_method='forward',
+ ):
+
+     class CapturedCallException(Exception):
+         def __init__(self, *args, **kwargs):
+             super().__init__()
+             self.args = args
+             self.kwargs = kwargs
+
+     class CapturedCall:
+         def __init__(self):
+             self.args: tuple[Any, ...] = ()
+             self.kwargs: dict[str, Any] = {}
+
+     component = getattr(pipeline, component_name)
+     captured_call = CapturedCall()
+
+     def capture_call(*args, **kwargs):
+         raise CapturedCallException(*args, **kwargs)
+
+     # Patch the component method so the first call raises and its args/kwargs are captured
+     with patch.object(component, component_method, new=capture_call):
+         try:
+             yield captured_call
+         except CapturedCallException as e:
+             captured_call.args = e.args
+             captured_call.kwargs = e.kwargs
+
+
+ def drain_module_parameters(module: torch.nn.Module):
+     # Replace all parameters with empty placeholders to release their memory,
+     # keeping only device/dtype metadata
+     state_dict_meta = {name: {'device': tensor.device, 'dtype': tensor.dtype} for name, tensor in module.state_dict().items()}
+     state_dict = {name: torch.nn.Parameter(torch.empty_like(tensor, device='cpu')) for name, tensor in module.state_dict().items()}
+     module.load_state_dict(state_dict, assign=True)
+     for name, param in state_dict.items():
+         meta = state_dict_meta[name]
+         param.data = torch.Tensor([]).to(**meta)
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ git+https://github.com/linoytsaban/diffusers.git@wan22-loras
+
+ transformers
+ accelerate
+ safetensors
+ sentencepiece
+ peft
+ ftfy
+ imageio-ffmpeg
+ opencv-python
+ torchao==0.11.0