---
license: apache-2.0
---
# BLADE: Block-Sparse Attention Meets Step Distillation for Efficient Video Generation
<div align="center">
[📖 Paper](https://arxiv.org/abs/2508.10774) | [🚀 Homepage](http://ziplab.co/BLADE-Homepage/) | [💾 Models](https://huggingface.co/GYP666/BLADE) | [📖 中文阅读](README_zh.md)
</div>
BLADE is a data-free framework for efficient video generation. It jointly trains an adaptive block-sparse attention mechanism with step distillation, reducing the number of inference steps from 50 to just 8 while maintaining high-quality generation.
## 📢 News
- **[Aug 2025]** 🎉 The code and pre-trained models for BLADE have been released!
- **[Aug 2025]** 📝 Support for two mainstream video generation models, CogVideoX-5B and WanX-1.3B, is now available.
- **[Aug 2025]** ⚡ Achieved high-quality video generation in just 8 steps, a significant speedup compared to the 50-step baseline.
## ✨ Key Features
- 🚀 **Efficient Inference**: Reduces the number of inference steps from 50 to 8 while preserving generation quality.
- 🎯 **Adaptive Sparse Attention**: Employs a block-sparse attention mechanism to significantly reduce computational complexity (see the sketch after this list).
- 📈 **Step Distillation**: Uses Trajectory Distribution Matching (TDM) for distillation, enabling training without any video data.
- 🎮 **Plug-and-Play**: Supports CogVideoX-5B and WanX-1.3B models without requiring modifications to their original architectures.
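To make the sparse-attention idea concrete, here is a minimal PyTorch sketch of block-sparse attention: queries and keys are grouped into blocks, and attention is only computed for (query-block, key-block) pairs that a mask marks as active. This is purely illustrative; the block size, the adaptive mask policy, and the fused kernel actually used by BLADE live in the `Block-Sparse-Attention` CUDA library, not in this snippet.

```python
# Illustrative reference only -- the real implementation is the fused CUDA kernel
# in Block-Sparse-Attention; block size and the mask policy below are assumptions.
import torch
import torch.nn.functional as F

def block_sparse_attention(q, k, v, block_mask, block_size=64):
    """q, k, v: (seq_len, head_dim); block_mask: (n_q_blocks, n_k_blocks) bool."""
    seq_len, head_dim = q.shape
    scale = head_dim ** -0.5
    out = torch.zeros_like(q)
    for qi in range(block_mask.shape[0]):
        q_slice = slice(qi * block_size, min((qi + 1) * block_size, seq_len))
        active = block_mask[qi].nonzero(as_tuple=True)[0]
        if active.numel() == 0:
            continue
        # Gather only the key/value blocks this query block attends to.
        k_idx = torch.cat([torch.arange(b * block_size,
                                        min((b + 1) * block_size, seq_len))
                           for b in active.tolist()])
        attn = (q[q_slice] @ k[k_idx].T) * scale          # scores for active blocks only
        out[q_slice] = F.softmax(attn, dim=-1) @ v[k_idx]
    return out

# Toy usage: 256 tokens, 64-dim head, keep roughly half of the key blocks per query block.
q, k, v = (torch.randn(256, 64) for _ in range(3))
mask = torch.rand(4, 4) > 0.5
mask.fill_diagonal_(True)                                  # always attend to the local block
print(block_sparse_attention(q, k, v, mask).shape)         # torch.Size([256, 64])
```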
## ๐Ÿ› ๏ธ Environment Setup
### System Requirements
- Python \>= 3.11 (Recommended)
- CUDA \>= 11.6 (Recommended)
- GPU Memory \>= 24GB (for Inference)
- GPU Memory \>= 80GB (for Training)
### Installation Steps
1. **Clone the repository**
```bash
git clone https://github.com/Tacossp/BLADE
cd BLADE
```
2. **Install dependencies**
```bash
# Install using uv (Recommended)
uv pip install -r requirements.txt
# Or use pip
pip install -r requirements.txt
```
3. **Compile the Block-Sparse-Attention library**
```bash
git clone https://github.com/mit-han-lab/Block-Sparse-Attention.git
cd Block-Sparse-Attention
pip install packaging
pip install ninja
python setup.py install
cd ..
```
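After the build finishes, a quick import check confirms the extension is visible to Python. The module name `block_sparse_attn` is assumed from the library's package directory; the exact symbols it exports may vary by version.

```python
# Minimal post-install sanity check; assumes the package installs under the name
# `block_sparse_attn` (matching its source tree) and that a CUDA GPU is present.
import torch
import block_sparse_attn  # raises ImportError if the build or install failed

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("block_sparse_attn loaded from:", block_sparse_attn.__file__)
```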
## 📥 Model Weights Download
### Base Model Weights
Please download the following base model weights and place them in the specified directories:
1. **CogVideoX-5B Model**
```bash
# Download from Hugging Face
git lfs install
git clone https://huggingface.co/zai-org/CogVideoX-5b cogvideox/CogVideoX-5b
```
2. **WanX-1.3B Model**
```bash
# Download from Hugging Face
git clone https://huggingface.co/Wan-AI/Wan2.1-T2V-1.3B-Diffusers wanx/wan1.3b
```
### Pre-trained BLADE Weights
We provide pre-trained weights for BLADE:
```bash
# Download pre-trained weights
git clone https://huggingface.co/GYP666/BLADE pretrained_weights
```
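If you prefer not to clone with git-lfs, the same repositories can be fetched with the `huggingface_hub` Python API. The snippet below mirrors the layout used above; repo IDs are taken directly from the commands in this section.

```python
# Alternative to the git clones above, using huggingface_hub's snapshot_download.
from huggingface_hub import snapshot_download

snapshot_download("zai-org/CogVideoX-5b", local_dir="cogvideox/CogVideoX-5b")
snapshot_download("Wan-AI/Wan2.1-T2V-1.3B-Diffusers", local_dir="wanx/wan1.3b")
snapshot_download("GYP666/BLADE", local_dir="pretrained_weights")
```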
### Weight Directory Structure
Ensure your directory structure for weights is as follows:
```
BLADE/
├── cogvideox/
│   └── CogVideoX-5b/            # Base model weights for CogVideoX
├── wanx/
│   └── wan1.3b/                 # Base model weights for WanX
└── pretrained_weights/          # Pre-trained weights for BLADE
    ├── BLADE_cogvideox_weight/
    └── BLADE_wanx_weight/
```
## 🚀 Quick Start - Inference
### CogVideoX Inference
```bash
cd cogvideox
python train/inference.py \
--lora_path ../pretrained_weights/cogvideox_checkpoints/your_checkpoint \
--gpu 0
```
**Argument Descriptions**:
- `--lora_path`: Path to the LoRA weights (for example, the downloaded BLADE weights under `pretrained_weights/`, or a checkpoint saved during your own training run).
- `--gpu`: The ID of the GPU device to use (Default: 0).
**Output**: The generated videos will be saved in the `cogvideox/outputs/inference/` directory.
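The repository's `train/inference.py` handles model loading and the 8-step sampling for you. For readers who want a rough picture of what happens under the hood, below is a minimal, hedged sketch using the Diffusers `CogVideoXPipeline`; the prompt, LoRA path, frame count, guidance scale, and output path are placeholders, and the actual script may differ in scheduler setup and sparse-attention patching.

```python
# Rough sketch of 8-step LoRA inference with Diffusers (not the repo's inference.py).
# Paths, prompt, and sampling hyper-parameters here are illustrative assumptions.
import os
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained(
    "CogVideoX-5b", torch_dtype=torch.bfloat16   # local base model dir (cogvideox/CogVideoX-5b)
).to("cuda")
pipe.load_lora_weights("../pretrained_weights/BLADE_cogvideox_weight")  # distilled LoRA
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()

video = pipe(
    prompt="A golden retriever running through a field of sunflowers",
    num_inference_steps=8,   # distilled step count
    guidance_scale=1.0,      # distilled students typically need little or no extra CFG
    num_frames=49,
).frames[0]

os.makedirs("outputs/inference", exist_ok=True)
export_to_video(video, "outputs/inference/sample.mp4", fps=8)
```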
### WanX Inference
```bash
cd wanx
python train/inference.py \
--lora_path ../pretrained_weights/wanx_checkpoints/your_checkpoint \
--gpu 0
```
**Output**: The generated videos will be saved in the `wanx/outputs/` directory.
## 🔧 Training Process
### Step 1: Prompt Preprocessing
Before training, you need to preprocess the text prompts to generate embeddings.
#### CogVideoX Preprocessing
```bash
cd utils
python process_prompts_cogvideox.py \
--input_file your_prompts.txt \
--output_dir ../cogvideox/prompts \
--model_path ../cogvideox/CogVideoX-5b \
--batch_size 32 \
--save_separate
```
**Argument Descriptions**:
- `--input_file`: A `.txt` file containing prompts, with one prompt per line.
- `--output_dir`: The directory to save the output embeddings.
- `--model_path`: Path to the CogVideoX model.
- `--batch_size`: The batch size for processing.
- `--save_separate`: Whether to save each embedding as a separate file.
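For intuition, the sketch below shows the kind of work this preprocessing performs: read one prompt per line, run it through the CogVideoX T5 text encoder, and save each embedding to disk. It is an illustrative approximation, not `process_prompts_cogvideox.py` itself; the sequence length, file naming, and output format are assumptions.

```python
# Illustrative sketch of prompt preprocessing (not the actual utils script).
import torch
from transformers import AutoTokenizer, T5EncoderModel

model_path = "../cogvideox/CogVideoX-5b"                       # same as --model_path above
tokenizer = AutoTokenizer.from_pretrained(model_path, subfolder="tokenizer")
text_encoder = T5EncoderModel.from_pretrained(
    model_path, subfolder="text_encoder", torch_dtype=torch.bfloat16
).to("cuda")

with open("your_prompts.txt") as f:                            # one prompt per line
    prompts = [line.strip() for line in f if line.strip()]

for i, prompt in enumerate(prompts):
    tokens = tokenizer(prompt, padding="max_length", max_length=226,  # assumed max length
                       truncation=True, return_tensors="pt").to("cuda")
    with torch.no_grad():
        emb = text_encoder(tokens.input_ids).last_hidden_state  # (1, seq_len, hidden_dim)
    # --save_separate style: one embedding file per prompt (naming is hypothetical)
    torch.save(emb.cpu(), f"../cogvideox/prompts/prompt_{i:05d}.pt")
```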
#### WanX Preprocessing
```bash
cd utils
python process_prompts_wanx.py
```
This script will automatically process the prompts in `utils/all_dimension_aug_wanx.txt` and generate the corresponding embeddings.
### Step 2: Start Training
#### CogVideoX Training
```bash
cd cogvideox
bash train_tdm_1.sh
```
**Core Training Parameters**:
```bash
# If you are not training on 8 GPUs, adjust CUDA_VISIBLE_DEVICES below and
# num_processes in train/config.yaml accordingly.
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch \
  --config_file train/config.yaml \
  train/train_cogvideo_tdm.py \
  --pretrained_model_name_or_path CogVideoX-5b \
  --mixed_precision bf16 \
  --train_batch_size 5 \
  --gradient_accumulation_steps 4 \
  --learning_rate 1e-4 \
  --learning_rate_g 1e-4 \
  --learning_rate_fake 5e-4 \
  --lambda_reg 0.5 \
  --k_step 8 \
  --cfg 3.5 \
  --eta 0.9 \
  --use_sparsity true \
  --rank 64 \
  --lora_alpha 64 \
  --max_train_steps 300 \
  --checkpointing_steps 15 \
  --gradient_checkpointing \
  --enable_slicing \
  --enable_tiling
```
**Argument Descriptions**:
- `--pretrained_model_name_or_path`: Path to the base model.
- `--mixed_precision bf16`: Use mixed precision to reduce memory usage.
- `--train_batch_size`: Training batch size.
- `--gradient_accumulation_steps`: Number of gradient accumulation steps.
- `--learning_rate`: Learning rate for the student model.
- `--learning_rate_fake`: Learning rate for the fake model.
- `--lambda_reg`: Regularization weight.
- `--k_step`: Target number of steps for distillation.
- `--cfg`: Classifier-Free Guidance scale.
- `--eta`: ETA parameter for DDIM.
- `--use_sparsity`: Enable sparse attention.
- `--rank` / `--lora_alpha`: LoRA configuration.
- `--max_train_steps`: Maximum number of training steps.
- `--checkpointing_steps`: Interval (in steps) for saving checkpoints.
- `--gradient_checkpointing`: Use gradient checkpointing to save memory.
- `--enable_slicing` / `--enable_tiling`: VAE memory optimizations.
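The `--rank 64 --lora_alpha 64` flags describe the LoRA adapter that the distilled student is trained into. Purely as an illustration of that configuration (the actual target modules and injection logic live in `train_cogvideo_tdm.py`), a PEFT-style setup might look like the following; the target module names are assumptions based on common Diffusers attention layers.

```python
# Hypothetical sketch of the LoRA adapter implied by --rank 64 --lora_alpha 64.
# target_modules below are an assumption; the training script defines the real ones.
from diffusers import CogVideoXTransformer3DModel
from peft import LoraConfig

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "CogVideoX-5b", subfolder="transformer"
)
lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    init_lora_weights=True,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # attention projections (assumed)
)
transformer.add_adapter(lora_config)  # only these low-rank weights are trained
```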
#### WanX Training
```bash
cd wanx
bash train_wanx_tdm.sh
```
## 📊 Project Structure
```
BLADE/
├── README.md                            # Project documentation
├── requirements.txt                     # List of Python dependencies
│
├── cogvideox/                           # Code related to CogVideoX
│   ├── CogVideoX-5b/                    # Directory for base model weights
│   ├── train/                           # Training scripts
│   │   ├── inference.py                 # Inference script
│   │   ├── train_cogvideo_tdm.py        # Training script
│   │   ├── train_tdm_1.sh               # Script to launch training
│   │   ├── modify_cogvideo.py           # Model modification script
│   │   └── config.yaml                  # Training configuration file
│   ├── prompts/                         # Preprocessed prompts and embeddings
│   └── outputs/                         # Output from training and inference
│
├── wanx/                                # Code related to WanX
│   ├── wan1.3b/                         # Directory for base model weights
│   ├── train/                           # Training scripts
│   │   ├── inference.py                 # Inference script
│   │   ├── train_wanx_tdm.py            # Training script
│   │   ├── train_wanx_tdm.sh            # Script to launch training
│   │   └── modify_wan.py                # Model modification script
│   ├── prompts/                         # Preprocessed prompts and embeddings
│   └── outputs/                         # Output from training and inference
│
├── utils/                               # Utility scripts
│   ├── process_prompts_cogvideox.py     # Data preprocessing for CogVideoX
│   ├── process_prompts_wanx.py          # Data preprocessing for WanX
│   └── all_dimension_aug_wanx.txt       # Training prompts for WanX
│
├── Block-Sparse-Attention/              # Sparse attention library
│   ├── setup.py                         # Compilation and installation script
│   ├── block_sparse_attn/               # Core library code
│   └── README.md                        # Library usage instructions
│
└── ds_config.json                       # DeepSpeed configuration file
```
## ๐Ÿค Acknowledgements
- [FlashAttention](https://github.com/Dao-AILab/flash-attention), [Block-Sparse-Attention](https://github.com/mit-han-lab/Block-Sparse-Attention): For the foundational work on sparse attention.
- [CogVideoX](https://github.com/THUDM/CogVideo), [Wan2.1](https://github.com/Wan-Video/Wan2.1): For the supported models.
- [TDM](https://github.com/Luo-Yihong/TDM): For the foundational work on the distillation implementation.
- [Diffusers](https://github.com/huggingface/diffusers): For the invaluable diffusion models library.
## 📄 Citation
If you use BLADE in your research, please cite our work:
```bibtex
@misc{gu2025videobladeblocksparseattentionmeets,
title={BLADE: Block-Sparse Attention Meets Step Distillation for Efficient Video Generation},
author={Youping Gu and Xiaolong Li and Yuhao Hu and Bohan Zhuang},
year={2025},
eprint={2508.10774},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2508.10774},
}
```
## 📧 Contact
For any questions or suggestions, feel free to:
- Contact Youping Gu at youpgu71@gmail.com.
- Submit an issue on our [GitHub page](https://github.com/ziplab/BLADE/issues).