nielsr HF Staff commited on
Commit
13eaa41
ยท
verified ยท
1 Parent(s): bc0cab7

Update model card for PSA: Pyramid Sparse Attention

Browse files

This PR completely updates the model card to reflect the "PSA: Pyramid Sparse Attention for Efficient Video Understanding and Generation" paper (https://huggingface.co/papers/2512.04025).

It replaces the existing, unrelated content with accurate information, including:
- The correct `pipeline_tag: text-to-video` for better discoverability.
- The `library_name: diffusers` to enable the "how to use" widget, based on compatibility with underlying models.
- Links to the paper, project page, and GitHub repository.
- A concise description of the model's core contributions.
- Installation steps, weight download instructions, and quick start inference examples directly from the official GitHub repository.

Please review and merge if everything looks good!

Files changed (1) hide show
  1. README.md +47 -247
README.md CHANGED
@@ -1,288 +1,88 @@
1
  ---
2
  license: apache-2.0
 
 
3
  ---
4
- # BLADE: Block-Sparse Attention Meets Step Distillation for Efficient Video Generation
5
 
6
- <div align="center">
7
 
8
- [๐Ÿ“– Paper](https://arxiv.org/abs/2508.10774) | [๐Ÿš€ Homepage](http://ziplab.co/BLADE-Homepage/) | [๐Ÿ’พ Models](https://huggingface.co/GYP666/BLADE) | [๐Ÿ“– ไธญๆ–‡้˜…่ฏป](README_zh.md)
9
 
10
- </div>
11
 
12
- BLADE is a data-free framework for efficient video generation. By jointly training an adaptive sparse attention mechanism with a step distillation technique, it achieves a significant acceleration in video generation models. This project combines a block-sparse attention mechanism with step distillation, reducing the number of inference steps from 50 to just 8 while maintaining high-quality generation.
 
 
13
 
14
- ## ๐Ÿ“ข News
15
 
16
- - **[Aug 2025]** ๐ŸŽ‰ The code and pre-trained models for BLADE have been released\!
17
- - **[Aug 2025]** ๐Ÿ“ Support for two mainstream video generation models, CogVideoX-5B and WanX-1.3B, is now available.
18
- - **[Aug 2025]** โšก Achieved high-quality video generation in just 8 steps, a significant speedup compared to the 50-step baseline.
19
 
20
- ## โœจ Key Features
21
 
22
- - ๐Ÿš€ **Efficient Inference**: Reduces the number of inference steps from 50 to 8 while preserving generation quality.
23
- - ๐ŸŽฏ **Adaptive Sparse Attention**: Employs a block-sparse attention mechanism to significantly reduce computational complexity.
24
- - ๐Ÿ“ˆ **Step Distillation**: Utilizes the Trajectory Distillation Method (TDM), enabling training without the need for video data.
25
- - ๐ŸŽฎ **Plug-and-Play**: Supports CogVideoX-5B and WanX-1.3B models without requiring modifications to their original architectures.
26
 
27
- ## ๐Ÿ› ๏ธ Environment Setup
28
-
29
- ### System Requirements
30
-
31
- - Python \>= 3.11 (Recommended)
32
- - CUDA \>= 11.6 (Recommended)
33
- - GPU Memory \>= 24GB (for Inference)
34
- - GPU Memory \>= 80GB (for Training)
35
-
36
- ### Installation Steps
37
-
38
- 1. **Clone the repository**
39
-
40
- ```bash
41
- git clone https://github.com/Tacossp/BLADE
42
- cd BLADE
43
- ```
44
-
45
- 2. **Install dependencies**
46
-
47
- ```bash
48
- # Install using uv (Recommended)
49
- uv pip install -r requirements.txt
50
-
51
- # Or use pip
52
- pip install -r requirements.txt
53
- ```
54
-
55
- 3. **Compile the Block-Sparse-Attention library**
56
-
57
- ```bash
58
- git clone https://github.com/mit-han-lab/Block-Sparse-Attention.git
59
- cd Block-Sparse-Attention
60
- pip install packaging
61
- pip install ninja
62
- python setup.py install
63
- cd ..
64
- ```
65
-
66
- ## ๐Ÿ“ฅ Model Weights Download
67
-
68
- ### Base Model Weights
69
-
70
- Please download the following base model weights and place them in the specified directories:
71
-
72
- 1. **CogVideoX-5B Model**
73
-
74
- ```bash
75
- # Download from Hugging Face
76
- git lfs install
77
- git clone https://huggingface.co/zai-org/CogVideoX-5b cogvideox/CogVideoX-5b
78
- ```
79
-
80
- 2. **WanX-1.3B Model**
81
-
82
- ```bash
83
- # Download from Hugging Face
84
- git clone https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers wanx/wan1.3b
85
- ```
86
-
87
- ### Pre-trained BLADE Weights
88
-
89
- We provide pre-trained weights for BLADE:
90
 
91
  ```bash
92
- # Download pre-trained weights
93
- git clone https://huggingface.co/GYP666/BLADE pretrained_weights
 
94
  ```
95
 
96
- ### Weight Directory Structure
97
-
98
- Ensure your directory structure for weights is as follows:
99
-
100
- ```
101
- BLADE/
102
- โ”œโ”€โ”€ cogvideox/
103
- โ”‚ โ””โ”€โ”€ CogVideoX-5b/ # Base model weights for CogVideoX
104
- โ”œโ”€โ”€ wanx/
105
- โ”‚ โ””โ”€โ”€ wan1.3b/ # Base model weights for WanX
106
- โ””โ”€โ”€ pretrained_weights/ # Pre-trained weights for BLADE
107
- โ”œโ”€โ”€ BLADE_cogvideox_weight/
108
- โ””โ”€โ”€ BLADE_wanx_weight/
109
- ```
110
-
111
- ## ๐Ÿš€ Quick Start - Inference
112
-
113
- ### CogVideoX Inference
114
 
115
  ```bash
116
- cd cogvideox
117
- python train/inference.py \
118
- --lora_path ../pretrained_weights/cogvideox_checkpoints/your_checkpoint \
119
- --gpu 0
120
  ```
121
 
122
- **Argument Descriptions**:
123
-
124
- - `--lora_path`: Path to the LoRA weights file.
125
- - `--gpu`: The ID of the GPU device to use (Default: 0).
126
 
127
- **Output**: The generated videos will be saved in the `cogvideox/outputs/inference/` directory.
128
 
129
- ### WanX Inference
130
 
131
  ```bash
132
- cd wanx
133
- python train/inference.py \
134
- --lora_path ../pretrained_weights/wanx_checkpoints/your_checkpoint \
135
- --gpu 0
136
  ```
137
 
138
- **Output**: The generated videos will be saved in the `wanx/outputs/` directory.
139
-
140
- ## ๐Ÿ”ง Training Process
141
 
142
- ### Step 1: Prompt Preprocessing
143
 
144
- Before training, you need to preprocess the text prompts to generate embeddings.
145
-
146
- #### CogVideoX Preprocessing
147
-
148
- ```bash
149
- cd utils
150
- python process_prompts_cogvideox.py \
151
- --input_file your_prompts.txt \
152
- --output_dir ../cogvideox/prompts \
153
- --model_path ../cogvideox/CogVideoX-5b \
154
- --batch_size 32 \
155
- --save_separate
156
- ```
157
-
158
- **Argument Descriptions**:
159
-
160
- - `--input_file`: A `.txt` file containing prompts, with one prompt per line.
161
- - `--output_dir`: The directory to save the output embeddings.
162
- - `--model_path`: Path to the CogVideoX model.
163
- - `--batch_size`: The batch size for processing.
164
- - `--save_separate`: Whether to save each embedding as a separate file.
165
-
166
- #### WanX Preprocessing
167
-
168
- ```bash
169
- cd utils
170
- python process_prompts_wanx.py
171
- ```
172
-
173
- This script will automatically process the prompts in `utils/all_dimension_aug_wanx.txt` and generate the corresponding embeddings.
174
-
175
- ### Step 2: Start Training
176
-
177
- #### CogVideoX Training
178
 
179
  ```bash
180
- cd cogvideox
181
- bash train_tdm_1.sh
 
 
182
  ```
183
 
184
- **Core Training Parameters**:
185
 
186
  ```bash
187
- # If not training with 8 GPUs, you must modify CUDA_VISIBLE_DEVICES and the num_processes in config.yaml
188
- CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
189
- --config_file train/config.yaml \
190
- train/train_cogvideo_tdm.py \
191
- --pretrained_model_name_or_path CogVideoX-5b \ # Path to the base model
192
- --mixed_precision bf16 \ # Use mixed-precision for reduced memory usage
193
- --train_batch_size 5 \ # Training batch size
194
- --gradient_accumulation_steps 4 \ # Number of gradient accumulation steps
195
- --learning_rate 1e-4 \ # Learning rate for the student model
196
- --learning_rate_g 1e-4 \
197
- --learning_rate_fake 5e-4 \ # Learning rate for the fake model
198
- --lambda_reg 0.5 \ # Regularization weight
199
- --k_step 8 \ # Target number of steps for distillation
200
- --cfg 3.5 \ # Classifier-Free Guidance scale
201
- --eta 0.9 \ # ETA parameter for DDIM
202
- --use_sparsity true \ # Enable sparse attention
203
- --rank 64 \
204
- --lora_alpha 64 \ # LoRA configuration
205
- --max_train_steps 300 \ # Maximum number of training steps
206
- --checkpointing_steps 15 \ # Interval for saving checkpoints
207
- --gradient_checkpointing \ # Use gradient checkpointing to save memory
208
- --enable_slicing \
209
- --enable_tiling # VAE memory optimization
210
  ```
211
 
212
- #### WanX Training
213
 
214
- ```bash
215
- cd wanx
216
- bash train_wanx_tdm.sh
217
- ```
218
-
219
- ## ๐Ÿ“Š Project Structure
220
-
221
- ```
222
- BLADE/
223
- โ”œโ”€โ”€ README.md # Project documentation
224
- โ”œโ”€โ”€ requirements.txt # List of Python dependencies
225
- โ”‚
226
- โ”œโ”€โ”€ cogvideox/ # Code related to CogVideoX
227
- โ”‚ โ”œโ”€โ”€ CogVideoX-5b/ # Directory for base model weights
228
- โ”‚ โ”œโ”€โ”€ train/ # Training scripts
229
- โ”‚ โ”‚ โ”œโ”€โ”€ inference.py # Inference script
230
- โ”‚ โ”‚ โ”œโ”€โ”€ train_cogvideo_tdm.py # Training script
231
- โ”‚ โ”‚ โ”œโ”€โ”€ train_tdm_1.sh # Script to launch training
232
- โ”‚ โ”‚ โ”œโ”€โ”€ modify_cogvideo.py # Model modification script
233
- โ”‚ โ”‚ โ””โ”€โ”€ config.yaml # Training configuration file
234
- โ”‚ โ”œโ”€โ”€ prompts/ # Preprocessed prompts and embeddings
235
- โ”‚ โ””โ”€โ”€ outputs/ # Output from training and inference
236
- โ”‚
237
- โ”œโ”€โ”€ wanx/ # Code related to WanX
238
- โ”‚ โ”œโ”€โ”€ wan1.3b/ # Directory for base model weights
239
- โ”‚ โ”œโ”€โ”€ train/ # Training scripts
240
- โ”‚ โ”‚ โ”œโ”€โ”€ inference.py # Inference script
241
- โ”‚ โ”‚ โ”œโ”€โ”€ train_wanx_tdm.py # Training script
242
- โ”‚ โ”‚ โ”œโ”€โ”€ train_wanx_tdm.sh # Script to launch training
243
- โ”‚ โ”‚ โ””โ”€โ”€ modify_wan.py # Model modification script
244
- โ”‚ โ”œโ”€โ”€ prompts/ # Preprocessed prompts and embeddings
245
- โ”‚ โ””โ”€โ”€ outputs/ # Output from training and inference
246
- โ”‚
247
- โ”œโ”€โ”€ utils/ # Utility scripts
248
- โ”‚ โ”œโ”€โ”€ process_prompts_cogvideox.py # Data preprocessing for CogVideoX
249
- โ”‚ โ”œโ”€โ”€ process_prompts_wanx.py # Data preprocessing for WanX
250
- โ”‚ โ””โ”€โ”€ all_dimension_aug_wanx.txt # Training prompts for WanX
251
- โ”‚
252
- โ”œโ”€โ”€ Block-Sparse-Attention/ # Sparse attention library
253
- โ”‚ โ”œโ”€โ”€ setup.py # Compilation and installation script
254
- โ”‚ โ”œโ”€โ”€ block_sparse_attn/ # Core library code
255
- โ”‚ โ””โ”€โ”€ README.md # Library usage instructions
256
- โ”‚
257
- โ””โ”€โ”€ ds_config.json # DeepSpeed configuration file
258
- ```
259
-
260
- ## ๐Ÿค Acknowledgements
261
 
262
- - [FlashAttention](https://github.com/Dao-AILab/flash-attention), [Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention): For the foundational work on sparse attention.
263
- - [CogVideoX](https://github.com/THUDM/CogVideo), [Wan2.1](https://github.com/Wan-Video/Wan2.1): For the supported models.
264
- - [TDM](https://www.google.com/search?q=https://github.com/Luo-Yihong/TDM): For the foundational work on distillation implementation.
265
- - [Diffusers](https://github.com/huggingface/diffusers): For the invaluable diffusion models library.
266
-
267
- ## ๐Ÿ“„ Citation
268
-
269
- If you use BLADE in your research, please cite our work:
270
 
271
  ```bibtex
272
- @misc{gu2025videobladeblocksparseattentionmeets,
273
- title={BLADE: Block-Sparse Attention Meets Step Distillation for Efficient Video Generation},
274
- author={Youping Gu and Xiaolong Li and Yuhao Hu and Bohan Zhuang},
275
- year={2025},
276
- eprint={2508.10774},
277
- archivePrefix={arXiv},
278
- primaryClass={cs.CV},
279
- url={https://arxiv.org/abs/2508.10774},
280
  }
281
- ```
282
-
283
- ## ๐Ÿ“ง Contact
284
-
285
- For any questions or suggestions, feel free to:
286
-
287
- - Contact Youping Gu at [email protected].
288
- - Submit an issue on our [Github page](https://github.com/ziplab/BLADE/issues).
 
1
  ---
2
  license: apache-2.0
3
+ pipeline_tag: text-to-video
4
+ library_name: diffusers
5
  ---
 
6
 
7
+ # PSA: Pyramid Sparse Attention for Efficient Video Understanding and Generation
8
 
9
+ [๐Ÿ“– Paper](https://huggingface.co/papers/2512.04025) | [๐Ÿš€ Project Page](http://ziplab.co/PSA) | [๐Ÿ’ป Code](https://github.com/ziplab/Pyramid-Sparse-Attention)
10
 
11
+ Official PyTorch implementation of [PSA: Pyramid Sparse Attention for Efficient Video Understanding and Generation](https://huggingface.co/papers/2512.04025).
12
 
13
+ <p align="center">
14
+ <img src="https://github.com/ziplab/Pyramid-Sparse-Attention/raw/main/figures/prompt007comparison.jpg" width="100%">
15
+ </p>
16
 
17
+ <p align="center"><em>Visual comparison of sparse attention methods at similar sparsity levels (~90%). PSA maintains visual fidelity close to full attention while other methods show noticeable artifacts.</em></p>
18
 
19
+ Pyramid Sparse Attention (PSA) is a versatile attention module designed to overcome the quadratic complexity bottleneck of attention mechanisms in foundation models. It introduces multi-level pooled Key-Value (KV) representations, enabling a finer mask granularity than traditional binary masking approaches. This design allows critical KV blocks to receive full resolution attention while less important blocks utilize progressively pooled representations, creating an informative interpolation between full retention and complete pruning. This approach effectively mitigates information loss and preserves computational efficiency. PSA is applicable to both video understanding and generation tasks, consistently outperforming or achieving comparable performance to existing sparse attention baselines with superior efficiency-quality trade-offs.
 
 
20
 
21
+ > **Note:** This release focuses on **inference-only** with **bidirectional attention**. Support for causal attention masks and backward propagation (training) is still under optimization and will be released in a future update.
22
 
23
+ ## Installation
 
 
 
24
 
25
+ ### Using uv (Recommended)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  ```bash
28
+ uv venv --python 3.11
29
+ source .venv/bin/activate
30
+ uv pip install -e .
31
  ```
32
 
33
+ ### Using pip
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  ```bash
36
+ python -m venv .venv
37
+ source .venv/bin/activate
38
+ pip install -e .
 
39
  ```
40
 
41
+ > For best performance, we recommend using PyTorch nightly version.
 
 
 
42
 
43
+ ## Download Weights
44
 
45
+ ### CogVideoX-5B LoRA (4-step)
46
 
47
  ```bash
48
+ huggingface-cli download GYP666/BLADE cogvideox-5b-psa-lora/pytorch_lora_weights.safetensors --local-dir ./weights
 
 
 
49
  ```
50
 
51
+ **Note:** After downloading, update the `lora_path` in `examples/configs/model_configs.py` to point to your weights directory.
 
 
52
 
53
+ ## Quick Start (Inference)
54
 
55
+ ### CogVideoX1.5-5B
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  ```bash
58
+ python examples/inference/cogvideo/cogvideo_5b.py \
59
+ --model cogvideo1.5_5b \
60
+ --prompt "your prompt here" \
61
+ --use_psa
62
  ```
63
 
64
+ ### Wan2.1-1.3B
65
 
66
  ```bash
67
+ python examples/inference/wan21/wan21_1.3b.py \
68
+ --prompt "your prompt here" \
69
+ --use_psa --no_warmup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  ```
71
 
72
+ For more inference examples, see [examples/README.md](https://github.com/ziplab/Pyramid-Sparse-Attention/blob/main/examples/README.md).
73
 
74
+ ## Citation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ If you find this work useful, please cite our paper:
 
 
 
 
 
 
 
77
 
78
  ```bibtex
79
+ @misc{li2025psapyramidsparseattention,
80
+ title={PSA: Pyramid Sparse Attention for Efficient Video Understanding and Generation},
81
+ author={Xiaolong Li and Youping Gu and Xi Lin and Weijie Wang and Bohan Zhuang},
82
+ year={2025},
83
+ eprint={2512.04025},
84
+ archivePrefix={arXiv},
85
+ primaryClass={cs.CV},
86
+ url={https://arxiv.org/abs/2512.04025},
87
  }
88
+ ```