---
license: apache-2.0
---

# BLADE: Block-Sparse Attention Meets Step Distillation for Efficient Video Generation

<div align="center">

[📄 Paper](https://arxiv.org/abs/2508.10774) | [🏠 Homepage](http://ziplab.co/BLADE-Homepage/) | [💾 Models](https://huggingface.co/GYP666/BLADE) | [📖 中文阅读](README_zh.md)

</div>

BLADE is a data-free framework for efficient video generation. It jointly trains an adaptive block-sparse attention mechanism with a step-distillation technique, reducing the number of inference steps from 50 to just 8 while maintaining high-quality generation.
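
To give a feel for the block-sparse half of the design, here is a minimal, purely illustrative PyTorch sketch (it is not the CUDA kernel from the Block-Sparse-Attention library, and the mean-pooled block scoring is a simplifying assumption of ours): block-level scores pick the most relevant key blocks for each query block, and attention is computed only inside the kept blocks.

```python
# Minimal illustration of block-sparse attention in pure PyTorch.
# NOT the optimized CUDA kernel BLADE uses; the block scoring here is a simplifying assumption.
import torch
import torch.nn.functional as F

def block_sparse_attention(q, k, v, block_size=64, keep_ratio=0.25):
    """q, k, v: (seq_len, head_dim); seq_len must be divisible by block_size."""
    seq_len, dim = q.shape
    n_blocks = seq_len // block_size
    # Coarse block-level relevance from mean-pooled queries and keys.
    q_blk = q.view(n_blocks, block_size, dim).mean(dim=1)   # (n_blocks, dim)
    k_blk = k.view(n_blocks, block_size, dim).mean(dim=1)
    block_scores = q_blk @ k_blk.T                          # (n_blocks, n_blocks)
    keep = max(1, int(n_blocks * keep_ratio))
    top = block_scores.topk(keep, dim=-1).indices           # kept key blocks per query block
    # Expand the block-level mask to token resolution.
    block_mask = torch.zeros(n_blocks, n_blocks, dtype=torch.bool)
    block_mask.scatter_(1, top, True)
    mask = block_mask.repeat_interleave(block_size, 0).repeat_interleave(block_size, 1)
    scores = (q @ k.T) / dim ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))       # pruned blocks never contribute
    return F.softmax(scores, dim=-1) @ v

out = block_sparse_attention(torch.randn(512, 64), torch.randn(512, 64), torch.randn(512, 64))
print(out.shape)  # torch.Size([512, 64])
```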

## 📢 News

- **[Aug 2025]** 🎉 The code and pre-trained models for BLADE have been released!
- **[Aug 2025]** 🚀 Support for two mainstream video generation models, CogVideoX-5B and WanX-1.3B, is now available.
- **[Aug 2025]** ⚡ Achieved high-quality video generation in just 8 steps, a significant speedup over the 50-step baseline.

## ✨ Key Features

- 🚀 **Efficient Inference**: Reduces the number of inference steps from 50 to 8 while preserving generation quality (see the sampling sketch after this list).
- 🎯 **Adaptive Sparse Attention**: Employs a block-sparse attention mechanism to significantly reduce computational complexity.
- 📚 **Step Distillation**: Utilizes Trajectory Distribution Matching (TDM), enabling training without the need for video data.
- 🔮 **Plug-and-Play**: Supports CogVideoX-5B and WanX-1.3B without requiring modifications to their original architectures.
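
As a concrete illustration of the 8-step setting, the sketch below uses the standard diffusers API. It is a hedged example only: the repository's own entry point is `train/inference.py`, and the LoRA path here is a placeholder.

```python
# Hedged sketch of 8-step sampling with a distilled LoRA via diffusers;
# paths are placeholders, and this is not the repository's inference script.
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("cogvideox/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("pretrained_weights/BLADE_cogvideox_weight")  # distilled BLADE LoRA (placeholder path)

# 8 steps instead of the usual 50; the distilled weights, not the scheduler alone, preserve quality.
frames = pipe(prompt="a corgi running on the beach", num_inference_steps=8).frames[0]
export_to_video(frames, "blade_8step.mp4", fps=8)
```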

## 🛠️ Environment Setup

### System Requirements

- Python >= 3.11 (recommended)
- CUDA >= 11.6 (recommended)
- GPU memory >= 24 GB (for inference)
- GPU memory >= 80 GB (for training)

### Installation Steps

1. **Clone the repository**

```bash
git clone https://github.com/Tacossp/BLADE
cd BLADE
```

2. **Install dependencies**

```bash
# Install using uv (recommended)
uv pip install -r requirements.txt

# Or use pip
pip install -r requirements.txt
```

3. **Compile the Block-Sparse-Attention library**

```bash
git clone https://github.com/mit-han-lab/Block-Sparse-Attention.git
cd Block-Sparse-Attention
pip install packaging ninja
python setup.py install
cd ..
```
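
A quick way to confirm the extension built correctly is to import it from Python (the `block_sparse_attn` package name matches the library's source tree shown later in this README):

```python
# Sanity check: the compiled extension should import without errors.
import torch
import block_sparse_attn  # raises ImportError if compilation did not succeed

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("block_sparse_attn loaded from:", block_sparse_attn.__file__)
```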

## 📥 Model Weights Download

### Base Model Weights

Please download the following base model weights and place them in the specified directories:

1. **CogVideoX-5B Model**

```bash
# Download from Hugging Face
git lfs install
git clone https://huggingface.co/zai-org/CogVideoX-5b cogvideox/CogVideoX-5b
```

2. **WanX-1.3B Model**

```bash
# Download from Hugging Face
git clone https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers wanx/wan1.3b
```

### Pre-trained BLADE Weights

We provide pre-trained weights for BLADE:

```bash
# Download pre-trained weights
git clone https://huggingface.co/GYP666/BLADE pretrained_weights
```
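
If you prefer the Hub client over `git lfs`, a short `huggingface_hub` sketch fetches the same three repositories into the same layout:

```python
# Alternative download path via huggingface_hub (same target directories as the git commands above).
from huggingface_hub import snapshot_download

snapshot_download(repo_id="zai-org/CogVideoX-5b", local_dir="cogvideox/CogVideoX-5b")
snapshot_download(repo_id="Wan-AI/Wan2.1-T2V-1.3B-Diffusers", local_dir="wanx/wan1.3b")
snapshot_download(repo_id="GYP666/BLADE", local_dir="pretrained_weights")
```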

### Weight Directory Structure

Ensure your directory structure for weights is as follows:

```
BLADE/
├── cogvideox/
│   └── CogVideoX-5b/          # Base model weights for CogVideoX
├── wanx/
│   └── wan1.3b/               # Base model weights for WanX
└── pretrained_weights/        # Pre-trained weights for BLADE
    ├── BLADE_cogvideox_weight/
    └── BLADE_wanx_weight/
```
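
Before launching inference or training, a small check that the layout above is in place can save a failed run (paths mirror the tree exactly):

```python
# Verify that all expected weight directories exist.
from pathlib import Path

expected = [
    "cogvideox/CogVideoX-5b",
    "wanx/wan1.3b",
    "pretrained_weights/BLADE_cogvideox_weight",
    "pretrained_weights/BLADE_wanx_weight",
]
missing = [p for p in expected if not Path(p).is_dir()]
print("All weight directories found." if not missing else f"Missing: {missing}")
```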

## 🚀 Quick Start - Inference

### CogVideoX Inference

```bash
cd cogvideox
python train/inference.py \
    --lora_path ../pretrained_weights/cogvideox_checkpoints/your_checkpoint \
    --gpu 0
```

**Argument Descriptions**:

- `--lora_path`: Path to the LoRA weights file.
- `--gpu`: The ID of the GPU device to use (default: 0).

**Output**: The generated videos will be saved in the `cogvideox/outputs/inference/` directory.

### WanX Inference

```bash
cd wanx
python train/inference.py \
    --lora_path ../pretrained_weights/wanx_checkpoints/your_checkpoint \
    --gpu 0
```

**Output**: The generated videos will be saved in the `wanx/outputs/` directory.

## 🔧 Training Process

### Step 1: Prompt Preprocessing

Before training, you need to preprocess the text prompts to generate embeddings.

#### CogVideoX Preprocessing

```bash
cd utils
python process_prompts_cogvideox.py \
    --input_file your_prompts.txt \
    --output_dir ../cogvideox/prompts \
    --model_path ../cogvideox/CogVideoX-5b \
    --batch_size 32 \
    --save_separate
```

**Argument Descriptions**:

- `--input_file`: A `.txt` file containing prompts, one per line.
- `--output_dir`: The directory to save the output embeddings.
- `--model_path`: Path to the CogVideoX model.
- `--batch_size`: The batch size for processing.
- `--save_separate`: Save each embedding as a separate file.
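
To double-check the preprocessing output, here is a hedged inspection sketch. It assumes `--save_separate` writes one PyTorch `.pt` file per prompt; the exact format is defined by `process_prompts_cogvideox.py`.

```python
# Hypothetical inspection of preprocessed embeddings; assumes one .pt tensor file per prompt.
import torch
from pathlib import Path

for f in sorted(Path("../cogvideox/prompts").glob("*.pt"))[:3]:
    emb = torch.load(f, map_location="cpu")
    shape = tuple(emb.shape) if torch.is_tensor(emb) else type(emb)
    print(f.name, shape)
```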

#### WanX Preprocessing

```bash
cd utils
python process_prompts_wanx.py
```

This script will automatically process the prompts in `utils/all_dimension_aug_wanx.txt` and generate the corresponding embeddings.

### Step 2: Start Training

#### CogVideoX Training

```bash
cd cogvideox
bash train_tdm_1.sh
```

**Core Training Parameters**:

```bash
# If not training with 8 GPUs, modify CUDA_VISIBLE_DEVICES and num_processes in train/config.yaml.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
    --config_file train/config.yaml \
    train/train_cogvideo_tdm.py \
    --pretrained_model_name_or_path CogVideoX-5b \
    --mixed_precision bf16 \
    --train_batch_size 5 \
    --gradient_accumulation_steps 4 \
    --learning_rate 1e-4 \
    --learning_rate_g 1e-4 \
    --learning_rate_fake 5e-4 \
    --lambda_reg 0.5 \
    --k_step 8 \
    --cfg 3.5 \
    --eta 0.9 \
    --use_sparsity true \
    --rank 64 \
    --lora_alpha 64 \
    --max_train_steps 300 \
    --checkpointing_steps 15 \
    --gradient_checkpointing \
    --enable_slicing \
    --enable_tiling
```

**Argument Descriptions**:

- `--pretrained_model_name_or_path`: Path to the base model.
- `--mixed_precision bf16`: Mixed precision for reduced memory usage.
- `--train_batch_size`, `--gradient_accumulation_steps`: Training batch size and number of gradient accumulation steps.
- `--learning_rate`, `--learning_rate_fake`: Learning rates for the student model and the fake model.
- `--lambda_reg`: Regularization weight.
- `--k_step`: Target number of steps for distillation.
- `--cfg`: Classifier-free guidance scale.
- `--eta`: ETA parameter for DDIM.
- `--use_sparsity`: Enables sparse attention.
- `--rank`, `--lora_alpha`: LoRA configuration.
- `--max_train_steps`, `--checkpointing_steps`: Maximum number of training steps and checkpoint-saving interval.
- `--gradient_checkpointing`, `--enable_slicing`, `--enable_tiling`: Memory optimizations (gradient checkpointing plus VAE slicing and tiling).
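
To clarify how the student, fake, and regularization terms above interact, here is a deliberately simplified distribution-matching update. It is a conceptual sketch under our own simplifying assumptions, not the loss implemented in `train_cogvideo_tdm.py`:

```python
# Schematic distribution-matching distillation step (conceptual only; NOT the repository's loss).
import torch

def distill_step(student, teacher, fake_score, noise, t, lambda_reg=0.5):
    """The student's few-step sample is nudged away from the fake-score field and
    toward the teacher's ("real") field; lambda_reg scales the update (cf. --lambda_reg)."""
    x_student = student(noise, t)                  # few-step generation by the student
    with torch.no_grad():
        s_real = teacher(x_student, t)             # teacher score on student samples
        s_fake = fake_score(x_student, t)          # fake score, fitted to the student's outputs
    grad = s_fake - s_real                         # direction from student toward teacher distribution
    loss = lambda_reg * (x_student * grad).mean()  # gradient flows only through x_student
    return loss

# The fake-score network is optimized separately (cf. --learning_rate_fake) with a
# standard denoising objective on the student's own generations.
```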

#### WanX Training

```bash
cd wanx
bash train_wanx_tdm.sh
```

## 📁 Project Structure

```
BLADE/
├── README.md                          # Project documentation
├── requirements.txt                   # List of Python dependencies
│
├── cogvideox/                         # Code related to CogVideoX
│   ├── CogVideoX-5b/                  # Directory for base model weights
│   ├── train/                         # Training scripts
│   │   ├── inference.py               # Inference script
│   │   ├── train_cogvideo_tdm.py      # Training script
│   │   ├── train_tdm_1.sh             # Script to launch training
│   │   ├── modify_cogvideo.py         # Model modification script
│   │   └── config.yaml                # Training configuration file
│   ├── prompts/                       # Preprocessed prompts and embeddings
│   └── outputs/                       # Output from training and inference
│
├── wanx/                              # Code related to WanX
│   ├── wan1.3b/                       # Directory for base model weights
│   ├── train/                         # Training scripts
│   │   ├── inference.py               # Inference script
│   │   ├── train_wanx_tdm.py          # Training script
│   │   ├── train_wanx_tdm.sh          # Script to launch training
│   │   └── modify_wan.py              # Model modification script
│   ├── prompts/                       # Preprocessed prompts and embeddings
│   └── outputs/                       # Output from training and inference
│
├── utils/                             # Utility scripts
│   ├── process_prompts_cogvideox.py   # Data preprocessing for CogVideoX
│   ├── process_prompts_wanx.py        # Data preprocessing for WanX
│   └── all_dimension_aug_wanx.txt     # Training prompts for WanX
│
├── Block-Sparse-Attention/            # Sparse attention library
│   ├── setup.py                       # Compilation and installation script
│   ├── block_sparse_attn/             # Core library code
│   └── README.md                      # Library usage instructions
│
└── ds_config.json                     # DeepSpeed configuration file
```

## 🤝 Acknowledgements

- [FlashAttention](https://github.com/Dao-AILab/flash-attention), [Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention): for the foundational work on sparse attention.
- [CogVideoX](https://github.com/THUDM/CogVideo), [Wan2.1](https://github.com/Wan-Video/Wan2.1): for the supported models.
- [TDM](https://github.com/Luo-Yihong/TDM): for the foundational work on the distillation implementation.
- [Diffusers](https://github.com/huggingface/diffusers): for the invaluable diffusion models library.

## 📝 Citation

If you use BLADE in your research, please cite our work:

```bibtex
@misc{gu2025videobladeblocksparseattentionmeets,
  title={BLADE: Block-Sparse Attention Meets Step Distillation for Efficient Video Generation},
  author={Youping Gu and Xiaolong Li and Yuhao Hu and Bohan Zhuang},
  year={2025},
  eprint={2508.10774},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
  url={https://arxiv.org/abs/2508.10774},
}
```

## 📧 Contact

For any questions or suggestions, feel free to:

- Contact Youping Gu at youpgu71@gmail.com.
- Submit an issue on our [GitHub page](https://github.com/ziplab/BLADE/issues).