Upload 91 files
This view is limited to 50 files because it contains too many changes; see the raw diff for the complete change set.
- .gitattributes +26 -0
- configs/distilled_model/gta_drive/config.json +49 -0
- configs/distilled_model/templerun/config.json +42 -0
- configs/distilled_model/universal/config.json +49 -0
- configs/foundation_model/config.json +49 -0
- configs/inference_yaml/inference_gta_drive.yaml +21 -0
- configs/inference_yaml/inference_templerun.yaml +22 -0
- configs/inference_yaml/inference_universal.yaml +21 -0
- demo_images/gta_drive/0000.png +3 -0
- demo_images/gta_drive/0001.png +3 -0
- demo_images/gta_drive/0002.png +3 -0
- demo_images/gta_drive/0003.png +3 -0
- demo_images/gta_drive/0004.png +3 -0
- demo_images/gta_drive/0005.png +3 -0
- demo_images/temple_run/0000.png +3 -0
- demo_images/temple_run/0001.png +3 -0
- demo_images/temple_run/0002.png +3 -0
- demo_images/temple_run/0003.png +3 -0
- demo_images/temple_run/0004.png +3 -0
- demo_images/temple_run/0005.png +3 -0
- demo_images/universal/0000.png +0 -0
- demo_images/universal/0001.png +3 -0
- demo_images/universal/0002.png +3 -0
- demo_images/universal/0003.png +3 -0
- demo_images/universal/0004.png +3 -0
- demo_images/universal/0005.png +3 -0
- demo_images/universal/0006.png +3 -0
- demo_images/universal/0007.png +3 -0
- demo_images/universal/0008.png +3 -0
- demo_images/universal/0009.png +3 -0
- demo_images/universal/0010.webp +0 -0
- demo_images/universal/0011.png +3 -0
- demo_images/universal/0012.png +3 -0
- demo_images/universal/0013.png +3 -0
- demo_images/universal/0014.png +3 -0
- demo_images/universal/0015.png +0 -0
- demo_images/universal/0016.png +3 -0
- demo_utils/constant.py +42 -0
- demo_utils/memory.py +135 -0
- demo_utils/taehv.py +313 -0
- demo_utils/utils.py +616 -0
- demo_utils/vae.py +390 -0
- demo_utils/vae_block3.py +291 -0
- demo_utils/vae_torch2trt.py +308 -0
- inference.py +169 -0
- inference_streaming.py +161 -0
- pipeline/__init__.py +5 -0
- pipeline/causal_inference.py +753 -0
- requirements.txt +41 -0
- setup.py +6 -0
.gitattributes
CHANGED
@@ -72,3 +72,29 @@
 GameWorldScore/GameWorld/third_party/DROID-SLAM/thirdparty/lietorch/examples/rgbdslam/assets/room.png filter=lfs diff=lfs merge=lfs -text
 GameWorldScore/GameWorld/third_party/DROID-SLAM/thirdparty/lietorch/lietorch.png filter=lfs diff=lfs merge=lfs -text
 GameWorldScore/GameWorld/third_party/RAFT/RAFT.png filter=lfs diff=lfs merge=lfs -text
+demo_images/gta_drive/0000.png filter=lfs diff=lfs merge=lfs -text
+demo_images/gta_drive/0001.png filter=lfs diff=lfs merge=lfs -text
+demo_images/gta_drive/0002.png filter=lfs diff=lfs merge=lfs -text
+demo_images/gta_drive/0003.png filter=lfs diff=lfs merge=lfs -text
+demo_images/gta_drive/0004.png filter=lfs diff=lfs merge=lfs -text
+demo_images/gta_drive/0005.png filter=lfs diff=lfs merge=lfs -text
+demo_images/temple_run/0000.png filter=lfs diff=lfs merge=lfs -text
+demo_images/temple_run/0001.png filter=lfs diff=lfs merge=lfs -text
+demo_images/temple_run/0002.png filter=lfs diff=lfs merge=lfs -text
+demo_images/temple_run/0003.png filter=lfs diff=lfs merge=lfs -text
+demo_images/temple_run/0004.png filter=lfs diff=lfs merge=lfs -text
+demo_images/temple_run/0005.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0001.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0002.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0003.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0004.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0005.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0006.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0007.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0008.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0009.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0011.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0012.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0013.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0014.png filter=lfs diff=lfs merge=lfs -text
+demo_images/universal/0016.png filter=lfs diff=lfs merge=lfs -text
configs/distilled_model/gta_drive/config.json
ADDED
@@ -0,0 +1,49 @@
{
  "_class_name": "CausalWanModel",
  "_diffusers_version": "0.35.0.dev0",
  "action_config": {
    "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
    "enable_keyboard": true,
    "enable_mouse": true,
    "heads_num": 16,
    "hidden_size": 128,
    "img_hidden_size": 1536,
    "keyboard_dim_in": 2,
    "keyboard_hidden_dim": 1024,
    "mouse_dim_in": 2,
    "mouse_hidden_dim": 1024,
    "mouse_qk_dim_list": [8, 28, 28],
    "patch_size": [1, 2, 2],
    "qk_norm": true,
    "qkv_bias": false,
    "rope_dim_list": [8, 28, 28],
    "rope_theta": 256,
    "vae_time_compression_ratio": 4,
    "windows_size": 3
  },
  "dim": 1536,
  "eps": 1e-06,
  "ffn_dim": 8960,
  "freq_dim": 256,
  "in_dim": 36,
  "inject_sample_info": false,
  "local_attn_size": 4,
  "model_type": "i2v",
  "num_heads": 12,
  "num_layers": 30,
  "out_dim": 16,
  "sink_size": 0,
  "text_len": 512
}
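For reference, a minimal sketch of reading this config from Python. It uses plain json only; how the repository actually instantiates CausalWanModel from it (for example via a diffusers-style from_config) is not shown in this diff, so treat the snippet as an illustration.

import json

with open("configs/distilled_model/gta_drive/config.json") as f:
    cfg = json.load(f)

print(cfg["_class_name"])                    # CausalWanModel
print(cfg["num_layers"], cfg["dim"])         # 30 transformer blocks, hidden size 1536
print(cfg["action_config"]["enable_mouse"])  # True: the GTA-Drive model conditions on mouse input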
configs/distilled_model/templerun/config.json
ADDED
@@ -0,0 +1,42 @@
{
  "_class_name": "CausalWanModel",
  "_diffusers_version": "0.35.0.dev0",
  "action_config": {
    "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
    "enable_keyboard": true,
    "enable_mouse": false,
    "heads_num": 16,
    "hidden_size": 128,
    "img_hidden_size": 1536,
    "keyboard_dim_in": 7,
    "keyboard_hidden_dim": 1024,
    "patch_size": [1, 2, 2],
    "qk_norm": true,
    "qkv_bias": false,
    "rope_dim_list": [8, 28, 28],
    "rope_theta": 256,
    "vae_time_compression_ratio": 4,
    "windows_size": 3
  },
  "dim": 1536,
  "eps": 1e-06,
  "ffn_dim": 8960,
  "freq_dim": 256,
  "in_dim": 36,
  "inject_sample_info": false,
  "local_attn_size": 6,
  "model_type": "i2v",
  "num_heads": 12,
  "num_layers": 30,
  "out_dim": 16,
  "sink_size": 0,
  "text_len": 512
}
configs/distilled_model/universal/config.json
ADDED
@@ -0,0 +1,49 @@
{
  "_class_name": "CausalWanModel",
  "_diffusers_version": "0.35.0.dev0",
  "action_config": {
    "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
    "enable_keyboard": true,
    "enable_mouse": true,
    "heads_num": 16,
    "hidden_size": 128,
    "img_hidden_size": 1536,
    "keyboard_dim_in": 4,
    "keyboard_hidden_dim": 1024,
    "mouse_dim_in": 2,
    "mouse_hidden_dim": 1024,
    "mouse_qk_dim_list": [8, 28, 28],
    "patch_size": [1, 2, 2],
    "qk_norm": true,
    "qkv_bias": false,
    "rope_dim_list": [8, 28, 28],
    "rope_theta": 256,
    "vae_time_compression_ratio": 4,
    "windows_size": 3
  },
  "dim": 1536,
  "eps": 1e-06,
  "ffn_dim": 8960,
  "freq_dim": 256,
  "in_dim": 36,
  "inject_sample_info": false,
  "local_attn_size": 6,
  "model_type": "i2v",
  "num_heads": 12,
  "num_layers": 30,
  "out_dim": 16,
  "sink_size": 0,
  "text_len": 512
}
configs/foundation_model/config.json
ADDED
@@ -0,0 +1,49 @@
{
  "_class_name": "CausalWanModel",
  "_diffusers_version": "0.35.0.dev0",
  "action_config": {
    "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
    "enable_keyboard": true,
    "enable_mouse": true,
    "heads_num": 16,
    "hidden_size": 128,
    "img_hidden_size": 1536,
    "keyboard_dim_in": 4,
    "keyboard_hidden_dim": 1024,
    "mouse_dim_in": 2,
    "mouse_hidden_dim": 1024,
    "mouse_qk_dim_list": [8, 28, 28],
    "patch_size": [1, 2, 2],
    "qk_norm": true,
    "qkv_bias": false,
    "rope_dim_list": [8, 28, 28],
    "rope_theta": 256,
    "vae_time_compression_ratio": 4,
    "windows_size": 3
  },
  "dim": 1536,
  "eps": 1e-06,
  "ffn_dim": 8960,
  "freq_dim": 256,
  "in_dim": 36,
  "inject_sample_info": false,
  "local_attn_size": -1,
  "model_type": "i2v",
  "num_heads": 12,
  "num_layers": 30,
  "out_dim": 16,
  "sink_size": 0,
  "text_len": 512
}
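The four configs above differ mainly in how the action conditioning is wired and how wide the attention window is. A small sketch to diff them side by side, assuming the paths are read from the repository root as in this commit:

import json

paths = {
    "gta_drive": "configs/distilled_model/gta_drive/config.json",
    "templerun": "configs/distilled_model/templerun/config.json",
    "universal": "configs/distilled_model/universal/config.json",
    "foundation": "configs/foundation_model/config.json",
}
for name, path in paths.items():
    cfg = json.load(open(path))
    ac = cfg["action_config"]
    print(name, len(ac["blocks"]), ac.get("enable_mouse", False),
          ac["keyboard_dim_in"], cfg["local_attn_size"])

# Distilled models condition 15 blocks and use a local attention window (4 or 6);
# the foundation model conditions all 30 blocks and uses global attention (local_attn_size = -1);
# Temple Run disables mouse input and uses a 7-dim keyboard encoding.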
configs/inference_yaml/inference_gta_drive.yaml
ADDED
@@ -0,0 +1,21 @@
denoising_step_list:
  - 1000
  - 666
  - 333
warp_denoising_step: true
ts_schedule: false
mixed_precision: true
seed: 42
image_or_video_shape:
  - 1
  - 16
  - 15
  - 44
  - 80
num_frame_per_block: 3
context_noise: 0
mode: gta_drive
causal: true
model_kwargs:
  timestep_shift: 5.0
model_config: configs/distilled_model/gta_drive
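A hedged sketch of reading this YAML from Python with PyYAML (assumed available); how inference.py actually consumes these fields is not part of this diff, and the shape interpretation in the comment is an assumption based on the model configs (16 latent channels, 44x80 latent resolution):

import yaml

with open("configs/inference_yaml/inference_gta_drive.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["denoising_step_list"])   # [1000, 666, 333]: three distilled denoising steps
print(cfg["image_or_video_shape"])  # [1, 16, 15, 44, 80], presumably batch x channels x frames x height x width
print(cfg["model_config"])          # configs/distilled_model/gta_drive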
configs/inference_yaml/inference_templerun.yaml
ADDED
@@ -0,0 +1,22 @@
denoising_step_list:
  - 1000
  - 750
  - 500
  - 250
warp_denoising_step: true
ts_schedule: false
mixed_precision: true
seed: 42
image_or_video_shape:
  - 1
  - 16
  - 15
  - 44
  - 80
num_frame_per_block: 3
context_noise: 0
mode: templerun
causal: true
model_kwargs:
  timestep_shift: 5.0
model_config: configs/distilled_model/templerun
configs/inference_yaml/inference_universal.yaml
ADDED
@@ -0,0 +1,21 @@
denoising_step_list:
  - 1000
  - 666
  - 333
warp_denoising_step: true
ts_schedule: false
mixed_precision: true
seed: 42
image_or_video_shape:
  - 1
  - 16
  - 15
  - 44
  - 80
num_frame_per_block: 3
context_noise: 0
mode: universal
causal: true
model_kwargs:
  timestep_shift: 5.0
model_config: configs/distilled_model/universal
demo_images/gta_drive/0000.png ADDED (Git LFS)
demo_images/gta_drive/0001.png ADDED (Git LFS)
demo_images/gta_drive/0002.png ADDED (Git LFS)
demo_images/gta_drive/0003.png ADDED (Git LFS)
demo_images/gta_drive/0004.png ADDED (Git LFS)
demo_images/gta_drive/0005.png ADDED (Git LFS)
demo_images/temple_run/0000.png ADDED (Git LFS)
demo_images/temple_run/0001.png ADDED (Git LFS)
demo_images/temple_run/0002.png ADDED (Git LFS)
demo_images/temple_run/0003.png ADDED (Git LFS)
demo_images/temple_run/0004.png ADDED (Git LFS)
demo_images/temple_run/0005.png ADDED (Git LFS)
demo_images/universal/0000.png ADDED
demo_images/universal/0001.png ADDED (Git LFS)
demo_images/universal/0002.png ADDED (Git LFS)
demo_images/universal/0003.png ADDED (Git LFS)
demo_images/universal/0004.png ADDED (Git LFS)
demo_images/universal/0005.png ADDED (Git LFS)
demo_images/universal/0006.png ADDED (Git LFS)
demo_images/universal/0007.png ADDED (Git LFS)
demo_images/universal/0008.png ADDED (Git LFS)
demo_images/universal/0009.png ADDED (Git LFS)
demo_images/universal/0010.webp ADDED
demo_images/universal/0011.png ADDED (Git LFS)
demo_images/universal/0012.png ADDED (Git LFS)
demo_images/universal/0013.png ADDED (Git LFS)
demo_images/universal/0014.png ADDED (Git LFS)
demo_images/universal/0015.png ADDED
demo_images/universal/0016.png ADDED (Git LFS)
demo_utils/constant.py
ADDED
@@ -0,0 +1,42 @@
import torch

base_size = 80
base_size2 = 44
ZERO_VAE_CACHE = [
    torch.zeros(1, 16, 2, base_size2, base_size),
    torch.zeros(1, 384, 2, base_size2, base_size),
    torch.zeros(1, 384, 2, base_size2, base_size),
    torch.zeros(1, 384, 2, base_size2, base_size),
    torch.zeros(1, 384, 2, base_size2, base_size),
    torch.zeros(1, 384, 2, base_size2, base_size),
    torch.zeros(1, 384, 2, base_size2, base_size),
    torch.zeros(1, 384, 2, base_size2, base_size),
    torch.zeros(1, 384, 2, base_size2, base_size),
    torch.zeros(1, 384, 2, base_size2, base_size),
    torch.zeros(1, 384, 2, base_size2, base_size),
    torch.zeros(1, 384, 2, base_size2, base_size),
    torch.zeros(1, 192, 2, base_size2*2, base_size*2),
    torch.zeros(1, 384, 2, base_size2*2, base_size*2),
    torch.zeros(1, 384, 2, base_size2*2, base_size*2),
    torch.zeros(1, 384, 2, base_size2*2, base_size*2),
    torch.zeros(1, 384, 2, base_size2*2, base_size*2),
    torch.zeros(1, 384, 2, base_size2*2, base_size*2),
    torch.zeros(1, 384, 2, base_size2*2, base_size*2),
    torch.zeros(1, 192, 2, base_size2*4, base_size*4),
    torch.zeros(1, 192, 2, base_size2*4, base_size*4),
    torch.zeros(1, 192, 2, base_size2*4, base_size*4),
    torch.zeros(1, 192, 2, base_size2*4, base_size*4),
    torch.zeros(1, 192, 2, base_size2*4, base_size*4),
    torch.zeros(1, 192, 2, base_size2*4, base_size*4),
    torch.zeros(1, 96, 2, base_size2*8, base_size*8),
    torch.zeros(1, 96, 2, base_size2*8, base_size*8),
    torch.zeros(1, 96, 2, base_size2*8, base_size*8),
    torch.zeros(1, 96, 2, base_size2*8, base_size*8),
    torch.zeros(1, 96, 2, base_size2*8, base_size*8),
    torch.zeros(1, 96, 2, base_size2*8, base_size*8),
    torch.zeros(1, 96, 2, base_size2*8, base_size*8)
]

feat_names = [f"vae_cache_{i}" for i in range(len(ZERO_VAE_CACHE))]
ALL_INPUTS_NAMES = ["z", "use_cache"] + feat_names
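ZERO_VAE_CACHE and ALL_INPUTS_NAMES look like the named inputs of an exported VAE decoder (for example the demo_utils/vae_torch2trt.py path in this commit), but the consumer is not shown here. A hypothetical sketch of pairing the names with tensors; the "z" and "use_cache" shapes below are assumptions:

import torch
from demo_utils.constant import ZERO_VAE_CACHE, ALL_INPUTS_NAMES

z = torch.zeros(1, 16, 2, 44, 80)  # assumed latent input shape, matching the first cache entry
use_cache = torch.zeros(1)         # assumed placeholder flag
named_inputs = dict(zip(ALL_INPUTS_NAMES, [z, use_cache] + ZERO_VAE_CACHE))
print(len(named_inputs), list(named_inputs)[:3])  # 34 inputs: z, use_cache, vae_cache_0, ...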
demo_utils/memory.py
ADDED
@@ -0,0 +1,135 @@
# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
# Apache-2.0 License
# By lllyasviel

import torch


cpu = torch.device('cpu')
gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
gpu_complete_modules = []


class DynamicSwapInstaller:
    @staticmethod
    def _install_module(module: torch.nn.Module, **kwargs):
        original_class = module.__class__
        module.__dict__['forge_backup_original_class'] = original_class

        def hacked_get_attr(self, name: str):
            if '_parameters' in self.__dict__:
                _parameters = self.__dict__['_parameters']
                if name in _parameters:
                    p = _parameters[name]
                    if p is None:
                        return None
                    if p.__class__ == torch.nn.Parameter:
                        return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
                    else:
                        return p.to(**kwargs)
            if '_buffers' in self.__dict__:
                _buffers = self.__dict__['_buffers']
                if name in _buffers:
                    return _buffers[name].to(**kwargs)
            return super(original_class, self).__getattr__(name)

        module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
            '__getattr__': hacked_get_attr,
        })

        return

    @staticmethod
    def _uninstall_module(module: torch.nn.Module):
        if 'forge_backup_original_class' in module.__dict__:
            module.__class__ = module.__dict__.pop('forge_backup_original_class')
        return

    @staticmethod
    def install_model(model: torch.nn.Module, **kwargs):
        for m in model.modules():
            DynamicSwapInstaller._install_module(m, **kwargs)
        return

    @staticmethod
    def uninstall_model(model: torch.nn.Module):
        for m in model.modules():
            DynamicSwapInstaller._uninstall_module(m)
        return


def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
    if hasattr(model, 'scale_shift_table'):
        model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
        return

    for k, p in model.named_modules():
        if hasattr(p, 'weight'):
            p.to(target_device)
            return


def get_cuda_free_memory_gb(device=None):
    if device is None:
        device = gpu

    memory_stats = torch.cuda.memory_stats(device)
    bytes_active = memory_stats['active_bytes.all.current']
    bytes_reserved = memory_stats['reserved_bytes.all.current']
    bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
    bytes_inactive_reserved = bytes_reserved - bytes_active
    bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
    return bytes_total_available / (1024 ** 3)


def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
    print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')

    for m in model.modules():
        if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
            torch.cuda.empty_cache()
            return

        if hasattr(m, 'weight'):
            m.to(device=target_device)

    model.to(device=target_device)
    torch.cuda.empty_cache()
    return


def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
    print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')

    for m in model.modules():
        if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
            torch.cuda.empty_cache()
            return

        if hasattr(m, 'weight'):
            m.to(device=cpu)

    model.to(device=cpu)
    torch.cuda.empty_cache()
    return


def unload_complete_models(*args):
    for m in gpu_complete_modules + list(args):
        m.to(device=cpu)
        print(f'Unloaded {m.__class__.__name__} as complete.')

    gpu_complete_modules.clear()
    torch.cuda.empty_cache()
    return


def load_model_as_complete(model, target_device, unload=True):
    if unload:
        unload_complete_models()

    model.to(device=target_device)
    print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')

    gpu_complete_modules.append(model)
    return
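A minimal usage sketch for the helpers above, assuming a CUDA device is present (the gpu handle is resolved at import time) and using a stand-in module in place of a real diffusion model:

import torch
from demo_utils.memory import (
    DynamicSwapInstaller, get_cuda_free_memory_gb, gpu,
    move_model_to_device_with_memory_preservation, unload_complete_models,
)

model = torch.nn.Linear(8, 8)  # stand-in for a large torch.nn.Module

print(f"{get_cuda_free_memory_gb(gpu):.1f} GB free")

# Option 1: keep weights on the CPU and move them to the GPU on attribute access.
DynamicSwapInstaller.install_model(model, device=gpu)
# ... run inference ...
DynamicSwapInstaller.uninstall_model(model)

# Option 2: move as many submodules as fit while keeping 6 GB of headroom free.
move_model_to_device_with_memory_preservation(model, gpu, preserved_memory_gb=6)
unload_complete_models(model)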
demo_utils/taehv.py
ADDED
@@ -0,0 +1,313 @@
#!/usr/bin/env python3
"""
Tiny AutoEncoder for Hunyuan Video
(DNN for encoding / decoding videos to Hunyuan Video's latent space)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from collections import namedtuple

DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))


def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)


class Clamp(nn.Module):
    def forward(self, x):
        return torch.tanh(x / 3) * 3


class MemBlock(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True),
                                  conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, past):
        return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))


class TPool(nn.Module):
    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)

    def forward(self, x):
        _NT, C, H, W = x.shape
        return self.conv(x.reshape(-1, self.stride * C, H, W))


class TGrow(nn.Module):
    def __init__(self, n_f, stride):
        super().__init__()
        self.stride = stride
        self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)

    def forward(self, x):
        _NT, C, H, W = x.shape
        x = self.conv(x)
        return x.reshape(-1, C, H, W)


def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
    """
    Apply a sequential model with memblocks to the given input.
    Args:
    - model: nn.Sequential of blocks to apply
    - x: input data, of dimensions NTCHW
    - parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
      if False, each timestep will be processed sequentially (slow but uses O(1) memory)
    - show_progress_bar: if True, enables tqdm progressbar display

    Returns NTCHW tensor of output data.
    """
    assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
    N, T, C, H, W = x.shape
    if parallel:
        x = x.reshape(N * T, C, H, W)
        # parallel over input timesteps, iterate over blocks
        for b in tqdm(model, disable=not show_progress_bar):
            if isinstance(b, MemBlock):
                NT, C, H, W = x.shape
                T = NT // N
                _x = x.reshape(N, T, C, H, W)
                mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
                x = b(x, mem)
            else:
                x = b(x)
        NT, C, H, W = x.shape
        T = NT // N
        x = x.view(N, T, C, H, W)
    else:
        # TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
        # need to fix :(
        out = []
        # iterate over input timesteps and also iterate over blocks.
        # because of the cursed TPool/TGrow blocks, this is not a nested loop,
        # it's actually a ***graph traversal*** problem! so let's make a queue
        work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
        # in addition to manually managing our queue, we also need to manually manage our progressbar.
        # we'll update it for every source node that we consume.
        progress_bar = tqdm(range(T), disable=not show_progress_bar)
        # we'll also need a separate addressable memory per node as well
        mem = [None] * len(model)
        while work_queue:
            xt, i = work_queue.pop(0)
            if i == 0:
                # new source node consumed
                progress_bar.update(1)
            if i == len(model):
                # reached end of the graph, append result to output list
                out.append(xt)
            else:
                # fetch the block to process
                b = model[i]
                if isinstance(b, MemBlock):
                    # mem blocks are simple since we're visiting the graph in causal order
                    if mem[i] is None:
                        xt_new = b(xt, xt * 0)
                        mem[i] = xt
                    else:
                        xt_new = b(xt, mem[i])
                        mem[i].copy_(xt)  # inplace might reduce mysterious pytorch memory allocations? doesn't help though
                    # add successor to work queue
                    work_queue.insert(0, TWorkItem(xt_new, i + 1))
                elif isinstance(b, TPool):
                    # pool blocks are miserable
                    if mem[i] is None:
                        mem[i] = []  # pool memory is itself a queue of inputs to pool
                    mem[i].append(xt)
                    if len(mem[i]) > b.stride:
                        # pool mem is in invalid state, we should have pooled before this
                        raise ValueError("???")
                    elif len(mem[i]) < b.stride:
                        # pool mem is not yet full, go back to processing the work queue
                        pass
                    else:
                        # pool mem is ready, run the pool block
                        N, C, H, W = xt.shape
                        xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
                        # reset the pool mem
                        mem[i] = []
                        # add successor to work queue
                        work_queue.insert(0, TWorkItem(xt, i + 1))
                elif isinstance(b, TGrow):
                    xt = b(xt)
                    NT, C, H, W = xt.shape
                    # each tgrow has multiple successor nodes
                    for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
                        # add successor to work queue
                        work_queue.insert(0, TWorkItem(xt_next, i + 1))
                else:
                    # normal block with no funny business
                    xt = b(xt)
                    # add successor to work queue
                    work_queue.insert(0, TWorkItem(xt, i + 1))
        progress_bar.close()
        x = torch.stack(out, 1)
    return x


class TAEHV(nn.Module):
    latent_channels = 16
    image_channels = 3

    def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
        """Initialize pretrained TAEHV from the given checkpoint.

        Args:
            checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
            decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
            decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
        """
        super().__init__()
        self.encoder = nn.Sequential(
            conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
            conv(64, TAEHV.latent_channels),
        )
        n_f = [256, 128, 64, 64]
        self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
        self.decoder = nn.Sequential(
            Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
            MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
            MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
            MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]),
            nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
            nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
        )
        if checkpoint_path is not None:
            self.load_state_dict(self.patch_tgrow_layers(torch.load(
                checkpoint_path, map_location="cpu", weights_only=True)))

    def patch_tgrow_layers(self, sd):
        """Patch TGrow layers to use a smaller kernel if needed.

        Args:
            sd: state dict to patch
        """
        new_sd = self.state_dict()
        for i, layer in enumerate(self.decoder):
            if isinstance(layer, TGrow):
                key = f"decoder.{i}.conv.weight"
                if sd[key].shape[0] > new_sd[key].shape[0]:
                    # take the last-timestep output channels
                    sd[key] = sd[key][-new_sd[key].shape[0]:]
        return sd

    def encode_video(self, x, parallel=True, show_progress_bar=True):
        """Encode a sequence of frames.

        Args:
            x: input NTCHW RGB (C=3) tensor with values in [0, 1].
            parallel: if True, all frames will be processed at once.
              (this is faster but may require more memory).
              if False, frames will be processed sequentially.
        Returns NTCHW latent tensor with ~Gaussian values.
        """
        return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)

    def decode_video(self, x, parallel=True, show_progress_bar=False):
        """Decode a sequence of frames.

        Args:
            x: input NTCHW latent (C=12) tensor with ~Gaussian values.
            parallel: if True, all frames will be processed at once.
              (this is faster but may require more memory).
              if False, frames will be processed sequentially.
        Returns NTCHW RGB tensor with ~[0, 1] values.
        """
        x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
        # return x[:, self.frames_to_trim:]
        return x

    def forward(self, x):
        return self.c(x)


@torch.no_grad()
def main():
    """Run TAEHV roundtrip reconstruction on the given video paths."""
    import os
    import sys
    import cv2  # no highly esteemed deed is commemorated here

    class VideoTensorReader:
        def __init__(self, video_file_path):
            self.cap = cv2.VideoCapture(video_file_path)
            assert self.cap.isOpened(), f"Could not load {video_file_path}"
            self.fps = self.cap.get(cv2.CAP_PROP_FPS)

        def __iter__(self):
            return self

        def __next__(self):
            ret, frame = self.cap.read()
            if not ret:
                self.cap.release()
                raise StopIteration  # End of video or error
            return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1)  # BGR HWC -> RGB CHW

    class VideoTensorWriter:
        def __init__(self, video_file_path, width_height, fps=30):
            self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, width_height)
            assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"

        def write(self, frame_tensor):
            assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
            self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(),
                                           cv2.COLOR_RGB2BGR))  # RGB CHW -> BGR HWC

        def __del__(self):
            if hasattr(self, 'writer'):
                self.writer.release()

    dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    dtype = torch.float16
    checkpoint_path = os.getenv("TAEHV_CHECKPOINT_PATH", "taehv.pth")
    checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
    print(
        f"Using device \033[31m{dev}\033[0m, dtype \033[32m{dtype}\033[0m, checkpoint \033[34m{checkpoint_name}\033[0m ({checkpoint_path})")
    taehv = TAEHV(checkpoint_path=checkpoint_path).to(dev, dtype)
    for video_path in sys.argv[1:]:
        print(f"Processing {video_path}...")
        video_in = VideoTensorReader(video_path)
        video = torch.stack(list(video_in), 0)[None]
        vid_dev = video.to(dev, dtype).div_(255.0)
        # convert to device tensor
        if video.numel() < 100_000_000:
            print(f"  {video_path} seems small enough, will process all frames in parallel")
            vid_enc = taehv.encode_video(vid_dev)
            print(f"  Encoded {video_path} -> {vid_enc.shape}. Decoding...")
            vid_dec = taehv.decode_video(vid_enc)
            print(f"  Decoded {video_path} -> {vid_dec.shape}")
        else:
            print(f"  {video_path} seems large, will process each frame sequentially")
            vid_enc = taehv.encode_video(vid_dev, parallel=False)
            print(f"  Encoded {video_path} -> {vid_enc.shape}. Decoding...")
            vid_dec = taehv.decode_video(vid_enc, parallel=False)
            print(f"  Decoded {video_path} -> {vid_dec.shape}")
        video_out_path = video_path + f".reconstructed_by_{checkpoint_name}.mp4"
        video_out = VideoTensorWriter(
            video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
        for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
            video_out.write(frame)
        print(f"  Saved to {video_out_path}")


if __name__ == "__main__":
    main()
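A minimal roundtrip sketch for TAEHV, mirroring what main() does per video but with a random tensor; it assumes a local taehv.pth checkpoint and uses 8 frames so the temporal pooling divides evenly:

import torch
from demo_utils.taehv import TAEHV

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
taehv = TAEHV(checkpoint_path="taehv.pth").to(device, dtype)

video = torch.rand(1, 8, 3, 256, 256, device=device, dtype=dtype)  # NTCHW RGB in [0, 1]
with torch.no_grad():
    latents = taehv.encode_video(video)   # roughly N x T/4 x 16 x H/8 x W/8, here (1, 2, 16, 32, 32)
    frames = taehv.decode_video(latents)  # back to an NTCHW RGB tensor in ~[0, 1]
print(latents.shape, frames.shape)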
demo_utils/utils.py
ADDED
@@ -0,0 +1,616 @@
# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
# Apache-2.0 License
# By lllyasviel

import os
import cv2
import json
import random
import glob
import torch
import einops
import numpy as np
import datetime
import torchvision

from PIL import Image


def min_resize(x, m):
    if x.shape[0] < x.shape[1]:
        s0 = m
        s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
    else:
        s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
        s1 = m
    new_max = max(s1, s0)
    raw_max = max(x.shape[0], x.shape[1])
    if new_max < raw_max:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_LANCZOS4
    y = cv2.resize(x, (s1, s0), interpolation=interpolation)
    return y


def d_resize(x, y):
    H, W, C = y.shape
    new_min = min(H, W)
    raw_min = min(x.shape[0], x.shape[1])
    if new_min < raw_min:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_LANCZOS4
    y = cv2.resize(x, (W, H), interpolation=interpolation)
    return y


def resize_and_center_crop(image, target_width, target_height):
    if target_height == image.shape[0] and target_width == image.shape[1]:
        return image

    pil_image = Image.fromarray(image)
    original_width, original_height = pil_image.size
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
    left = (resized_width - target_width) / 2
    top = (resized_height - target_height) / 2
    right = (resized_width + target_width) / 2
    bottom = (resized_height + target_height) / 2
    cropped_image = resized_image.crop((left, top, right, bottom))
    return np.array(cropped_image)


def resize_and_center_crop_pytorch(image, target_width, target_height):
    B, C, H, W = image.shape

    if H == target_height and W == target_width:
        return image

    scale_factor = max(target_width / W, target_height / H)
    resized_width = int(round(W * scale_factor))
    resized_height = int(round(H * scale_factor))

    resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)

    top = (resized_height - target_height) // 2
    left = (resized_width - target_width) // 2
    cropped = resized[:, :, top:top + target_height, left:left + target_width]

    return cropped


def resize_without_crop(image, target_width, target_height):
    if target_height == image.shape[0] and target_width == image.shape[1]:
        return image

    pil_image = Image.fromarray(image)
    resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized_image)


def just_crop(image, w, h):
    if h == image.shape[0] and w == image.shape[1]:
        return image

    original_height, original_width = image.shape[:2]
    k = min(original_height / h, original_width / w)
    new_width = int(round(w * k))
    new_height = int(round(h * k))
    x_start = (original_width - new_width) // 2
    y_start = (original_height - new_height) // 2
    cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
    return cropped_image


def write_to_json(data, file_path):
    temp_file_path = file_path + ".tmp"
    with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
        json.dump(data, temp_file, indent=4)
    os.replace(temp_file_path, file_path)
    return


def read_from_json(file_path):
    with open(file_path, 'rt', encoding='utf-8') as file:
        data = json.load(file)
    return data


def get_active_parameters(m):
    return {k: v for k, v in m.named_parameters() if v.requires_grad}


def cast_training_params(m, dtype=torch.float32):
    result = {}
    for n, param in m.named_parameters():
        if param.requires_grad:
            param.data = param.to(dtype)
            result[n] = param
    return result


def separate_lora_AB(parameters, B_patterns=None):
    parameters_normal = {}
    parameters_B = {}

    if B_patterns is None:
        B_patterns = ['.lora_B.', '__zero__']

    for k, v in parameters.items():
        if any(B_pattern in k for B_pattern in B_patterns):
            parameters_B[k] = v
        else:
            parameters_normal[k] = v

    return parameters_normal, parameters_B


def set_attr_recursive(obj, attr, value):
    attrs = attr.split(".")
    for name in attrs[:-1]:
        obj = getattr(obj, name)
    setattr(obj, attrs[-1], value)
    return


def print_tensor_list_size(tensors):
    total_size = 0
    total_elements = 0

    if isinstance(tensors, dict):
        tensors = tensors.values()

    for tensor in tensors:
        total_size += tensor.nelement() * tensor.element_size()
        total_elements += tensor.nelement()

    total_size_MB = total_size / (1024 ** 2)
    total_elements_B = total_elements / 1e9

    print(f"Total number of tensors: {len(tensors)}")
    print(f"Total size of tensors: {total_size_MB:.2f} MB")
    print(f"Total number of parameters: {total_elements_B:.3f} billion")
    return


@torch.no_grad()
def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
    batch_size = a.size(0)

    if b is None:
        b = torch.zeros_like(a)

    if mask_a is None:
        mask_a = torch.rand(batch_size) < probability_a

    mask_a = mask_a.to(a.device)
    mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
    result = torch.where(mask_a, a, b)
    return result


@torch.no_grad()
def zero_module(module):
    for p in module.parameters():
        p.detach().zero_()
    return module


@torch.no_grad()
def supress_lower_channels(m, k, alpha=0.01):
    data = m.weight.data.clone()

    assert int(data.shape[1]) >= k

    data[:, :k] = data[:, :k] * alpha
    m.weight.data = data.contiguous().clone()
    return m


def freeze_module(m):
    if not hasattr(m, '_forward_inside_frozen_module'):
        m._forward_inside_frozen_module = m.forward
    m.requires_grad_(False)
    m.forward = torch.no_grad()(m.forward)
    return m


def get_latest_safetensors(folder_path):
    safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))

    if not safetensors_files:
        raise ValueError('No file to resume!')

    latest_file = max(safetensors_files, key=os.path.getmtime)
    latest_file = os.path.abspath(os.path.realpath(latest_file))
    return latest_file


def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    tags = tags_str.split(', ')
    tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
    prompt = ', '.join(tags)
    return prompt


def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
    numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
    if round_to_int:
        numbers = np.round(numbers).astype(int)
    return numbers.tolist()


def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
    edges = np.linspace(0, 1, n + 1)
    points = np.random.uniform(edges[:-1], edges[1:])
    numbers = inclusive + (exclusive - inclusive) * points
    if round_to_int:
        numbers = np.round(numbers).astype(int)
    return numbers.tolist()


def soft_append_bcthw(history, current, overlap=0):
    if overlap <= 0:
        return torch.cat([history, current], dim=2)

    assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
    assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"

    weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
    blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
    output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)

    return output.to(history)


def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
    b, c, t, h, w = x.shape

    per_row = b
    for p in [6, 5, 4, 3, 2]:
        if b % p == 0:
            per_row = p
            break

    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
    torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
    return x


def save_bcthw_as_png(x, output_filename):
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
    torchvision.io.write_png(x, output_filename)
    return output_filename


def save_bchw_as_png(x, output_filename):
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, 'b c h w -> c h (b w)')
    torchvision.io.write_png(x, output_filename)
    return output_filename


def add_tensors_with_padding(tensor1, tensor2):
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2

    shape1 = tensor1.shape
    shape2 = tensor2.shape

    new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))

    padded_tensor1 = torch.zeros(new_shape)
    padded_tensor2 = torch.zeros(new_shape)

    padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
    padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2

    result = padded_tensor1 + padded_tensor2
    return result


def print_free_mem():
    torch.cuda.empty_cache()
    free_mem, total_mem = torch.cuda.mem_get_info(0)
    free_mem_mb = free_mem / (1024 ** 2)
    total_mem_mb = total_mem / (1024 ** 2)
    print(f"Free memory: {free_mem_mb:.2f} MB")
    print(f"Total memory: {total_mem_mb:.2f} MB")
    return


def print_gpu_parameters(device, state_dict, log_count=1):
    summary = {"device": device, "keys_count": len(state_dict)}

    logged_params = {}
    for i, (key, tensor) in enumerate(state_dict.items()):
        if i >= log_count:
            break
        logged_params[key] = tensor.flatten()[:3].tolist()

    summary["params"] = logged_params

    print(str(summary))
    return


def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
    from PIL import Image, ImageDraw, ImageFont

    txt = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(txt)
    font = ImageFont.truetype(font_path, size=size)

    if text == '':
        return np.array(txt)

    # Split text into lines that fit within the image width
    lines = []
    words = text.split()
    current_line = words[0]

    for word in words[1:]:
        line_with_word = f"{current_line} {word}"
        if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
            current_line = line_with_word
        else:
            lines.append(current_line)
            current_line = word

    lines.append(current_line)

    # Draw the text line by line
    y = 0
    line_height = draw.textbbox((0, 0), "A", font=font)[3]

    for line in lines:
        if y + line_height > height:
            break  # stop drawing if the next line will be outside the image
        draw.text((0, y), line, fill="black", font=font)
        y += line_height

    return np.array(txt)


def blue_mark(x):
    x = x.copy()
    c = x[:, :, 2]
    b = cv2.blur(c, (9, 9))
    x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
    return x


def green_mark(x):
    x = x.copy()
    x[:, :, 2] = -1
    x[:, :, 0] = -1
    return x


def frame_mark(x):
    x = x.copy()
    x[:64] = -1
    x[-64:] = -1
    x[:, :8] = 1
    x[:, -8:] = 1
    return x


@torch.inference_mode()
def pytorch2numpy(imgs):
    results = []
    for x in imgs:
        y = x.movedim(0, -1)
        y = y * 127.5 + 127.5
        y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
|
| 417 |
+
results.append(y)
|
| 418 |
+
return results
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
@torch.inference_mode()
|
| 422 |
+
def numpy2pytorch(imgs):
|
| 423 |
+
h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
|
| 424 |
+
h = h.movedim(-1, 1)
|
| 425 |
+
return h
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
@torch.no_grad()
|
| 429 |
+
def duplicate_prefix_to_suffix(x, count, zero_out=False):
|
| 430 |
+
if zero_out:
|
| 431 |
+
return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
|
| 432 |
+
else:
|
| 433 |
+
return torch.cat([x, x[:count]], dim=0)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def weighted_mse(a, b, weight):
|
| 437 |
+
return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
|
| 441 |
+
x = (x - x_min) / (x_max - x_min)
|
| 442 |
+
x = max(0.0, min(x, 1.0))
|
| 443 |
+
x = x ** sigma
|
| 444 |
+
return y_min + x * (y_max - y_min)
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def expand_to_dims(x, target_dims):
|
| 448 |
+
return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
|
| 452 |
+
if tensor is None:
|
| 453 |
+
return None
|
| 454 |
+
|
| 455 |
+
first_dim = tensor.shape[0]
|
| 456 |
+
|
| 457 |
+
if first_dim == batch_size:
|
| 458 |
+
return tensor
|
| 459 |
+
|
| 460 |
+
if batch_size % first_dim != 0:
|
| 461 |
+
raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
|
| 462 |
+
|
| 463 |
+
repeat_times = batch_size // first_dim
|
| 464 |
+
|
| 465 |
+
return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
def dim5(x):
|
| 469 |
+
return expand_to_dims(x, 5)
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def dim4(x):
|
| 473 |
+
return expand_to_dims(x, 4)
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def dim3(x):
|
| 477 |
+
return expand_to_dims(x, 3)
|
| 478 |
+
|
| 479 |
+
|
| 480 |
+
def crop_or_pad_yield_mask(x, length):
|
| 481 |
+
B, F, C = x.shape
|
| 482 |
+
device = x.device
|
| 483 |
+
dtype = x.dtype
|
| 484 |
+
|
| 485 |
+
if F < length:
|
| 486 |
+
y = torch.zeros((B, length, C), dtype=dtype, device=device)
|
| 487 |
+
mask = torch.zeros((B, length), dtype=torch.bool, device=device)
|
| 488 |
+
y[:, :F, :] = x
|
| 489 |
+
mask[:, :F] = True
|
| 490 |
+
return y, mask
|
| 491 |
+
|
| 492 |
+
return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def extend_dim(x, dim, minimal_length, zero_pad=False):
|
| 496 |
+
original_length = int(x.shape[dim])
|
| 497 |
+
|
| 498 |
+
if original_length >= minimal_length:
|
| 499 |
+
return x
|
| 500 |
+
|
| 501 |
+
if zero_pad:
|
| 502 |
+
padding_shape = list(x.shape)
|
| 503 |
+
padding_shape[dim] = minimal_length - original_length
|
| 504 |
+
padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
|
| 505 |
+
else:
|
| 506 |
+
idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
|
| 507 |
+
last_element = x[idx]
|
| 508 |
+
padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
|
| 509 |
+
|
| 510 |
+
return torch.cat([x, padding], dim=dim)
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def lazy_positional_encoding(t, repeats=None):
|
| 514 |
+
if not isinstance(t, list):
|
| 515 |
+
t = [t]
|
| 516 |
+
|
| 517 |
+
from diffusers.models.embeddings import get_timestep_embedding
|
| 518 |
+
|
| 519 |
+
te = torch.tensor(t)
|
| 520 |
+
te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
|
| 521 |
+
|
| 522 |
+
if repeats is None:
|
| 523 |
+
return te
|
| 524 |
+
|
| 525 |
+
te = te[:, None, :].expand(-1, repeats, -1)
|
| 526 |
+
|
| 527 |
+
return te
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def state_dict_offset_merge(A, B, C=None):
|
| 531 |
+
result = {}
|
| 532 |
+
keys = A.keys()
|
| 533 |
+
|
| 534 |
+
for key in keys:
|
| 535 |
+
A_value = A[key]
|
| 536 |
+
B_value = B[key].to(A_value)
|
| 537 |
+
|
| 538 |
+
if C is None:
|
| 539 |
+
result[key] = A_value + B_value
|
| 540 |
+
else:
|
| 541 |
+
C_value = C[key].to(A_value)
|
| 542 |
+
result[key] = A_value + B_value - C_value
|
| 543 |
+
|
| 544 |
+
return result
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def state_dict_weighted_merge(state_dicts, weights):
|
| 548 |
+
if len(state_dicts) != len(weights):
|
| 549 |
+
raise ValueError("Number of state dictionaries must match number of weights")
|
| 550 |
+
|
| 551 |
+
if not state_dicts:
|
| 552 |
+
return {}
|
| 553 |
+
|
| 554 |
+
total_weight = sum(weights)
|
| 555 |
+
|
| 556 |
+
if total_weight == 0:
|
| 557 |
+
raise ValueError("Sum of weights cannot be zero")
|
| 558 |
+
|
| 559 |
+
normalized_weights = [w / total_weight for w in weights]
|
| 560 |
+
|
| 561 |
+
keys = state_dicts[0].keys()
|
| 562 |
+
result = {}
|
| 563 |
+
|
| 564 |
+
for key in keys:
|
| 565 |
+
result[key] = state_dicts[0][key] * normalized_weights[0]
|
| 566 |
+
|
| 567 |
+
for i in range(1, len(state_dicts)):
|
| 568 |
+
state_dict_value = state_dicts[i][key].to(result[key])
|
| 569 |
+
result[key] += state_dict_value * normalized_weights[i]
|
| 570 |
+
|
| 571 |
+
return result
|
| 572 |
+
|
| 573 |
+
|
| 574 |
+
def group_files_by_folder(all_files):
|
| 575 |
+
grouped_files = {}
|
| 576 |
+
|
| 577 |
+
for file in all_files:
|
| 578 |
+
folder_name = os.path.basename(os.path.dirname(file))
|
| 579 |
+
if folder_name not in grouped_files:
|
| 580 |
+
grouped_files[folder_name] = []
|
| 581 |
+
grouped_files[folder_name].append(file)
|
| 582 |
+
|
| 583 |
+
list_of_lists = list(grouped_files.values())
|
| 584 |
+
return list_of_lists
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def generate_timestamp():
|
| 588 |
+
now = datetime.datetime.now()
|
| 589 |
+
timestamp = now.strftime('%y%m%d_%H%M%S')
|
| 590 |
+
milliseconds = f"{int(now.microsecond / 1000):03d}"
|
| 591 |
+
random_number = random.randint(0, 9999)
|
| 592 |
+
return f"{timestamp}_{milliseconds}_{random_number}"
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
def write_PIL_image_with_png_info(image, metadata, path):
|
| 596 |
+
from PIL.PngImagePlugin import PngInfo
|
| 597 |
+
|
| 598 |
+
png_info = PngInfo()
|
| 599 |
+
for key, value in metadata.items():
|
| 600 |
+
png_info.add_text(key, value)
|
| 601 |
+
|
| 602 |
+
image.save(path, "PNG", pnginfo=png_info)
|
| 603 |
+
return image
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def torch_safe_save(content, path):
|
| 607 |
+
torch.save(content, path + '_tmp')
|
| 608 |
+
os.replace(path + '_tmp', path)
|
| 609 |
+
return path
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
def move_optimizer_to_device(optimizer, device):
|
| 613 |
+
for state in optimizer.state.values():
|
| 614 |
+
for k, v in state.items():
|
| 615 |
+
if isinstance(v, torch.Tensor):
|
| 616 |
+
state[k] = v.to(device)
|
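A minimal usage sketch of the chunk-blending helper soft_append_bcthw defined above (assuming the module is importable as demo_utils.utils; the tensor shapes are illustrative only, not taken from the repository):

import torch

from demo_utils.utils import soft_append_bcthw  # assumed import path

history = torch.randn(1, 3, 16, 64, 64)   # B, C, T, H, W
current = torch.randn(1, 3, 16, 64, 64)

# The last `overlap` frames of `history` are linearly cross-faded into the
# first `overlap` frames of `current`: 16 + 16 - 4 = 28 frames remain.
merged = soft_append_bcthw(history, current, overlap=4)
print(merged.shape)  # torch.Size([1, 3, 28, 64, 64])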
demo_utils/vae.py
ADDED
|
@@ -0,0 +1,390 @@
| 1 |
+
from typing import List
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
import tensorrt as trt
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from demo_utils.constant import ALL_INPUTS_NAMES, ZERO_VAE_CACHE
|
| 8 |
+
from wan.modules.vae import AttentionBlock, CausalConv3d, RMS_norm, Upsample
|
| 9 |
+
|
| 10 |
+
CACHE_T = 2
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ResidualBlock(nn.Module):
|
| 14 |
+
|
| 15 |
+
def __init__(self, in_dim, out_dim, dropout=0.0):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.in_dim = in_dim
|
| 18 |
+
self.out_dim = out_dim
|
| 19 |
+
|
| 20 |
+
# layers
|
| 21 |
+
self.residual = nn.Sequential(
|
| 22 |
+
RMS_norm(in_dim, images=False), nn.SiLU(),
|
| 23 |
+
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
| 24 |
+
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
| 25 |
+
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
| 26 |
+
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
| 27 |
+
if in_dim != out_dim else nn.Identity()
|
| 28 |
+
|
| 29 |
+
def forward(self, x, feat_cache_1, feat_cache_2):
|
| 30 |
+
h = self.shortcut(x)
|
| 31 |
+
feat_cache = feat_cache_1
|
| 32 |
+
out_feat_cache = []
|
| 33 |
+
for layer in self.residual:
|
| 34 |
+
if isinstance(layer, CausalConv3d):
|
| 35 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 36 |
+
if cache_x.shape[2] < 2 and feat_cache is not None:
|
| 37 |
+
# cache last frame of last two chunk
|
| 38 |
+
cache_x = torch.cat([
|
| 39 |
+
feat_cache[:, :, -1, :, :].unsqueeze(2).to(
|
| 40 |
+
cache_x.device), cache_x
|
| 41 |
+
],
|
| 42 |
+
dim=2)
|
| 43 |
+
x = layer(x, feat_cache)
|
| 44 |
+
out_feat_cache.append(cache_x)
|
| 45 |
+
feat_cache = feat_cache_2
|
| 46 |
+
else:
|
| 47 |
+
x = layer(x)
|
| 48 |
+
return x + h, *out_feat_cache
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Resample(nn.Module):
|
| 52 |
+
|
| 53 |
+
def __init__(self, dim, mode):
|
| 54 |
+
assert mode in ('none', 'upsample2d', 'upsample3d')
|
| 55 |
+
super().__init__()
|
| 56 |
+
self.dim = dim
|
| 57 |
+
self.mode = mode
|
| 58 |
+
|
| 59 |
+
# layers
|
| 60 |
+
if mode == 'upsample2d':
|
| 61 |
+
self.resample = nn.Sequential(
|
| 62 |
+
Upsample(scale_factor=(2., 2.), mode='nearest'),
|
| 63 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
| 64 |
+
elif mode == 'upsample3d':
|
| 65 |
+
self.resample = nn.Sequential(
|
| 66 |
+
Upsample(scale_factor=(2., 2.), mode='nearest'),
|
| 67 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
| 68 |
+
self.time_conv = CausalConv3d(
|
| 69 |
+
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
| 70 |
+
else:
|
| 71 |
+
self.resample = nn.Identity()
|
| 72 |
+
|
| 73 |
+
def forward(self, x, is_first_frame, feat_cache):
|
| 74 |
+
if self.mode == 'upsample3d':
|
| 75 |
+
b, c, t, h, w = x.size()
|
| 76 |
+
# x, out_feat_cache = torch.cond(
|
| 77 |
+
# is_first_frame,
|
| 78 |
+
# lambda: (torch.cat([torch.zeros_like(x), x], dim=2), feat_cache.clone()),
|
| 79 |
+
# lambda: self.temporal_conv(x, feat_cache),
|
| 80 |
+
# )
|
| 81 |
+
# x, out_feat_cache = torch.cond(
|
| 82 |
+
# is_first_frame,
|
| 83 |
+
# lambda: (torch.cat([torch.zeros_like(x), x], dim=2), feat_cache.clone()),
|
| 84 |
+
# lambda: self.temporal_conv(x, feat_cache),
|
| 85 |
+
# )
|
| 86 |
+
x, out_feat_cache = self.temporal_conv(x, is_first_frame, feat_cache)
|
| 87 |
+
out_feat_cache = torch.cond(
|
| 88 |
+
is_first_frame,
|
| 89 |
+
lambda: feat_cache.clone().contiguous(),
|
| 90 |
+
lambda: out_feat_cache.clone().contiguous(),
|
| 91 |
+
)
|
| 92 |
+
# if is_first_frame:
|
| 93 |
+
# x = torch.cat([torch.zeros_like(x), x], dim=2)
|
| 94 |
+
# out_feat_cache = feat_cache.clone()
|
| 95 |
+
# else:
|
| 96 |
+
# x, out_feat_cache = self.temporal_conv(x, feat_cache)
|
| 97 |
+
else:
|
| 98 |
+
out_feat_cache = None
|
| 99 |
+
t = x.shape[2]
|
| 100 |
+
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
| 101 |
+
x = self.resample(x)
|
| 102 |
+
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
| 103 |
+
return x, out_feat_cache
|
| 104 |
+
|
| 105 |
+
def temporal_conv(self, x, is_first_frame, feat_cache):
|
| 106 |
+
b, c, t, h, w = x.size()
|
| 107 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 108 |
+
if cache_x.shape[2] < 2 and feat_cache is not None:
|
| 109 |
+
cache_x = torch.cat([
|
| 110 |
+
torch.zeros_like(cache_x),
|
| 111 |
+
cache_x
|
| 112 |
+
], dim=2)
|
| 113 |
+
x = torch.cond(
|
| 114 |
+
is_first_frame,
|
| 115 |
+
lambda: torch.cat([torch.zeros_like(x), x], dim=1).contiguous(),
|
| 116 |
+
lambda: self.time_conv(x, feat_cache).contiguous(),
|
| 117 |
+
)
|
| 118 |
+
# x = self.time_conv(x, feat_cache)
|
| 119 |
+
out_feat_cache = cache_x
|
| 120 |
+
|
| 121 |
+
x = x.reshape(b, 2, c, t, h, w)
|
| 122 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
| 123 |
+
3)
|
| 124 |
+
x = x.reshape(b, c, t * 2, h, w)
|
| 125 |
+
return x.contiguous(), out_feat_cache.contiguous()
|
| 126 |
+
|
| 127 |
+
def init_weight(self, conv):
|
| 128 |
+
conv_weight = conv.weight
|
| 129 |
+
nn.init.zeros_(conv_weight)
|
| 130 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 131 |
+
one_matrix = torch.eye(c1, c2)
|
| 132 |
+
init_matrix = one_matrix
|
| 133 |
+
nn.init.zeros_(conv_weight)
|
| 134 |
+
# conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
| 135 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
|
| 136 |
+
conv.weight.data.copy_(conv_weight)
|
| 137 |
+
nn.init.zeros_(conv.bias.data)
|
| 138 |
+
|
| 139 |
+
def init_weight2(self, conv):
|
| 140 |
+
conv_weight = conv.weight.data
|
| 141 |
+
nn.init.zeros_(conv_weight)
|
| 142 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 143 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
| 144 |
+
# init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
| 145 |
+
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
| 146 |
+
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
| 147 |
+
conv.weight.data.copy_(conv_weight)
|
| 148 |
+
nn.init.zeros_(conv.bias.data)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class VAEDecoderWrapperSingle(nn.Module):
|
| 152 |
+
def __init__(self):
|
| 153 |
+
super().__init__()
|
| 154 |
+
self.decoder = VAEDecoder3d()
|
| 155 |
+
mean = [
|
| 156 |
+
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
| 157 |
+
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
| 158 |
+
]
|
| 159 |
+
std = [
|
| 160 |
+
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
| 161 |
+
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
| 162 |
+
]
|
| 163 |
+
self.mean = torch.tensor(mean, dtype=torch.float32)
|
| 164 |
+
self.std = torch.tensor(std, dtype=torch.float32)
|
| 165 |
+
self.z_dim = 16
|
| 166 |
+
self.conv2 = CausalConv3d(self.z_dim, self.z_dim, 1)
|
| 167 |
+
|
| 168 |
+
def forward(
|
| 169 |
+
self,
|
| 170 |
+
z: torch.Tensor,
|
| 171 |
+
is_first_frame: torch.Tensor,
|
| 172 |
+
*feat_cache: List[torch.Tensor]
|
| 173 |
+
):
|
| 174 |
+
# from [batch_size, num_frames, num_channels, height, width]
|
| 175 |
+
# to [batch_size, num_channels, num_frames, height, width]
|
| 176 |
+
z = z.permute(0, 2, 1, 3, 4)
|
| 177 |
+
assert z.shape[2] == 1
|
| 178 |
+
feat_cache = list(feat_cache)
|
| 179 |
+
is_first_frame = is_first_frame.bool()
|
| 180 |
+
|
| 181 |
+
device, dtype = z.device, z.dtype
|
| 182 |
+
scale = [self.mean.to(device=device, dtype=dtype),
|
| 183 |
+
1.0 / self.std.to(device=device, dtype=dtype)]
|
| 184 |
+
|
| 185 |
+
if isinstance(scale[0], torch.Tensor):
|
| 186 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
| 187 |
+
1, self.z_dim, 1, 1, 1)
|
| 188 |
+
else:
|
| 189 |
+
z = z / scale[1] + scale[0]
|
| 190 |
+
x = self.conv2(z)
|
| 191 |
+
out, feat_cache = self.decoder(x, is_first_frame, feat_cache=feat_cache)
|
| 192 |
+
out = out.clamp_(-1, 1)
|
| 193 |
+
# from [batch_size, num_channels, num_frames, height, width]
|
| 194 |
+
# to [batch_size, num_frames, num_channels, height, width]
|
| 195 |
+
out = out.permute(0, 2, 1, 3, 4)
|
| 196 |
+
return out, feat_cache
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class VAEDecoder3d(nn.Module):
|
| 200 |
+
def __init__(self,
|
| 201 |
+
dim=96,
|
| 202 |
+
z_dim=16,
|
| 203 |
+
dim_mult=[1, 2, 4, 4],
|
| 204 |
+
num_res_blocks=2,
|
| 205 |
+
attn_scales=[],
|
| 206 |
+
temperal_upsample=[True, True, False],
|
| 207 |
+
dropout=0.0):
|
| 208 |
+
super().__init__()
|
| 209 |
+
self.dim = dim
|
| 210 |
+
self.z_dim = z_dim
|
| 211 |
+
self.dim_mult = dim_mult
|
| 212 |
+
self.num_res_blocks = num_res_blocks
|
| 213 |
+
self.attn_scales = attn_scales
|
| 214 |
+
self.temperal_upsample = temperal_upsample
|
| 215 |
+
self.cache_t = 2
|
| 216 |
+
self.decoder_conv_num = 32
|
| 217 |
+
|
| 218 |
+
# dimensions
|
| 219 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 220 |
+
scale = 1.0 / 2**(len(dim_mult) - 2)
|
| 221 |
+
|
| 222 |
+
# init block
|
| 223 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
| 224 |
+
|
| 225 |
+
# middle blocks
|
| 226 |
+
self.middle = nn.Sequential(
|
| 227 |
+
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
|
| 228 |
+
ResidualBlock(dims[0], dims[0], dropout))
|
| 229 |
+
|
| 230 |
+
# upsample blocks
|
| 231 |
+
upsamples = []
|
| 232 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 233 |
+
# residual (+attention) blocks
|
| 234 |
+
if i == 1 or i == 2 or i == 3:
|
| 235 |
+
in_dim = in_dim // 2
|
| 236 |
+
for _ in range(num_res_blocks + 1):
|
| 237 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 238 |
+
if scale in attn_scales:
|
| 239 |
+
upsamples.append(AttentionBlock(out_dim))
|
| 240 |
+
in_dim = out_dim
|
| 241 |
+
|
| 242 |
+
# upsample block
|
| 243 |
+
if i != len(dim_mult) - 1:
|
| 244 |
+
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
| 245 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
| 246 |
+
scale *= 2.0
|
| 247 |
+
self.upsamples = nn.Sequential(*upsamples)
|
| 248 |
+
|
| 249 |
+
# output blocks
|
| 250 |
+
self.head = nn.Sequential(
|
| 251 |
+
RMS_norm(out_dim, images=False), nn.SiLU(),
|
| 252 |
+
CausalConv3d(out_dim, 3, 3, padding=1))
|
| 253 |
+
|
| 254 |
+
def forward(
|
| 255 |
+
self,
|
| 256 |
+
x: torch.Tensor,
|
| 257 |
+
is_first_frame: torch.Tensor,
|
| 258 |
+
feat_cache: List[torch.Tensor]
|
| 259 |
+
):
|
| 260 |
+
idx = 0
|
| 261 |
+
out_feat_cache = []
|
| 262 |
+
|
| 263 |
+
# conv1
|
| 264 |
+
cache_x = x[:, :, -self.cache_t:, :, :].clone()
|
| 265 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 266 |
+
# cache last frame of last two chunk
|
| 267 |
+
cache_x = torch.cat([
|
| 268 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 269 |
+
cache_x.device), cache_x
|
| 270 |
+
],
|
| 271 |
+
dim=2)
|
| 272 |
+
x = self.conv1(x, feat_cache[idx])
|
| 273 |
+
out_feat_cache.append(cache_x)
|
| 274 |
+
idx += 1
|
| 275 |
+
|
| 276 |
+
# middle
|
| 277 |
+
for layer in self.middle:
|
| 278 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
| 279 |
+
x, out_feat_cache_1, out_feat_cache_2 = layer(x, feat_cache[idx], feat_cache[idx + 1])
|
| 280 |
+
idx += 2
|
| 281 |
+
out_feat_cache.append(out_feat_cache_1)
|
| 282 |
+
out_feat_cache.append(out_feat_cache_2)
|
| 283 |
+
else:
|
| 284 |
+
x = layer(x)
|
| 285 |
+
|
| 286 |
+
# upsamples
|
| 287 |
+
for layer in self.upsamples:
|
| 288 |
+
if isinstance(layer, Resample):
|
| 289 |
+
x, cache_x = layer(x, is_first_frame, feat_cache[idx])
|
| 290 |
+
if cache_x is not None:
|
| 291 |
+
out_feat_cache.append(cache_x)
|
| 292 |
+
idx += 1
|
| 293 |
+
else:
|
| 294 |
+
x, out_feat_cache_1, out_feat_cache_2 = layer(x, feat_cache[idx], feat_cache[idx + 1])
|
| 295 |
+
idx += 2
|
| 296 |
+
out_feat_cache.append(out_feat_cache_1)
|
| 297 |
+
out_feat_cache.append(out_feat_cache_2)
|
| 298 |
+
|
| 299 |
+
# head
|
| 300 |
+
for layer in self.head:
|
| 301 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
| 302 |
+
cache_x = x[:, :, -self.cache_t:, :, :].clone()
|
| 303 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 304 |
+
# cache last frame of last two chunk
|
| 305 |
+
cache_x = torch.cat([
|
| 306 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 307 |
+
cache_x.device), cache_x
|
| 308 |
+
],
|
| 309 |
+
dim=2)
|
| 310 |
+
x = layer(x, feat_cache[idx])
|
| 311 |
+
out_feat_cache.append(cache_x)
|
| 312 |
+
idx += 1
|
| 313 |
+
else:
|
| 314 |
+
x = layer(x)
|
| 315 |
+
return x, out_feat_cache
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class VAETRTWrapper():
|
| 319 |
+
def __init__(self):
|
| 320 |
+
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
| 321 |
+
with open("checkpoints/vae_decoder_int8.trt", "rb") as f, trt.Runtime(TRT_LOGGER) as rt:
|
| 322 |
+
self.engine: trt.ICudaEngine = rt.deserialize_cuda_engine(f.read())
|
| 323 |
+
|
| 324 |
+
self.context: trt.IExecutionContext = self.engine.create_execution_context()
|
| 325 |
+
self.stream = torch.cuda.current_stream().cuda_stream
|
| 326 |
+
|
| 327 |
+
# ──────────────────────────────
|
| 328 |
+
# 2️⃣ Feed the engine with tensors
|
| 329 |
+
# (name-based API in TRT ≥10)
|
| 330 |
+
# ──────────────────────────────
|
| 331 |
+
self.dtype_map = {
|
| 332 |
+
trt.float32: torch.float32,
|
| 333 |
+
trt.float16: torch.float16,
|
| 334 |
+
trt.int8: torch.int8,
|
| 335 |
+
trt.int32: torch.int32,
|
| 336 |
+
}
|
| 337 |
+
test_input = torch.zeros(1, 16, 1, 60, 104).cuda().half()
|
| 338 |
+
is_first_frame = torch.tensor(1.0).cuda().half()
|
| 339 |
+
test_cache_inputs = [c.cuda().half() for c in ZERO_VAE_CACHE]
|
| 340 |
+
test_inputs = [test_input, is_first_frame] + test_cache_inputs
|
| 341 |
+
|
| 342 |
+
# keep references so buffers stay alive
|
| 343 |
+
self.device_buffers, self.outputs = {}, []
|
| 344 |
+
|
| 345 |
+
# ---- inputs ----
|
| 346 |
+
for i, name in enumerate(ALL_INPUTS_NAMES):
|
| 347 |
+
tensor, scale = test_inputs[i], 1 / 127
|
| 348 |
+
tensor = self.quantize_if_needed(tensor, self.engine.get_tensor_dtype(name), scale)
|
| 349 |
+
|
| 350 |
+
# dynamic shapes
|
| 351 |
+
if -1 in self.engine.get_tensor_shape(name):
|
| 352 |
+
# new API
|
| 353 |
+
self.context.set_input_shape(name, tuple(tensor.shape))
|
| 354 |
+
|
| 355 |
+
# replaces bindings[]
|
| 356 |
+
self.context.set_tensor_address(name, int(tensor.data_ptr()))
|
| 357 |
+
self.device_buffers[name] = tensor # keep pointer alive
|
| 358 |
+
|
| 359 |
+
# ---- (after all input shapes are known) infer output shapes ----
|
| 360 |
+
# propagates shapes
|
| 361 |
+
self.context.infer_shapes()
|
| 362 |
+
|
| 363 |
+
for i in range(self.engine.num_io_tensors):
|
| 364 |
+
name = self.engine.get_tensor_name(i)
|
| 365 |
+
# replaces binding_is_input
|
| 366 |
+
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
|
| 367 |
+
shape = tuple(self.context.get_tensor_shape(name))
|
| 368 |
+
dtype = self.dtype_map[self.engine.get_tensor_dtype(name)]
|
| 369 |
+
out = torch.empty(shape, dtype=dtype, device="cuda").contiguous()
|
| 370 |
+
|
| 371 |
+
self.context.set_tensor_address(name, int(out.data_ptr()))
|
| 372 |
+
self.outputs.append(out)
|
| 373 |
+
self.device_buffers[name] = out
|
| 374 |
+
|
| 375 |
+
# helper to quant-convert on the fly
|
| 376 |
+
def quantize_if_needed(self, t, expected_dtype, scale):
|
| 377 |
+
if expected_dtype == trt.int8 and t.dtype != torch.int8:
|
| 378 |
+
t = torch.clamp((t / scale).round(), -128, 127).to(torch.int8).contiguous()
|
| 379 |
+
return t # keep pointer alive
|
| 380 |
+
|
| 381 |
+
def forward(self, *test_inputs):
|
| 382 |
+
for i, name in enumerate(ALL_INPUTS_NAMES):
|
| 383 |
+
tensor, scale = test_inputs[i], 1 / 127
|
| 384 |
+
tensor = self.quantize_if_needed(tensor, self.engine.get_tensor_dtype(name), scale)
|
| 385 |
+
self.context.set_tensor_address(name, int(tensor.data_ptr()))
|
| 386 |
+
self.device_buffers[name] = tensor
|
| 387 |
+
|
| 388 |
+
self.context.execute_async_v3(stream_handle=self.stream)
|
| 389 |
+
torch.cuda.current_stream().synchronize()
|
| 390 |
+
return self.outputs
|
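A hedged sketch of how VAEDecoderWrapperSingle can be driven frame by frame, threading feat_cache through successive calls (the latent shape and the use of ZERO_VAE_CACHE as the initial cache are assumptions inferred from the code above, not a documented API):

import torch

from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae import VAEDecoderWrapperSingle

decoder = VAEDecoderWrapperSingle().half().cuda().eval()  # weight loading omitted in this sketch
feat_cache = [c.cuda().half() for c in ZERO_VAE_CACHE]

latents = torch.randn(1, 4, 16, 60, 104).cuda().half()  # B, T, C_latent, H, W
frames = []
for t in range(latents.shape[1]):
    is_first = torch.tensor(1.0 if t == 0 else 0.0).cuda().half()
    # each call decodes exactly one latent frame and returns the updated cache
    rgb, feat_cache = decoder(latents[:, t:t + 1], is_first, *feat_cache)
    frames.append(rgb)
video = torch.cat(frames, dim=1)  # B, T, 3, H_out, W_out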
demo_utils/vae_block3.py
ADDED
|
@@ -0,0 +1,291 @@
| 1 |
+
from typing import List
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from wan.modules.vae import AttentionBlock, CausalConv3d, RMS_norm, ResidualBlock, Upsample
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Resample(nn.Module):
|
| 10 |
+
|
| 11 |
+
def __init__(self, dim, mode):
|
| 12 |
+
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
| 13 |
+
'downsample3d')
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.dim = dim
|
| 16 |
+
self.mode = mode
|
| 17 |
+
self.cache_t = 2
|
| 18 |
+
|
| 19 |
+
# layers
|
| 20 |
+
if mode == 'upsample2d':
|
| 21 |
+
self.resample = nn.Sequential(
|
| 22 |
+
Upsample(scale_factor=(2., 2.), mode='nearest'),
|
| 23 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
| 24 |
+
elif mode == 'upsample3d':
|
| 25 |
+
self.resample = nn.Sequential(
|
| 26 |
+
Upsample(scale_factor=(2., 2.), mode='nearest'),
|
| 27 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
| 28 |
+
self.time_conv = CausalConv3d(
|
| 29 |
+
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
| 30 |
+
|
| 31 |
+
elif mode == 'downsample2d':
|
| 32 |
+
self.resample = nn.Sequential(
|
| 33 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
| 34 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 35 |
+
elif mode == 'downsample3d':
|
| 36 |
+
self.resample = nn.Sequential(
|
| 37 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
| 38 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
| 39 |
+
self.time_conv = CausalConv3d(
|
| 40 |
+
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
| 41 |
+
|
| 42 |
+
else:
|
| 43 |
+
self.resample = nn.Identity()
|
| 44 |
+
|
| 45 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 46 |
+
b, c, t, h, w = x.size()
|
| 47 |
+
if self.mode == 'upsample3d':
|
| 48 |
+
if feat_cache is not None:
|
| 49 |
+
idx = feat_idx[0]
|
| 50 |
+
if feat_cache[idx] is None:
|
| 51 |
+
feat_cache[idx] = 'Rep'
|
| 52 |
+
feat_idx[0] += 1
|
| 53 |
+
else:
|
| 54 |
+
|
| 55 |
+
cache_x = x[:, :, -self.cache_t:, :, :].clone()
|
| 56 |
+
if cache_x.shape[2] < 2 and feat_cache[
|
| 57 |
+
idx] is not None and feat_cache[idx] != 'Rep':
|
| 58 |
+
# cache last frame of last two chunk
|
| 59 |
+
cache_x = torch.cat([
|
| 60 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 61 |
+
cache_x.device), cache_x
|
| 62 |
+
],
|
| 63 |
+
dim=2)
|
| 64 |
+
if cache_x.shape[2] < 2 and feat_cache[
|
| 65 |
+
idx] is not None and feat_cache[idx] == 'Rep':
|
| 66 |
+
cache_x = torch.cat([
|
| 67 |
+
torch.zeros_like(cache_x).to(cache_x.device),
|
| 68 |
+
cache_x
|
| 69 |
+
],
|
| 70 |
+
dim=2)
|
| 71 |
+
if feat_cache[idx] == 'Rep':
|
| 72 |
+
x = self.time_conv(x)
|
| 73 |
+
else:
|
| 74 |
+
x = self.time_conv(x, feat_cache[idx])
|
| 75 |
+
feat_cache[idx] = cache_x
|
| 76 |
+
feat_idx[0] += 1
|
| 77 |
+
|
| 78 |
+
x = x.reshape(b, 2, c, t, h, w)
|
| 79 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
| 80 |
+
3)
|
| 81 |
+
x = x.reshape(b, c, t * 2, h, w)
|
| 82 |
+
t = x.shape[2]
|
| 83 |
+
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
| 84 |
+
x = self.resample(x)
|
| 85 |
+
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
| 86 |
+
|
| 87 |
+
if self.mode == 'downsample3d':
|
| 88 |
+
if feat_cache is not None:
|
| 89 |
+
idx = feat_idx[0]
|
| 90 |
+
if feat_cache[idx] is None:
|
| 91 |
+
feat_cache[idx] = x.clone()
|
| 92 |
+
feat_idx[0] += 1
|
| 93 |
+
else:
|
| 94 |
+
|
| 95 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
| 96 |
+
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
| 97 |
+
# # cache last frame of last two chunk
|
| 98 |
+
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 99 |
+
|
| 100 |
+
x = self.time_conv(
|
| 101 |
+
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
| 102 |
+
feat_cache[idx] = cache_x
|
| 103 |
+
feat_idx[0] += 1
|
| 104 |
+
return x
|
| 105 |
+
|
| 106 |
+
def init_weight(self, conv):
|
| 107 |
+
conv_weight = conv.weight
|
| 108 |
+
nn.init.zeros_(conv_weight)
|
| 109 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 110 |
+
one_matrix = torch.eye(c1, c2)
|
| 111 |
+
init_matrix = one_matrix
|
| 112 |
+
nn.init.zeros_(conv_weight)
|
| 113 |
+
# conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
| 114 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
|
| 115 |
+
conv.weight.data.copy_(conv_weight)
|
| 116 |
+
nn.init.zeros_(conv.bias.data)
|
| 117 |
+
|
| 118 |
+
def init_weight2(self, conv):
|
| 119 |
+
conv_weight = conv.weight.data
|
| 120 |
+
nn.init.zeros_(conv_weight)
|
| 121 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 122 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
| 123 |
+
# init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
| 124 |
+
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
| 125 |
+
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
| 126 |
+
conv.weight.data.copy_(conv_weight)
|
| 127 |
+
nn.init.zeros_(conv.bias.data)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class VAEDecoderWrapper(nn.Module):
|
| 131 |
+
def __init__(self):
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.decoder = VAEDecoder3d()
|
| 134 |
+
mean = [
|
| 135 |
+
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
| 136 |
+
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
| 137 |
+
]
|
| 138 |
+
std = [
|
| 139 |
+
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
| 140 |
+
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
| 141 |
+
]
|
| 142 |
+
self.mean = torch.tensor(mean, dtype=torch.float32)
|
| 143 |
+
self.std = torch.tensor(std, dtype=torch.float32)
|
| 144 |
+
self.z_dim = 16
|
| 145 |
+
self.conv2 = CausalConv3d(self.z_dim, self.z_dim, 1)
|
| 146 |
+
|
| 147 |
+
def forward(
|
| 148 |
+
self,
|
| 149 |
+
z: torch.Tensor,
|
| 150 |
+
*feat_cache: List[torch.Tensor]
|
| 151 |
+
):
|
| 152 |
+
# from [batch_size, num_frames, num_channels, height, width]
|
| 153 |
+
# to [batch_size, num_channels, num_frames, height, width]
|
| 154 |
+
z = z.permute(0, 2, 1, 3, 4)
|
| 155 |
+
feat_cache = list(feat_cache)
|
| 156 |
+
# print("Length of feat_cache: ", len(feat_cache))
|
| 157 |
+
|
| 158 |
+
device, dtype = z.device, z.dtype
|
| 159 |
+
scale = [self.mean.to(device=device, dtype=dtype),
|
| 160 |
+
1.0 / self.std.to(device=device, dtype=dtype)]
|
| 161 |
+
|
| 162 |
+
if isinstance(scale[0], torch.Tensor):
|
| 163 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
| 164 |
+
1, self.z_dim, 1, 1, 1)
|
| 165 |
+
else:
|
| 166 |
+
z = z / scale[1] + scale[0]
|
| 167 |
+
iter_ = z.shape[2]
|
| 168 |
+
x = self.conv2(z)
|
| 169 |
+
for i in range(iter_):
|
| 170 |
+
if i == 0:
|
| 171 |
+
out, feat_cache = self.decoder(
|
| 172 |
+
x[:, :, i:i + 1, :, :],
|
| 173 |
+
feat_cache=feat_cache)
|
| 174 |
+
else:
|
| 175 |
+
out_, feat_cache = self.decoder(
|
| 176 |
+
x[:, :, i:i + 1, :, :],
|
| 177 |
+
feat_cache=feat_cache)
|
| 178 |
+
out = torch.cat([out, out_], 2)
|
| 179 |
+
|
| 180 |
+
out = out.float().clamp_(-1, 1)
|
| 181 |
+
# from [batch_size, num_channels, num_frames, height, width]
|
| 182 |
+
# to [batch_size, num_frames, num_channels, height, width]
|
| 183 |
+
out = out.permute(0, 2, 1, 3, 4)
|
| 184 |
+
return out, feat_cache
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class VAEDecoder3d(nn.Module):
|
| 188 |
+
def __init__(self,
|
| 189 |
+
dim=96,
|
| 190 |
+
z_dim=16,
|
| 191 |
+
dim_mult=[1, 2, 4, 4],
|
| 192 |
+
num_res_blocks=2,
|
| 193 |
+
attn_scales=[],
|
| 194 |
+
temperal_upsample=[True, True, False],
|
| 195 |
+
dropout=0.0):
|
| 196 |
+
super().__init__()
|
| 197 |
+
self.dim = dim
|
| 198 |
+
self.z_dim = z_dim
|
| 199 |
+
self.dim_mult = dim_mult
|
| 200 |
+
self.num_res_blocks = num_res_blocks
|
| 201 |
+
self.attn_scales = attn_scales
|
| 202 |
+
self.temperal_upsample = temperal_upsample
|
| 203 |
+
self.cache_t = 2
|
| 204 |
+
self.decoder_conv_num = 32
|
| 205 |
+
|
| 206 |
+
# dimensions
|
| 207 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 208 |
+
scale = 1.0 / 2**(len(dim_mult) - 2)
|
| 209 |
+
|
| 210 |
+
# init block
|
| 211 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
| 212 |
+
|
| 213 |
+
# middle blocks
|
| 214 |
+
self.middle = nn.Sequential(
|
| 215 |
+
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
|
| 216 |
+
ResidualBlock(dims[0], dims[0], dropout))
|
| 217 |
+
|
| 218 |
+
# upsample blocks
|
| 219 |
+
upsamples = []
|
| 220 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 221 |
+
# residual (+attention) blocks
|
| 222 |
+
if i == 1 or i == 2 or i == 3:
|
| 223 |
+
in_dim = in_dim // 2
|
| 224 |
+
for _ in range(num_res_blocks + 1):
|
| 225 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 226 |
+
if scale in attn_scales:
|
| 227 |
+
upsamples.append(AttentionBlock(out_dim))
|
| 228 |
+
in_dim = out_dim
|
| 229 |
+
|
| 230 |
+
# upsample block
|
| 231 |
+
if i != len(dim_mult) - 1:
|
| 232 |
+
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
| 233 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
| 234 |
+
scale *= 2.0
|
| 235 |
+
self.upsamples = nn.Sequential(*upsamples)
|
| 236 |
+
|
| 237 |
+
# output blocks
|
| 238 |
+
self.head = nn.Sequential(
|
| 239 |
+
RMS_norm(out_dim, images=False), nn.SiLU(),
|
| 240 |
+
CausalConv3d(out_dim, 3, 3, padding=1))
|
| 241 |
+
|
| 242 |
+
def forward(
|
| 243 |
+
self,
|
| 244 |
+
x: torch.Tensor,
|
| 245 |
+
feat_cache: List[torch.Tensor]
|
| 246 |
+
):
|
| 247 |
+
feat_idx = [0]
|
| 248 |
+
|
| 249 |
+
# conv1
|
| 250 |
+
idx = feat_idx[0]
|
| 251 |
+
cache_x = x[:, :, -self.cache_t:, :, :].clone()
|
| 252 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 253 |
+
# cache last frame of last two chunk
|
| 254 |
+
cache_x = torch.cat([
|
| 255 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 256 |
+
cache_x.device), cache_x
|
| 257 |
+
],
|
| 258 |
+
dim=2)
|
| 259 |
+
x = self.conv1(x, feat_cache[idx])
|
| 260 |
+
feat_cache[idx] = cache_x
|
| 261 |
+
feat_idx[0] += 1
|
| 262 |
+
|
| 263 |
+
# middle
|
| 264 |
+
for layer in self.middle:
|
| 265 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
| 266 |
+
x = layer(x, feat_cache, feat_idx)
|
| 267 |
+
else:
|
| 268 |
+
x = layer(x)
|
| 269 |
+
|
| 270 |
+
# upsamples
|
| 271 |
+
for layer in self.upsamples:
|
| 272 |
+
x = layer(x, feat_cache, feat_idx)
|
| 273 |
+
|
| 274 |
+
# head
|
| 275 |
+
for layer in self.head:
|
| 276 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
| 277 |
+
idx = feat_idx[0]
|
| 278 |
+
cache_x = x[:, :, -self.cache_t:, :, :].clone()
|
| 279 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 280 |
+
# cache last frame of last two chunk
|
| 281 |
+
cache_x = torch.cat([
|
| 282 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
| 283 |
+
cache_x.device), cache_x
|
| 284 |
+
],
|
| 285 |
+
dim=2)
|
| 286 |
+
x = layer(x, feat_cache[idx])
|
| 287 |
+
feat_cache[idx] = cache_x
|
| 288 |
+
feat_idx[0] += 1
|
| 289 |
+
else:
|
| 290 |
+
x = layer(x)
|
| 291 |
+
return x, feat_cache
|
demo_utils/vae_torch2trt.py
ADDED
|
@@ -0,0 +1,308 @@
| 1 |
+
# ---- INT8 (optional) ----
|
| 2 |
+
from demo_utils.vae import (
|
| 3 |
+
VAEDecoderWrapperSingle, # main nn.Module
|
| 4 |
+
ZERO_VAE_CACHE # helper constants shipped with your code base
|
| 5 |
+
)
|
| 6 |
+
import pycuda.driver as cuda # ← add
|
| 7 |
+
import pycuda.autoinit # noqa
|
| 8 |
+
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import tensorrt as trt
|
| 14 |
+
|
| 15 |
+
from utils.dataset import ShardingLMDBDataset
|
| 16 |
+
|
| 17 |
+
data_path = "/mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb_oneshard"
|
| 18 |
+
dataset = ShardingLMDBDataset(data_path, max_pair=int(1e8))
|
| 19 |
+
dataloader = torch.utils.data.DataLoader(
|
| 20 |
+
dataset,
|
| 21 |
+
batch_size=1,
|
| 22 |
+
num_workers=0
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
# ─────────────────────────────────────────────────────────
|
| 26 |
+
# 1️⃣ Bring the PyTorch model into scope
|
| 27 |
+
# (all code you pasted lives in `vae_decoder.py`)
|
| 28 |
+
# ─────────────────────────────────────────────────────────
|
| 29 |
+
|
| 30 |
+
# --- dummy tensors (exact shapes you posted) ---
|
| 31 |
+
dummy_input = torch.randn(1, 1, 16, 60, 104).half().cuda()
|
| 32 |
+
is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
|
| 33 |
+
dummy_cache_input = [
|
| 34 |
+
torch.randn(*s.shape).half().cuda() if isinstance(s, torch.Tensor) else s
|
| 35 |
+
for s in ZERO_VAE_CACHE # keep exactly the same ordering
|
| 36 |
+
]
|
| 37 |
+
inputs = [dummy_input, is_first_frame, *dummy_cache_input]
|
| 38 |
+
|
| 39 |
+
# ─────────────────────────────────────────────────────────
|
| 40 |
+
# 2️⃣ Export → ONNX
|
| 41 |
+
# ─────────────────────────────────────────────────────────
|
| 42 |
+
model = VAEDecoderWrapperSingle().half().cuda().eval()
|
| 43 |
+
|
| 44 |
+
vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
|
| 45 |
+
decoder_state_dict = {}
|
| 46 |
+
for key, value in vae_state_dict.items():
|
| 47 |
+
if 'decoder.' in key or 'conv2' in key:
|
| 48 |
+
decoder_state_dict[key] = value
|
| 49 |
+
model.load_state_dict(decoder_state_dict)
|
| 50 |
+
model = model.half().cuda().eval() # only batch dim dynamic
|
| 51 |
+
|
| 52 |
+
onnx_path = Path("vae_decoder.onnx")
|
| 53 |
+
feat_names = [f"vae_cache_{i}" for i in range(len(dummy_cache_input))]
|
| 54 |
+
all_inputs_names = ["z", "use_cache"] + feat_names
|
| 55 |
+
|
| 56 |
+
with torch.inference_mode():
|
| 57 |
+
torch.onnx.export(
|
| 58 |
+
model,
|
| 59 |
+
tuple(inputs), # must be a tuple
|
| 60 |
+
onnx_path.as_posix(),
|
| 61 |
+
input_names=all_inputs_names,
|
| 62 |
+
output_names=["rgb_out", "cache_out"],
|
| 63 |
+
opset_version=17,
|
| 64 |
+
do_constant_folding=True,
|
| 65 |
+
dynamo=True
|
| 66 |
+
)
|
| 67 |
+
print(f"✅ ONNX graph saved to {onnx_path.resolve()}")
|
| 68 |
+
|
| 69 |
+
# (Optional) quick sanity-check with ONNX-Runtime
|
| 70 |
+
try:
|
| 71 |
+
import onnxruntime as ort
|
| 72 |
+
sess = ort.InferenceSession(onnx_path.as_posix(),
|
| 73 |
+
providers=["CUDAExecutionProvider"])
|
| 74 |
+
ort_inputs = {n: t.cpu().numpy() for n, t in zip(all_inputs_names, inputs)}
|
| 75 |
+
_ = sess.run(None, ort_inputs)
|
| 76 |
+
print("✅ ONNX graph is executable")
|
| 77 |
+
except Exception as e:
|
| 78 |
+
print("⚠️ ONNX check failed:", e)
|
| 79 |
+
|
| 80 |
+
# ─────────────────────────────────────────────────────────
|
| 81 |
+
# 3️⃣ Build the TensorRT engine
|
| 82 |
+
# ─────────────────────────────────────────────────────────
|
| 83 |
+
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
| 84 |
+
builder = trt.Builder(TRT_LOGGER)
|
| 85 |
+
network = builder.create_network(
|
| 86 |
+
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
| 87 |
+
parser = trt.OnnxParser(network, TRT_LOGGER)
|
| 88 |
+
|
| 89 |
+
with open(onnx_path, "rb") as f:
|
| 90 |
+
if not parser.parse(f.read()):
|
| 91 |
+
for i in range(parser.num_errors):
|
| 92 |
+
print(parser.get_error(i))
|
| 93 |
+
sys.exit("❌ ONNX → TRT parsing failed")
|
| 94 |
+
|
| 95 |
+
config = builder.create_builder_config()
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def set_workspace(config, bytes_):
|
| 99 |
+
"""Version-agnostic workspace limit."""
|
| 100 |
+
if hasattr(config, "max_workspace_size"): # TRT 8 / 9
|
| 101 |
+
config.max_workspace_size = bytes_
|
| 102 |
+
else: # TRT 10+
|
| 103 |
+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, bytes_)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# …
|
| 107 |
+
config = builder.create_builder_config()
|
| 108 |
+
set_workspace(config, 4 << 30) # 4 GB
|
| 109 |
+
# 4 GB
|
| 110 |
+
|
| 111 |
+
if builder.platform_has_fast_fp16:
|
| 112 |
+
config.set_flag(trt.BuilderFlag.FP16)
|
| 113 |
+
|
| 114 |
+
# ---- INT8 (optional) ----
|
| 115 |
+
# provide a calibrator if you need an INT8 engine; comment this
|
| 116 |
+
# block if you only care about FP16.
|
| 117 |
+
# ─────────────────────────────────────────────────────────
|
| 118 |
+
# helper: version-agnostic workspace limit
|
| 119 |
+
# ─────────────────────────────────────────────────────────
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def set_workspace(config: trt.IBuilderConfig, bytes_: int = 4 << 30):
|
| 123 |
+
"""
|
| 124 |
+
TRT < 10.x → config.max_workspace_size
|
| 125 |
+
TRT ≥ 10.x → config.set_memory_pool_limit(...)
|
| 126 |
+
"""
|
| 127 |
+
if hasattr(config, "max_workspace_size"): # TRT 8 / 9
|
| 128 |
+
config.max_workspace_size = bytes_
|
| 129 |
+
else: # TRT 10+
|
| 130 |
+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
|
| 131 |
+
bytes_)
|
| 132 |
+
|
| 133 |
+
# ─────────────────────────────────────────────────────────
|
| 134 |
+
# (optional) INT-8 calibrator
|
| 135 |
+
# ─────────────────────────────────────────────────────────
|
| 136 |
+
# ‼ Only keep this block if you really need INT-8; it is skipped gracefully if PyCUDA is not present ‼
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class VAECalibrator(trt.IInt8EntropyCalibrator2):
|
| 140 |
+
def __init__(self, loader, cache="calibration.cache", max_batches=10):
|
| 141 |
+
super().__init__()
|
| 142 |
+
self.loader = iter(loader)
|
| 143 |
+
self.batch_size = loader.batch_size or 1
|
| 144 |
+
self.max_batches = max_batches
|
| 145 |
+
self.count = 0
|
| 146 |
+
self.cache_file = cache
|
| 147 |
+
self.stream = cuda.Stream()
|
| 148 |
+
self.dev_ptrs = {}
|
| 149 |
+
|
| 150 |
+
# --- TRT 10 needs BOTH spellings ---
|
| 151 |
+
def get_batch_size(self):
|
| 152 |
+
return self.batch_size
|
| 153 |
+
|
| 154 |
+
def getBatchSize(self):
|
| 155 |
+
return self.batch_size
|
| 156 |
+
|
| 157 |
+
def get_batch(self, names):
|
| 158 |
+
if self.count >= self.max_batches:
|
| 159 |
+
return None
|
| 160 |
+
|
| 161 |
+
# Randomly sample a number from 0 to 10
|
| 162 |
+
import random
|
| 163 |
+
vae_idx = random.randint(0, 10)
|
| 164 |
+
data = next(self.loader)
|
| 165 |
+
|
| 166 |
+
latent = data['ode_latent'][0][:, :1]
|
| 167 |
+
is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
|
| 168 |
+
feat_cache = ZERO_VAE_CACHE
|
| 169 |
+
for i in range(vae_idx):
|
| 170 |
+
inputs = [latent, is_first_frame, *feat_cache]
|
| 171 |
+
with torch.inference_mode():
|
| 172 |
+
outputs = model(*inputs)
|
| 173 |
+
latent = data['ode_latent'][0][:, i + 1:i + 2]
|
| 174 |
+
is_first_frame = torch.tensor([0.0], device="cuda", dtype=torch.float16)
|
| 175 |
+
feat_cache = outputs[1:]
|
| 176 |
+
|
| 177 |
+
# -------- ensure context is current --------
|
| 178 |
+
z_np = latent.cpu().numpy().astype('float32')
|
| 179 |
+
|
| 180 |
+
ptrs = [] # list[int] – one entry per name
|
| 181 |
+
for name in names: # <-- match TRT's binding order
|
| 182 |
+
if name == "z":
|
| 183 |
+
arr = z_np
|
| 184 |
+
elif name == "use_cache":
|
| 185 |
+
arr = is_first_frame.cpu().numpy().astype('float32')
|
| 186 |
+
else:
|
| 187 |
+
idx = int(name.split('_')[-1]) # "vae_cache_17" -> 17
|
| 188 |
+
arr = feat_cache[idx].cpu().numpy().astype('float32')
|
| 189 |
+
|
| 190 |
+
if name not in self.dev_ptrs:
|
| 191 |
+
self.dev_ptrs[name] = cuda.mem_alloc(arr.nbytes)
|
| 192 |
+
|
| 193 |
+
cuda.memcpy_htod_async(self.dev_ptrs[name], arr, self.stream)
|
| 194 |
+
ptrs.append(int(self.dev_ptrs[name])) # ***int() is required***
|
| 195 |
+
|
| 196 |
+
self.stream.synchronize()
|
| 197 |
+
self.count += 1
|
| 198 |
+
print(f"Calibration batch {self.count}/{self.max_batches}")
|
| 199 |
+
return ptrs
|
| 200 |
+
|
| 201 |
+
# --- calibration-cache helpers (both spellings) ---
|
| 202 |
+
def read_calibration_cache(self):
|
| 203 |
+
try:
|
| 204 |
+
with open(self.cache_file, "rb") as f:
|
| 205 |
+
return f.read()
|
| 206 |
+
except FileNotFoundError:
|
| 207 |
+
return None
|
| 208 |
+
|
| 209 |
+
def readCalibrationCache(self):
|
| 210 |
+
return self.read_calibration_cache()
|
| 211 |
+
|
| 212 |
+
def write_calibration_cache(self, cache):
|
| 213 |
+
with open(self.cache_file, "wb") as f:
|
| 214 |
+
f.write(cache)
|
| 215 |
+
|
| 216 |
+
def writeCalibrationCache(self, cache):
|
| 217 |
+
self.write_calibration_cache(cache)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
# ─────────────────────────────────────────────────────────
|
| 221 |
+
# Builder-config + optimisation profile
|
| 222 |
+
# ─────────────────────────────────────────────────────────
|
| 223 |
+
config = builder.create_builder_config()
|
| 224 |
+
set_workspace(config, 4 << 30) # 4 GB
|
| 225 |
+
|
| 226 |
+
# ► enable FP16 if possible
|
| 227 |
+
if builder.platform_has_fast_fp16:
|
| 228 |
+
config.set_flag(trt.BuilderFlag.FP16)
|
| 229 |
+
|
| 230 |
+
# ► enable INT-8 (delete this block if you don’t need it)
|
| 231 |
+
if cuda is not None:
|
| 232 |
+
config.set_flag(trt.BuilderFlag.INT8)
|
| 233 |
+
# supply any representative batch you like – here we reuse the latent z
|
| 234 |
+
calib = VAECalibrator(dataloader)
|
| 235 |
+
# TRT-10 renamed the setter:
|
| 236 |
+
if hasattr(config, "set_int8_calibrator"): # TRT 10+
|
| 237 |
+
config.set_int8_calibrator(calib)
|
| 238 |
+
else: # TRT ≤ 9
|
| 239 |
+
config.int8_calibrator = calib
|
| 240 |
+
|
| 241 |
+
# ---- optimisation profile ----
|
| 242 |
+
profile = builder.create_optimization_profile()
|
| 243 |
+
profile.set_shape(all_inputs_names[0], # latent z
|
| 244 |
+
min=(1, 1, 16, 60, 104),
|
| 245 |
+
opt=(1, 1, 16, 60, 104),
|
| 246 |
+
max=(1, 1, 16, 60, 104))
|
| 247 |
+
profile.set_shape("use_cache", # scalar flag
|
| 248 |
+
min=(1,), opt=(1,), max=(1,))
|
| 249 |
+
for name, tensor in zip(all_inputs_names[2:], dummy_cache_input):
|
| 250 |
+
profile.set_shape(name, tensor.shape, tensor.shape, tensor.shape)
|
| 251 |
+
|
| 252 |
+
config.add_optimization_profile(profile)
|
| 253 |
+
|
| 254 |
+
# ─────────────────────────────────────────────────────────
|
| 255 |
+
# Build the engine (API changed in TRT-10)
|
| 256 |
+
# ─────────────────────────────────────────────────────────
|
| 257 |
+
print("⚙️ Building engine … (can take a minute)")
|
| 258 |
+
|
| 259 |
+
if hasattr(builder, "build_serialized_network"): # TRT 10+
|
| 260 |
+
serialized_engine = builder.build_serialized_network(network, config)
|
| 261 |
+
assert serialized_engine is not None, "build_serialized_network() failed"
|
| 262 |
+
plan_path = Path("checkpoints/vae_decoder_int8.trt")
|
| 263 |
+
plan_path.write_bytes(serialized_engine)
|
| 264 |
+
engine_bytes = serialized_engine # keep for smoke-test
|
| 265 |
+
else: # TRT ≤ 9
|
| 266 |
+
engine = builder.build_engine(network, config)
|
| 267 |
+
assert engine is not None, "build_engine() returned None"
|
| 268 |
+
plan_path = Path("checkpoints/vae_decoder_int8.trt")
|
| 269 |
+
plan_path.write_bytes(engine.serialize())
|
| 270 |
+
engine_bytes = engine.serialize()
|
| 271 |
+
|
| 272 |
+
print(f"✅ TensorRT engine written to {plan_path.resolve()}")
|
| 273 |
+
|
| 274 |
+
# ─────────────────────────────────────────────────────────
|
| 275 |
+
# 4️⃣ Quick smoke-test with the brand-new engine
|
| 276 |
+
# ─────────────────────────────────────────────────────────
|
| 277 |
+
with trt.Runtime(TRT_LOGGER) as rt:
|
| 278 |
+
engine = rt.deserialize_cuda_engine(engine_bytes)
|
| 279 |
+
context = engine.create_execution_context()
|
| 280 |
+
stream = torch.cuda.current_stream().cuda_stream
|
| 281 |
+
|
| 282 |
+
# pre-allocate device buffers once
|
| 283 |
+
device_buffers, outputs = {}, []
|
| 284 |
+
dtype_map = {trt.float32: torch.float32,
|
| 285 |
+
trt.float16: torch.float16,
|
| 286 |
+
trt.int8: torch.int8,
|
| 287 |
+
trt.int32: torch.int32}
|
| 288 |
+
|
| 289 |
+
for name, tensor in zip(all_inputs_names, inputs):
|
| 290 |
+
if -1 in engine.get_tensor_shape(name): # dynamic input
|
| 291 |
+
context.set_input_shape(name, tensor.shape)
|
| 292 |
+
context.set_tensor_address(name, int(tensor.data_ptr()))
|
| 293 |
+
device_buffers[name] = tensor
|
| 294 |
+
|
| 295 |
+
context.infer_shapes() # propagate ⇢ outputs
|
| 296 |
+
for i in range(engine.num_io_tensors):
|
| 297 |
+
name = engine.get_tensor_name(i)
|
| 298 |
+
if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
|
| 299 |
+
shape = tuple(context.get_tensor_shape(name))
|
| 300 |
+
dtype = dtype_map[engine.get_tensor_dtype(name)]
|
| 301 |
+
out = torch.empty(shape, dtype=dtype, device="cuda")
|
| 302 |
+
context.set_tensor_address(name, int(out.data_ptr()))
|
| 303 |
+
outputs.append(out)
|
| 304 |
+
print(f"output {name} shape: {shape}")
|
| 305 |
+
|
| 306 |
+
context.execute_async_v3(stream_handle=stream)
|
| 307 |
+
torch.cuda.current_stream().synchronize()
|
| 308 |
+
print("✅ TRT execution OK – first output shape:", outputs[0].shape)
|
inference.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from omegaconf import OmegaConf
|
| 7 |
+
from torchvision.transforms import v2
|
| 8 |
+
from diffusers.utils import load_image
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from pipeline import CausalInferencePipeline
|
| 11 |
+
from wan.vae.wanx_vae import get_wanx_vae_wrapper
|
| 12 |
+
from demo_utils.vae_block3 import VAEDecoderWrapper
|
| 13 |
+
from utils.visualize import process_video
|
| 14 |
+
from utils.misc import set_seed
|
| 15 |
+
from utils.conditions import *
|
| 16 |
+
from utils.wan_wrapper import WanDiffusionWrapper
|
| 17 |
+
from safetensors.torch import load_file
|
| 18 |
+
|
| 19 |
+
def parse_args():
|
| 20 |
+
parser = argparse.ArgumentParser()
|
| 21 |
+
parser.add_argument("--config_path", type=str, default="configs/inference_yaml/inference_universal.yaml", help="Path to the config file")
|
| 22 |
+
parser.add_argument("--checkpoint_path", type=str, default="", help="Path to the checkpoint")
|
| 23 |
+
parser.add_argument("--img_path", type=str, default="demo_images/universal/0000.png", help="Path to the image")
|
| 24 |
+
parser.add_argument("--output_folder", type=str, default="outputs/", help="Output folder")
|
| 25 |
+
parser.add_argument("--num_output_frames", type=int, default=150,
|
| 26 |
+
help="Number of output latent frames")
|
| 27 |
+
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
| 28 |
+
parser.add_argument("--pretrained_model_path", type=str, default="Matrix-Game-2.0", help="Path to the VAE model folder")
|
| 29 |
+
args = parser.parse_args()
|
| 30 |
+
return args
|
| 31 |
+
|
| 32 |
+
class InteractiveGameInference:
|
| 33 |
+
def __init__(self, args):
|
| 34 |
+
self.args = args
|
| 35 |
+
self.device = torch.device("cuda")
|
| 36 |
+
self.weight_dtype = torch.bfloat16
|
| 37 |
+
|
| 38 |
+
self._init_config()
|
| 39 |
+
self._init_models()
|
| 40 |
+
|
| 41 |
+
self.frame_process = v2.Compose([
|
| 42 |
+
v2.Resize(size=(352, 640), antialias=True),
|
| 43 |
+
v2.ToTensor(),
|
| 44 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
| 45 |
+
])
|
| 46 |
+
|
| 47 |
+
def _init_config(self):
|
| 48 |
+
self.config = OmegaConf.load(self.args.config_path)
|
| 49 |
+
|
| 50 |
+
def _init_models(self):
|
| 51 |
+
# Initialize pipeline
|
| 52 |
+
generator = WanDiffusionWrapper(
|
| 53 |
+
**getattr(self.config, "model_kwargs", {}), is_causal=True)
|
| 54 |
+
current_vae_decoder = VAEDecoderWrapper()
|
| 55 |
+
vae_state_dict = torch.load(os.path.join(self.args.pretrained_model_path, "Wan2.1_VAE.pth"), map_location="cpu")
|
| 56 |
+
decoder_state_dict = {}
|
| 57 |
+
for key, value in vae_state_dict.items():
|
| 58 |
+
if 'decoder.' in key or 'conv2' in key:
|
| 59 |
+
decoder_state_dict[key] = value
|
| 60 |
+
current_vae_decoder.load_state_dict(decoder_state_dict)
|
| 61 |
+
current_vae_decoder.to(self.device, torch.float16)
|
| 62 |
+
current_vae_decoder.requires_grad_(False)
|
| 63 |
+
current_vae_decoder.eval()
|
| 64 |
+
current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")
|
| 65 |
+
pipeline = CausalInferencePipeline(self.config, generator=generator, vae_decoder=current_vae_decoder)
|
| 66 |
+
if self.args.checkpoint_path:
|
| 67 |
+
print("Loading Pretrained Model...")
|
| 68 |
+
state_dict = load_file(self.args.checkpoint_path)
|
| 69 |
+
pipeline.generator.load_state_dict(state_dict)
|
| 70 |
+
|
| 71 |
+
self.pipeline = pipeline.to(device=self.device, dtype=self.weight_dtype)
|
| 72 |
+
self.pipeline.vae_decoder.to(torch.float16)
|
| 73 |
+
|
| 74 |
+
vae = get_wanx_vae_wrapper(self.args.pretrained_model_path, torch.float16)
|
| 75 |
+
vae.requires_grad_(False)
|
| 76 |
+
vae.eval()
|
| 77 |
+
self.vae = vae.to(self.device, self.weight_dtype)
|
| 78 |
+
|
| 79 |
+
def _resizecrop(self, image, th, tw):
|
| 80 |
+
w, h = image.size
|
| 81 |
+
if h / w > th / tw:
|
| 82 |
+
new_w = int(w)
|
| 83 |
+
new_h = int(new_w * th / tw)
|
| 84 |
+
else:
|
| 85 |
+
new_h = int(h)
|
| 86 |
+
new_w = int(new_h * tw / th)
|
| 87 |
+
left = (w - new_w) / 2
|
| 88 |
+
top = (h - new_h) / 2
|
| 89 |
+
right = (w + new_w) / 2
|
| 90 |
+
bottom = (h + new_h) / 2
|
| 91 |
+
image = image.crop((left, top, right, bottom))
|
| 92 |
+
return image
|
| 93 |
+
|
| 94 |
+
def generate_videos(self):
|
| 95 |
+
mode = self.config.pop('mode')
|
| 96 |
+
assert mode in ['universal', 'gta_drive', 'templerun']
|
| 97 |
+
|
| 98 |
+
image = load_image(self.args.img_path)
|
| 99 |
+
image = self._resizecrop(image, 352, 640)
|
| 100 |
+
image = self.frame_process(image)[None, :, None, :, :].to(dtype=self.weight_dtype, device=self.device)
|
| 101 |
+
# Encode the input image as the first latent
|
| 102 |
+
padding_video = torch.zeros_like(image).repeat(1, 1, 4 * (self.args.num_output_frames - 1), 1, 1)
|
| 103 |
+
img_cond = torch.concat([image, padding_video], dim=2)
|
| 104 |
+
tiler_kwargs={"tiled": True, "tile_size": [44, 80], "tile_stride": [23, 38]}
|
| 105 |
+
img_cond = self.vae.encode(img_cond, device=self.device, **tiler_kwargs).to(self.device)
|
| 106 |
+
mask_cond = torch.ones_like(img_cond)
|
| 107 |
+
mask_cond[:, :, 1:] = 0
|
| 108 |
+
cond_concat = torch.cat([mask_cond[:, :4], img_cond], dim=1)
|
| 109 |
+
visual_context = self.vae.clip.encode_video(image)
|
| 110 |
+
sampled_noise = torch.randn(
|
| 111 |
+
[1, 16,self.args.num_output_frames, 44, 80], device=self.device, dtype=self.weight_dtype
|
| 112 |
+
)
|
| 113 |
+
num_frames = (self.args.num_output_frames - 1) * 4 + 1
|
| 114 |
+
|
| 115 |
+
conditional_dict = {
|
| 116 |
+
"cond_concat": cond_concat.to(device=self.device, dtype=self.weight_dtype),
|
| 117 |
+
"visual_context": visual_context.to(device=self.device, dtype=self.weight_dtype)
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
if mode == 'universal':
|
| 121 |
+
cond_data = Bench_actions_universal(num_frames)
|
| 122 |
+
mouse_condition = cond_data['mouse_condition'].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
|
| 123 |
+
conditional_dict['mouse_cond'] = mouse_condition
|
| 124 |
+
elif mode == 'gta_drive':
|
| 125 |
+
cond_data = Bench_actions_gta_drive(num_frames)
|
| 126 |
+
mouse_condition = cond_data['mouse_condition'].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
|
| 127 |
+
conditional_dict['mouse_cond'] = mouse_condition
|
| 128 |
+
else:
|
| 129 |
+
cond_data = Bench_actions_templerun(num_frames)
|
| 130 |
+
keyboard_condition = cond_data['keyboard_condition'].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
|
| 131 |
+
conditional_dict['keyboard_cond'] = keyboard_condition
|
| 132 |
+
|
| 133 |
+
with torch.no_grad():
|
| 134 |
+
videos = self.pipeline.inference(
|
| 135 |
+
noise=sampled_noise,
|
| 136 |
+
conditional_dict=conditional_dict,
|
| 137 |
+
return_latents=False,
|
| 138 |
+
mode=mode,
|
| 139 |
+
profile=False
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
videos_tensor = torch.cat(videos, dim=1)
|
| 143 |
+
videos = rearrange(videos_tensor, "B T C H W -> B T H W C")
|
| 144 |
+
videos = ((videos.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)[0]
|
| 145 |
+
video = np.ascontiguousarray(videos)
|
| 146 |
+
mouse_icon = 'assets/images/mouse.png'
|
| 147 |
+
if mode != 'templerun':
|
| 148 |
+
config = (
|
| 149 |
+
keyboard_condition[0].float().cpu().numpy(),
|
| 150 |
+
mouse_condition[0].float().cpu().numpy()
|
| 151 |
+
)
|
| 152 |
+
else:
|
| 153 |
+
config = (
|
| 154 |
+
keyboard_condition[0].float().cpu().numpy()
|
| 155 |
+
)
|
| 156 |
+
process_video(video.astype(np.uint8), self.args.output_folder+f'/demo.mp4', config, mouse_icon, mouse_scale=0.1, process_icon=False, mode=mode)
|
| 157 |
+
process_video(video.astype(np.uint8), self.args.output_folder+f'/demo_icon.mp4', config, mouse_icon, mouse_scale=0.1, process_icon=True, mode=mode)
|
| 158 |
+
print("Done")
|
| 159 |
+
|
| 160 |
+
def main():
|
| 161 |
+
"""Main entry point for video generation."""
|
| 162 |
+
args = parse_args()
|
| 163 |
+
set_seed(args.seed)
|
| 164 |
+
os.makedirs(args.output_folder, exist_ok=True)
|
| 165 |
+
pipeline = InteractiveGameInference(args)
|
| 166 |
+
pipeline.generate_videos()
|
| 167 |
+
|
| 168 |
+
if __name__ == "__main__":
|
| 169 |
+
main()
|
inference_streaming.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import argparse
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import copy
|
| 6 |
+
|
| 7 |
+
from omegaconf import OmegaConf
|
| 8 |
+
from torchvision.transforms import v2
|
| 9 |
+
from diffusers.utils import load_image
|
| 10 |
+
|
| 11 |
+
from pipeline import CausalInferenceStreamingPipeline
|
| 12 |
+
from wan.vae.wanx_vae import get_wanx_vae_wrapper
|
| 13 |
+
from demo_utils.vae_block3 import VAEDecoderWrapper
|
| 14 |
+
from utils.visualize import process_video
|
| 15 |
+
from utils.misc import set_seed
|
| 16 |
+
from utils.conditions import *
|
| 17 |
+
from utils.wan_wrapper import WanDiffusionWrapper
|
| 18 |
+
from safetensors.torch import load_file
|
| 19 |
+
|
| 20 |
+
def parse_args():
|
| 21 |
+
parser = argparse.ArgumentParser()
|
| 22 |
+
parser.add_argument("--config_path", type=str, default="configs/inference_yaml/inference_universal.yaml", help="Path to the config file")
|
| 23 |
+
parser.add_argument("--checkpoint_path", type=str, default="", help="Path to the checkpoint")
|
| 24 |
+
parser.add_argument("--output_folder", type=str, default="outputs/", help="Output folder")
|
| 25 |
+
parser.add_argument("--max_num_output_frames", type=int, default=360,
|
| 26 |
+
help="Max number of output latent frames")
|
| 27 |
+
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
| 28 |
+
parser.add_argument("--pretrained_model_path", type=str, default="Matrix-Game-2.0", help="Path to the VAE model folder")
|
| 29 |
+
args = parser.parse_args()
|
| 30 |
+
return args
|
| 31 |
+
|
| 32 |
+
class InteractiveGameInference:
|
| 33 |
+
def __init__(self, args):
|
| 34 |
+
self.args = args
|
| 35 |
+
self.device = torch.device("cuda")
|
| 36 |
+
self.weight_dtype = torch.bfloat16
|
| 37 |
+
|
| 38 |
+
self._init_config()
|
| 39 |
+
self._init_models()
|
| 40 |
+
|
| 41 |
+
self.frame_process = v2.Compose([
|
| 42 |
+
v2.Resize(size=(352, 640), antialias=True),
|
| 43 |
+
v2.ToTensor(),
|
| 44 |
+
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
|
| 45 |
+
])
|
| 46 |
+
|
| 47 |
+
def _init_config(self):
|
| 48 |
+
self.config = OmegaConf.load(self.args.config_path)
|
| 49 |
+
|
| 50 |
+
def _init_models(self):
|
| 51 |
+
# Initialize pipeline
|
| 52 |
+
generator = WanDiffusionWrapper(
|
| 53 |
+
**getattr(self.config, "model_kwargs", {}), is_causal=True)
|
| 54 |
+
current_vae_decoder = VAEDecoderWrapper()
|
| 55 |
+
vae_state_dict = torch.load(os.path.join(self.args.pretrained_model_path, "Wan2.1_VAE.pth"), map_location="cpu")
|
| 56 |
+
decoder_state_dict = {}
|
| 57 |
+
for key, value in vae_state_dict.items():
|
| 58 |
+
if 'decoder.' in key or 'conv2' in key:
|
| 59 |
+
decoder_state_dict[key] = value
|
| 60 |
+
current_vae_decoder.load_state_dict(decoder_state_dict)
|
| 61 |
+
current_vae_decoder.to(self.device, torch.float16)
|
| 62 |
+
current_vae_decoder.requires_grad_(False)
|
| 63 |
+
current_vae_decoder.eval()
|
| 64 |
+
current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")
|
| 65 |
+
pipeline = CausalInferenceStreamingPipeline(self.config, generator=generator, vae_decoder=current_vae_decoder)
|
| 66 |
+
if self.args.checkpoint_path:
|
| 67 |
+
print("Loading Pretrained Model...")
|
| 68 |
+
state_dict = load_file(self.args.checkpoint_path)
|
| 69 |
+
pipeline.generator.load_state_dict(state_dict)
|
| 70 |
+
|
| 71 |
+
self.pipeline = pipeline.to(device=self.device, dtype=self.weight_dtype)
|
| 72 |
+
self.pipeline.vae_decoder.to(torch.float16)
|
| 73 |
+
|
| 74 |
+
vae = get_wanx_vae_wrapper(self.args.pretrained_model_path, torch.float16)
|
| 75 |
+
vae.requires_grad_(False)
|
| 76 |
+
vae.eval()
|
| 77 |
+
self.vae = vae.to(self.device, self.weight_dtype)
|
| 78 |
+
|
| 79 |
+
def _resizecrop(self, image, th, tw):
|
| 80 |
+
w, h = image.size
|
| 81 |
+
if h / w > th / tw:
|
| 82 |
+
new_w = int(w)
|
| 83 |
+
new_h = int(new_w * th / tw)
|
| 84 |
+
else:
|
| 85 |
+
new_h = int(h)
|
| 86 |
+
new_w = int(new_h * tw / th)
|
| 87 |
+
left = (w - new_w) / 2
|
| 88 |
+
top = (h - new_h) / 2
|
| 89 |
+
right = (w + new_w) / 2
|
| 90 |
+
bottom = (h + new_h) / 2
|
| 91 |
+
image = image.crop((left, top, right, bottom))
|
| 92 |
+
return image
|
| 93 |
+
|
| 94 |
+
def generate_videos(self, mode='universal'):
|
| 95 |
+
assert mode in ['universal', 'gta_drive', 'templerun']
|
| 96 |
+
|
| 97 |
+
while True:
|
| 98 |
+
try:
|
| 99 |
+
img_path = input("Please input the image path: ")
|
| 100 |
+
image = load_image(img_path.strip())
|
| 101 |
+
break
|
| 102 |
+
except:
|
| 103 |
+
print(f"Fail to load image from {img_path}!")
|
| 104 |
+
|
| 105 |
+
image = self._resizecrop(image, 352, 640)
|
| 106 |
+
image = self.frame_process(image)[None, :, None, :, :].to(dtype=self.weight_dtype, device=self.device)
|
| 107 |
+
# Encode the input image as the first latent
|
| 108 |
+
padding_video = torch.zeros_like(image).repeat(1, 1, 4 * (self.args.max_num_output_frames - 1), 1, 1)
|
| 109 |
+
img_cond = torch.concat([image, padding_video], dim=2)
|
| 110 |
+
tiler_kwargs={"tiled": True, "tile_size": [44, 80], "tile_stride": [23, 38]}
|
| 111 |
+
img_cond = self.vae.encode(img_cond, device=self.device, **tiler_kwargs).to(self.device)
|
| 112 |
+
mask_cond = torch.ones_like(img_cond)
|
| 113 |
+
mask_cond[:, :, 1:] = 0
|
| 114 |
+
cond_concat = torch.cat([mask_cond[:, :4], img_cond], dim=1)
|
| 115 |
+
visual_context = self.vae.clip.encode_video(image)
|
| 116 |
+
sampled_noise = torch.randn(
|
| 117 |
+
[1, 16,self.args.max_num_output_frames, 44, 80], device=self.device, dtype=self.weight_dtype
|
| 118 |
+
)
|
| 119 |
+
num_frames = (self.args.max_num_output_frames - 1) * 4 + 1
|
| 120 |
+
|
| 121 |
+
conditional_dict = {
|
| 122 |
+
"cond_concat": cond_concat.to(device=self.device, dtype=self.weight_dtype),
|
| 123 |
+
"visual_context": visual_context.to(device=self.device, dtype=self.weight_dtype)
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
if mode == 'universal':
|
| 127 |
+
cond_data = Bench_actions_universal(num_frames)
|
| 128 |
+
mouse_condition = cond_data['mouse_condition'].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
|
| 129 |
+
conditional_dict['mouse_cond'] = mouse_condition
|
| 130 |
+
elif mode == 'gta_drive':
|
| 131 |
+
cond_data = Bench_actions_gta_drive(num_frames)
|
| 132 |
+
mouse_condition = cond_data['mouse_condition'].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
|
| 133 |
+
conditional_dict['mouse_cond'] = mouse_condition
|
| 134 |
+
else:
|
| 135 |
+
cond_data = Bench_actions_templerun(num_frames)
|
| 136 |
+
keyboard_condition = cond_data['keyboard_condition'].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
|
| 137 |
+
conditional_dict['keyboard_cond'] = keyboard_condition
|
| 138 |
+
|
| 139 |
+
with torch.no_grad():
|
| 140 |
+
videos = self.pipeline.inference(
|
| 141 |
+
noise=sampled_noise,
|
| 142 |
+
conditional_dict=conditional_dict,
|
| 143 |
+
return_latents=False,
|
| 144 |
+
output_folder=self.args.output_folder,
|
| 145 |
+
name=os.path.basename(img_path),
|
| 146 |
+
mode=mode
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
def main():
|
| 150 |
+
"""Main entry point for video generation."""
|
| 151 |
+
args = parse_args()
|
| 152 |
+
set_seed(args.seed)
|
| 153 |
+
os.makedirs(args.output_folder, exist_ok=True)
|
| 154 |
+
pipeline = InteractiveGameInference(args)
|
| 155 |
+
mode = pipeline.config.pop('mode')
|
| 156 |
+
stop = ''
|
| 157 |
+
while stop != 'n':
|
| 158 |
+
pipeline.generate_videos(mode)
|
| 159 |
+
stop = input("Press `n` to stop generation: ").strip().lower()
|
| 160 |
+
if __name__ == "__main__":
|
| 161 |
+
main()
|
pipeline/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .causal_inference import CausalInferencePipeline, CausalInferenceStreamingPipeline
|
| 2 |
+
__all__ = [
|
| 3 |
+
"CausalInferencePipeline",
|
| 4 |
+
"CausalInferenceStreamingPipeline"
|
| 5 |
+
]
|
pipeline/causal_inference.py
ADDED
|
@@ -0,0 +1,753 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import time
|
| 5 |
+
import copy
|
| 6 |
+
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanVAEWrapper
|
| 9 |
+
from utils.visualize import process_video
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from demo_utils.constant import ZERO_VAE_CACHE
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
def get_current_action(mode="universal"):
|
| 15 |
+
|
| 16 |
+
CAM_VALUE = 0.1
|
| 17 |
+
if mode == 'universal':
|
| 18 |
+
print()
|
| 19 |
+
print('-'*30)
|
| 20 |
+
print("PRESS [I, K, J, L, U] FOR CAMERA TRANSFORM\n (I: up, K: down, J: left, L: right, U: no move)")
|
| 21 |
+
print("PRESS [W, S, A, D, Q] FOR MOVEMENT\n (W: forward, S: back, A: left, D: right, Q: no move)")
|
| 22 |
+
print('-'*30)
|
| 23 |
+
CAMERA_VALUE_MAP = {
|
| 24 |
+
"i": [CAM_VALUE, 0],
|
| 25 |
+
"k": [-CAM_VALUE, 0],
|
| 26 |
+
"j": [0, -CAM_VALUE],
|
| 27 |
+
"l": [0, CAM_VALUE],
|
| 28 |
+
"u": [0, 0]
|
| 29 |
+
}
|
| 30 |
+
KEYBOARD_IDX = {
|
| 31 |
+
"w": [1, 0, 0, 0], "s": [0, 1, 0, 0], "a": [0, 0, 1, 0], "d": [0, 0, 0, 1],
|
| 32 |
+
"q": [0, 0, 0, 0]
|
| 33 |
+
}
|
| 34 |
+
flag = 0
|
| 35 |
+
while flag != 1:
|
| 36 |
+
try:
|
| 37 |
+
idx_mouse = input('Please input the mouse action (e.g. `U`):\n').strip().lower()
|
| 38 |
+
idx_keyboard = input('Please input the keyboard action (e.g. `W`):\n').strip().lower()
|
| 39 |
+
if idx_mouse in CAMERA_VALUE_MAP.keys() and idx_keyboard in KEYBOARD_IDX.keys():
|
| 40 |
+
flag = 1
|
| 41 |
+
except:
|
| 42 |
+
pass
|
| 43 |
+
mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse]).cuda()
|
| 44 |
+
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).cuda()
|
| 45 |
+
elif mode == 'gta_drive':
|
| 46 |
+
print()
|
| 47 |
+
print('-'*30)
|
| 48 |
+
print("PRESS [W, S, A, D, Q] FOR MOVEMENT\n (W: forward, S: back, A: left, D: right, Q: no move)")
|
| 49 |
+
print('-'*30)
|
| 50 |
+
CAMERA_VALUE_MAP = {
|
| 51 |
+
"a": [0, -CAM_VALUE],
|
| 52 |
+
"d": [0, CAM_VALUE],
|
| 53 |
+
"q": [0, 0]
|
| 54 |
+
}
|
| 55 |
+
KEYBOARD_IDX = {
|
| 56 |
+
"w": [1, 0], "s": [0, 1],
|
| 57 |
+
"q": [0, 0]
|
| 58 |
+
}
|
| 59 |
+
flag = 0
|
| 60 |
+
while flag != 1:
|
| 61 |
+
try:
|
| 62 |
+
indexes = input('Please input the actions (split with ` `):\n(e.g. `W` for forward, `W A` for forward and left)\n').strip().lower().split(' ')
|
| 63 |
+
idx_mouse = []
|
| 64 |
+
idx_keyboard = []
|
| 65 |
+
for i in indexes:
|
| 66 |
+
if i in CAMERA_VALUE_MAP.keys():
|
| 67 |
+
idx_mouse += [i]
|
| 68 |
+
elif i in KEYBOARD_IDX.keys():
|
| 69 |
+
idx_keyboard += [i]
|
| 70 |
+
if len(idx_mouse) == 0:
|
| 71 |
+
idx_mouse += ['q']
|
| 72 |
+
if len(idx_keyboard) == 0:
|
| 73 |
+
idx_keyboard += ['q']
|
| 74 |
+
assert idx_mouse in [['a'], ['d'], ['q']] and idx_keyboard in [['q'], ['w'], ['s']]
|
| 75 |
+
flag = 1
|
| 76 |
+
except:
|
| 77 |
+
pass
|
| 78 |
+
mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse[0]]).cuda()
|
| 79 |
+
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard[0]]).cuda()
|
| 80 |
+
elif mode == 'templerun':
|
| 81 |
+
print()
|
| 82 |
+
print('-'*30)
|
| 83 |
+
print("PRESS [W, S, A, D, Z, C, Q] FOR ACTIONS\n (W: jump, S: slide, A: left side, D: right side, Z: turn left, C: turn right, Q: no move)")
|
| 84 |
+
print('-'*30)
|
| 85 |
+
KEYBOARD_IDX = {
|
| 86 |
+
"w": [0, 1, 0, 0, 0, 0, 0], "s": [0, 0, 1, 0, 0, 0, 0],
|
| 87 |
+
"a": [0, 0, 0, 0, 0, 1, 0], "d": [0, 0, 0, 0, 0, 0, 1],
|
| 88 |
+
"z": [0, 0, 0, 1, 0, 0, 0], "c": [0, 0, 0, 0, 1, 0, 0],
|
| 89 |
+
"q": [1, 0, 0, 0, 0, 0, 0]
|
| 90 |
+
}
|
| 91 |
+
flag = 0
|
| 92 |
+
while flag != 1:
|
| 93 |
+
try:
|
| 94 |
+
idx_keyboard = input('Please input the action: \n(e.g. `W` for forward, `Z` for turning left)\n').strip().lower()
|
| 95 |
+
if idx_keyboard in KEYBOARD_IDX.keys():
|
| 96 |
+
flag = 1
|
| 97 |
+
except:
|
| 98 |
+
pass
|
| 99 |
+
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).cuda()
|
| 100 |
+
|
| 101 |
+
if mode != 'templerun':
|
| 102 |
+
return {
|
| 103 |
+
"mouse": mouse_cond,
|
| 104 |
+
"keyboard": keyboard_cond
|
| 105 |
+
}
|
| 106 |
+
return {
|
| 107 |
+
"keyboard": keyboard_cond
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
def cond_current(conditional_dict, current_start_frame, num_frame_per_block, replace=None, mode='universal'):
|
| 111 |
+
|
| 112 |
+
new_cond = {}
|
| 113 |
+
|
| 114 |
+
new_cond["cond_concat"] = conditional_dict["cond_concat"][:, :, current_start_frame: current_start_frame + num_frame_per_block]
|
| 115 |
+
new_cond["visual_context"] = conditional_dict["visual_context"]
|
| 116 |
+
if replace != None:
|
| 117 |
+
if current_start_frame == 0:
|
| 118 |
+
last_frame_num = 1 + 4 * (num_frame_per_block - 1)
|
| 119 |
+
else:
|
| 120 |
+
last_frame_num = 4 * num_frame_per_block
|
| 121 |
+
final_frame = 1 + 4 * (current_start_frame + num_frame_per_block-1)
|
| 122 |
+
if mode != 'templerun':
|
| 123 |
+
conditional_dict["mouse_cond"][:, -last_frame_num + final_frame: final_frame] = replace['mouse'][None, None, :].repeat(1, last_frame_num, 1)
|
| 124 |
+
conditional_dict["keyboard_cond"][:, -last_frame_num + final_frame: final_frame] = replace['keyboard'][None, None, :].repeat(1, last_frame_num, 1)
|
| 125 |
+
if mode != 'templerun':
|
| 126 |
+
new_cond["mouse_cond"] = conditional_dict["mouse_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
|
| 127 |
+
new_cond["keyboard_cond"] = conditional_dict["keyboard_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
|
| 128 |
+
|
| 129 |
+
if replace != None:
|
| 130 |
+
return new_cond, conditional_dict
|
| 131 |
+
else:
|
| 132 |
+
return new_cond
|
| 133 |
+
|
| 134 |
+
class CausalInferencePipeline(torch.nn.Module):
|
| 135 |
+
def __init__(
|
| 136 |
+
self,
|
| 137 |
+
args,
|
| 138 |
+
device="cuda",
|
| 139 |
+
generator=None,
|
| 140 |
+
vae_decoder=None,
|
| 141 |
+
):
|
| 142 |
+
super().__init__()
|
| 143 |
+
# Step 1: Initialize all models
|
| 144 |
+
self.generator = WanDiffusionWrapper(
|
| 145 |
+
**getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
|
| 146 |
+
|
| 147 |
+
self.vae_decoder = vae_decoder
|
| 148 |
+
# Step 2: Initialize all causal hyperparmeters
|
| 149 |
+
self.scheduler = self.generator.get_scheduler()
|
| 150 |
+
self.denoising_step_list = torch.tensor(
|
| 151 |
+
args.denoising_step_list, dtype=torch.long)
|
| 152 |
+
if args.warp_denoising_step:
|
| 153 |
+
timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
|
| 154 |
+
self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
|
| 155 |
+
|
| 156 |
+
self.num_transformer_blocks = 30
|
| 157 |
+
self.frame_seq_length = 880
|
| 158 |
+
|
| 159 |
+
self.kv_cache1 = None
|
| 160 |
+
self.kv_cache_mouse = None
|
| 161 |
+
self.kv_cache_keyboard = None
|
| 162 |
+
self.args = args
|
| 163 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
| 164 |
+
self.local_attn_size = self.generator.model.local_attn_size
|
| 165 |
+
assert self.local_attn_size != -1
|
| 166 |
+
print(f"KV inference with {self.num_frame_per_block} frames per block")
|
| 167 |
+
|
| 168 |
+
if self.num_frame_per_block > 1:
|
| 169 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
| 170 |
+
|
| 171 |
+
def inference(
|
| 172 |
+
self,
|
| 173 |
+
noise: torch.Tensor,
|
| 174 |
+
conditional_dict,
|
| 175 |
+
initial_latent = None,
|
| 176 |
+
return_latents = False,
|
| 177 |
+
mode = 'universal',
|
| 178 |
+
profile = False,
|
| 179 |
+
) -> torch.Tensor:
|
| 180 |
+
"""
|
| 181 |
+
Perform inference on the given noise and text prompts.
|
| 182 |
+
Inputs:
|
| 183 |
+
noise (torch.Tensor): The input noise tensor of shape
|
| 184 |
+
(batch_size, num_output_frames, num_channels, height, width).
|
| 185 |
+
text_prompts (List[str]): The list of text prompts.
|
| 186 |
+
initial_latent (torch.Tensor): The initial latent tensor of shape
|
| 187 |
+
(batch_size, num_input_frames, num_channels, height, width).
|
| 188 |
+
If num_input_frames is 1, perform image to video.
|
| 189 |
+
If num_input_frames is greater than 1, perform video extension.
|
| 190 |
+
return_latents (bool): Whether to return the latents.
|
| 191 |
+
Outputs:
|
| 192 |
+
video (torch.Tensor): The generated video tensor of shape
|
| 193 |
+
(batch_size, num_output_frames, num_channels, height, width).
|
| 194 |
+
It is normalized to be in the range [0, 1].
|
| 195 |
+
"""
|
| 196 |
+
|
| 197 |
+
assert noise.shape[1] == 16
|
| 198 |
+
batch_size, num_channels, num_frames, height, width = noise.shape
|
| 199 |
+
|
| 200 |
+
assert num_frames % self.num_frame_per_block == 0
|
| 201 |
+
num_blocks = num_frames // self.num_frame_per_block
|
| 202 |
+
|
| 203 |
+
num_input_frames = initial_latent.shape[2] if initial_latent is not None else 0
|
| 204 |
+
num_output_frames = num_frames + num_input_frames # add the initial latent frames
|
| 205 |
+
|
| 206 |
+
output = torch.zeros(
|
| 207 |
+
[batch_size, num_channels, num_output_frames, height, width],
|
| 208 |
+
device=noise.device,
|
| 209 |
+
dtype=noise.dtype
|
| 210 |
+
)
|
| 211 |
+
videos = []
|
| 212 |
+
vae_cache = copy.deepcopy(ZERO_VAE_CACHE)
|
| 213 |
+
for j in range(len(vae_cache)):
|
| 214 |
+
vae_cache[j] = None
|
| 215 |
+
|
| 216 |
+
self.kv_cache1 = self.kv_cache_keyboard = self.kv_cache_mouse = self.crossattn_cache=None
|
| 217 |
+
# Step 1: Initialize KV cache to all zeros
|
| 218 |
+
if self.kv_cache1 is None:
|
| 219 |
+
self._initialize_kv_cache(
|
| 220 |
+
batch_size=batch_size,
|
| 221 |
+
dtype=noise.dtype,
|
| 222 |
+
device=noise.device
|
| 223 |
+
)
|
| 224 |
+
self._initialize_kv_cache_mouse_and_keyboard(
|
| 225 |
+
batch_size=batch_size,
|
| 226 |
+
dtype=noise.dtype,
|
| 227 |
+
device=noise.device
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
self._initialize_crossattn_cache(
|
| 231 |
+
batch_size=batch_size,
|
| 232 |
+
dtype=noise.dtype,
|
| 233 |
+
device=noise.device
|
| 234 |
+
)
|
| 235 |
+
else:
|
| 236 |
+
# reset cross attn cache
|
| 237 |
+
for block_index in range(self.num_transformer_blocks):
|
| 238 |
+
self.crossattn_cache[block_index]["is_init"] = False
|
| 239 |
+
# reset kv cache
|
| 240 |
+
for block_index in range(len(self.kv_cache1)):
|
| 241 |
+
self.kv_cache1[block_index]["global_end_index"] = torch.tensor(
|
| 242 |
+
[0], dtype=torch.long, device=noise.device)
|
| 243 |
+
self.kv_cache1[block_index]["local_end_index"] = torch.tensor(
|
| 244 |
+
[0], dtype=torch.long, device=noise.device)
|
| 245 |
+
self.kv_cache_mouse[block_index]["global_end_index"] = torch.tensor(
|
| 246 |
+
[0], dtype=torch.long, device=noise.device)
|
| 247 |
+
self.kv_cache_mouse[block_index]["local_end_index"] = torch.tensor(
|
| 248 |
+
[0], dtype=torch.long, device=noise.device)
|
| 249 |
+
self.kv_cache_keyboard[block_index]["global_end_index"] = torch.tensor(
|
| 250 |
+
[0], dtype=torch.long, device=noise.device)
|
| 251 |
+
self.kv_cache_keyboard[block_index]["local_end_index"] = torch.tensor(
|
| 252 |
+
[0], dtype=torch.long, device=noise.device)
|
| 253 |
+
# Step 2: Cache context feature
|
| 254 |
+
current_start_frame = 0
|
| 255 |
+
if initial_latent is not None:
|
| 256 |
+
timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
|
| 257 |
+
# Assume num_input_frames is self.num_frame_per_block * num_input_blocks
|
| 258 |
+
assert num_input_frames % self.num_frame_per_block == 0
|
| 259 |
+
num_input_blocks = num_input_frames // self.num_frame_per_block
|
| 260 |
+
|
| 261 |
+
for _ in range(num_input_blocks):
|
| 262 |
+
current_ref_latents = \
|
| 263 |
+
initial_latent[:, :, current_start_frame:current_start_frame + self.num_frame_per_block]
|
| 264 |
+
output[:, :, current_start_frame:current_start_frame + self.num_frame_per_block] = current_ref_latents
|
| 265 |
+
|
| 266 |
+
self.generator(
|
| 267 |
+
noisy_image_or_video=current_ref_latents,
|
| 268 |
+
conditional_dict=cond_current(conditional_dict, current_start_frame, self.num_frame_per_block, mode=mode),
|
| 269 |
+
timestep=timestep * 0,
|
| 270 |
+
kv_cache=self.kv_cache1,
|
| 271 |
+
kv_cache_mouse=self.kv_cache_mouse,
|
| 272 |
+
kv_cache_keyboard=self.kv_cache_keyboard,
|
| 273 |
+
crossattn_cache=self.crossattn_cache,
|
| 274 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 275 |
+
)
|
| 276 |
+
current_start_frame += self.num_frame_per_block
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
# Step 3: Temporal denoising loop
|
| 280 |
+
all_num_frames = [self.num_frame_per_block] * num_blocks
|
| 281 |
+
if profile:
|
| 282 |
+
diffusion_start = torch.cuda.Event(enable_timing=True)
|
| 283 |
+
diffusion_end = torch.cuda.Event(enable_timing=True)
|
| 284 |
+
for current_num_frames in tqdm(all_num_frames):
|
| 285 |
+
|
| 286 |
+
noisy_input = noise[
|
| 287 |
+
:, :, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]
|
| 288 |
+
|
| 289 |
+
# Step 3.1: Spatial denoising loop
|
| 290 |
+
if profile:
|
| 291 |
+
torch.cuda.synchronize()
|
| 292 |
+
diffusion_start.record()
|
| 293 |
+
for index, current_timestep in enumerate(self.denoising_step_list):
|
| 294 |
+
# set current timestep
|
| 295 |
+
timestep = torch.ones(
|
| 296 |
+
[batch_size, current_num_frames],
|
| 297 |
+
device=noise.device,
|
| 298 |
+
dtype=torch.int64) * current_timestep
|
| 299 |
+
|
| 300 |
+
if index < len(self.denoising_step_list) - 1:
|
| 301 |
+
_, denoised_pred = self.generator(
|
| 302 |
+
noisy_image_or_video=noisy_input,
|
| 303 |
+
conditional_dict=cond_current(conditional_dict, current_start_frame, self.num_frame_per_block, mode=mode),
|
| 304 |
+
timestep=timestep,
|
| 305 |
+
kv_cache=self.kv_cache1,
|
| 306 |
+
kv_cache_mouse=self.kv_cache_mouse,
|
| 307 |
+
kv_cache_keyboard=self.kv_cache_keyboard,
|
| 308 |
+
crossattn_cache=self.crossattn_cache,
|
| 309 |
+
current_start=current_start_frame * self.frame_seq_length
|
| 310 |
+
)
|
| 311 |
+
next_timestep = self.denoising_step_list[index + 1]
|
| 312 |
+
noisy_input = self.scheduler.add_noise(
|
| 313 |
+
rearrange(denoised_pred, 'b c f h w -> (b f) c h w'),# .flatten(0, 1),
|
| 314 |
+
torch.randn_like(rearrange(denoised_pred, 'b c f h w -> (b f) c h w')),
|
| 315 |
+
next_timestep * torch.ones(
|
| 316 |
+
[batch_size * current_num_frames], device=noise.device, dtype=torch.long)
|
| 317 |
+
)
|
| 318 |
+
noisy_input = rearrange(noisy_input, '(b f) c h w -> b c f h w', b=denoised_pred.shape[0])
|
| 319 |
+
else:
|
| 320 |
+
# for getting real output
|
| 321 |
+
_, denoised_pred = self.generator(
|
| 322 |
+
noisy_image_or_video=noisy_input,
|
| 323 |
+
conditional_dict=cond_current(conditional_dict, current_start_frame, self.num_frame_per_block, mode=mode),
|
| 324 |
+
timestep=timestep,
|
| 325 |
+
kv_cache=self.kv_cache1,
|
| 326 |
+
kv_cache_mouse=self.kv_cache_mouse,
|
| 327 |
+
kv_cache_keyboard=self.kv_cache_keyboard,
|
| 328 |
+
crossattn_cache=self.crossattn_cache,
|
| 329 |
+
current_start=current_start_frame * self.frame_seq_length
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
# Step 3.2: record the model's output
|
| 333 |
+
output[:, :, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
|
| 334 |
+
|
| 335 |
+
# Step 3.3: rerun with timestep zero to update KV cache using clean context
|
| 336 |
+
context_timestep = torch.ones_like(timestep) * self.args.context_noise
|
| 337 |
+
|
| 338 |
+
self.generator(
|
| 339 |
+
noisy_image_or_video=denoised_pred,
|
| 340 |
+
conditional_dict=cond_current(conditional_dict, current_start_frame, self.num_frame_per_block, mode=mode),
|
| 341 |
+
timestep=context_timestep,
|
| 342 |
+
kv_cache=self.kv_cache1,
|
| 343 |
+
kv_cache_mouse=self.kv_cache_mouse,
|
| 344 |
+
kv_cache_keyboard=self.kv_cache_keyboard,
|
| 345 |
+
crossattn_cache=self.crossattn_cache,
|
| 346 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
# Step 3.4: update the start and end frame indices
|
| 350 |
+
current_start_frame += current_num_frames
|
| 351 |
+
|
| 352 |
+
denoised_pred = denoised_pred.transpose(1,2)
|
| 353 |
+
video, vae_cache = self.vae_decoder(denoised_pred.half(), *vae_cache)
|
| 354 |
+
videos += [video]
|
| 355 |
+
|
| 356 |
+
if profile:
|
| 357 |
+
torch.cuda.synchronize()
|
| 358 |
+
diffusion_end.record()
|
| 359 |
+
diffusion_time = diffusion_start.elapsed_time(diffusion_end)
|
| 360 |
+
print(f"diffusion_time: {diffusion_time}", flush=True)
|
| 361 |
+
fps = video.shape[1]*1000/ diffusion_time
|
| 362 |
+
print(f" - FPS: {fps:.2f}")
|
| 363 |
+
|
| 364 |
+
if return_latents:
|
| 365 |
+
return output
|
| 366 |
+
else:
|
| 367 |
+
return videos
|
| 368 |
+
|
| 369 |
+
def _initialize_kv_cache(self, batch_size, dtype, device):
|
| 370 |
+
"""
|
| 371 |
+
Initialize a Per-GPU KV cache for the Wan model.
|
| 372 |
+
"""
|
| 373 |
+
kv_cache1 = []
|
| 374 |
+
if self.local_attn_size != -1:
|
| 375 |
+
# Use the local attention size to compute the KV cache size
|
| 376 |
+
kv_cache_size = self.local_attn_size * self.frame_seq_length
|
| 377 |
+
else:
|
| 378 |
+
# Use the default KV cache size
|
| 379 |
+
kv_cache_size = 15 * 1 * self.frame_seq_length # 32760
|
| 380 |
+
|
| 381 |
+
for _ in range(self.num_transformer_blocks):
|
| 382 |
+
kv_cache1.append({
|
| 383 |
+
"k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
| 384 |
+
"v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
| 385 |
+
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
|
| 386 |
+
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
|
| 387 |
+
})
|
| 388 |
+
|
| 389 |
+
self.kv_cache1 = kv_cache1 # always store the clean cache
|
| 390 |
+
|
| 391 |
+
def _initialize_kv_cache_mouse_and_keyboard(self, batch_size, dtype, device):
|
| 392 |
+
"""
|
| 393 |
+
Initialize a Per-GPU KV cache for the Wan model.
|
| 394 |
+
"""
|
| 395 |
+
kv_cache_mouse = []
|
| 396 |
+
kv_cache_keyboard = []
|
| 397 |
+
if self.local_attn_size != -1:
|
| 398 |
+
kv_cache_size = self.local_attn_size
|
| 399 |
+
else:
|
| 400 |
+
kv_cache_size = 15 * 1
|
| 401 |
+
for _ in range(self.num_transformer_blocks):
|
| 402 |
+
kv_cache_keyboard.append({
|
| 403 |
+
"k": torch.zeros([batch_size, kv_cache_size, 16, 64], dtype=dtype, device=device),
|
| 404 |
+
"v": torch.zeros([batch_size, kv_cache_size, 16, 64], dtype=dtype, device=device),
|
| 405 |
+
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
|
| 406 |
+
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
|
| 407 |
+
})
|
| 408 |
+
kv_cache_mouse.append({
|
| 409 |
+
"k": torch.zeros([batch_size * self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device),
|
| 410 |
+
"v": torch.zeros([batch_size * self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device),
|
| 411 |
+
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
|
| 412 |
+
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
|
| 413 |
+
})
|
| 414 |
+
self.kv_cache_keyboard = kv_cache_keyboard # always store the clean cache
|
| 415 |
+
self.kv_cache_mouse = kv_cache_mouse # always store the clean cache
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def _initialize_crossattn_cache(self, batch_size, dtype, device):
|
| 420 |
+
"""
|
| 421 |
+
Initialize a Per-GPU cross-attention cache for the Wan model.
|
| 422 |
+
"""
|
| 423 |
+
crossattn_cache = []
|
| 424 |
+
|
| 425 |
+
for _ in range(self.num_transformer_blocks):
|
| 426 |
+
crossattn_cache.append({
|
| 427 |
+
"k": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
|
| 428 |
+
"v": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
|
| 429 |
+
"is_init": False
|
| 430 |
+
})
|
| 431 |
+
self.crossattn_cache = crossattn_cache
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class CausalInferenceStreamingPipeline(torch.nn.Module):
|
| 435 |
+
def __init__(
|
| 436 |
+
self,
|
| 437 |
+
args,
|
| 438 |
+
device="cuda",
|
| 439 |
+
vae_decoder=None,
|
| 440 |
+
generator=None,
|
| 441 |
+
):
|
| 442 |
+
super().__init__()
|
| 443 |
+
# Step 1: Initialize all models
|
| 444 |
+
self.generator = WanDiffusionWrapper(
|
| 445 |
+
**getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
|
| 446 |
+
self.vae_decoder = vae_decoder
|
| 447 |
+
|
| 448 |
+
# Step 2: Initialize all causal hyperparmeters
|
| 449 |
+
self.scheduler = self.generator.get_scheduler()
|
| 450 |
+
self.denoising_step_list = torch.tensor(
|
| 451 |
+
args.denoising_step_list, dtype=torch.long)
|
| 452 |
+
if args.warp_denoising_step:
|
| 453 |
+
timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
|
| 454 |
+
self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
|
| 455 |
+
|
| 456 |
+
self.num_transformer_blocks = 30
|
| 457 |
+
self.frame_seq_length = 880 # 1590 # HW/4
|
| 458 |
+
|
| 459 |
+
self.kv_cache1 = None
|
| 460 |
+
self.kv_cache_mouse = None
|
| 461 |
+
self.kv_cache_keyboard = None
|
| 462 |
+
self.args = args
|
| 463 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
| 464 |
+
self.local_attn_size = self.generator.model.local_attn_size
|
| 465 |
+
assert self.local_attn_size != -1
|
| 466 |
+
print(f"KV inference with {self.num_frame_per_block} frames per block")
|
| 467 |
+
|
| 468 |
+
if self.num_frame_per_block > 1:
|
| 469 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
| 470 |
+
|
| 471 |
+
def inference(
|
| 472 |
+
self,
|
| 473 |
+
noise: torch.Tensor,
|
| 474 |
+
conditional_dict,
|
| 475 |
+
initial_latent: Optional[torch.Tensor] = None,
|
| 476 |
+
return_latents: bool = False,
|
| 477 |
+
output_folder = None,
|
| 478 |
+
name = None,
|
| 479 |
+
mode = 'universal'
|
| 480 |
+
) -> torch.Tensor:
|
| 481 |
+
"""
|
| 482 |
+
Perform inference on the given noise and text prompts.
|
| 483 |
+
Inputs:
|
| 484 |
+
noise (torch.Tensor): The input noise tensor of shape
|
| 485 |
+
(batch_size, num_output_frames, num_channels, height, width).
|
| 486 |
+
text_prompts (List[str]): The list of text prompts.
|
| 487 |
+
initial_latent (torch.Tensor): The initial latent tensor of shape
|
| 488 |
+
(batch_size, num_input_frames, num_channels, height, width).
|
| 489 |
+
If num_input_frames is 1, perform image to video.
|
| 490 |
+
If num_input_frames is greater than 1, perform video extension.
|
| 491 |
+
return_latents (bool): Whether to return the latents.
|
| 492 |
+
Outputs:
|
| 493 |
+
video (torch.Tensor): The generated video tensor of shape
|
| 494 |
+
(batch_size, num_output_frames, num_channels, height, width).
|
| 495 |
+
It is normalized to be in the range [0, 1].
|
| 496 |
+
"""
|
| 497 |
+
|
| 498 |
+
assert noise.shape[1] == 16
|
| 499 |
+
batch_size, num_channels, num_frames, height, width = noise.shape
|
| 500 |
+
|
| 501 |
+
assert num_frames % self.num_frame_per_block == 0
|
| 502 |
+
num_blocks = num_frames // self.num_frame_per_block
|
| 503 |
+
|
| 504 |
+
num_input_frames = initial_latent.shape[2] if initial_latent is not None else 0
|
| 505 |
+
num_output_frames = num_frames + num_input_frames # add the initial latent frames
|
| 506 |
+
|
| 507 |
+
output = torch.zeros(
|
| 508 |
+
[batch_size, num_channels, num_output_frames, height, width],
|
| 509 |
+
device=noise.device,
|
| 510 |
+
dtype=noise.dtype
|
| 511 |
+
)
|
| 512 |
+
videos = []
|
| 513 |
+
vae_cache = copy.deepcopy(ZERO_VAE_CACHE)
|
| 514 |
+
for j in range(len(vae_cache)):
|
| 515 |
+
vae_cache[j] = None
|
| 516 |
+
# Set up profiling if requested
|
| 517 |
+
self.kv_cache1=self.kv_cache_keyboard=self.kv_cache_mouse=self.crossattn_cache=None
|
| 518 |
+
# Step 1: Initialize KV cache to all zeros
|
| 519 |
+
if self.kv_cache1 is None:
|
| 520 |
+
self._initialize_kv_cache(
|
| 521 |
+
batch_size=batch_size,
|
| 522 |
+
dtype=noise.dtype,
|
| 523 |
+
device=noise.device
|
| 524 |
+
)
|
| 525 |
+
self._initialize_kv_cache_mouse_and_keyboard(
|
| 526 |
+
batch_size=batch_size,
|
| 527 |
+
dtype=noise.dtype,
|
| 528 |
+
device=noise.device
|
| 529 |
+
)
|
| 530 |
+
|
| 531 |
+
self._initialize_crossattn_cache(
|
| 532 |
+
batch_size=batch_size,
|
| 533 |
+
dtype=noise.dtype,
|
| 534 |
+
device=noise.device
|
| 535 |
+
)
|
| 536 |
+
else:
|
| 537 |
+
# reset cross attn cache
|
| 538 |
+
for block_index in range(self.num_transformer_blocks):
|
| 539 |
+
self.crossattn_cache[block_index]["is_init"] = False
|
| 540 |
+
# reset kv cache
|
| 541 |
+
for block_index in range(len(self.kv_cache1)):
|
| 542 |
+
self.kv_cache1[block_index]["global_end_index"] = torch.tensor(
|
| 543 |
+
[0], dtype=torch.long, device=noise.device)
|
| 544 |
+
self.kv_cache1[block_index]["local_end_index"] = torch.tensor(
|
| 545 |
+
[0], dtype=torch.long, device=noise.device)
|
| 546 |
+
self.kv_cache_mouse[block_index]["global_end_index"] = torch.tensor(
|
| 547 |
+
[0], dtype=torch.long, device=noise.device)
|
| 548 |
+
self.kv_cache_mouse[block_index]["local_end_index"] = torch.tensor(
|
| 549 |
+
[0], dtype=torch.long, device=noise.device)
|
| 550 |
+
self.kv_cache_keyboard[block_index]["global_end_index"] = torch.tensor(
|
| 551 |
+
[0], dtype=torch.long, device=noise.device)
|
| 552 |
+
self.kv_cache_keyboard[block_index]["local_end_index"] = torch.tensor(
|
| 553 |
+
[0], dtype=torch.long, device=noise.device)
|
| 554 |
+
# Step 2: Cache context feature
|
| 555 |
+
current_start_frame = 0
|
| 556 |
+
if initial_latent is not None:
|
| 557 |
+
timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
|
| 558 |
+
|
| 559 |
+
# Assume num_input_frames is self.num_frame_per_block * num_input_blocks
|
| 560 |
+
assert num_input_frames % self.num_frame_per_block == 0
|
| 561 |
+
num_input_blocks = num_input_frames // self.num_frame_per_block
|
| 562 |
+
|
| 563 |
+
for _ in range(num_input_blocks):
|
| 564 |
+
current_ref_latents = \
|
| 565 |
+
initial_latent[:, :, current_start_frame:current_start_frame + self.num_frame_per_block]
|
| 566 |
+
output[:, :, current_start_frame:current_start_frame + self.num_frame_per_block] = current_ref_latents
|
| 567 |
+
self.generator(
|
| 568 |
+
noisy_image_or_video=current_ref_latents,
|
| 569 |
+
conditional_dict=cond_current(conditional_dict, current_start_frame, self.num_frame_per_block, replace=True),
|
| 570 |
+
timestep=timestep * 0,
|
| 571 |
+
kv_cache=self.kv_cache1,
|
| 572 |
+
kv_cache_mouse=self.kv_cache_mouse,
|
| 573 |
+
kv_cache_keyboard=self.kv_cache_keyboard,
|
| 574 |
+
crossattn_cache=self.crossattn_cache,
|
| 575 |
+
current_start=current_start_frame * self.frame_seq_length,
|
| 576 |
+
)
|
| 577 |
+
current_start_frame += self.num_frame_per_block
|
| 578 |
+
|
| 579 |
+
# Step 3: Temporal denoising loop
|
| 580 |
+
all_num_frames = [self.num_frame_per_block] * num_blocks
|
| 581 |
+
|
| 582 |
+
for current_num_frames in all_num_frames:
|
| 583 |
+
noisy_input = noise[
|
| 584 |
+
:, :, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]
|
| 585 |
+
|
| 586 |
+
current_actions = get_current_action(mode=mode)
|
| 587 |
+
new_act, conditional_dict = cond_current(conditional_dict, current_start_frame, self.num_frame_per_block, replace=current_actions, mode=mode)
|
| 588 |
+
# Step 3.1: Spatial denoising loop
|
| 589 |
+
|
| 590 |
+
for index, current_timestep in enumerate(self.denoising_step_list):
|
| 591 |
+
# set current timestep
|
| 592 |
+
timestep = torch.ones(
|
| 593 |
+
[batch_size, current_num_frames],
|
| 594 |
+
device=noise.device,
|
| 595 |
+
dtype=torch.int64) * current_timestep
|
| 596 |
+
|
| 597 |
+
if index < len(self.denoising_step_list) - 1:
|
| 598 |
+
_, denoised_pred = self.generator(
|
| 599 |
+
noisy_image_or_video=noisy_input,
|
| 600 |
+
conditional_dict=new_act,
|
| 601 |
+
timestep=timestep,
|
| 602 |
+
kv_cache=self.kv_cache1,
|
| 603 |
+
kv_cache_mouse=self.kv_cache_mouse,
|
| 604 |
+
kv_cache_keyboard=self.kv_cache_keyboard,
|
| 605 |
+
crossattn_cache=self.crossattn_cache,
|
| 606 |
+
current_start=current_start_frame * self.frame_seq_length
|
| 607 |
+
)
|
| 608 |
+
next_timestep = self.denoising_step_list[index + 1]
|
| 609 |
+
noisy_input = self.scheduler.add_noise(
|
| 610 |
+
rearrange(denoised_pred, 'b c f h w -> (b f) c h w'),# .flatten(0, 1),
|
| 611 |
+
torch.randn_like(rearrange(denoised_pred, 'b c f h w -> (b f) c h w')),
|
| 612 |
+
next_timestep * torch.ones(
|
| 613 |
+
[batch_size * current_num_frames], device=noise.device, dtype=torch.long)
|
| 614 |
+
)
|
| 615 |
+
noisy_input = rearrange(noisy_input, '(b f) c h w -> b c f h w', b=denoised_pred.shape[0])
|
| 616 |
+
else:
|
| 617 |
+
# for getting real output
|
| 618 |
+
_, denoised_pred = self.generator(
|
| 619 |
+
noisy_image_or_video=noisy_input,
|
| 620 |
+
conditional_dict=new_act,
|
| 621 |
+
timestep=timestep,
|
| 622 |
+
kv_cache=self.kv_cache1,
|
| 623 |
+
kv_cache_mouse=self.kv_cache_mouse,
|
| 624 |
+
kv_cache_keyboard=self.kv_cache_keyboard,
|
| 625 |
+
crossattn_cache=self.crossattn_cache,
|
| 626 |
+
current_start=current_start_frame * self.frame_seq_length
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
# Step 3.2: record the model's output
|
| 630 |
+
output[:, :, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
|
| 631 |
+
|
| 632 |
+
# Step 3.3: rerun with timestep zero to update KV cache using clean context
|
| 633 |
+
context_timestep = torch.ones_like(timestep) * self.args.context_noise
|
| 634 |
+
|
| 635 |
+
self.generator(
|
| 636 |
+
noisy_image_or_video=denoised_pred,
|
| 637 |
+
conditional_dict=new_act,
|
| 638 |
+
timestep=context_timestep,
|
| 639 |
+
                kv_cache=self.kv_cache1,
                kv_cache_mouse=self.kv_cache_mouse,
                kv_cache_keyboard=self.kv_cache_keyboard,
                crossattn_cache=self.crossattn_cache,
                current_start=current_start_frame * self.frame_seq_length,
            )

            # Step 3.4: decode the current block and update the start frame index
            denoised_pred = denoised_pred.transpose(1, 2)
            video, vae_cache = self.vae_decoder(denoised_pred.half(), *vae_cache)
            videos += [video]
            video = rearrange(video, "B T C H W -> B T H W C")
            video = ((video.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)[0]
            video = np.ascontiguousarray(video)
            mouse_icon = 'assets/images/mouse.png'
            if mode != 'templerun':
                config = (
                    conditional_dict["keyboard_cond"][0, : 1 + 4 * (current_start_frame + self.num_frame_per_block - 1)].float().cpu().numpy(),
                    conditional_dict["mouse_cond"][0, : 1 + 4 * (current_start_frame + self.num_frame_per_block - 1)].float().cpu().numpy(),
                )
            else:
                config = (
                    conditional_dict["keyboard_cond"][0, : 1 + 4 * (current_start_frame + self.num_frame_per_block - 1)].float().cpu().numpy()
                )
            process_video(video.astype(np.uint8), output_folder + f'/{name}_current.mp4', config, mouse_icon, mouse_scale=0.1, process_icon=False, mode=mode)
            current_start_frame += current_num_frames

            if input("Continue? (Press `n` to break)").strip() == "n":
                break

        videos_tensor = torch.cat(videos, dim=1)
        videos = rearrange(videos_tensor, "B T C H W -> B T H W C")
        videos = ((videos.float() + 1) * 127.5).clip(0, 255).cpu().numpy().astype(np.uint8)[0]
        video = np.ascontiguousarray(videos)
        mouse_icon = 'assets/images/mouse.png'
        if mode != 'templerun':
            config = (
                conditional_dict["keyboard_cond"][0, : 1 + 4 * (current_start_frame + self.num_frame_per_block - 1)].float().cpu().numpy(),
                conditional_dict["mouse_cond"][0, : 1 + 4 * (current_start_frame + self.num_frame_per_block - 1)].float().cpu().numpy(),
            )
        else:
            config = (
                conditional_dict["keyboard_cond"][0, : 1 + 4 * (current_start_frame + self.num_frame_per_block - 1)].float().cpu().numpy()
            )
        process_video(video.astype(np.uint8), output_folder + f'/{name}_icon.mp4', config, mouse_icon, mouse_scale=0.1, mode=mode)
        process_video(video.astype(np.uint8), output_folder + f'/{name}.mp4', config, mouse_icon, mouse_scale=0.1, process_icon=False, mode=mode)

        if return_latents:
            return output
        else:
            return video

    def _initialize_kv_cache(self, batch_size, dtype, device):
        """
        Initialize a per-GPU KV cache for the Wan model.
        """
        kv_cache1 = []
        if self.local_attn_size != -1:
            # Use the local attention size to compute the KV cache size
            kv_cache_size = self.local_attn_size * self.frame_seq_length
        else:
            # Use the default KV cache size
            kv_cache_size = 15 * 1 * self.frame_seq_length  # 32760

        for _ in range(self.num_transformer_blocks):
            kv_cache1.append({
                "k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
            })

        self.kv_cache1 = kv_cache1  # always store the clean cache

    def _initialize_kv_cache_mouse_and_keyboard(self, batch_size, dtype, device):
        """
        Initialize per-GPU KV caches for the keyboard and mouse action branches of the Wan model.
        """
        kv_cache_mouse = []
        kv_cache_keyboard = []
        if self.local_attn_size != -1:
            kv_cache_size = self.local_attn_size
        else:
            kv_cache_size = 15 * 1
        for _ in range(self.num_transformer_blocks):
            kv_cache_keyboard.append({
                "k": torch.zeros([batch_size, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
            })
            kv_cache_mouse.append({
                "k": torch.zeros([batch_size * self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "v": torch.zeros([batch_size * self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device),
                "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
            })
        self.kv_cache_keyboard = kv_cache_keyboard  # always store the clean cache
        self.kv_cache_mouse = kv_cache_mouse  # always store the clean cache

    def _initialize_crossattn_cache(self, batch_size, dtype, device):
        """
        Initialize a per-GPU cross-attention cache for the Wan model.
        """
        crossattn_cache = []

        for _ in range(self.num_transformer_blocks):
            crossattn_cache.append({
                "k": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                "v": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
                "is_init": False
            })
        self.crossattn_cache = crossattn_cache
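For a sense of the memory these caches occupy: the default self-attention cache keeps 15 latent frames of keys and values per transformer block with shape [batch, 15 * frame_seq_length, 12, 128], and the `# 32760` comment implies frame_seq_length = 2184. Below is a minimal back-of-the-envelope sketch, not code from this repo; it assumes bfloat16 storage and uses a placeholder transformer block count rather than a value read from the configs.

# Rough size of the self-attention KV cache built by _initialize_kv_cache.
# Sketch only: bf16 storage assumed, block count is a placeholder.
frame_seq_length = 2184                 # implied by 15 * frame_seq_length == 32760
kv_cache_size = 15 * frame_seq_length   # tokens kept per transformer block
batch, heads, head_dim = 1, 12, 128
bytes_per_elem = 2                      # bfloat16 / float16

per_block = 2 * batch * kv_cache_size * heads * head_dim * bytes_per_elem  # k and v
num_transformer_blocks = 30             # placeholder, not taken from this repo
print(f"{per_block / 2**20:.0f} MiB per block, "
      f"{per_block * num_transformer_blocks / 2**30:.1f} GiB total")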
requirements.txt
ADDED
@@ -0,0 +1,41 @@
torch>=2.4.0
torchvision>=0.19.0
opencv-python>=4.9.0.80
diffusers
transformers>=4.49.0
tokenizers>=0.20.3
accelerate>=1.1.1
tqdm
imageio
easydict
ftfy
dashscope
imageio-ffmpeg
numpy
wandb
omegaconf
einops
av
safetensors
opencv-python
git+https://github.com/openai/CLIP.git
open_clip_torch
starlette
pycocotools
lmdb
matplotlib
sentencepiece
pydantic
scikit-image
huggingface_hub[cli]
dominate
nvidia-pyindex
nvidia-tensorrt
pycuda
onnx
onnxruntime
onnxscript
onnxconverter_common
flask
flask-socketio
torchao
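Note that `opencv-python` appears twice in the list, once with a version floor and once without. When setting up the environment, a small, purely illustrative check of a few of the version floors against what is currently installed can be handy; the package names and floors below are copied from the file above, everything else is a sketch.

# Illustrative sanity check of a few lower bounds from requirements.txt.
from importlib.metadata import PackageNotFoundError, version

floors = {"torch": "2.4.0", "torchvision": "0.19.0", "transformers": "4.49.0"}
for pkg, minimum in floors.items():
    try:
        print(f"{pkg}: installed {version(pkg)}, requires >= {minimum}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")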
setup.py
ADDED
@@ -0,0 +1,6 @@
from setuptools import setup, find_packages
setup(
    name="matrix-game-2.0",
    version="0.0.1",
    packages=find_packages(),
)
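The setup script declares only package metadata; runtime dependencies come from requirements.txt, so the usual workflow would be installing the requirements first and then the package itself in editable mode with `pip install -e .`. As a small, illustrative check (not part of the repo), the snippet below shows what `find_packages()` will actually bundle when run from the repo root; only directories that carry an `__init__.py` are discovered by default.

# Illustrative: list the packages setuptools will discover and ship.
from setuptools import find_packages

print(find_packages())  # output depends on which directories contain an __init__.py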