Commit
·
9c451a3
1
Parent(s):
dc5e848
first commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +15 -0
- .gitignore +20 -0
- README.md +171 -0
- assets/DMDR.webp +3 -0
- assets/Z-Image-Gallery.pdf +3 -0
- assets/architecture.webp +3 -0
- assets/bottle.jpg +3 -0
- assets/canny.jpg +3 -0
- assets/decoupled-dmd.webp +3 -0
- assets/depth.jpg +3 -0
- assets/depth_cat.png +3 -0
- assets/hed.jpg +3 -0
- assets/inpaint.jpg +3 -0
- assets/leaderboard.png +3 -0
- assets/leaderboard.webp +0 -0
- assets/man_hed.png +3 -0
- assets/mask.jpg +3 -0
- assets/mask_inpaint.jpg +3 -0
- assets/pose.jpg +3 -0
- assets/pose2.jpg +3 -0
- assets/pose3.jpg +3 -0
- assets/pose4.png +3 -0
- assets/reasoning.png +3 -0
- assets/room_mlsd.png +3 -0
- assets/showcase.jpg +3 -0
- assets/showcase_editing.png +3 -0
- assets/showcase_realistic.png +3 -0
- assets/showcase_rendering.png +3 -0
- diffusers_local/__init__.py +7 -0
- diffusers_local/patch.py +509 -0
- diffusers_local/pipeline_z_image_control_unified.py +910 -0
- diffusers_local/z_image_control_transformer_2d.py +1443 -0
- infer_controlnet.py +146 -0
- infer_i2i.py +94 -0
- infer_inpaint.py +109 -0
- infer_t2i.py +89 -0
- model_index.json +24 -0
- requirements.txt +22 -0
- results/canny.png +3 -0
- results/depth.png +3 -0
- results/hed.png +3 -0
- results/new_tests/controlnet_result_i2i.png +3 -0
- results/new_tests/result_control_canny.png +3 -0
- results/new_tests/result_control_depth.png +3 -0
- results/new_tests/result_control_hed.png +3 -0
- results/new_tests/result_control_inpaint_original_mask.png +3 -0
- results/new_tests/result_control_mlsd.png +3 -0
- results/new_tests/result_control_pose.png +3 -0
- results/new_tests/result_inpaint.png +3 -0
- results/new_tests/result_t2i.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,18 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/architecture.webp filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/decoupled-dmd.webp filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/DMDR.webp filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/leaderboard.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
assets/reasoning.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
assets/showcase_editing.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
assets/showcase_realistic.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
assets/showcase_rendering.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
assets/showcase.jpg filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
assets/Z-Image-Gallery.pdf filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
assets/*.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
assets/*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
transformer/z_image_turbo_control_unified_v2.1_q4_k_m.gguf filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
local_tests/
|
| 4 |
+
/.vs
|
| 5 |
+
.vscode/
|
| 6 |
+
.ruff_cache/
|
| 7 |
+
.idea/
|
| 8 |
+
models/
|
| 9 |
+
venv/
|
| 10 |
+
models/
|
| 11 |
+
.venv/
|
| 12 |
+
*.log
|
| 13 |
+
.DS_Store
|
| 14 |
+
.gradio
|
| 15 |
+
download.py
|
| 16 |
+
bk/
|
| 17 |
+
outputs/
|
| 18 |
+
original/
|
| 19 |
+
Makefile
|
| 20 |
+
pyproject.toml
|
README.md
CHANGED
|
@@ -1,3 +1,174 @@
|
|
| 1 |
---
|
| 2 |
license: apache-2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- text-to-image
|
| 5 |
+
- image-to-image
|
| 6 |
+
- inpainting
|
| 7 |
+
- controlnet
|
| 8 |
+
- diffusers
|
| 9 |
+
- gguf
|
| 10 |
+
- z-image-turbo
|
| 11 |
+
pipeline_tag: text-to-image
|
| 12 |
---
|
| 13 |
+
|
| 14 |
+
# Z-Image Turbo Control Unified V2 (V2.1)
|
| 15 |
+
|
| 16 |
+
[](https://github.com/aigc-apps/VideoX-Fun)
|
| 17 |
+
[](https://huggingface.co/alibaba-pai/Z-Image-Turbo-Fun-Controlnet-Union-2.1)
|
| 18 |
+
|
| 19 |
+
This repository hosts the **Z-Image Turbo Control Unified V2** model. This is a specialized architecture that unifies the powerful **Z-Image Turbo** base transformer with enhanced **ControlNet** capabilities into a single, cohesive model. This unified pipeline supports multiple generation modes in one place: **Text-to-Image, Image-to-Image, ControlNet, and Inpainting**.
|
| 20 |
+
|
| 21 |
+
Unlike traditional pipelines where ControlNet is an external add-on, this model integrates control layers directly into the transformer structure. This enables **Unified GGUF Quantization**, allowing the entire merged architecture (Base + Control) to be quantized (e.g., Q4_K_M) and run efficiently on consumer hardware with limited VRAM. This version also introduces significant optimizations, architectural improvements, and bug fixes for features like `group_offload`.
|
| 22 |
+
|
| 23 |
+
## 📥 Installation
|
| 24 |
+
|
| 25 |
+
To set up the environment, simply install the dependencies:
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
#create virtual env
|
| 29 |
+
python -m venv venv
|
| 30 |
+
|
| 31 |
+
# Activate your venv
|
| 32 |
+
|
| 33 |
+
#upgrade pip
|
| 34 |
+
python.exe -m pip install --upgrade pip
|
| 35 |
+
|
| 36 |
+
#install requirements
|
| 37 |
+
pip install -r requirements.txt
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
*Note: This repository contains a `diffusers_local` folder with the custom `ZImageControlUnifiedPipeline` and transformer logic required to run this specific architecture.*
|
| 41 |
+
|
| 42 |
+
## 🚀 Usage
|
| 43 |
+
|
| 44 |
+
## 📂 Repository Structure
|
| 45 |
+
|
| 46 |
+
* `./transformer/z_image_turbo_control_unified_v2.1_q4_k_m.gguf`: The unified, quantized model weights.
|
| 47 |
+
* `infer_controlnet.py`: Script for running controlnet inference.
|
| 48 |
+
* `infer_inpaint.py`: Script for running inpaint inference.
|
| 49 |
+
* `infer_t2i.py`: Script for running text-to-image inference.
|
| 50 |
+
* `infer_i2i.py`: Script for running image-to-image inference.
|
| 51 |
+
* `diffusers_local/`: Custom pipeline code (`ZImageControlUnifiedPipeline`) and transformer logic.
|
| 52 |
+
* `requirements.txt`: Python dependencies.
|
| 53 |
+
|
| 54 |
+
The primary script for inference is `infer_controlnet.py`, which is designed to handle all supported generation modes.
|
| 55 |
+
|
| 56 |
+
### Option 1: Low VRAM (GGUF) - Recommended
|
| 57 |
+
Use this version if you have limited VRAM (e.g., 6GB - 8GB). It loads the model from a quantized **GGUF** file (`z_image_turbo_control_unified_v2.1_q4_k_m.gguf`). Simply configure the `infer_controlnet.py` script to point to the GGUF file.
|
| 58 |
+
|
| 59 |
+
**Key Features of this mode:**
|
| 60 |
+
* Loads the unified transformer from a single 4-bit quantized file.
|
| 61 |
+
* Enables aggressive `group_offload` to fit large models in consumer GPUs.
|
| 62 |
+
|
| 63 |
+
### Option 2: High Precision (Diffusers/BF16)
|
| 64 |
+
Use this version if you have ample VRAM (e.g., 24GB+). Configure `infer_controlnet.py` to load the model using the standard `from_pretrained` directory structure for full **BFloat16** precision.
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
## 🛠️ Model Features & Configuration (V2)
|
| 68 |
+
|
| 69 |
+
### Original Features
|
| 70 |
+
- This ControlNet is added on 15 layer blocks and 2 refiner layer blocks.
|
| 71 |
+
- The model was trained from scratch for 70,000 steps on a dataset of 1 million high-quality images.
|
| 72 |
+
- Multiple Control Conditions Supports Canny, HED, Depth, Pose, and MLSD, which can be used like a standard ControlNet.
|
| 73 |
+
- You can adjust `controlnet_conditioning_scale` for stronger control. For better stability, we highly recommend using a detailed prompt. The optimal range for `controlnet_conditioning_scale` is from 0.65 to 0.90.
|
| 74 |
+
- **Note on Steps: As you increase the control strength, it's recommended to appropriately increase the number of inference steps to achieve better results.**
|
| 75 |
+
|
| 76 |
+
This optmized V2 model introduces several new features and parameters for enhanced control and flexibility:
|
| 77 |
+
|
| 78 |
+
* **Unified Pipeline:** A single pipeline now handles Text-to-Image, Image-to-Image, ControlNet, and Inpainting tasks.
|
| 79 |
+
* **Refiner Scale (`controlnet_refiner_conditioning_scale`):** It provides fine-grained control over the influence of the initial refining layers, allowing for isolated adjustments without the influence of the controlnet's conditioning force.
|
| 80 |
+
* **Optional Refiner (`add_control_noise_refiner=False`):** You can now disable the control noise refiner layers when loading the model to save memory or for different stylistic results.
|
| 81 |
+
* **Inpainting Blur (`mask_blur_radius`):** A parameter to soften the edges of the inpainting mask for smoother transitions.
|
| 82 |
+
* **Backward Compatibility:** The model supports running weights from V1.
|
| 83 |
+
* **Group Offload Fixes:** The underlying code includes crucial fixes to ensure diffusers `group_offload` works correctly with `use_stream=True`, enabling efficient memory management without errors.
|
| 84 |
+
|
| 85 |
+
## 🏞️ V2 Examples: Refiner Scale Test
|
| 86 |
+
|
| 87 |
+
The new `controlnet_refiner_conditioning_scale` parameter allows for fine-tuning the control signal. Here is a comparison showing its effect while keeping the main control scale fixed.
|
| 88 |
+
|
| 89 |
+
**Prompt:** "Photorealistic portrait of a beautiful young East Asian woman with long, vibrant purple hair and a black bow. She is wearing a flowing white summer dress, standing on a sunny beach with a sparkling ocean and clear blue sky in the background. Bright natural sunlight, sharp focus, ultra-detailed."
|
| 90 |
+
**Control Image:** Pose.
|
| 91 |
+
| `controlnet_conditioning_scale=0.75, num_steps=25` | Refiner: Off | Refiner Scale: 0.75 | Refiner Scale: 1.0 | Refiner Scale: 1.5 | Refiner Scale: 2.0 |
|
| 92 |
+
|:---:|:---:|:---:|:---:|:---:|:---:|
|
| 93 |
+
| **Output** |  |  |  |  |  |
|
| 94 |
+
|
| 95 |
+
---
|
| 96 |
+
### New Tests with this pipeline
|
| 97 |
+
|
| 98 |
+
<table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
|
| 99 |
+
<tr>
|
| 100 |
+
<td>Pose + Inpaint</td>
|
| 101 |
+
<td>Output</td>
|
| 102 |
+
</tr>
|
| 103 |
+
<tr>
|
| 104 |
+
<td><img src="assets/inpaint.jpg" width="50%" /><img src="assets/mask_inpaint.jpg" width="50%" /></td>
|
| 105 |
+
<td><img src="results/new_tests/result_inpaint.png" width="50%" /></td>
|
| 106 |
+
</tr>
|
| 107 |
+
</table>
|
| 108 |
+
<table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
|
| 109 |
+
<tr>
|
| 110 |
+
<td>Pose</td>
|
| 111 |
+
<td>Output</td>
|
| 112 |
+
</tr>
|
| 113 |
+
<tr>
|
| 114 |
+
<td><img src="assets/pose.jpg" width="85%" /></td>
|
| 115 |
+
<td><img src="results/new_tests/result_control_pose.png" width="50%" /></td>
|
| 116 |
+
</tr>
|
| 117 |
+
</table>
|
| 118 |
+
<table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
|
| 119 |
+
<tr>
|
| 120 |
+
<td>Canny</td>
|
| 121 |
+
<td>Output</td>
|
| 122 |
+
</tr>
|
| 123 |
+
<tr>
|
| 124 |
+
<td><img src="assets/canny.jpg" width="50%" /></td>
|
| 125 |
+
<td><img src="results/new_tests/result_control_canny.png" width="50%" /></td>
|
| 126 |
+
</tr>
|
| 127 |
+
</table>
|
| 128 |
+
<table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
|
| 129 |
+
<tr>
|
| 130 |
+
<td>HED</td>
|
| 131 |
+
<td>Output</td>
|
| 132 |
+
</tr>
|
| 133 |
+
<tr>
|
| 134 |
+
<td><img src="assets/man_hed.png" width="50%" /></td>
|
| 135 |
+
<td><img src="results/new_tests/result_control_hed.png" width="50%" /></td>
|
| 136 |
+
</tr>
|
| 137 |
+
</table>
|
| 138 |
+
<table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
|
| 139 |
+
<tr>
|
| 140 |
+
<td>Depth</td>
|
| 141 |
+
<td>Output</td>
|
| 142 |
+
</tr>
|
| 143 |
+
<tr>
|
| 144 |
+
<td><img src="assets/depth_cat.png" width="50%" /></td>
|
| 145 |
+
<td><img src="results/new_tests/result_control_depth.png" width="50%" /></td>
|
| 146 |
+
</tr>
|
| 147 |
+
</table>
|
| 148 |
+
<table border="0" style="width: 100%; text-align: left; margin-top: 20px;">
|
| 149 |
+
<tr>
|
| 150 |
+
<td>MLSD</td>
|
| 151 |
+
<td>Output</td>
|
| 152 |
+
</tr>
|
| 153 |
+
<tr>
|
| 154 |
+
<td><img src="assets/room_mlsd.png" width="100%" /></td>
|
| 155 |
+
<td><img src="results/new_tests/result_control_mlsd.png" width="50%" /></td>
|
| 156 |
+
</tr>
|
| 157 |
+
</table>
|
| 158 |
+
|
| 159 |
+
## Original V2 Model Results
|
| 160 |
+
|
| 161 |
+
This section includes examples from the original model for reference. The V2 model is capable of producing these results and more.
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
### Original Scale Test Results
|
| 165 |
+
|
| 166 |
+
The table below shows the generation results under different combinations of Diffusion steps and Control Scale strength from the original model:
|
| 167 |
+
|
| 168 |
+
| Diffusion Steps | Scale 0.65 | Scale 0.70 | Scale 0.75 | Scale 0.8 | Scale 0.9 | Scale 1.0 |
|
| 169 |
+
|:---------------:|:----------:|:----------:|:----------:|:---------:|:---------:|:---------:|
|
| 170 |
+
| **9** |  |  |  |  |  |  |
|
| 171 |
+
| **10** |  |  |  |  |  |  |
|
| 172 |
+
| **20** |  |  |  |  |  |  |
|
| 173 |
+
| **30** |  |  |  |  |  |  |
|
| 174 |
+
| **40** |  |  |  |  |  |  |
|
assets/DMDR.webp
ADDED
|
Git LFS Details
|
assets/Z-Image-Gallery.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6f9895b3246d2547bac74bbe0be975da500eaae93f2cad4248ad3281786b1ac6
|
| 3 |
+
size 15767436
|
assets/architecture.webp
ADDED
|
Git LFS Details
|
assets/bottle.jpg
ADDED
|
Git LFS Details
|
assets/canny.jpg
ADDED
|
Git LFS Details
|
assets/decoupled-dmd.webp
ADDED
|
Git LFS Details
|
assets/depth.jpg
ADDED
|
Git LFS Details
|
assets/depth_cat.png
ADDED
|
Git LFS Details
|
assets/hed.jpg
ADDED
|
Git LFS Details
|
assets/inpaint.jpg
ADDED
|
Git LFS Details
|
assets/leaderboard.png
ADDED
|
Git LFS Details
|
assets/leaderboard.webp
ADDED
|
assets/man_hed.png
ADDED
|
Git LFS Details
|
assets/mask.jpg
ADDED
|
Git LFS Details
|
assets/mask_inpaint.jpg
ADDED
|
Git LFS Details
|
assets/pose.jpg
ADDED
|
Git LFS Details
|
assets/pose2.jpg
ADDED
|
Git LFS Details
|
assets/pose3.jpg
ADDED
|
Git LFS Details
|
assets/pose4.png
ADDED
|
Git LFS Details
|
assets/reasoning.png
ADDED
|
Git LFS Details
|
assets/room_mlsd.png
ADDED
|
Git LFS Details
|
assets/showcase.jpg
ADDED
|
Git LFS Details
|
assets/showcase_editing.png
ADDED
|
Git LFS Details
|
assets/showcase_realistic.png
ADDED
|
Git LFS Details
|
assets/showcase_rendering.png
ADDED
|
Git LFS Details
|
diffusers_local/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .z_image_control_transformer_2d import Transformer2DModelOutput, ZImageControlTransformer2DModel
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"Transformer2DModelOutput",
|
| 6 |
+
"ZImageControlTransformer2DModel",
|
| 7 |
+
]
|
diffusers_local/patch.py
ADDED
|
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib
|
| 2 |
+
import os
|
| 3 |
+
from typing import Optional, Set
|
| 4 |
+
|
| 5 |
+
import diffusers.loaders.single_file_model as single_file_model
|
| 6 |
+
import diffusers.pipelines.pipeline_loading_utils as pipe_loading_utils
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers.loaders.single_file_utils import (
|
| 9 |
+
convert_animatediff_checkpoint_to_diffusers,
|
| 10 |
+
convert_auraflow_transformer_checkpoint_to_diffusers,
|
| 11 |
+
convert_autoencoder_dc_checkpoint_to_diffusers,
|
| 12 |
+
convert_chroma_transformer_checkpoint_to_diffusers,
|
| 13 |
+
convert_controlnet_checkpoint,
|
| 14 |
+
convert_cosmos_transformer_checkpoint_to_diffusers,
|
| 15 |
+
convert_flux2_transformer_checkpoint_to_diffusers,
|
| 16 |
+
convert_flux_transformer_checkpoint_to_diffusers,
|
| 17 |
+
convert_hidream_transformer_to_diffusers,
|
| 18 |
+
convert_hunyuan_video_transformer_to_diffusers,
|
| 19 |
+
convert_ldm_unet_checkpoint,
|
| 20 |
+
convert_ldm_vae_checkpoint,
|
| 21 |
+
convert_ltx_transformer_checkpoint_to_diffusers,
|
| 22 |
+
convert_ltx_vae_checkpoint_to_diffusers,
|
| 23 |
+
convert_lumina2_to_diffusers,
|
| 24 |
+
convert_mochi_transformer_checkpoint_to_diffusers,
|
| 25 |
+
convert_sana_transformer_to_diffusers,
|
| 26 |
+
convert_sd3_transformer_checkpoint_to_diffusers,
|
| 27 |
+
convert_stable_cascade_unet_single_file_to_diffusers,
|
| 28 |
+
convert_wan_transformer_to_diffusers,
|
| 29 |
+
convert_wan_vae_to_diffusers,
|
| 30 |
+
convert_z_image_transformer_checkpoint_to_diffusers,
|
| 31 |
+
create_controlnet_diffusers_config_from_ldm,
|
| 32 |
+
create_unet_diffusers_config_from_ldm,
|
| 33 |
+
create_vae_diffusers_config_from_ldm,
|
| 34 |
+
)
|
| 35 |
+
from diffusers.pipelines.pipeline_loading_utils import _unwrap_model
|
| 36 |
+
from diffusers.utils import (
|
| 37 |
+
_maybe_remap_transformers_class,
|
| 38 |
+
get_class_from_dynamic_module,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
try:
|
| 43 |
+
from diffusers.hooks.group_offloading import (
|
| 44 |
+
_GROUP_ID_LAZY_LEAF,
|
| 45 |
+
GroupOffloadingConfig,
|
| 46 |
+
ModelHook,
|
| 47 |
+
ModuleGroup,
|
| 48 |
+
_apply_group_offloading_hook,
|
| 49 |
+
_apply_lazy_group_offloading_hook,
|
| 50 |
+
_find_parent_module_in_module_dict,
|
| 51 |
+
_gather_buffers_with_no_group_offloading_parent,
|
| 52 |
+
_gather_parameters_with_no_group_offloading_parent,
|
| 53 |
+
send_to_device,
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
except ImportError:
|
| 57 |
+
ModelHook = object
|
| 58 |
+
ModuleGroup = object
|
| 59 |
+
GroupOffloadingConfig = object
|
| 60 |
+
|
| 61 |
+
def _apply_group_offloading_hook(*args, **kwargs):
|
| 62 |
+
pass
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
_MY_GO_LC_SUPPORTED_PYTORCH_LAYERS = (
|
| 66 |
+
torch.nn.Conv1d,
|
| 67 |
+
torch.nn.Conv2d,
|
| 68 |
+
torch.nn.Conv3d,
|
| 69 |
+
torch.nn.ConvTranspose1d,
|
| 70 |
+
torch.nn.ConvTranspose2d,
|
| 71 |
+
torch.nn.ConvTranspose3d,
|
| 72 |
+
torch.nn.Linear,
|
| 73 |
+
torch.nn.Sequential, # A camada que queremos adicionar
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class GroupOffloadingHook(ModelHook):
|
| 78 |
+
r"""
|
| 79 |
+
A hook that offloads groups of torch.nn.Module to the CPU for storage and onloads to accelerator device for
|
| 80 |
+
computation. Each group has one "onload leader" module that is responsible for onloading, and an "offload leader"
|
| 81 |
+
module that is responsible for offloading. If prefetching is enabled, the onload leader of the previous module
|
| 82 |
+
group is responsible for onloading the current module group.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
_is_stateful = False
|
| 86 |
+
|
| 87 |
+
def __init__(self, group: ModuleGroup, *, config: GroupOffloadingConfig) -> None:
|
| 88 |
+
self.group = group
|
| 89 |
+
self.next_group: Optional[ModuleGroup] = None
|
| 90 |
+
self.config = config
|
| 91 |
+
|
| 92 |
+
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
| 93 |
+
if self.group.offload_leader == module:
|
| 94 |
+
self.group.offload_()
|
| 95 |
+
return module
|
| 96 |
+
|
| 97 |
+
def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
|
| 98 |
+
# If there wasn't an onload_leader assigned, we assume that the submodule that first called its forward
|
| 99 |
+
# method is the onload_leader of the group.
|
| 100 |
+
if self.group.onload_leader is None:
|
| 101 |
+
self.group.onload_leader = module
|
| 102 |
+
|
| 103 |
+
if self.group.onload_leader == module:
|
| 104 |
+
# STEP 1: GUARANTEE THE CURRENT GROUP'S STATE
|
| 105 |
+
# This section ensures that the parameters for the *current* module are on the correct device
|
| 106 |
+
# before its forward pass is executed.
|
| 107 |
+
|
| 108 |
+
# This block handles modules that are part of the prefetching chain (`onload_self` is False).
|
| 109 |
+
# The original design relied on the previous module to initiate the onload, which proved fragile.
|
| 110 |
+
# Our robust fix makes each module responsible for itself:
|
| 111 |
+
# 1. `self.group.onload_()`: Guarantees the data transfer is initiated, acting as a backup if the
|
| 112 |
+
# previous module in the chain failed to do so.
|
| 113 |
+
# 2. `self.group.stream.synchronize()`: This is the critical synchronization barrier. It forces the
|
| 114 |
+
# CPU to wait until the asynchronous transfer to the GPU is complete, preventing device mismatch errors.
|
| 115 |
+
if not self.group.onload_self and self.group.stream is not None:
|
| 116 |
+
self.group.onload_()
|
| 117 |
+
self.group.stream.synchronize()
|
| 118 |
+
|
| 119 |
+
# This block handles the first module in an execution chain (`onload_self` is True).
|
| 120 |
+
# It is responsible for loading itself onto the device.
|
| 121 |
+
if self.group.onload_self:
|
| 122 |
+
self.group.onload_()
|
| 123 |
+
# If streams are used, the onload() call above is asynchronous. We MUST synchronize here
|
| 124 |
+
# to ensure the module is ready before its computation starts.
|
| 125 |
+
if self.group.stream is not None:
|
| 126 |
+
self.group.stream.synchronize()
|
| 127 |
+
|
| 128 |
+
# At this point, we are 100% certain that the current group's parameters are on the onload_device.
|
| 129 |
+
|
| 130 |
+
# STEP 2: INITIATE PREFETCHING FOR THE NEXT GROUP
|
| 131 |
+
# With the current group secured, we can now look ahead and start the asynchronous data transfer
|
| 132 |
+
# for the next module in the execution chain. This allows the data transfer to overlap with the
|
| 133 |
+
# computation of the current module's forward pass, which is the core benefit of prefetching.
|
| 134 |
+
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
|
| 135 |
+
if should_onload_next_group:
|
| 136 |
+
self.next_group.onload_()
|
| 137 |
+
|
| 138 |
+
# The rest of the function handles moving positional (*args) and keyword (**kwargs)
|
| 139 |
+
# arguments to the correct device.
|
| 140 |
+
args = send_to_device(args, self.group.onload_device, non_blocking=self.group.non_blocking)
|
| 141 |
+
|
| 142 |
+
exclude_kwargs = self.config.exclude_kwargs or []
|
| 143 |
+
if exclude_kwargs:
|
| 144 |
+
moved_kwargs = send_to_device(
|
| 145 |
+
{k: v for k, v in kwargs.items() if k not in exclude_kwargs},
|
| 146 |
+
self.group.onload_device,
|
| 147 |
+
non_blocking=self.group.non_blocking,
|
| 148 |
+
)
|
| 149 |
+
kwargs.update(moved_kwargs)
|
| 150 |
+
else:
|
| 151 |
+
kwargs = send_to_device(kwargs, self.group.onload_device, non_blocking=self.group.non_blocking)
|
| 152 |
+
|
| 153 |
+
return args, kwargs
|
| 154 |
+
|
| 155 |
+
def post_forward(self, module: torch.nn.Module, output):
|
| 156 |
+
if self.group.offload_leader == module:
|
| 157 |
+
self.group.offload_()
|
| 158 |
+
return output
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def _apply_group_offloading_leaf_level_patched(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
| 162 |
+
"""
|
| 163 |
+
Versão corrigida de _apply_group_offloading_leaf_level que suporta nn.Sequential.
|
| 164 |
+
"""
|
| 165 |
+
modules_with_group_offloading: Set[str] = set()
|
| 166 |
+
for name, submodule in module.named_modules():
|
| 167 |
+
if not isinstance(submodule, _MY_GO_LC_SUPPORTED_PYTORCH_LAYERS):
|
| 168 |
+
continue
|
| 169 |
+
|
| 170 |
+
group = ModuleGroup(
|
| 171 |
+
modules=[submodule],
|
| 172 |
+
offload_device=config.offload_device,
|
| 173 |
+
onload_device=config.onload_device,
|
| 174 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 175 |
+
offload_leader=submodule,
|
| 176 |
+
onload_leader=submodule,
|
| 177 |
+
non_blocking=config.non_blocking,
|
| 178 |
+
stream=config.stream,
|
| 179 |
+
record_stream=config.record_stream,
|
| 180 |
+
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
| 181 |
+
onload_self=True,
|
| 182 |
+
group_id=name,
|
| 183 |
+
)
|
| 184 |
+
_apply_group_offloading_hook(submodule, group, config=config)
|
| 185 |
+
modules_with_group_offloading.add(name)
|
| 186 |
+
|
| 187 |
+
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
|
| 188 |
+
# of the module is called
|
| 189 |
+
module_dict = dict(module.named_modules())
|
| 190 |
+
parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
| 191 |
+
buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
|
| 192 |
+
|
| 193 |
+
# Find closest module parent for each parameter and buffer, and attach group hooks
|
| 194 |
+
parent_to_parameters = {}
|
| 195 |
+
for name, param in parameters:
|
| 196 |
+
parent_name = _find_parent_module_in_module_dict(name, module_dict)
|
| 197 |
+
if parent_name in parent_to_parameters:
|
| 198 |
+
parent_to_parameters[parent_name].append(param)
|
| 199 |
+
else:
|
| 200 |
+
parent_to_parameters[parent_name] = [param]
|
| 201 |
+
|
| 202 |
+
parent_to_buffers = {}
|
| 203 |
+
for name, buffer in buffers:
|
| 204 |
+
parent_name = _find_parent_module_in_module_dict(name, module_dict)
|
| 205 |
+
if parent_name in parent_to_buffers:
|
| 206 |
+
parent_to_buffers[parent_name].append(buffer)
|
| 207 |
+
else:
|
| 208 |
+
parent_to_buffers[parent_name] = [buffer]
|
| 209 |
+
|
| 210 |
+
parent_names = set(parent_to_parameters.keys()) | set(parent_to_buffers.keys())
|
| 211 |
+
for name in parent_names:
|
| 212 |
+
parameters = parent_to_parameters.get(name, [])
|
| 213 |
+
buffers = parent_to_buffers.get(name, [])
|
| 214 |
+
parent_module = module_dict[name]
|
| 215 |
+
group = ModuleGroup(
|
| 216 |
+
modules=[],
|
| 217 |
+
offload_device=config.offload_device,
|
| 218 |
+
onload_device=config.onload_device,
|
| 219 |
+
offload_leader=parent_module,
|
| 220 |
+
onload_leader=parent_module,
|
| 221 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 222 |
+
parameters=parameters,
|
| 223 |
+
buffers=buffers,
|
| 224 |
+
non_blocking=config.non_blocking,
|
| 225 |
+
stream=config.stream,
|
| 226 |
+
record_stream=config.record_stream,
|
| 227 |
+
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
| 228 |
+
onload_self=True,
|
| 229 |
+
group_id=name,
|
| 230 |
+
)
|
| 231 |
+
_apply_group_offloading_hook(parent_module, group, config=config)
|
| 232 |
+
|
| 233 |
+
if config.stream is not None:
|
| 234 |
+
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
|
| 235 |
+
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
|
| 236 |
+
# execution order and apply prefetching in the correct order.
|
| 237 |
+
unmatched_group = ModuleGroup(
|
| 238 |
+
modules=[],
|
| 239 |
+
offload_device=config.offload_device,
|
| 240 |
+
onload_device=config.onload_device,
|
| 241 |
+
offload_to_disk_path=config.offload_to_disk_path,
|
| 242 |
+
offload_leader=module,
|
| 243 |
+
onload_leader=module,
|
| 244 |
+
parameters=None,
|
| 245 |
+
buffers=None,
|
| 246 |
+
non_blocking=False,
|
| 247 |
+
stream=None,
|
| 248 |
+
record_stream=False,
|
| 249 |
+
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
| 250 |
+
onload_self=True,
|
| 251 |
+
group_id=_GROUP_ID_LAZY_LEAF,
|
| 252 |
+
)
|
| 253 |
+
_apply_lazy_group_offloading_hook(module, unmatched_group, config=config)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
try:
|
| 257 |
+
import diffusers.hooks.group_offloading as group_offloading_module
|
| 258 |
+
|
| 259 |
+
setattr(group_offloading_module, "_apply_group_offloading_leaf_level", _apply_group_offloading_leaf_level_patched)
|
| 260 |
+
setattr(group_offloading_module, "GroupOffloadingHook", GroupOffloadingHook)
|
| 261 |
+
except ImportError as e:
|
| 262 |
+
print(f"-> ERRO: Não foi possível importar o módulo `diffusers.hooks.group_offloading` para aplicar o patch: {e}")
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def convert_z_image_control_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
| 266 |
+
Z_IMAGE_KEYS_RENAME_DICT = {
|
| 267 |
+
"final_layer.": "all_final_layer.2-1.",
|
| 268 |
+
"x_embedder.": "all_x_embedder.2-1.",
|
| 269 |
+
".attention.out.bias": ".attention.to_out.0.bias",
|
| 270 |
+
".attention.k_norm.weight": ".attention.norm_k.weight",
|
| 271 |
+
".attention.q_norm.weight": ".attention.norm_q.weight",
|
| 272 |
+
".attention.out.weight": ".attention.to_out.0.weight",
|
| 273 |
+
"control_x_embedder.": "control_all_x_embedder.2-1.",
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
|
| 277 |
+
if ".attention.qkv.weight" not in key:
|
| 278 |
+
return
|
| 279 |
+
|
| 280 |
+
fused_qkv_weight = state_dict.pop(key)
|
| 281 |
+
to_q_weight, to_k_weight, to_v_weight = torch.chunk(fused_qkv_weight, 3, dim=0)
|
| 282 |
+
new_q_name = key.replace(".attention.qkv.weight", ".attention.to_q.weight")
|
| 283 |
+
new_k_name = key.replace(".attention.qkv.weight", ".attention.to_k.weight")
|
| 284 |
+
new_v_name = key.replace(".attention.qkv.weight", ".attention.to_v.weight")
|
| 285 |
+
|
| 286 |
+
state_dict[new_q_name] = to_q_weight
|
| 287 |
+
state_dict[new_k_name] = to_k_weight
|
| 288 |
+
state_dict[new_v_name] = to_v_weight
|
| 289 |
+
return
|
| 290 |
+
|
| 291 |
+
TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
| 292 |
+
".attention.qkv.weight": convert_z_image_fused_attention,
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
def update_state_dict(state_dict: dict[str, object], old_key: str, new_key: str) -> None:
|
| 296 |
+
state_dict[new_key] = state_dict.pop(old_key)
|
| 297 |
+
|
| 298 |
+
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
| 299 |
+
|
| 300 |
+
# Handle single file --> diffusers key remapping via the remap dict
|
| 301 |
+
for key in list(converted_state_dict.keys()):
|
| 302 |
+
new_key = key[:]
|
| 303 |
+
for replace_key, rename_key in Z_IMAGE_KEYS_RENAME_DICT.items():
|
| 304 |
+
new_key = new_key.replace(replace_key, rename_key)
|
| 305 |
+
|
| 306 |
+
update_state_dict(converted_state_dict, key, new_key)
|
| 307 |
+
|
| 308 |
+
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
| 309 |
+
# special_keys_remap
|
| 310 |
+
for key in list(converted_state_dict.keys()):
|
| 311 |
+
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
| 312 |
+
if special_key not in key:
|
| 313 |
+
continue
|
| 314 |
+
handler_fn_inplace(key, converted_state_dict)
|
| 315 |
+
|
| 316 |
+
return converted_state_dict
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
SINGLE_FILE_LOADABLE_CLASSES = {
|
| 320 |
+
"StableCascadeUNet": {
|
| 321 |
+
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
|
| 322 |
+
},
|
| 323 |
+
"UNet2DConditionModel": {
|
| 324 |
+
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
|
| 325 |
+
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
|
| 326 |
+
"default_subfolder": "unet",
|
| 327 |
+
"legacy_kwargs": {
|
| 328 |
+
"num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
|
| 329 |
+
},
|
| 330 |
+
},
|
| 331 |
+
"AutoencoderKL": {
|
| 332 |
+
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
|
| 333 |
+
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
|
| 334 |
+
"default_subfolder": "vae",
|
| 335 |
+
},
|
| 336 |
+
"ControlNetModel": {
|
| 337 |
+
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
|
| 338 |
+
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
|
| 339 |
+
},
|
| 340 |
+
"SD3Transformer2DModel": {
|
| 341 |
+
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
|
| 342 |
+
"default_subfolder": "transformer",
|
| 343 |
+
},
|
| 344 |
+
"MotionAdapter": {
|
| 345 |
+
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
| 346 |
+
},
|
| 347 |
+
"SparseControlNetModel": {
|
| 348 |
+
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
| 349 |
+
},
|
| 350 |
+
"FluxTransformer2DModel": {
|
| 351 |
+
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
|
| 352 |
+
"default_subfolder": "transformer",
|
| 353 |
+
},
|
| 354 |
+
"ChromaTransformer2DModel": {
|
| 355 |
+
"checkpoint_mapping_fn": convert_chroma_transformer_checkpoint_to_diffusers,
|
| 356 |
+
"default_subfolder": "transformer",
|
| 357 |
+
},
|
| 358 |
+
"LTXVideoTransformer3DModel": {
|
| 359 |
+
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
|
| 360 |
+
"default_subfolder": "transformer",
|
| 361 |
+
},
|
| 362 |
+
"AutoencoderKLLTXVideo": {
|
| 363 |
+
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
|
| 364 |
+
"default_subfolder": "vae",
|
| 365 |
+
},
|
| 366 |
+
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
|
| 367 |
+
"MochiTransformer3DModel": {
|
| 368 |
+
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
| 369 |
+
"default_subfolder": "transformer",
|
| 370 |
+
},
|
| 371 |
+
"HunyuanVideoTransformer3DModel": {
|
| 372 |
+
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
|
| 373 |
+
"default_subfolder": "transformer",
|
| 374 |
+
},
|
| 375 |
+
"AuraFlowTransformer2DModel": {
|
| 376 |
+
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
|
| 377 |
+
"default_subfolder": "transformer",
|
| 378 |
+
},
|
| 379 |
+
"Lumina2Transformer2DModel": {
|
| 380 |
+
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
|
| 381 |
+
"default_subfolder": "transformer",
|
| 382 |
+
},
|
| 383 |
+
"SanaTransformer2DModel": {
|
| 384 |
+
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
|
| 385 |
+
"default_subfolder": "transformer",
|
| 386 |
+
},
|
| 387 |
+
"WanTransformer3DModel": {
|
| 388 |
+
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
| 389 |
+
"default_subfolder": "transformer",
|
| 390 |
+
},
|
| 391 |
+
"WanVACETransformer3DModel": {
|
| 392 |
+
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
| 393 |
+
"default_subfolder": "transformer",
|
| 394 |
+
},
|
| 395 |
+
"AutoencoderKLWan": {
|
| 396 |
+
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
| 397 |
+
"default_subfolder": "vae",
|
| 398 |
+
},
|
| 399 |
+
"HiDreamImageTransformer2DModel": {
|
| 400 |
+
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
|
| 401 |
+
"default_subfolder": "transformer",
|
| 402 |
+
},
|
| 403 |
+
"CosmosTransformer3DModel": {
|
| 404 |
+
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
|
| 405 |
+
"default_subfolder": "transformer",
|
| 406 |
+
},
|
| 407 |
+
"QwenImageTransformer2DModel": {
|
| 408 |
+
"checkpoint_mapping_fn": lambda x: x,
|
| 409 |
+
"default_subfolder": "transformer",
|
| 410 |
+
},
|
| 411 |
+
"Flux2Transformer2DModel": {
|
| 412 |
+
"checkpoint_mapping_fn": convert_flux2_transformer_checkpoint_to_diffusers,
|
| 413 |
+
"default_subfolder": "transformer",
|
| 414 |
+
},
|
| 415 |
+
"ZImageTransformer2DModel": {
|
| 416 |
+
"checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
|
| 417 |
+
"default_subfolder": "transformer",
|
| 418 |
+
},
|
| 419 |
+
"ZImageControlTransformer2DModel": {
|
| 420 |
+
"checkpoint_mapping_fn": convert_z_image_control_transformer_checkpoint_to_diffusers,
|
| 421 |
+
"default_subfolder": "transformer",
|
| 422 |
+
},
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def get_class_obj_and_candidates(library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None):
|
| 427 |
+
"""Simple helper method to retrieve class object of module as well as potential parent class objects"""
|
| 428 |
+
component_folder = os.path.join(cache_dir, component_name) if component_name and cache_dir else None
|
| 429 |
+
|
| 430 |
+
if is_pipeline_module:
|
| 431 |
+
pipeline_module = getattr(pipelines, library_name)
|
| 432 |
+
|
| 433 |
+
class_obj = getattr(pipeline_module, class_name)
|
| 434 |
+
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
|
| 435 |
+
elif component_folder and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
|
| 436 |
+
# load custom component
|
| 437 |
+
class_obj = get_class_from_dynamic_module(component_folder, module_file=library_name + ".py", class_name=class_name)
|
| 438 |
+
class_candidates = dict.fromkeys(importable_classes.keys(), class_obj)
|
| 439 |
+
else:
|
| 440 |
+
# else we just import it from the library.
|
| 441 |
+
library = importlib.import_module(library_name)
|
| 442 |
+
|
| 443 |
+
# Handle deprecated Transformers classes
|
| 444 |
+
if library_name == "transformers":
|
| 445 |
+
class_name = _maybe_remap_transformers_class(class_name) or class_name
|
| 446 |
+
|
| 447 |
+
try:
|
| 448 |
+
class_obj = getattr(library, class_name)
|
| 449 |
+
except Exception:
|
| 450 |
+
module = importlib.import_module("diffusers_local")
|
| 451 |
+
class_obj = getattr(module, class_name)
|
| 452 |
+
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
| 453 |
+
|
| 454 |
+
return class_obj, class_candidates
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def _get_single_file_loadable_mapping_class(cls):
|
| 458 |
+
diffusers_module = importlib.import_module("diffusers")
|
| 459 |
+
class_name_str = cls.__name__
|
| 460 |
+
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
| 461 |
+
try:
|
| 462 |
+
loadable_class = getattr(diffusers_module, loadable_class_str)
|
| 463 |
+
except Exception:
|
| 464 |
+
module = importlib.import_module("diffusers_local")
|
| 465 |
+
loadable_class = getattr(module, loadable_class_str)
|
| 466 |
+
if issubclass(cls, loadable_class):
|
| 467 |
+
return loadable_class_str
|
| 468 |
+
|
| 469 |
+
return class_name_str
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
def maybe_raise_or_warn(library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module):
|
| 473 |
+
"""Simple helper method to raise or warn in case incorrect module has been passed"""
|
| 474 |
+
if not is_pipeline_module:
|
| 475 |
+
library = importlib.import_module(library_name)
|
| 476 |
+
|
| 477 |
+
# Handle deprecated Transformers classes
|
| 478 |
+
if library_name == "transformers":
|
| 479 |
+
class_name = _maybe_remap_transformers_class(class_name) or class_name
|
| 480 |
+
|
| 481 |
+
try:
|
| 482 |
+
class_obj = getattr(library, class_name)
|
| 483 |
+
except Exception:
|
| 484 |
+
module = importlib.import_module("diffusers_local")
|
| 485 |
+
class_obj = getattr(module, class_name)
|
| 486 |
+
|
| 487 |
+
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
|
| 488 |
+
|
| 489 |
+
expected_class_obj = None
|
| 490 |
+
for class_name, class_candidate in class_candidates.items():
|
| 491 |
+
if class_candidate is not None and issubclass(class_obj, class_candidate):
|
| 492 |
+
expected_class_obj = class_candidate
|
| 493 |
+
|
| 494 |
+
# Dynamo wraps the original model in a private class.
|
| 495 |
+
# I didn't find a public API to get the original class.
|
| 496 |
+
sub_model = passed_class_obj[name]
|
| 497 |
+
unwrapped_sub_model = _unwrap_model(sub_model)
|
| 498 |
+
model_cls = unwrapped_sub_model.__class__
|
| 499 |
+
|
| 500 |
+
if not issubclass(model_cls, expected_class_obj):
|
| 501 |
+
raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
|
| 502 |
+
else:
|
| 503 |
+
print(f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it has the correct type")
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
pipe_loading_utils.get_class_obj_and_candidates = get_class_obj_and_candidates
|
| 507 |
+
pipe_loading_utils.maybe_raise_or_warn = maybe_raise_or_warn
|
| 508 |
+
single_file_model.SINGLE_FILE_LOADABLE_CLASSES = SINGLE_FILE_LOADABLE_CLASSES
|
| 509 |
+
single_file_model._get_single_file_loadable_mapping_class = _get_single_file_loadable_mapping_class
|
diffusers_local/pipeline_z_image_control_unified.py
ADDED
|
@@ -0,0 +1,910 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
# Refactored and optimized by DEVAIEXP Team
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import inspect
|
| 18 |
+
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from diffusers import AutoencoderKL, DiffusionPipeline, FlowMatchEulerDiscreteScheduler
|
| 24 |
+
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
| 25 |
+
from diffusers.loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
|
| 26 |
+
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
|
| 27 |
+
from diffusers.utils import logging
|
| 28 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 29 |
+
from PIL import Image, ImageFilter
|
| 30 |
+
from transformers import AutoTokenizer, PreTrainedModel
|
| 31 |
+
|
| 32 |
+
from diffusers_local.z_image_control_transformer_2d import ZImageControlTransformer2DModel
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
logger = logging.get_logger(__name__)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def calculate_shift(
|
| 39 |
+
image_seq_len,
|
| 40 |
+
base_seq_len: int = 256,
|
| 41 |
+
max_seq_len: int = 4096,
|
| 42 |
+
base_shift: float = 0.5,
|
| 43 |
+
max_shift: float = 1.15,
|
| 44 |
+
):
|
| 45 |
+
"""
|
| 46 |
+
Calculates the shift value `mu` for the scheduler based on the image sequence length.
|
| 47 |
+
|
| 48 |
+
This function implements a linear interpolation to determine the shift value based on the input
|
| 49 |
+
image's sequence length, scaling between a base and a maximum shift value.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
image_seq_len (`int`):
|
| 53 |
+
The sequence length of the image latents (height * width).
|
| 54 |
+
base_seq_len (`int`, *optional*, defaults to 256):
|
| 55 |
+
The base sequence length for the shift calculation.
|
| 56 |
+
max_seq_len (`int`, *optional*, defaults to 4096):
|
| 57 |
+
The maximum sequence length for the shift calculation.
|
| 58 |
+
base_shift (`float`, *optional*, defaults to 0.5):
|
| 59 |
+
The shift value corresponding to `base_seq_len`.
|
| 60 |
+
max_shift (`float`, *optional*, defaults to 1.15):
|
| 61 |
+
The shift value corresponding to `max_seq_len`.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
`float`: The calculated shift value `mu`.
|
| 65 |
+
"""
|
| 66 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 67 |
+
b = base_shift - m * base_seq_len
|
| 68 |
+
mu = image_seq_len * m + b
|
| 69 |
+
return mu
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def retrieve_latents(encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"):
|
| 73 |
+
"""
|
| 74 |
+
Retrieves latents from a VAE encoder output.
|
| 75 |
+
|
| 76 |
+
Args:
|
| 77 |
+
encoder_output (`torch.Tensor`):
|
| 78 |
+
The output of a VAE encoder.
|
| 79 |
+
generator (`torch.Generator`, *optional*):
|
| 80 |
+
A random number generator for sampling from the latent distribution.
|
| 81 |
+
sample_mode (`str`, *optional*, defaults to "sample"):
|
| 82 |
+
The method to retrieve latents. Can be "sample" to sample from the distribution or
|
| 83 |
+
"argmax" to take the mode.
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
`torch.Tensor`: The retrieved latents.
|
| 87 |
+
"""
|
| 88 |
+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
| 89 |
+
return encoder_output.latent_dist.sample(generator)
|
| 90 |
+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
| 91 |
+
return encoder_output.latent_dist.mode()
|
| 92 |
+
elif hasattr(encoder_output, "latents"):
|
| 93 |
+
return encoder_output.latents
|
| 94 |
+
else:
|
| 95 |
+
raise AttributeError("Could not access latents of provided encoder_output")
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def retrieve_timesteps(
|
| 99 |
+
scheduler,
|
| 100 |
+
num_inference_steps: Optional[int] = None,
|
| 101 |
+
device: Optional[Union[str, torch.device]] = None,
|
| 102 |
+
timesteps: Optional[List[int]] = None,
|
| 103 |
+
sigmas: Optional[List[float]] = None,
|
| 104 |
+
**kwargs,
|
| 105 |
+
):
|
| 106 |
+
"""
|
| 107 |
+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
| 108 |
+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
scheduler (`SchedulerMixin`):
|
| 112 |
+
The scheduler to get timesteps from.
|
| 113 |
+
num_inference_steps (`int`):
|
| 114 |
+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
| 115 |
+
must be `None`.
|
| 116 |
+
device (`str` or `torch.device`, *optional*):
|
| 117 |
+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
| 118 |
+
timesteps (`List[int]`, *optional*):
|
| 119 |
+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
| 120 |
+
`num_inference_steps` and `sigmas` must be `None`.
|
| 121 |
+
sigmas (`List[float]`, *optional*):
|
| 122 |
+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
| 123 |
+
`num_inference_steps` and `timesteps` must be `None`.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
| 127 |
+
second element is the number of inference steps.
|
| 128 |
+
"""
|
| 129 |
+
if timesteps is not None and sigmas is not None:
|
| 130 |
+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
| 131 |
+
if timesteps is not None:
|
| 132 |
+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 133 |
+
if not accepts_timesteps:
|
| 134 |
+
raise ValueError(
|
| 135 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 136 |
+
f" timestep schedules. Please check whether you are using the correct scheduler."
|
| 137 |
+
)
|
| 138 |
+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
| 139 |
+
timesteps = scheduler.timesteps
|
| 140 |
+
num_inference_steps = len(timesteps)
|
| 141 |
+
elif sigmas is not None:
|
| 142 |
+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
| 143 |
+
if not accept_sigmas:
|
| 144 |
+
raise ValueError(
|
| 145 |
+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
| 146 |
+
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
| 147 |
+
)
|
| 148 |
+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
| 149 |
+
timesteps = scheduler.timesteps
|
| 150 |
+
num_inference_steps = len(timesteps)
|
| 151 |
+
else:
|
| 152 |
+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
| 153 |
+
timesteps = scheduler.timesteps
|
| 154 |
+
return timesteps, num_inference_steps
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class ZImageControlUnifiedPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
|
| 158 |
+
model_cpu_offload_seq = "text_encoder->vae->transformer"
|
| 159 |
+
_optional_components = []
|
| 160 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
| 161 |
+
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
scheduler: FlowMatchEulerDiscreteScheduler,
|
| 165 |
+
vae: AutoencoderKL,
|
| 166 |
+
text_encoder: PreTrainedModel,
|
| 167 |
+
tokenizer: AutoTokenizer,
|
| 168 |
+
transformer: ZImageControlTransformer2DModel,
|
| 169 |
+
):
|
| 170 |
+
"""
|
| 171 |
+
Initializes the ZImageControlUnifiedPipeline.
|
| 172 |
+
|
| 173 |
+
Args:
|
| 174 |
+
scheduler (`FlowMatchEulerDiscreteScheduler`):
|
| 175 |
+
A scheduler to be used in combination with `transformer` to denoise the latents.
|
| 176 |
+
vae (`AutoencoderKL`):
|
| 177 |
+
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
| 178 |
+
text_encoder (`PreTrainedModel`):
|
| 179 |
+
A pretrained text encoder model.
|
| 180 |
+
tokenizer (`AutoTokenizer`):
|
| 181 |
+
A tokenizer to prepare text prompts for the `text_encoder`.
|
| 182 |
+
transformer (`ZImageControlTransformer2DModel`):
|
| 183 |
+
The main transformer model for the diffusion process.
|
| 184 |
+
"""
|
| 185 |
+
super().__init__()
|
| 186 |
+
self.register_modules(
|
| 187 |
+
vae=vae,
|
| 188 |
+
text_encoder=text_encoder,
|
| 189 |
+
tokenizer=tokenizer,
|
| 190 |
+
scheduler=scheduler,
|
| 191 |
+
transformer=transformer,
|
| 192 |
+
)
|
| 193 |
+
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
| 194 |
+
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
| 195 |
+
self.mask_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
|
| 196 |
+
|
| 197 |
+
def encode_prompt(
|
| 198 |
+
self,
|
| 199 |
+
prompt: Union[str, List[str]],
|
| 200 |
+
device: Optional[torch.device] = None,
|
| 201 |
+
num_images_per_prompt: int = 1,
|
| 202 |
+
do_classifier_free_guidance: bool = True,
|
| 203 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 204 |
+
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
| 205 |
+
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 206 |
+
max_sequence_length: int = 512,
|
| 207 |
+
):
|
| 208 |
+
"""
|
| 209 |
+
Encodes the prompt into text embeddings.
|
| 210 |
+
|
| 211 |
+
Args:
|
| 212 |
+
prompt (`Union[str, List[str]]`):
|
| 213 |
+
The prompt or prompts to guide the image generation.
|
| 214 |
+
device (`Optional[torch.device]`):
|
| 215 |
+
The device to move the embeddings to.
|
| 216 |
+
num_images_per_prompt (`int`):
|
| 217 |
+
The number of images to generate per prompt.
|
| 218 |
+
do_classifier_free_guidance (`bool`):
|
| 219 |
+
Whether to generate embeddings for classifier-free guidance.
|
| 220 |
+
negative_prompt (`Optional[Union[str, List[str]]]`):
|
| 221 |
+
The negative prompt or prompts.
|
| 222 |
+
prompt_embeds (`Optional[List[torch.FloatTensor]]`):
|
| 223 |
+
Pre-generated positive prompt embeddings.
|
| 224 |
+
negative_prompt_embeds (`Optional[torch.FloatTensor]`):
|
| 225 |
+
Pre-generated negative prompt embeddings.
|
| 226 |
+
max_sequence_length (`int`):
|
| 227 |
+
The maximum sequence length for tokenization.
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
`Tuple[List[torch.Tensor], List[torch.Tensor]]`: A tuple containing the positive and negative prompt embeddings.
|
| 231 |
+
"""
|
| 232 |
+
device = device or self._execution_device
|
| 233 |
+
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 234 |
+
|
| 235 |
+
if prompt_embeds is not None:
|
| 236 |
+
pass
|
| 237 |
+
else:
|
| 238 |
+
prompt_embeds = self._encode_prompt(
|
| 239 |
+
prompt=prompt,
|
| 240 |
+
device=device,
|
| 241 |
+
max_sequence_length=max_sequence_length,
|
| 242 |
+
)
|
| 243 |
+
if num_images_per_prompt > 1:
|
| 244 |
+
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
| 245 |
+
|
| 246 |
+
if do_classifier_free_guidance:
|
| 247 |
+
if negative_prompt_embeds is not None:
|
| 248 |
+
pass
|
| 249 |
+
else:
|
| 250 |
+
if negative_prompt is None:
|
| 251 |
+
negative_prompt = [""] * len(prompt)
|
| 252 |
+
else:
|
| 253 |
+
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 254 |
+
assert len(prompt) == len(negative_prompt)
|
| 255 |
+
negative_prompt_embeds = self._encode_prompt(
|
| 256 |
+
prompt=negative_prompt,
|
| 257 |
+
device=device,
|
| 258 |
+
max_sequence_length=max_sequence_length,
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
if num_images_per_prompt > 1:
|
| 262 |
+
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
| 263 |
+
|
| 264 |
+
return prompt_embeds, negative_prompt_embeds
|
| 265 |
+
|
| 266 |
+
def _encode_prompt(self, prompt: Union[str, List[str]], device: torch.device, max_sequence_length: int) -> List[torch.Tensor]:
|
| 267 |
+
"""
|
| 268 |
+
Internal helper to encode a list of prompts into embeddings, applying chat templates if available.
|
| 269 |
+
|
| 270 |
+
Args:
|
| 271 |
+
prompt (`Union[str, List[str]]`):
|
| 272 |
+
A list of strings to be encoded.
|
| 273 |
+
device (`torch.device`):
|
| 274 |
+
The target device for the embeddings.
|
| 275 |
+
max_sequence_length (`int`):
|
| 276 |
+
The maximum length for tokenization.
|
| 277 |
+
|
| 278 |
+
Returns:
|
| 279 |
+
`List[torch.Tensor]`: A list of embedding tensors, one for each prompt.
|
| 280 |
+
"""
|
| 281 |
+
formatted_prompts = []
|
| 282 |
+
for p in prompt:
|
| 283 |
+
messages = [{"role": "user", "content": p}]
|
| 284 |
+
if hasattr(self.tokenizer, "apply_chat_template"):
|
| 285 |
+
formatted_prompts.append(self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=True))
|
| 286 |
+
else:
|
| 287 |
+
formatted_prompts.append(p)
|
| 288 |
+
|
| 289 |
+
text_inputs = self.tokenizer(
|
| 290 |
+
formatted_prompts,
|
| 291 |
+
padding="max_length",
|
| 292 |
+
max_length=max_sequence_length,
|
| 293 |
+
truncation=True,
|
| 294 |
+
return_tensors="pt",
|
| 295 |
+
).to(device)
|
| 296 |
+
|
| 297 |
+
prompt_masks = text_inputs.attention_mask.bool()
|
| 298 |
+
|
| 299 |
+
with torch.no_grad():
|
| 300 |
+
prompt_embeds_batch = self.text_encoder(input_ids=text_inputs.input_ids, attention_mask=prompt_masks, output_hidden_states=True).hidden_states[-2]
|
| 301 |
+
|
| 302 |
+
embeddings_list = []
|
| 303 |
+
for i in range(prompt_embeds_batch.shape[0]):
|
| 304 |
+
embeddings_list.append(prompt_embeds_batch[i][prompt_masks[i]])
|
| 305 |
+
|
| 306 |
+
return embeddings_list
|
| 307 |
+
|
| 308 |
+
def get_timesteps(self, num_inference_steps, strength, device):
|
| 309 |
+
"""
|
| 310 |
+
Calculates the timesteps for the scheduler based on the number of inference steps and strength.
|
| 311 |
+
This is primarily used for image-to-image pipelines.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
num_inference_steps (`int`): The total number of diffusion steps.
|
| 315 |
+
strength (`float`): The strength of the denoising process. A value of 1.0 means full denoising.
|
| 316 |
+
device (`torch.device`): The device to place the timesteps on.
|
| 317 |
+
|
| 318 |
+
Returns:
|
| 319 |
+
`Tuple[torch.Tensor, int]`: A tuple containing the timesteps and the number of steps to run.
|
| 320 |
+
"""
|
| 321 |
+
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
| 322 |
+
|
| 323 |
+
t_start = int(max(num_inference_steps - init_timestep, 0))
|
| 324 |
+
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
| 325 |
+
if hasattr(self.scheduler, "set_begin_index"):
|
| 326 |
+
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
| 327 |
+
|
| 328 |
+
return timesteps, num_inference_steps - t_start
|
| 329 |
+
|
| 330 |
+
def prepare_latents(
|
| 331 |
+
self,
|
| 332 |
+
batch_size: int,
|
| 333 |
+
num_channels_latents: int,
|
| 334 |
+
height: int,
|
| 335 |
+
width: int,
|
| 336 |
+
dtype: torch.dtype,
|
| 337 |
+
device: torch.device,
|
| 338 |
+
generator: torch.Generator,
|
| 339 |
+
image: Optional[PipelineImageInput] = None,
|
| 340 |
+
timestep: Optional[torch.Tensor] = None,
|
| 341 |
+
latents: Optional[torch.Tensor] = None,
|
| 342 |
+
):
|
| 343 |
+
"""
|
| 344 |
+
Prepares the initial latents for the diffusion process.
|
| 345 |
+
|
| 346 |
+
This function handles three cases:
|
| 347 |
+
1. `latents` are provided: They are returned directly.
|
| 348 |
+
2. `image` is None (Text-to-Image): Random noise is generated.
|
| 349 |
+
3. `image` is provided (Image-to-Image): The image is encoded, and noise is added according to the timestep.
|
| 350 |
+
|
| 351 |
+
Args:
|
| 352 |
+
batch_size (`int`): The number of latents to generate.
|
| 353 |
+
num_channels_latents (`int`): The number of channels in the latents.
|
| 354 |
+
height (`int`): The height of the output image in pixels.
|
| 355 |
+
width (`int`): The width of the output image in pixels.
|
| 356 |
+
dtype (`torch.dtype`): The data type for the latents.
|
| 357 |
+
device (`torch.device`): The device to create the latents on.
|
| 358 |
+
generator (`torch.Generator`): A random generator for creating the initial noise.
|
| 359 |
+
image (`Optional[PipelineImageInput]`): An initial image for img2img mode.
|
| 360 |
+
timestep (`Optional[torch.Tensor]`): The starting timestep for adding noise in img2img mode.
|
| 361 |
+
latents (`Optional[torch.Tensor]`): Pre-generated latents.
|
| 362 |
+
|
| 363 |
+
Returns:
|
| 364 |
+
`torch.Tensor`: The prepared latents.
|
| 365 |
+
"""
|
| 366 |
+
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 367 |
+
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 368 |
+
shape = (batch_size, num_channels_latents, latent_height, latent_width)
|
| 369 |
+
|
| 370 |
+
if latents is not None:
|
| 371 |
+
return latents.to(device=device, dtype=dtype)
|
| 372 |
+
|
| 373 |
+
if image is None:
|
| 374 |
+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 375 |
+
return latents
|
| 376 |
+
|
| 377 |
+
image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device=device, dtype=self.vae.dtype)
|
| 378 |
+
with torch.no_grad():
|
| 379 |
+
if image_tensor.shape[1] != num_channels_latents:
|
| 380 |
+
if isinstance(generator, list):
|
| 381 |
+
image_latents = [retrieve_latents(self.vae.encode(image_tensor[i : i + 1]), generator=generator[i]) for i in range(image_tensor.shape[0])]
|
| 382 |
+
image_latents = torch.cat(image_latents, dim=0)
|
| 383 |
+
else:
|
| 384 |
+
image_latents = retrieve_latents(self.vae.encode(image_tensor), generator=generator)
|
| 385 |
+
|
| 386 |
+
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 387 |
+
image_latents = image_latents.to(dtype)
|
| 388 |
+
|
| 389 |
+
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
| 390 |
+
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
| 391 |
+
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
| 392 |
+
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
| 393 |
+
raise ValueError(f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts.")
|
| 394 |
+
|
| 395 |
+
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
| 396 |
+
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
| 397 |
+
|
| 398 |
+
return latents
|
| 399 |
+
|
| 400 |
+
def _prepare_image_latents(
|
| 401 |
+
self,
|
| 402 |
+
image: PipelineImageInput,
|
| 403 |
+
mask_image: PipelineImageInput,
|
| 404 |
+
width: int,
|
| 405 |
+
height: int,
|
| 406 |
+
batch_size: int,
|
| 407 |
+
num_images_per_prompt: int,
|
| 408 |
+
device: torch.device,
|
| 409 |
+
dtype: torch.dtype,
|
| 410 |
+
do_preprocess: bool = True,
|
| 411 |
+
) -> torch.Tensor:
|
| 412 |
+
"""
|
| 413 |
+
Generic function to encode an image into 5D latents for inpainting context.
|
| 414 |
+
|
| 415 |
+
If `do_preprocess` is True, it processes the image (PIL/np).
|
| 416 |
+
If `do_preprocess` is False, it assumes 'image' is already a ready-to-use tensor.
|
| 417 |
+
|
| 418 |
+
Args:
|
| 419 |
+
image (`PipelineImageInput`): The input image. Can be None to return zeros.
|
| 420 |
+
width (`int`): The target width.
|
| 421 |
+
height (`int`): The target height.
|
| 422 |
+
batch_size (`int`): The prompt batch size.
|
| 423 |
+
num_images_per_prompt (`int`): The number of images per prompt.
|
| 424 |
+
device (`torch.device`): The target device.
|
| 425 |
+
dtype (`torch.dtype`): The target data type.
|
| 426 |
+
do_preprocess (`bool`): Whether to preprocess the image.
|
| 427 |
+
|
| 428 |
+
Returns:
|
| 429 |
+
`torch.Tensor`: A 5D tensor of the encoded image latents.
|
| 430 |
+
"""
|
| 431 |
+
if image is None:
|
| 432 |
+
latent_h = height // self.vae_scale_factor
|
| 433 |
+
latent_w = width // self.vae_scale_factor
|
| 434 |
+
shape = (batch_size * num_images_per_prompt, self.transformer.in_channels, 1, latent_h, latent_w)
|
| 435 |
+
return torch.zeros(shape, device=device, dtype=dtype)
|
| 436 |
+
|
| 437 |
+
if do_preprocess:
|
| 438 |
+
image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device=device, dtype=self.vae.dtype)
|
| 439 |
+
else:
|
| 440 |
+
image_tensor = image.to(device=device, dtype=self.vae.dtype)
|
| 441 |
+
|
| 442 |
+
if mask_image is not None:
|
| 443 |
+
mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width).to(device=device, dtype=self.vae.dtype)
|
| 444 |
+
# Tile para 3 canais (RGB)
|
| 445 |
+
mask_condition = torch.tile(mask_condition, [1, 3, 1, 1])
|
| 446 |
+
# Aplica máscara: mantém apenas áreas escuras (< 0.5)
|
| 447 |
+
image_tensor = image_tensor * (mask_condition < 0.5)
|
| 448 |
+
|
| 449 |
+
with torch.no_grad():
|
| 450 |
+
latents = retrieve_latents(self.vae.encode(image_tensor), sample_mode="argmax")
|
| 451 |
+
latents = (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 452 |
+
|
| 453 |
+
effective_batch_size = batch_size * num_images_per_prompt
|
| 454 |
+
if latents.shape[0] != effective_batch_size:
|
| 455 |
+
repeat_by = effective_batch_size // latents.shape[0]
|
| 456 |
+
latents = latents.repeat_interleave(repeat_by, dim=0)
|
| 457 |
+
|
| 458 |
+
return latents.to(dtype=dtype).unsqueeze(2)
|
| 459 |
+
|
| 460 |
+
def _prepare_mask_latents(
|
| 461 |
+
self,
|
| 462 |
+
mask_image: PipelineImageInput,
|
| 463 |
+
width: int,
|
| 464 |
+
height: int,
|
| 465 |
+
batch_size: int,
|
| 466 |
+
num_images_per_prompt: int,
|
| 467 |
+
reference_latents_shape: Tuple,
|
| 468 |
+
device: torch.device,
|
| 469 |
+
dtype: torch.dtype,
|
| 470 |
+
) -> torch.Tensor:
|
| 471 |
+
"""
|
| 472 |
+
Processes a MASK using the mask_processor, inverts it, resizes it, and formats it for the control_context.
|
| 473 |
+
|
| 474 |
+
Args:
|
| 475 |
+
mask_image (`PipelineImageInput`): The mask image. Can be None to return zeros.
|
| 476 |
+
width (`int`): The target width.
|
| 477 |
+
height (`int`): The target height.
|
| 478 |
+
batch_size (`int`): The prompt batch size.
|
| 479 |
+
num_images_per_prompt (`int`): The number of images per prompt.
|
| 480 |
+
reference_latents_shape (`Tuple`): The shape of the inpainting latents for resizing.
|
| 481 |
+
device (`torch.device`): The target device.
|
| 482 |
+
dtype (`torch.dtype`): The target data type.
|
| 483 |
+
|
| 484 |
+
Returns:
|
| 485 |
+
`torch.Tensor`: A 5D tensor of the processed mask latents.
|
| 486 |
+
"""
|
| 487 |
+
if mask_image is None:
|
| 488 |
+
placeholder_shape = (
|
| 489 |
+
batch_size * num_images_per_prompt,
|
| 490 |
+
1,
|
| 491 |
+
1,
|
| 492 |
+
reference_latents_shape[-2],
|
| 493 |
+
reference_latents_shape[-1],
|
| 494 |
+
)
|
| 495 |
+
return torch.zeros(placeholder_shape, device=device, dtype=dtype)
|
| 496 |
+
|
| 497 |
+
mask_condition = self.mask_processor.preprocess(mask_image, height=height, width=width).to(device=device, dtype=dtype)
|
| 498 |
+
|
| 499 |
+
mask_for_inpainting = 1.0 - mask_condition
|
| 500 |
+
|
| 501 |
+
mask_latents = F.interpolate(mask_for_inpainting, size=reference_latents_shape[-2:], mode="nearest")
|
| 502 |
+
|
| 503 |
+
return mask_latents.unsqueeze(2)
|
| 504 |
+
|
| 505 |
+
def prepare_control_latents(
|
| 506 |
+
self, image: PipelineImageInput, width: int, height: int, batch_size: int, num_images_per_prompt: int, device: torch.device, dtype: torch.dtype
|
| 507 |
+
) -> torch.Tensor:
|
| 508 |
+
"""
|
| 509 |
+
Preprocesses a control image, ENCODES it with the VAE to latent space,
|
| 510 |
+
and returns a 5D tensor ready for the transformer model.
|
| 511 |
+
|
| 512 |
+
Args:
|
| 513 |
+
image (`PipelineImageInput`): The control image. Can be None to return zeros.
|
| 514 |
+
width (`int`): The target width.
|
| 515 |
+
height (`int`): The target height.
|
| 516 |
+
batch_size (`int`): The prompt batch size.
|
| 517 |
+
num_images_per_prompt (`int`): The number of images per prompt.
|
| 518 |
+
device (`torch.device`): The target device.
|
| 519 |
+
dtype (`torch.dtype`): The target data type.
|
| 520 |
+
|
| 521 |
+
Returns:
|
| 522 |
+
`torch.Tensor`: A 5D tensor of the control image latents.
|
| 523 |
+
"""
|
| 524 |
+
if image is None:
|
| 525 |
+
latent_h = 2 * (int(height) // (self.vae_scale_factor * 2))
|
| 526 |
+
latent_w = 2 * (int(width) // (self.vae_scale_factor * 2))
|
| 527 |
+
return torch.zeros(
|
| 528 |
+
(batch_size * num_images_per_prompt, self.transformer.in_channels, 1, latent_h, latent_w),
|
| 529 |
+
device=device,
|
| 530 |
+
dtype=dtype,
|
| 531 |
+
)
|
| 532 |
+
|
| 533 |
+
image_tensor = self.image_processor.preprocess(image, height=height, width=width).to(device=device, dtype=self.vae.dtype)
|
| 534 |
+
with torch.no_grad():
|
| 535 |
+
latents = retrieve_latents(self.vae.encode(image_tensor), sample_mode="argmax")
|
| 536 |
+
latents = (latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
| 537 |
+
|
| 538 |
+
effective_batch_size = batch_size * num_images_per_prompt
|
| 539 |
+
if latents.shape[0] < effective_batch_size:
|
| 540 |
+
latents = latents.repeat_interleave(effective_batch_size // latents.shape[0], dim=0)
|
| 541 |
+
|
| 542 |
+
return latents.to(dtype=dtype).unsqueeze(2)
|
| 543 |
+
|
| 544 |
+
def _apply_mask_blur(self, mask_image, mask_blur_radius, is_inpaint_mode):
|
| 545 |
+
"""
|
| 546 |
+
Apply Gaussian blur to a mask image for inpainting operations.
|
| 547 |
+
Args:
|
| 548 |
+
mask_image (Image.Image | np.ndarray | torch.Tensor): The mask image to be blurred.
|
| 549 |
+
Can be provided as a PIL Image, NumPy array, or PyTorch tensor.
|
| 550 |
+
mask_blur_radius (float): The radius of the Gaussian blur filter in pixels.
|
| 551 |
+
Only applied if is_inpaint_mode is True and mask_blur_radius > 0.
|
| 552 |
+
is_inpaint_mode (bool): Flag indicating whether the pipeline is in inpainting mode.
|
| 553 |
+
Blur is only applied when this is True.
|
| 554 |
+
Returns:
|
| 555 |
+
Image.Image | np.ndarray | torch.Tensor: The mask image with Gaussian blur applied
|
| 556 |
+
if is_inpaint_mode is True and mask_blur_radius > 0. Otherwise, returns the
|
| 557 |
+
original mask_image unchanged. The return type matches the input type.
|
| 558 |
+
"""
|
| 559 |
+
mask_to_use = mask_image
|
| 560 |
+
if is_inpaint_mode and mask_blur_radius > 0:
|
| 561 |
+
if isinstance(mask_image, Image.Image):
|
| 562 |
+
mask_pil = mask_image
|
| 563 |
+
elif isinstance(mask_image, np.ndarray):
|
| 564 |
+
mask_pil = Image.fromarray(mask_image)
|
| 565 |
+
elif isinstance(mask_image, torch.Tensor):
|
| 566 |
+
mask_pil = Image.fromarray(mask_image.cpu().numpy().astype(np.uint8))
|
| 567 |
+
else:
|
| 568 |
+
mask_pil = mask_image
|
| 569 |
+
|
| 570 |
+
mask_to_use = mask_pil.filter(ImageFilter.GaussianBlur(radius=mask_blur_radius))
|
| 571 |
+
return mask_to_use
|
| 572 |
+
|
| 573 |
+
@property
|
| 574 |
+
def guidance_scale(self):
|
| 575 |
+
return self._guidance_scale
|
| 576 |
+
|
| 577 |
+
@property
|
| 578 |
+
def do_classifier_free_guidance(self):
|
| 579 |
+
return self._guidance_scale > 1
|
| 580 |
+
|
| 581 |
+
@property
|
| 582 |
+
def joint_attention_kwargs(self):
|
| 583 |
+
return self._joint_attention_kwargs
|
| 584 |
+
|
| 585 |
+
@property
|
| 586 |
+
def num_timesteps(self):
|
| 587 |
+
return self._num_timesteps
|
| 588 |
+
|
| 589 |
+
@property
|
| 590 |
+
def interrupt(self):
|
| 591 |
+
return self._interrupt
|
| 592 |
+
|
| 593 |
+
def __call__(
|
| 594 |
+
self,
|
| 595 |
+
prompt: Union[str, List[str]],
|
| 596 |
+
image: Optional[PipelineImageInput] = None,
|
| 597 |
+
mask_image: Optional[PipelineImageInput] = None,
|
| 598 |
+
mask_blur_radius: float = 4.0,
|
| 599 |
+
control_image: Optional[PipelineImageInput] = None,
|
| 600 |
+
height: Optional[int] = None,
|
| 601 |
+
width: Optional[int] = None,
|
| 602 |
+
num_inference_steps: int = 20,
|
| 603 |
+
sigmas: Optional[List[float]] = None,
|
| 604 |
+
strength: float = 1.0,
|
| 605 |
+
guidance_scale: float = 4.0,
|
| 606 |
+
cfg_normalization: bool = False,
|
| 607 |
+
cfg_truncation: float = 1.0,
|
| 608 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 609 |
+
num_images_per_prompt: int = 1,
|
| 610 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 611 |
+
latents: Optional[torch.Tensor] = None,
|
| 612 |
+
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
| 613 |
+
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
| 614 |
+
controlnet_conditioning_scale: float = 1.0,
|
| 615 |
+
controlnet_refiner_conditioning_scale: float = 1.0,
|
| 616 |
+
output_type: str = "pil",
|
| 617 |
+
return_dict: bool = True,
|
| 618 |
+
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 619 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 620 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 621 |
+
max_sequence_length: int = 512,
|
| 622 |
+
):
|
| 623 |
+
r"""
|
| 624 |
+
The main entry point for the Z-Image unified pipeline for generation.
|
| 625 |
+
|
| 626 |
+
Args:
|
| 627 |
+
prompt (`str` or `List[str]`, *optional*):
|
| 628 |
+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 629 |
+
image (`PipelineImageInput`, *optional*):
|
| 630 |
+
The initial image for image-to-image or inpainting modes.
|
| 631 |
+
mask_image (`PipelineImageInput`, *optional*):
|
| 632 |
+
The mask image for inpainting. White areas are preserved, black areas are inpainted.
|
| 633 |
+
mask_blur_radius (`float`, *optional*, defaults to 4.0):
|
| 634 |
+
The radius for blurring the edges of the inpainting mask to create a smoother transition.
|
| 635 |
+
control_image (`PipelineImageInput`, *optional*):
|
| 636 |
+
The conditioning image for control modes (e.g., Canny, depth).
|
| 637 |
+
height (`int`, *optional*, defaults to 1024):
|
| 638 |
+
The height in pixels of the generated image.
|
| 639 |
+
width (`int`, *optional*, defaults to 1024):
|
| 640 |
+
The width in pixels of the generated image.
|
| 641 |
+
num_inference_steps (`int`, *optional*, defaults to 20):
|
| 642 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 643 |
+
expense of slower inference.
|
| 644 |
+
sigmas (`List[float]`, *optional*):
|
| 645 |
+
Custom sigmas to use for the denoising process. If not defined, the scheduler's default behavior
|
| 646 |
+
will be used.
|
| 647 |
+
strength (`float`, *optional*, defaults to 1.0):
|
| 648 |
+
Denoising strength for image-to-image. A value of 1.0 means the initial image is fully replaced,
|
| 649 |
+
while a lower value preserves more of the original image structure. Only used in img2img mode.
|
| 650 |
+
guidance_scale (`float`, *optional*, defaults to 4.0):
|
| 651 |
+
The scale for classifier-free guidance. A value > 1 enables it. Higher values encourage images
|
| 652 |
+
closer to the prompt, potentially at the cost of quality.
|
| 653 |
+
cfg_normalization (`bool`, *optional*, defaults to False):
|
| 654 |
+
Whether to apply normalization to the guidance, which can prevent oversaturation.
|
| 655 |
+
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
| 656 |
+
A value between 0.0 and 1.0 that disables CFG for the final portion of the denoising steps,
|
| 657 |
+
specified as a fraction of total steps. For example, 0.8 disables CFG for the last 20% of steps.
|
| 658 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 659 |
+
The prompt or prompts not to guide the image generation.
|
| 660 |
+
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
| 661 |
+
The number of images to generate per prompt.
|
| 662 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 663 |
+
A torch generator to make generation deterministic.
|
| 664 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 665 |
+
Pre-generated noisy latents to be used as inputs for image generation.
|
| 666 |
+
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
| 667 |
+
Pre-generated positive text embeddings.
|
| 668 |
+
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
| 669 |
+
Pre-generated negative text embeddings.
|
| 670 |
+
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
| 671 |
+
The scale of the control conditioning influence.
|
| 672 |
+
controlnet_refiner_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
| 673 |
+
The scale of the control refiner conditioning influence.
|
| 674 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 675 |
+
The output format of the generated image. Choose between "pil" (`PIL.Image.Image`), "np.array", or "latent".
|
| 676 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 677 |
+
Whether to return a `ZImagePipelineOutput` instead of a plain tuple.
|
| 678 |
+
joint_attention_kwargs (`dict`, *optional*):
|
| 679 |
+
A kwargs dictionary for the `AttentionProcessor`.
|
| 680 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 681 |
+
A function that is called at the end of each denoising step.
|
| 682 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 683 |
+
The list of tensor inputs for the `callback_on_step_end` function.
|
| 684 |
+
max_sequence_length (`int`, *optional*, defaults to 512):
|
| 685 |
+
Maximum sequence length to use with the `prompt`.
|
| 686 |
+
|
| 687 |
+
Examples:
|
| 688 |
+
|
| 689 |
+
Returns:
|
| 690 |
+
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`:
|
| 691 |
+
If `return_dict` is True, a `ZImagePipelineOutput` is returned, otherwise a `tuple` with the generated images.
|
| 692 |
+
"""
|
| 693 |
+
self._guidance_scale = guidance_scale
|
| 694 |
+
self._joint_attention_kwargs = joint_attention_kwargs
|
| 695 |
+
self._interrupt = False
|
| 696 |
+
self._cfg_normalization = cfg_normalization
|
| 697 |
+
self._cfg_truncation = cfg_truncation
|
| 698 |
+
is_two_stage_control_model = self.transformer.control_in_dim > self.transformer.in_channels if hasattr(self.transformer, "control_in_dim") else False
|
| 699 |
+
device = self._execution_device
|
| 700 |
+
dtype = self.transformer.dtype
|
| 701 |
+
vae_scale = self.vae_scale_factor * 2
|
| 702 |
+
|
| 703 |
+
ref_image = control_image or image
|
| 704 |
+
image_height = None
|
| 705 |
+
image_width = None
|
| 706 |
+
if ref_image is not None:
|
| 707 |
+
if isinstance(ref_image, Image.Image):
|
| 708 |
+
image_height, image_width = ref_image.height, ref_image.width
|
| 709 |
+
else:
|
| 710 |
+
image_height, image_width = ref_image.shape[-2], ref_image.shape[-1]
|
| 711 |
+
|
| 712 |
+
height = height or image_height or 1024
|
| 713 |
+
width = width or image_width or 1024
|
| 714 |
+
|
| 715 |
+
if height % vae_scale != 0 or width % vae_scale != 0:
|
| 716 |
+
raise ValueError(f"Height/width must be divisible by {vae_scale}.")
|
| 717 |
+
|
| 718 |
+
batch_size = len(prompt) if isinstance(prompt, list) else 1 if prompt else len(prompt_embeds)
|
| 719 |
+
effective_batch_size = batch_size * num_images_per_prompt
|
| 720 |
+
|
| 721 |
+
if prompt_embeds is not None and prompt is None:
|
| 722 |
+
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 723 |
+
raise ValueError(
|
| 724 |
+
"When `prompt_embeds` is provided without `prompt`, `negative_prompt_embeds` must also be provided for classifier-free guidance."
|
| 725 |
+
)
|
| 726 |
+
else:
|
| 727 |
+
(
|
| 728 |
+
prompt_embeds,
|
| 729 |
+
negative_prompt_embeds,
|
| 730 |
+
) = self.encode_prompt(
|
| 731 |
+
prompt=prompt,
|
| 732 |
+
num_images_per_prompt=num_images_per_prompt,
|
| 733 |
+
negative_prompt=negative_prompt,
|
| 734 |
+
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
| 735 |
+
prompt_embeds=prompt_embeds,
|
| 736 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 737 |
+
device=device,
|
| 738 |
+
max_sequence_length=max_sequence_length,
|
| 739 |
+
)
|
| 740 |
+
|
| 741 |
+
if self.do_classifier_free_guidance:
|
| 742 |
+
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
| 743 |
+
else:
|
| 744 |
+
prompt_embeds_model_input = prompt_embeds
|
| 745 |
+
|
| 746 |
+
is_inpaint_mode = image is not None and mask_image is not None
|
| 747 |
+
is_img2img_mode = image is not None and not is_inpaint_mode
|
| 748 |
+
|
| 749 |
+
if control_image is not None or is_inpaint_mode:
|
| 750 |
+
control_latents = self.prepare_control_latents(control_image, width, height, batch_size, num_images_per_prompt, device, dtype)
|
| 751 |
+
|
| 752 |
+
if is_two_stage_control_model:
|
| 753 |
+
mask_to_use = self._apply_mask_blur(mask_image, mask_blur_radius, is_inpaint_mode)
|
| 754 |
+
|
| 755 |
+
inpaint_latents = self._prepare_image_latents(
|
| 756 |
+
image, mask_to_use, width, height, batch_size, num_images_per_prompt, device, dtype, do_preprocess=True
|
| 757 |
+
)
|
| 758 |
+
|
| 759 |
+
mask_latents = self._prepare_mask_latents(
|
| 760 |
+
mask_to_use,
|
| 761 |
+
width,
|
| 762 |
+
height,
|
| 763 |
+
batch_size,
|
| 764 |
+
num_images_per_prompt,
|
| 765 |
+
inpaint_latents.shape,
|
| 766 |
+
device,
|
| 767 |
+
dtype,
|
| 768 |
+
)
|
| 769 |
+
control_context = torch.cat([control_latents, mask_latents, inpaint_latents], dim=1)
|
| 770 |
+
else:
|
| 771 |
+
control_context = control_latents
|
| 772 |
+
else:
|
| 773 |
+
control_context = None
|
| 774 |
+
|
| 775 |
+
if self.do_classifier_free_guidance:
|
| 776 |
+
control_context_model_input = control_context.repeat(2, 1, 1, 1, 1)
|
| 777 |
+
else:
|
| 778 |
+
control_context_model_input = control_context
|
| 779 |
+
|
| 780 |
+
image_seq_len = (height // (self.vae_scale_factor * 2)) * (width // (self.vae_scale_factor * 2))
|
| 781 |
+
mu = calculate_shift(image_seq_len)
|
| 782 |
+
self.scheduler.sigma_min = 0.0
|
| 783 |
+
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas, mu=mu)
|
| 784 |
+
self._num_timesteps = len(timesteps)
|
| 785 |
+
|
| 786 |
+
if is_img2img_mode and not is_inpaint_mode:
|
| 787 |
+
strength = min(strength, 1.0)
|
| 788 |
+
else:
|
| 789 |
+
strength = 1.0
|
| 790 |
+
|
| 791 |
+
if strength < 1.0:
|
| 792 |
+
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
|
| 793 |
+
t_start = max(num_inference_steps - init_timestep, 0)
|
| 794 |
+
timesteps = timesteps[t_start * self.scheduler.order :]
|
| 795 |
+
num_steps_to_run = len(timesteps) // self.scheduler.order
|
| 796 |
+
else:
|
| 797 |
+
num_steps_to_run = num_inference_steps
|
| 798 |
+
|
| 799 |
+
latent_timestep = timesteps[:1].repeat(effective_batch_size) if strength < 1.0 else None
|
| 800 |
+
|
| 801 |
+
use_image_for_latents = is_img2img_mode and not is_inpaint_mode
|
| 802 |
+
latents = self.prepare_latents(
|
| 803 |
+
effective_batch_size,
|
| 804 |
+
self.transformer.in_channels,
|
| 805 |
+
height,
|
| 806 |
+
width,
|
| 807 |
+
torch.float32,
|
| 808 |
+
device,
|
| 809 |
+
generator,
|
| 810 |
+
image=image if use_image_for_latents else None,
|
| 811 |
+
timestep=latent_timestep if use_image_for_latents else None,
|
| 812 |
+
latents=latents,
|
| 813 |
+
)
|
| 814 |
+
|
| 815 |
+
num_warmup_steps = len(timesteps) - num_steps_to_run * self.scheduler.order
|
| 816 |
+
with torch.inference_mode():
|
| 817 |
+
with self.progress_bar(total=num_steps_to_run) as progress_bar:
|
| 818 |
+
for i, t in enumerate(timesteps):
|
| 819 |
+
if self.interrupt:
|
| 820 |
+
continue
|
| 821 |
+
|
| 822 |
+
timestep = t.expand(latents.shape[0])
|
| 823 |
+
timestep = (1000 - timestep) / 1000
|
| 824 |
+
t_norm = timestep[0].item()
|
| 825 |
+
|
| 826 |
+
current_guidance_scale = self.guidance_scale
|
| 827 |
+
if self.do_classifier_free_guidance and self._cfg_truncation is not None and float(self._cfg_truncation) <= 1:
|
| 828 |
+
if t_norm > self._cfg_truncation:
|
| 829 |
+
current_guidance_scale = 0.0
|
| 830 |
+
|
| 831 |
+
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
| 832 |
+
|
| 833 |
+
if apply_cfg:
|
| 834 |
+
latents_typed = latents.to(self.transformer.dtype)
|
| 835 |
+
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
| 836 |
+
timestep_model_input = timestep.repeat(2)
|
| 837 |
+
else:
|
| 838 |
+
latent_model_input = latents.to(self.transformer.dtype)
|
| 839 |
+
timestep_model_input = timestep
|
| 840 |
+
|
| 841 |
+
latent_model_input = latent_model_input.unsqueeze(2)
|
| 842 |
+
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
| 843 |
+
|
| 844 |
+
model_out_list = self.transformer(
|
| 845 |
+
x=latent_model_input_list,
|
| 846 |
+
t=timestep_model_input,
|
| 847 |
+
cap_feats=prompt_embeds_model_input,
|
| 848 |
+
control_context=control_context_model_input,
|
| 849 |
+
conditioning_scale=controlnet_conditioning_scale,
|
| 850 |
+
refiner_conditioning_scale=controlnet_refiner_conditioning_scale,
|
| 851 |
+
)[0]
|
| 852 |
+
|
| 853 |
+
if apply_cfg:
|
| 854 |
+
pos_out = model_out_list[:effective_batch_size]
|
| 855 |
+
neg_out = model_out_list[effective_batch_size:]
|
| 856 |
+
|
| 857 |
+
noise_pred = []
|
| 858 |
+
for j in range(effective_batch_size):
|
| 859 |
+
pos = pos_out[j].float()
|
| 860 |
+
neg = neg_out[j].float()
|
| 861 |
+
|
| 862 |
+
pred = pos + current_guidance_scale * (pos - neg)
|
| 863 |
+
|
| 864 |
+
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
| 865 |
+
ori_pos_norm = torch.linalg.vector_norm(pos)
|
| 866 |
+
new_pos_norm = torch.linalg.vector_norm(pred)
|
| 867 |
+
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
| 868 |
+
if new_pos_norm > max_new_norm:
|
| 869 |
+
pred = pred * (max_new_norm / new_pos_norm)
|
| 870 |
+
|
| 871 |
+
noise_pred.append(pred)
|
| 872 |
+
|
| 873 |
+
noise_pred = torch.stack(noise_pred, dim=0)
|
| 874 |
+
else:
|
| 875 |
+
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
| 876 |
+
|
| 877 |
+
noise_pred = noise_pred.squeeze(2)
|
| 878 |
+
noise_pred = -noise_pred
|
| 879 |
+
|
| 880 |
+
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents).prev_sample
|
| 881 |
+
|
| 882 |
+
if callback_on_step_end is not None:
|
| 883 |
+
callback_kwargs = {}
|
| 884 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 885 |
+
callback_kwargs[k] = locals()[k]
|
| 886 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 887 |
+
|
| 888 |
+
if isinstance(callback_outputs, dict):
|
| 889 |
+
latents = callback_outputs.pop("latents", latents)
|
| 890 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 891 |
+
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
| 892 |
+
|
| 893 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 894 |
+
progress_bar.update()
|
| 895 |
+
|
| 896 |
+
if output_type != "latent":
|
| 897 |
+
latents = latents.to(self.vae.dtype)
|
| 898 |
+
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
| 899 |
+
with torch.no_grad():
|
| 900 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 901 |
+
image = self.image_processor.postprocess(image, output_type=output_type)
|
| 902 |
+
else:
|
| 903 |
+
image = latents
|
| 904 |
+
|
| 905 |
+
self.maybe_free_model_hooks()
|
| 906 |
+
|
| 907 |
+
if not return_dict:
|
| 908 |
+
return (image,)
|
| 909 |
+
|
| 910 |
+
return ZImagePipelineOutput(images=image)
|
diffusers_local/z_image_control_transformer_2d.py
ADDED
|
@@ -0,0 +1,1443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
| 2 |
+
# Refactored and optimized by DEVAIEXP Team
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from typing import Dict, List, Optional, Tuple, Union
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 24 |
+
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
|
| 25 |
+
from diffusers.models.attention_dispatch import dispatch_attention_fn
|
| 26 |
+
from diffusers.models.attention_processor import Attention, AttentionProcessor
|
| 27 |
+
from diffusers.models.modeling_outputs import Transformer2DModelOutput
|
| 28 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 29 |
+
from diffusers.models.normalization import RMSNorm
|
| 30 |
+
from diffusers.utils import (
|
| 31 |
+
is_torch_version,
|
| 32 |
+
)
|
| 33 |
+
from diffusers.utils.torch_utils import maybe_allow_in_graph
|
| 34 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
ADALN_EMBED_DIM = 256
|
| 38 |
+
SEQ_MULTI_OF = 32
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def zero_module(module):
|
| 42 |
+
"""
|
| 43 |
+
Initializes the parameters of a given module with zeros.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
module (nn.Module): The module to be zero-initialized.
|
| 47 |
+
|
| 48 |
+
Returns:
|
| 49 |
+
nn.Module: The same module with its parameters initialized to zero.
|
| 50 |
+
"""
|
| 51 |
+
for p in module.parameters():
|
| 52 |
+
nn.init.zeros_(p)
|
| 53 |
+
return module
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class TimestepEmbedder(nn.Module):
|
| 57 |
+
"""
|
| 58 |
+
A module to embed timesteps into a higher-dimensional space using sinusoidal embeddings
|
| 59 |
+
followed by a multilayer perceptron (MLP).
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
|
| 63 |
+
"""
|
| 64 |
+
Initializes the TimestepEmbedder module.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
out_size (int): The output dimension of the embedding.
|
| 68 |
+
mid_size (int, optional): The intermediate dimension of the MLP. Defaults to `out_size`.
|
| 69 |
+
frequency_embedding_size (int, optional): The dimension of the sinusoidal frequency embedding. Defaults to 256.
|
| 70 |
+
"""
|
| 71 |
+
super().__init__()
|
| 72 |
+
if mid_size is None:
|
| 73 |
+
mid_size = out_size
|
| 74 |
+
self.mlp = nn.Sequential(
|
| 75 |
+
nn.Linear(
|
| 76 |
+
frequency_embedding_size,
|
| 77 |
+
mid_size,
|
| 78 |
+
bias=True,
|
| 79 |
+
),
|
| 80 |
+
nn.SiLU(),
|
| 81 |
+
nn.Linear(
|
| 82 |
+
mid_size,
|
| 83 |
+
out_size,
|
| 84 |
+
bias=True,
|
| 85 |
+
),
|
| 86 |
+
)
|
| 87 |
+
self.frequency_embedding_size = frequency_embedding_size
|
| 88 |
+
|
| 89 |
+
@staticmethod
|
| 90 |
+
def timestep_embedding(t, dim, max_period=10000):
|
| 91 |
+
"""
|
| 92 |
+
Creates sinusoidal timestep embeddings.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
t (torch.Tensor): A 1-D Tensor of N timesteps.
|
| 96 |
+
dim (int): The dimension of the embedding.
|
| 97 |
+
max_period (int, optional): The maximum period for the sinusoidal frequencies. Defaults to 10000.
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
torch.Tensor: The timestep embeddings with shape (N, dim).
|
| 101 |
+
"""
|
| 102 |
+
with torch.amp.autocast("cuda", enabled=False):
|
| 103 |
+
half = dim // 2
|
| 104 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
|
| 105 |
+
args = t[:, None] * freqs[None]
|
| 106 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 107 |
+
if dim % 2:
|
| 108 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 109 |
+
return embedding
|
| 110 |
+
|
| 111 |
+
def forward(self, t):
|
| 112 |
+
"""
|
| 113 |
+
Processes the input timesteps to generate embeddings.
|
| 114 |
+
|
| 115 |
+
Args:
|
| 116 |
+
t (torch.Tensor): The input timesteps.
|
| 117 |
+
|
| 118 |
+
Returns:
|
| 119 |
+
torch.Tensor: The final timestep embeddings after passing through the MLP.
|
| 120 |
+
"""
|
| 121 |
+
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
| 122 |
+
weight_dtype = self.mlp[0].weight.dtype
|
| 123 |
+
if weight_dtype.is_floating_point:
|
| 124 |
+
t_freq = t_freq.to(weight_dtype)
|
| 125 |
+
t_emb = self.mlp(t_freq)
|
| 126 |
+
return t_emb
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class FeedForward(nn.Module):
|
| 130 |
+
"""
|
| 131 |
+
A Feed-Forward Network module using SwiGLU activation.
|
| 132 |
+
"""
|
| 133 |
+
|
| 134 |
+
def __init__(self, dim: int, hidden_dim: int):
|
| 135 |
+
"""
|
| 136 |
+
Initializes the FeedForward module.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
dim (int): Input and output dimension.
|
| 140 |
+
hidden_dim (int): The hidden dimension of the network.
|
| 141 |
+
"""
|
| 142 |
+
super().__init__()
|
| 143 |
+
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
| 144 |
+
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
| 145 |
+
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
| 146 |
+
|
| 147 |
+
def _forward_silu_gating(self, x1, x3):
|
| 148 |
+
"""
|
| 149 |
+
Applies the SiLU gating mechanism.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
x1 (torch.Tensor): The first intermediate tensor.
|
| 153 |
+
x3 (torch.Tensor): The second intermediate tensor (gate).
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
torch.Tensor: The result of the gating operation.
|
| 157 |
+
"""
|
| 158 |
+
return F.silu(x1) * x3
|
| 159 |
+
|
| 160 |
+
def forward(self, x):
|
| 161 |
+
"""
|
| 162 |
+
Defines the forward pass of the FeedForward network.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
x (torch.Tensor): The input tensor.
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
torch.Tensor: The output tensor.
|
| 169 |
+
"""
|
| 170 |
+
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class FinalLayer(nn.Module):
|
| 174 |
+
"""
|
| 175 |
+
The final layer of the transformer, which applies AdaLN modulation and a linear projection.
|
| 176 |
+
"""
|
| 177 |
+
|
| 178 |
+
def __init__(self, hidden_size, out_channels):
|
| 179 |
+
"""
|
| 180 |
+
Initializes the FinalLayer module.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
hidden_size (int): The input hidden size.
|
| 184 |
+
out_channels (int): The output dimension (number of channels).
|
| 185 |
+
"""
|
| 186 |
+
super().__init__()
|
| 187 |
+
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 188 |
+
self.linear = nn.Linear(hidden_size, out_channels, bias=True)
|
| 189 |
+
self.adaLN_modulation = nn.Sequential(
|
| 190 |
+
nn.SiLU(),
|
| 191 |
+
nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
def forward(self, x, c):
|
| 195 |
+
"""
|
| 196 |
+
Defines the forward pass for the final layer.
|
| 197 |
+
|
| 198 |
+
Args:
|
| 199 |
+
x (torch.Tensor): The main input tensor from the transformer blocks.
|
| 200 |
+
c (torch.Tensor): The conditioning tensor (usually from timestep embedding) for AdaLN modulation.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
torch.Tensor: The final output tensor projected to the patch dimension.
|
| 204 |
+
"""
|
| 205 |
+
scale = 1.0 + self.adaLN_modulation(c)
|
| 206 |
+
x = self.norm_final(x) * scale.unsqueeze(1)
|
| 207 |
+
x = self.linear(x)
|
| 208 |
+
return x
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class RopeEmbedder:
|
| 212 |
+
"""
|
| 213 |
+
Computes Rotary Positional Embeddings (RoPE) for 3D coordinates.
|
| 214 |
+
"""
|
| 215 |
+
|
| 216 |
+
def __init__(self, theta: float = 256.0, axes_dims: List[int] = (32, 48, 48), axes_lens: List[int] = (1024, 512, 512)):
|
| 217 |
+
"""
|
| 218 |
+
Initializes the RopeEmbedder.
|
| 219 |
+
|
| 220 |
+
Args:
|
| 221 |
+
theta (float, optional): The base for the rotary frequencies. Defaults to 256.0.
|
| 222 |
+
axes_dims (List[int], optional): The dimensions for each axis (F, H, W). Defaults to (32, 48, 48).
|
| 223 |
+
axes_lens (List[int], optional): The maximum length for each axis. Defaults to (1024, 512, 512).
|
| 224 |
+
"""
|
| 225 |
+
self.theta = theta
|
| 226 |
+
self.axes_dims = axes_dims
|
| 227 |
+
self.axes_lens = axes_lens
|
| 228 |
+
self.freqs_cis_cache = {}
|
| 229 |
+
|
| 230 |
+
def _precompute_freqs_cis(self, device):
|
| 231 |
+
"""
|
| 232 |
+
Precomputes and caches the rotary frequency tensors (cos and sin values).
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
device (torch.device): The device to store the cached tensors on.
|
| 236 |
+
|
| 237 |
+
Returns:
|
| 238 |
+
List[torch.Tensor]: A list of precomputed frequency tensors for each axis.
|
| 239 |
+
"""
|
| 240 |
+
if device in self.freqs_cis_cache:
|
| 241 |
+
return self.freqs_cis_cache[device]
|
| 242 |
+
freqs_cis_list = []
|
| 243 |
+
for dim, max_len in zip(self.axes_dims, self.axes_lens):
|
| 244 |
+
half = dim // 2
|
| 245 |
+
freqs = 1.0 / (self.theta ** (torch.arange(0, half, device=device, dtype=torch.float32) / half))
|
| 246 |
+
t = torch.arange(max_len, device=device, dtype=torch.float32)
|
| 247 |
+
freqs = torch.outer(t, freqs)
|
| 248 |
+
emb = torch.stack([freqs.cos(), freqs.sin()], dim=-1)
|
| 249 |
+
freqs_cis_list.append(emb)
|
| 250 |
+
self.freqs_cis_cache[device] = freqs_cis_list
|
| 251 |
+
return freqs_cis_list
|
| 252 |
+
|
| 253 |
+
def __call__(self, ids: torch.Tensor):
|
| 254 |
+
"""
|
| 255 |
+
Generates RoPE embeddings for a batch of 3D coordinates.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
ids (torch.Tensor): A tensor of coordinates with shape (N, 3).
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
torch.Tensor: The concatenated RoPE embeddings for the input coordinates.
|
| 262 |
+
"""
|
| 263 |
+
assert ids.ndim == 2 and ids.shape[1] == len(self.axes_dims)
|
| 264 |
+
device = ids.device
|
| 265 |
+
freqs_cis_list = self._precompute_freqs_cis(device)
|
| 266 |
+
result = []
|
| 267 |
+
for i in range(len(self.axes_dims)):
|
| 268 |
+
result.append(freqs_cis_list[i][ids[:, i]])
|
| 269 |
+
return torch.cat(result, dim=-2)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class ZSingleStreamAttnProcessor:
|
| 273 |
+
"""
|
| 274 |
+
An attention processor that applies Rotary Positional Embeddings (RoPE) to query and key tensors
|
| 275 |
+
before computing scaled dot-product attention.
|
| 276 |
+
"""
|
| 277 |
+
|
| 278 |
+
_attention_backend = None
|
| 279 |
+
_parallel_config = None
|
| 280 |
+
|
| 281 |
+
def __init__(self):
|
| 282 |
+
"""
|
| 283 |
+
Initializes the ZSingleStreamAttnProcessor.
|
| 284 |
+
"""
|
| 285 |
+
if not hasattr(F, "scaled_dot_product_attention"):
|
| 286 |
+
raise ImportError("ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher.")
|
| 287 |
+
|
| 288 |
+
def __call__(
|
| 289 |
+
self,
|
| 290 |
+
attn: Attention,
|
| 291 |
+
hidden_states: torch.Tensor,
|
| 292 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 293 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 294 |
+
freqs_cis: Optional[torch.Tensor] = None,
|
| 295 |
+
) -> torch.Tensor:
|
| 296 |
+
"""
|
| 297 |
+
The forward call for the attention processor.
|
| 298 |
+
|
| 299 |
+
Args:
|
| 300 |
+
attn (Attention): The attention layer that this processor is attached to.
|
| 301 |
+
hidden_states (torch.Tensor): The input hidden states.
|
| 302 |
+
encoder_hidden_states (Optional[torch.Tensor], optional): Not used in self-attention. Defaults to None.
|
| 303 |
+
attention_mask (Optional[torch.Tensor], optional): The attention mask. Defaults to None.
|
| 304 |
+
freqs_cis (Optional[torch.Tensor], optional): The precomputed RoPE frequencies. Defaults to None.
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
torch.Tensor: The output of the attention mechanism.
|
| 308 |
+
"""
|
| 309 |
+
|
| 310 |
+
def apply_rotary_emb(q_or_k: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
| 311 |
+
"""
|
| 312 |
+
Applies RoPE to a query or key tensor.
|
| 313 |
+
"""
|
| 314 |
+
x = q_or_k.transpose(1, 2)
|
| 315 |
+
x_reshaped = x.float().reshape(*x.shape[:-1], -1, 2)
|
| 316 |
+
x0 = x_reshaped[..., 0]
|
| 317 |
+
x1 = x_reshaped[..., 1]
|
| 318 |
+
freqs_cos = freqs_cis[..., 0].unsqueeze(1)
|
| 319 |
+
freqs_sin = freqs_cis[..., 1].unsqueeze(1)
|
| 320 |
+
x_rotated_0 = x0 * freqs_cos - x1 * freqs_sin
|
| 321 |
+
x_rotated_1 = x0 * freqs_sin + x1 * freqs_cos
|
| 322 |
+
x_rotated = torch.stack((x_rotated_0, x_rotated_1), dim=-1)
|
| 323 |
+
x_out = x_rotated.flatten(-2).transpose(1, 2)
|
| 324 |
+
return x_out.to(q_or_k.dtype)
|
| 325 |
+
|
| 326 |
+
query = attn.to_q(hidden_states)
|
| 327 |
+
key = attn.to_k(hidden_states)
|
| 328 |
+
value = attn.to_v(hidden_states)
|
| 329 |
+
|
| 330 |
+
query = query.unflatten(-1, (attn.heads, -1))
|
| 331 |
+
key = key.unflatten(-1, (attn.heads, -1))
|
| 332 |
+
value = value.unflatten(-1, (attn.heads, -1))
|
| 333 |
+
|
| 334 |
+
if attn.norm_q is not None:
|
| 335 |
+
query = attn.norm_q(query)
|
| 336 |
+
if attn.norm_k is not None:
|
| 337 |
+
key = attn.norm_k(key)
|
| 338 |
+
|
| 339 |
+
if freqs_cis is not None:
|
| 340 |
+
query = apply_rotary_emb(query, freqs_cis)
|
| 341 |
+
key = apply_rotary_emb(key, freqs_cis)
|
| 342 |
+
|
| 343 |
+
if attention_mask is not None and attention_mask.ndim == 2:
|
| 344 |
+
attention_mask = attention_mask[:, None, None, :]
|
| 345 |
+
|
| 346 |
+
hidden_states = dispatch_attention_fn(
|
| 347 |
+
query,
|
| 348 |
+
key,
|
| 349 |
+
value,
|
| 350 |
+
attn_mask=attention_mask,
|
| 351 |
+
dropout_p=0.0,
|
| 352 |
+
is_causal=False,
|
| 353 |
+
backend=self._attention_backend,
|
| 354 |
+
parallel_config=self._parallel_config,
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
hidden_states = hidden_states.flatten(2, 3)
|
| 358 |
+
|
| 359 |
+
output = attn.to_out[0](hidden_states.to(hidden_states.dtype))
|
| 360 |
+
if len(attn.to_out) > 1:
|
| 361 |
+
output = attn.to_out[1](output)
|
| 362 |
+
|
| 363 |
+
return output
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
@maybe_allow_in_graph
|
| 367 |
+
class ZImageTransformerBlock(nn.Module):
|
| 368 |
+
"""
|
| 369 |
+
A standard transformer block consisting of a self-attention layer and a feed-forward network.
|
| 370 |
+
Includes support for AdaLN modulation.
|
| 371 |
+
"""
|
| 372 |
+
|
| 373 |
+
def __init__(
|
| 374 |
+
self,
|
| 375 |
+
layer_id: int,
|
| 376 |
+
dim: int,
|
| 377 |
+
n_heads: int,
|
| 378 |
+
n_kv_heads: int,
|
| 379 |
+
norm_eps: float,
|
| 380 |
+
qk_norm: bool,
|
| 381 |
+
modulation=True,
|
| 382 |
+
):
|
| 383 |
+
"""
|
| 384 |
+
Initializes the ZImageTransformerBlock.
|
| 385 |
+
|
| 386 |
+
Args:
|
| 387 |
+
layer_id (int): The index of the layer.
|
| 388 |
+
dim (int): The dimension of the input and output features.
|
| 389 |
+
n_heads (int): The number of attention heads.
|
| 390 |
+
n_kv_heads (int): The number of key/value heads (not directly used in this simplified attention).
|
| 391 |
+
norm_eps (float): Epsilon for RMSNorm.
|
| 392 |
+
qk_norm (bool): Whether to apply normalization to query and key tensors.
|
| 393 |
+
modulation (bool, optional): Whether to enable AdaLN modulation. Defaults to True.
|
| 394 |
+
"""
|
| 395 |
+
super().__init__()
|
| 396 |
+
self.dim = dim
|
| 397 |
+
self.head_dim = dim // n_heads
|
| 398 |
+
self.attention = Attention(
|
| 399 |
+
query_dim=dim,
|
| 400 |
+
cross_attention_dim=None,
|
| 401 |
+
dim_head=dim // n_heads,
|
| 402 |
+
heads=n_heads,
|
| 403 |
+
qk_norm="rms_norm" if qk_norm else None,
|
| 404 |
+
eps=1e-5,
|
| 405 |
+
bias=False,
|
| 406 |
+
out_bias=False,
|
| 407 |
+
processor=ZSingleStreamAttnProcessor(),
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
|
| 411 |
+
self.layer_id = layer_id
|
| 412 |
+
|
| 413 |
+
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
|
| 414 |
+
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
| 415 |
+
|
| 416 |
+
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
|
| 417 |
+
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
| 418 |
+
|
| 419 |
+
self.modulation = modulation
|
| 420 |
+
if modulation:
|
| 421 |
+
self.adaLN_modulation = nn.Sequential(
|
| 422 |
+
nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
|
| 423 |
+
)
|
| 424 |
+
|
| 425 |
+
@property
|
| 426 |
+
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
| 427 |
+
"""
|
| 428 |
+
Returns a dictionary of all attention processors used in the module.
|
| 429 |
+
"""
|
| 430 |
+
processors = {}
|
| 431 |
+
|
| 432 |
+
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
| 433 |
+
if hasattr(module, "get_processor"):
|
| 434 |
+
processors[f"{name}.processor"] = module.get_processor()
|
| 435 |
+
for sub_name, child in module.named_children():
|
| 436 |
+
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
| 437 |
+
return processors
|
| 438 |
+
|
| 439 |
+
for name, module in self.named_children():
|
| 440 |
+
fn_recursive_add_processors(name, module, processors)
|
| 441 |
+
|
| 442 |
+
return processors
|
| 443 |
+
|
| 444 |
+
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
|
| 445 |
+
"""
|
| 446 |
+
Sets the attention processor for the attention layer in this block.
|
| 447 |
+
"""
|
| 448 |
+
count = len(self.attn_processors.keys())
|
| 449 |
+
|
| 450 |
+
if isinstance(processor, dict) and len(processor) != count:
|
| 451 |
+
raise ValueError(
|
| 452 |
+
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
| 453 |
+
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
| 457 |
+
if hasattr(module, "set_processor"):
|
| 458 |
+
if not isinstance(processor, dict):
|
| 459 |
+
module.set_processor(processor)
|
| 460 |
+
else:
|
| 461 |
+
module.set_processor(processor.pop(f"{name}.processor"))
|
| 462 |
+
for sub_name, child in module.named_children():
|
| 463 |
+
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
| 464 |
+
|
| 465 |
+
for name, module in self.named_children():
|
| 466 |
+
fn_recursive_attn_processor(name, module, processor)
|
| 467 |
+
|
| 468 |
+
def forward(self, x, attn_mask, freqs_cis, adaln_input=None):
|
| 469 |
+
"""
|
| 470 |
+
Defines the forward pass for the transformer block.
|
| 471 |
+
|
| 472 |
+
Args:
|
| 473 |
+
x (torch.Tensor): The input tensor.
|
| 474 |
+
attn_mask (torch.Tensor): The attention mask.
|
| 475 |
+
freqs_cis (torch.Tensor): The RoPE frequencies.
|
| 476 |
+
adaln_input (torch.Tensor, optional): The conditioning tensor for AdaLN. Defaults to None.
|
| 477 |
+
|
| 478 |
+
Returns:
|
| 479 |
+
torch.Tensor: The output tensor of the block.
|
| 480 |
+
"""
|
| 481 |
+
if self.modulation:
|
| 482 |
+
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
| 483 |
+
scale_msa = scale_msa + 1.0
|
| 484 |
+
gate_msa = gate_msa.tanh()
|
| 485 |
+
scale_mlp = scale_mlp + 1.0
|
| 486 |
+
gate_mlp = gate_mlp.tanh()
|
| 487 |
+
|
| 488 |
+
normed = self.attention_norm1(x)
|
| 489 |
+
normed = normed * scale_msa
|
| 490 |
+
attn_out = self.attention(normed, attention_mask=attn_mask, freqs_cis=freqs_cis)
|
| 491 |
+
attn_out = self.attention_norm2(attn_out) * gate_msa
|
| 492 |
+
x = x + attn_out
|
| 493 |
+
|
| 494 |
+
normed = self.ffn_norm1(x)
|
| 495 |
+
normed = normed * scale_mlp
|
| 496 |
+
ffn_out = self.feed_forward(normed)
|
| 497 |
+
ffn_out = self.ffn_norm2(ffn_out) * gate_mlp
|
| 498 |
+
x = x + ffn_out
|
| 499 |
+
else:
|
| 500 |
+
normed = self.attention_norm1(x)
|
| 501 |
+
attn_out = self.attention(normed, attention_mask=attn_mask, freqs_cis=freqs_cis)
|
| 502 |
+
x = x + self.attention_norm2(attn_out)
|
| 503 |
+
normed = self.ffn_norm1(x)
|
| 504 |
+
ffn_out = self.feed_forward(normed)
|
| 505 |
+
x = x + self.ffn_norm2(ffn_out)
|
| 506 |
+
return x
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
class ZImageControlTransformerBlock(ZImageTransformerBlock):
|
| 510 |
+
"""
|
| 511 |
+
A specialized transformer block for the control pathway. It inherits from ZImageTransformerBlock
|
| 512 |
+
and adds projection layers to generate and combine control signals.
|
| 513 |
+
"""
|
| 514 |
+
|
| 515 |
+
def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, block_id=0):
|
| 516 |
+
"""
|
| 517 |
+
Initializes the ZImageControlTransformerBlock.
|
| 518 |
+
|
| 519 |
+
Args:
|
| 520 |
+
layer_id (int): The index of the layer.
|
| 521 |
+
dim (int): The dimension of the features.
|
| 522 |
+
n_heads (int): The number of attention heads.
|
| 523 |
+
n_kv_heads (int): The number of key/value heads.
|
| 524 |
+
norm_eps (float): Epsilon for RMSNorm.
|
| 525 |
+
qk_norm (bool): Whether to apply normalization to query and key.
|
| 526 |
+
modulation (bool, optional): Whether to enable AdaLN modulation. Defaults to True.
|
| 527 |
+
block_id (int, optional): The index of this control block. Defaults to 0.
|
| 528 |
+
"""
|
| 529 |
+
super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
|
| 530 |
+
self.block_id = block_id
|
| 531 |
+
if block_id == 0:
|
| 532 |
+
self.before_proj = zero_module(nn.Linear(self.dim, self.dim))
|
| 533 |
+
self.after_proj = zero_module(nn.Linear(self.dim, self.dim))
|
| 534 |
+
|
| 535 |
+
def forward(self, c, x, **kwargs):
|
| 536 |
+
"""
|
| 537 |
+
Defines the forward pass for the control block.
|
| 538 |
+
|
| 539 |
+
Args:
|
| 540 |
+
c (torch.Tensor): The control signal tensor.
|
| 541 |
+
x (torch.Tensor): The reference tensor from the main pathway.
|
| 542 |
+
**kwargs: Additional arguments for the parent's forward method.
|
| 543 |
+
|
| 544 |
+
Returns:
|
| 545 |
+
torch.Tensor: A stacked tensor containing the skip connection and the final output.
|
| 546 |
+
"""
|
| 547 |
+
if self.block_id == 0:
|
| 548 |
+
c = self.before_proj(c) + x
|
| 549 |
+
all_c = []
|
| 550 |
+
else:
|
| 551 |
+
all_c = list(torch.unbind(c))
|
| 552 |
+
c = all_c.pop(-1)
|
| 553 |
+
|
| 554 |
+
c = super().forward(c, **kwargs)
|
| 555 |
+
c_skip = self.after_proj(c)
|
| 556 |
+
all_c += [c_skip, c]
|
| 557 |
+
c = torch.stack(all_c)
|
| 558 |
+
return c
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
class BaseZImageTransformerBlock(ZImageTransformerBlock):
|
| 562 |
+
"""
|
| 563 |
+
The main transformer block used in the primary pathway. It inherits from ZImageTransformerBlock
|
| 564 |
+
and adds the logic to inject control "hints" from the control pathway.
|
| 565 |
+
"""
|
| 566 |
+
|
| 567 |
+
def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int, norm_eps: float, qk_norm: bool, modulation=True, block_id=0):
|
| 568 |
+
"""
|
| 569 |
+
Initializes the BaseZImageTransformerBlock.
|
| 570 |
+
|
| 571 |
+
Args:
|
| 572 |
+
layer_id (int): The index of the layer.
|
| 573 |
+
dim (int): The dimension of the features.
|
| 574 |
+
n_heads (int): The number of attention heads.
|
| 575 |
+
n_kv_heads (int): The number of key/value heads.
|
| 576 |
+
norm_eps (float): Epsilon for RMSNorm.
|
| 577 |
+
qk_norm (bool): Whether to apply normalization to query and key.
|
| 578 |
+
modulation (bool, optional): Whether to enable AdaLN modulation. Defaults to True.
|
| 579 |
+
block_id (int, optional): The index used to retrieve the corresponding control hint. Defaults to 0.
|
| 580 |
+
"""
|
| 581 |
+
super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
|
| 582 |
+
self.block_id = block_id
|
| 583 |
+
|
| 584 |
+
def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs):
|
| 585 |
+
"""
|
| 586 |
+
Defines the forward pass, including the injection of control hints.
|
| 587 |
+
|
| 588 |
+
Args:
|
| 589 |
+
hidden_states (torch.Tensor): The input tensor.
|
| 590 |
+
hints (List[torch.Tensor], optional): A list of control hints from the control pathway. Defaults to None.
|
| 591 |
+
context_scale (float, optional): A scale factor for the control hints. Defaults to 1.0.
|
| 592 |
+
**kwargs: Additional arguments for the parent's forward method.
|
| 593 |
+
|
| 594 |
+
Returns:
|
| 595 |
+
torch.Tensor: The output tensor of the block.
|
| 596 |
+
"""
|
| 597 |
+
hidden_states = super().forward(hidden_states, **kwargs)
|
| 598 |
+
if self.block_id is not None and hints is not None:
|
| 599 |
+
hidden_states = hidden_states + hints[self.block_id] * context_scale
|
| 600 |
+
return hidden_states
|
| 601 |
+
|
| 602 |
+
|
| 603 |
+
class ZImageControlTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
| 604 |
+
_supports_gradient_checkpointing = True
|
| 605 |
+
_keys_to_ignore_on_load_unexpected = [
|
| 606 |
+
r"control_layers\..*",
|
| 607 |
+
r"control_noise_refiner\..*",
|
| 608 |
+
r"control_all_x_embedder\..*",
|
| 609 |
+
]
|
| 610 |
+
_no_split_modules = ["ZImageTransformerBlock", "BaseZImageTransformerBlock", "ZImageControlTransformerBlock"]
|
| 611 |
+
_skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"]
|
| 612 |
+
_group_offload_block_modules = ["t_embedder", "cap_embedder"]
|
| 613 |
+
|
| 614 |
+
@register_to_config
|
| 615 |
+
def __init__(
|
| 616 |
+
self,
|
| 617 |
+
control_layers_places=None,
|
| 618 |
+
control_refiner_layers_places=None,
|
| 619 |
+
control_in_dim=None,
|
| 620 |
+
add_control_noise_refiner=False,
|
| 621 |
+
all_patch_size=(2,),
|
| 622 |
+
all_f_patch_size=(1,),
|
| 623 |
+
in_channels=16,
|
| 624 |
+
dim=3840,
|
| 625 |
+
n_layers=30,
|
| 626 |
+
n_refiner_layers=2,
|
| 627 |
+
n_heads=30,
|
| 628 |
+
n_kv_heads=30,
|
| 629 |
+
norm_eps=1e-5,
|
| 630 |
+
qk_norm=True,
|
| 631 |
+
cap_feat_dim=2560,
|
| 632 |
+
rope_theta=256.0,
|
| 633 |
+
t_scale=1000.0,
|
| 634 |
+
axes_dims=[32, 48, 48],
|
| 635 |
+
axes_lens=[1024, 512, 512],
|
| 636 |
+
use_controlnet=True,
|
| 637 |
+
checkpoint_ratio=0.5,
|
| 638 |
+
):
|
| 639 |
+
"""
|
| 640 |
+
Initializes the ZImageControlTransformer2DModel.
|
| 641 |
+
|
| 642 |
+
Args:
|
| 643 |
+
control_layers_places (List[int], optional): Indices of main layers where control hints are injected.
|
| 644 |
+
control_refiner_layers_places (List[int], optional): Indices of noise refiner layers for two-stage control.
|
| 645 |
+
control_in_dim (int, optional): Input channel dimension for the control context.
|
| 646 |
+
add_control_noise_refiner (bool, optional): Whether to add a dedicated refiner for the control signal.
|
| 647 |
+
all_patch_size (Tuple[int], optional): Tuple of patch sizes for spatial dimensions.
|
| 648 |
+
all_f_patch_size (Tuple[int], optional): Tuple of patch sizes for the frame dimension.
|
| 649 |
+
in_channels (int, optional): Number of input channels for the latent image.
|
| 650 |
+
dim (int, optional): The main dimension of the transformer model.
|
| 651 |
+
n_layers (int, optional): The number of main transformer layers.
|
| 652 |
+
n_refiner_layers (int, optional): The number of layers in the refiner blocks.
|
| 653 |
+
n_heads (int, optional): The number of attention heads.
|
| 654 |
+
n_kv_heads (int, optional): The number of key/value heads.
|
| 655 |
+
norm_eps (float, optional): Epsilon for RMSNorm.
|
| 656 |
+
qk_norm (bool, optional): Whether to apply normalization to query and key.
|
| 657 |
+
cap_feat_dim (int, optional): The dimension of the input caption features.
|
| 658 |
+
rope_theta (float, optional): The base for RoPE.
|
| 659 |
+
t_scale (float, optional): A scaling factor for the timestep.
|
| 660 |
+
axes_dims (List[int], optional): Dimensions for each axis in RoPE.
|
| 661 |
+
axes_lens (List[int], optional): Maximum lengths for each axis in RoPE.
|
| 662 |
+
use_controlnet (bool, optional): If False, control-related layers will not be created to save memory.
|
| 663 |
+
checkpoint_ratio (float, optional): The ratio of layers to apply gradient checkpointing to.
|
| 664 |
+
"""
|
| 665 |
+
super().__init__()
|
| 666 |
+
self.use_controlnet = use_controlnet
|
| 667 |
+
self.in_channels = in_channels
|
| 668 |
+
self.out_channels = in_channels
|
| 669 |
+
self.all_patch_size = all_patch_size
|
| 670 |
+
self.all_f_patch_size = all_f_patch_size
|
| 671 |
+
self.dim = dim
|
| 672 |
+
self.control_in_dim = self.dim if control_in_dim is None else control_in_dim
|
| 673 |
+
self.is_two_stage_control = self.control_in_dim > 16
|
| 674 |
+
self.n_heads = n_heads
|
| 675 |
+
self.rope_theta = rope_theta
|
| 676 |
+
self.t_scale = t_scale
|
| 677 |
+
self.gradient_checkpointing = False
|
| 678 |
+
self.checkpoint_ratio = checkpoint_ratio
|
| 679 |
+
assert len(all_patch_size) == len(all_f_patch_size)
|
| 680 |
+
|
| 681 |
+
self.control_layers_places = list(range(0, n_layers, 2)) if control_layers_places is None else control_layers_places
|
| 682 |
+
self.control_refiner_layers_places = list(range(0, n_refiner_layers)) if control_refiner_layers_places is None else control_refiner_layers_places
|
| 683 |
+
self.add_control_noise_refiner = add_control_noise_refiner
|
| 684 |
+
assert 0 in self.control_layers_places
|
| 685 |
+
self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers_places)}
|
| 686 |
+
self.control_refiner_layers_mapping = {i: n for n, i in enumerate(self.control_refiner_layers_places)}
|
| 687 |
+
|
| 688 |
+
self.all_x_embedder = nn.ModuleDict(
|
| 689 |
+
{
|
| 690 |
+
f"{patch_size}-{f_patch_size}": nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
|
| 691 |
+
for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
|
| 692 |
+
}
|
| 693 |
+
)
|
| 694 |
+
|
| 695 |
+
self.all_final_layer = nn.ModuleDict(
|
| 696 |
+
{
|
| 697 |
+
f"{patch_size}-{f_patch_size}": FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
|
| 698 |
+
for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
|
| 699 |
+
}
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
self.context_refiner = nn.ModuleList(
|
| 703 |
+
[ZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=False) for i in range(n_refiner_layers)]
|
| 704 |
+
)
|
| 705 |
+
self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
|
| 706 |
+
self.cap_embedder = nn.Sequential(RMSNorm(cap_feat_dim, eps=norm_eps), nn.Linear(cap_feat_dim, dim, bias=True))
|
| 707 |
+
self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
|
| 708 |
+
self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
|
| 709 |
+
|
| 710 |
+
head_dim = dim // n_heads
|
| 711 |
+
assert head_dim == sum(axes_dims)
|
| 712 |
+
self.axes_dims = axes_dims
|
| 713 |
+
self.axes_lens = axes_lens
|
| 714 |
+
self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
|
| 715 |
+
|
| 716 |
+
self.layers = nn.ModuleList(
|
| 717 |
+
[BaseZImageTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=self.control_layers_mapping.get(i)) for i in range(n_layers)]
|
| 718 |
+
)
|
| 719 |
+
|
| 720 |
+
self.noise_refiner = nn.ModuleList(
|
| 721 |
+
[
|
| 722 |
+
BaseZImageTransformerBlock(
|
| 723 |
+
1000 + i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, block_id=self.control_refiner_layers_mapping.get(i)
|
| 724 |
+
)
|
| 725 |
+
for i in range(n_refiner_layers)
|
| 726 |
+
]
|
| 727 |
+
)
|
| 728 |
+
|
| 729 |
+
if self.use_controlnet:
|
| 730 |
+
self.control_layers = nn.ModuleList(
|
| 731 |
+
[ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i) for i in self.control_layers_places]
|
| 732 |
+
)
|
| 733 |
+
self.control_all_x_embedder = nn.ModuleDict(
|
| 734 |
+
{
|
| 735 |
+
f"{patch_size}-{f_patch_size}": nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
|
| 736 |
+
for patch_size, f_patch_size in zip(all_patch_size, all_f_patch_size)
|
| 737 |
+
}
|
| 738 |
+
)
|
| 739 |
+
|
| 740 |
+
if self.is_two_stage_control:
|
| 741 |
+
if self.add_control_noise_refiner:
|
| 742 |
+
self.control_noise_refiner = nn.ModuleList(
|
| 743 |
+
[
|
| 744 |
+
ZImageControlTransformerBlock(1000 + layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True, block_id=layer_id)
|
| 745 |
+
for layer_id in range(n_refiner_layers)
|
| 746 |
+
]
|
| 747 |
+
)
|
| 748 |
+
else:
|
| 749 |
+
self.control_noise_refiner = None
|
| 750 |
+
else: # V1
|
| 751 |
+
self.control_noise_refiner = nn.ModuleList(
|
| 752 |
+
[ZImageTransformerBlock(1000 + i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation=True) for i in range(n_refiner_layers)]
|
| 753 |
+
)
|
| 754 |
+
else:
|
| 755 |
+
self.control_layers = None
|
| 756 |
+
self.control_all_x_embedder = None
|
| 757 |
+
self.control_noise_refiner = None
|
| 758 |
+
|
| 759 |
+
def _unpatchify(self, x_image_tokens: torch.Tensor, all_sizes: List[Tuple], patch_size: int, f_patch_size: int) -> torch.Tensor:
|
| 760 |
+
"""
|
| 761 |
+
Converts a sequence of image tokens back into a batched image tensor. This version is robust
|
| 762 |
+
to batches containing images of different original sizes.
|
| 763 |
+
|
| 764 |
+
Args:
|
| 765 |
+
x_image_tokens (torch.Tensor): A tensor of image tokens with shape [B, SeqLen, Dim].
|
| 766 |
+
all_sizes (List[Tuple]): A list of tuples with the original (F, H, W) size for each image in the batch.
|
| 767 |
+
patch_size (int): The spatial patch size (height and width).
|
| 768 |
+
f_patch_size (int): The frame/temporal patch size.
|
| 769 |
+
|
| 770 |
+
Returns:
|
| 771 |
+
torch.Tensor: The reconstructed latent tensor with shape [B, C, F, H, W].
|
| 772 |
+
"""
|
| 773 |
+
pH = pW = patch_size
|
| 774 |
+
pF = f_patch_size
|
| 775 |
+
batch_size = x_image_tokens.shape[0]
|
| 776 |
+
unpatched_images = []
|
| 777 |
+
|
| 778 |
+
for i in range(batch_size):
|
| 779 |
+
F, H, W = all_sizes[i]
|
| 780 |
+
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
| 781 |
+
original_seq_len = F_tokens * H_tokens * W_tokens
|
| 782 |
+
current_image_tokens = x_image_tokens[i, :original_seq_len, :]
|
| 783 |
+
unpatched_image = current_image_tokens.view(F_tokens, H_tokens, W_tokens, pF, pH, pW, self.out_channels)
|
| 784 |
+
unpatched_image = unpatched_image.permute(6, 0, 3, 1, 4, 2, 5).reshape(self.out_channels, F, H, W)
|
| 785 |
+
unpatched_images.append(unpatched_image)
|
| 786 |
+
|
| 787 |
+
try:
|
| 788 |
+
final_tensor = torch.stack(unpatched_images, dim=0)
|
| 789 |
+
except RuntimeError:
|
| 790 |
+
raise ValueError(
|
| 791 |
+
"Could not stack unpatched images into a single batch tensor. "
|
| 792 |
+
"This typically occurs if you are trying to generate images of different sizes in the same batch."
|
| 793 |
+
)
|
| 794 |
+
|
| 795 |
+
return final_tensor
|
| 796 |
+
|
| 797 |
+
def _patchify(
|
| 798 |
+
self,
|
| 799 |
+
all_image: List[torch.Tensor],
|
| 800 |
+
patch_size: int,
|
| 801 |
+
f_patch_size: int,
|
| 802 |
+
cap_padding_len: int,
|
| 803 |
+
):
|
| 804 |
+
"""
|
| 805 |
+
Converts a list of image tensors into patch sequences and computes their positional IDs.
|
| 806 |
+
|
| 807 |
+
Args:
|
| 808 |
+
all_image (List[torch.Tensor]): A list of image tensors to process.
|
| 809 |
+
patch_size (int): The spatial patch size.
|
| 810 |
+
f_patch_size (int): The frame/temporal patch size.
|
| 811 |
+
cap_padding_len (int): The length of the padded caption sequence, used as an offset for image position IDs.
|
| 812 |
+
|
| 813 |
+
Returns:
|
| 814 |
+
Tuple: A tuple containing lists of processed patches, sizes, position IDs, and padding masks.
|
| 815 |
+
"""
|
| 816 |
+
pH = pW = patch_size
|
| 817 |
+
pF = f_patch_size
|
| 818 |
+
device = all_image[0].device
|
| 819 |
+
|
| 820 |
+
all_image_out = []
|
| 821 |
+
all_image_size = []
|
| 822 |
+
all_image_pos_ids = []
|
| 823 |
+
all_image_pad_mask = []
|
| 824 |
+
|
| 825 |
+
for i, image in enumerate(all_image):
|
| 826 |
+
C, F, H, W = image.size()
|
| 827 |
+
all_image_size.append((F, H, W))
|
| 828 |
+
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
| 829 |
+
|
| 830 |
+
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
| 831 |
+
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
| 832 |
+
|
| 833 |
+
image_ori_len = len(image)
|
| 834 |
+
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
| 835 |
+
|
| 836 |
+
image_ori_pos_ids = self._create_coordinate_grid(
|
| 837 |
+
size=(F_tokens, H_tokens, W_tokens),
|
| 838 |
+
start=(cap_padding_len + 1, 0, 0),
|
| 839 |
+
device=device,
|
| 840 |
+
).flatten(0, 2)
|
| 841 |
+
image_padding_pos_ids = (
|
| 842 |
+
self._create_coordinate_grid(
|
| 843 |
+
size=(1, 1, 1),
|
| 844 |
+
start=(0, 0, 0),
|
| 845 |
+
device=device,
|
| 846 |
+
)
|
| 847 |
+
.flatten(0, 2)
|
| 848 |
+
.repeat(image_padding_len, 1)
|
| 849 |
+
)
|
| 850 |
+
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
|
| 851 |
+
all_image_pos_ids.append(image_padded_pos_ids)
|
| 852 |
+
all_image_pad_mask.append(
|
| 853 |
+
torch.cat(
|
| 854 |
+
[
|
| 855 |
+
torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
|
| 856 |
+
torch.ones((image_padding_len,), dtype=torch.bool, device=device),
|
| 857 |
+
],
|
| 858 |
+
dim=0,
|
| 859 |
+
)
|
| 860 |
+
)
|
| 861 |
+
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
| 862 |
+
all_image_out.append(image_padded_feat)
|
| 863 |
+
|
| 864 |
+
return (
|
| 865 |
+
all_image_out,
|
| 866 |
+
all_image_size,
|
| 867 |
+
all_image_pos_ids,
|
| 868 |
+
all_image_pad_mask,
|
| 869 |
+
)
|
| 870 |
+
|
| 871 |
+
def _patchify_and_embed(
|
| 872 |
+
self,
|
| 873 |
+
all_image: List[torch.Tensor],
|
| 874 |
+
all_cap_feats: List[torch.Tensor],
|
| 875 |
+
patch_size: int,
|
| 876 |
+
f_patch_size: int,
|
| 877 |
+
):
|
| 878 |
+
"""
|
| 879 |
+
Processes a batch of images and caption features by converting them into padded patch sequences
|
| 880 |
+
and generating their corresponding positional IDs and padding masks. This is the general-purpose,
|
| 881 |
+
robust version that iterates through the batch.
|
| 882 |
+
|
| 883 |
+
Args:
|
| 884 |
+
all_image (List[torch.Tensor]): A list of image tensors.
|
| 885 |
+
all_cap_feats (List[torch.Tensor]): A list of caption feature tensors.
|
| 886 |
+
patch_size (int): The spatial patch size.
|
| 887 |
+
f_patch_size (int): The frame/temporal patch size.
|
| 888 |
+
|
| 889 |
+
Returns:
|
| 890 |
+
Tuple: A tuple containing all processed data structures (image patches, caption features, sizes,
|
| 891 |
+
position IDs, and padding masks) as lists.
|
| 892 |
+
"""
|
| 893 |
+
pH = pW = patch_size
|
| 894 |
+
pF = f_patch_size
|
| 895 |
+
device = all_image[0].device
|
| 896 |
+
|
| 897 |
+
all_image_out, all_image_size, all_image_pos_ids, all_image_pad_mask = [], [], [], []
|
| 898 |
+
all_cap_pos_ids, all_cap_pad_mask, all_cap_feats_out = [], [], []
|
| 899 |
+
|
| 900 |
+
for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
|
| 901 |
+
cap_ori_len = len(cap_feat)
|
| 902 |
+
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
|
| 903 |
+
cap_total_len = cap_ori_len + cap_padding_len
|
| 904 |
+
|
| 905 |
+
cap_padded_pos_ids = self._create_coordinate_grid(size=(cap_total_len, 1, 1), start=(1, 0, 0), device=device).flatten(0, 2)
|
| 906 |
+
all_cap_pos_ids.append(cap_padded_pos_ids)
|
| 907 |
+
|
| 908 |
+
cap_mask = torch.ones(cap_total_len, dtype=torch.bool, device=device)
|
| 909 |
+
cap_mask[:cap_ori_len] = False
|
| 910 |
+
all_cap_pad_mask.append(cap_mask)
|
| 911 |
+
|
| 912 |
+
if cap_padding_len > 0:
|
| 913 |
+
padding_tensor = cap_feat[-1:].repeat(cap_padding_len, 1)
|
| 914 |
+
cap_padded_feat = torch.cat([cap_feat, padding_tensor], dim=0)
|
| 915 |
+
else:
|
| 916 |
+
cap_padded_feat = cap_feat
|
| 917 |
+
all_cap_feats_out.append(cap_padded_feat)
|
| 918 |
+
|
| 919 |
+
C, Fr, H, W = image.size()
|
| 920 |
+
all_image_size.append((Fr, H, W))
|
| 921 |
+
F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW
|
| 922 |
+
|
| 923 |
+
image_reshaped = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW).permute(1, 3, 5, 2, 4, 6, 0).reshape(-1, pF * pH * pW * C)
|
| 924 |
+
|
| 925 |
+
image_ori_len = image_reshaped.shape[0]
|
| 926 |
+
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
| 927 |
+
image_total_len = image_ori_len + image_padding_len
|
| 928 |
+
|
| 929 |
+
image_ori_pos_ids = self._create_coordinate_grid(size=(F_tokens, H_tokens, W_tokens), start=(cap_total_len + 1, 0, 0), device=device).flatten(0, 2)
|
| 930 |
+
if image_padding_len > 0:
|
| 931 |
+
image_padding_pos_ids = torch.zeros((image_padding_len, 3), dtype=torch.int32, device=device)
|
| 932 |
+
image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
|
| 933 |
+
else:
|
| 934 |
+
image_padded_pos_ids = image_ori_pos_ids
|
| 935 |
+
all_image_pos_ids.append(image_padded_pos_ids)
|
| 936 |
+
|
| 937 |
+
image_mask = torch.ones(image_total_len, dtype=torch.bool, device=device)
|
| 938 |
+
image_mask[:image_ori_len] = False
|
| 939 |
+
all_image_pad_mask.append(image_mask)
|
| 940 |
+
|
| 941 |
+
if image_padding_len > 0:
|
| 942 |
+
padding_tensor = image_reshaped[-1:].repeat(image_padding_len, 1)
|
| 943 |
+
image_padded_feat = torch.cat([image_reshaped, padding_tensor], dim=0)
|
| 944 |
+
else:
|
| 945 |
+
image_padded_feat = image_reshaped
|
| 946 |
+
all_image_out.append(image_padded_feat)
|
| 947 |
+
|
| 948 |
+
return (
|
| 949 |
+
all_image_out,
|
| 950 |
+
all_cap_feats_out,
|
| 951 |
+
all_image_size,
|
| 952 |
+
all_image_pos_ids,
|
| 953 |
+
all_cap_pos_ids,
|
| 954 |
+
all_image_pad_mask,
|
| 955 |
+
all_cap_pad_mask,
|
| 956 |
+
)
|
| 957 |
+
|
| 958 |
+
def _process_cap_feats_with_cfg_cache(self, cap_feats_list, cap_pos_ids, cap_inner_pad_mask):
|
| 959 |
+
"""
|
| 960 |
+
Processes caption features with intelligent duplicate detection to avoid redundant computation,
|
| 961 |
+
especially for Classifier-Free Guidance (CFG) where prompts are repeated.
|
| 962 |
+
|
| 963 |
+
Args:
|
| 964 |
+
cap_feats_list (List[torch.Tensor]): List of padded caption feature tensors.
|
| 965 |
+
cap_pos_ids (List[torch.Tensor]): List of corresponding position ID tensors.
|
| 966 |
+
cap_inner_pad_mask (List[torch.Tensor]): List of corresponding padding masks.
|
| 967 |
+
|
| 968 |
+
Returns:
|
| 969 |
+
Tuple: A tuple of batched tensors for padded features, RoPE frequencies, attention mask, and sequence lengths.
|
| 970 |
+
"""
|
| 971 |
+
device = cap_feats_list[0].device
|
| 972 |
+
bsz = len(cap_feats_list)
|
| 973 |
+
|
| 974 |
+
shapes_equal = all(c.shape == cap_feats_list[0].shape for c in cap_feats_list)
|
| 975 |
+
|
| 976 |
+
if shapes_equal and bsz >= 2:
|
| 977 |
+
unique_indices = [0]
|
| 978 |
+
unique_tensors = [cap_feats_list[0]]
|
| 979 |
+
tensor_mapping = [0]
|
| 980 |
+
|
| 981 |
+
for i in range(1, bsz):
|
| 982 |
+
found_match = False
|
| 983 |
+
for j, unique_tensor in enumerate(unique_tensors):
|
| 984 |
+
if torch.equal(cap_feats_list[i], unique_tensor):
|
| 985 |
+
tensor_mapping.append(j)
|
| 986 |
+
found_match = True
|
| 987 |
+
break
|
| 988 |
+
|
| 989 |
+
if not found_match:
|
| 990 |
+
unique_indices.append(i)
|
| 991 |
+
unique_tensors.append(cap_feats_list[i])
|
| 992 |
+
tensor_mapping.append(len(unique_tensors) - 1)
|
| 993 |
+
|
| 994 |
+
if len(unique_tensors) < bsz:
|
| 995 |
+
unique_cap_feats_list = [cap_feats_list[i] for i in unique_indices]
|
| 996 |
+
unique_cap_pos_ids = [cap_pos_ids[i] for i in unique_indices]
|
| 997 |
+
unique_cap_inner_pad_mask = [cap_inner_pad_mask[i] for i in unique_indices]
|
| 998 |
+
|
| 999 |
+
cap_item_seqlens_unique = [len(i) for i in unique_cap_feats_list]
|
| 1000 |
+
cap_max_item_seqlen = max(cap_item_seqlens_unique)
|
| 1001 |
+
|
| 1002 |
+
cap_feats_cat = torch.cat(unique_cap_feats_list, dim=0)
|
| 1003 |
+
cap_feats_embedded = self.cap_embedder(cap_feats_cat)
|
| 1004 |
+
cap_feats_embedded[torch.cat(unique_cap_inner_pad_mask)] = self.cap_pad_token
|
| 1005 |
+
cap_feats_padded_unique = pad_sequence(list(cap_feats_embedded.split(cap_item_seqlens_unique, dim=0)), batch_first=True, padding_value=0.0)
|
| 1006 |
+
|
| 1007 |
+
cap_freqs_cis_cat = self.rope_embedder(torch.cat(unique_cap_pos_ids, dim=0))
|
| 1008 |
+
cap_freqs_cis_unique = pad_sequence(list(cap_freqs_cis_cat.split(cap_item_seqlens_unique, dim=0)), batch_first=True, padding_value=0.0)
|
| 1009 |
+
|
| 1010 |
+
cap_feats_padded = cap_feats_padded_unique[tensor_mapping]
|
| 1011 |
+
cap_freqs_cis = cap_freqs_cis_unique[tensor_mapping]
|
| 1012 |
+
|
| 1013 |
+
seq_lens_tensor = torch.tensor([cap_max_item_seqlen] * bsz, device=device, dtype=torch.int32)
|
| 1014 |
+
arange = torch.arange(cap_max_item_seqlen, device=device, dtype=torch.int32)
|
| 1015 |
+
cap_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
|
| 1016 |
+
|
| 1017 |
+
cap_item_seqlens = [cap_max_item_seqlen] * bsz
|
| 1018 |
+
|
| 1019 |
+
return cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens
|
| 1020 |
+
|
| 1021 |
+
cap_item_seqlens = [len(i) for i in cap_feats_list]
|
| 1022 |
+
cap_max_item_seqlen = max(cap_item_seqlens)
|
| 1023 |
+
cap_feats_cat = torch.cat(cap_feats_list, dim=0)
|
| 1024 |
+
cap_feats_embedded = self.cap_embedder(cap_feats_cat)
|
| 1025 |
+
cap_feats_embedded[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
|
| 1026 |
+
cap_feats_padded = pad_sequence(list(cap_feats_embedded.split(cap_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
|
| 1027 |
+
|
| 1028 |
+
cap_freqs_cis_cat = self.rope_embedder(torch.cat(cap_pos_ids, dim=0))
|
| 1029 |
+
cap_freqs_cis = pad_sequence(list(cap_freqs_cis_cat.split(cap_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
|
| 1030 |
+
|
| 1031 |
+
seq_lens_tensor = torch.tensor(cap_item_seqlens, device=device, dtype=torch.int32)
|
| 1032 |
+
arange = torch.arange(cap_max_item_seqlen, device=device, dtype=torch.int32)
|
| 1033 |
+
cap_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
|
| 1034 |
+
|
| 1035 |
+
return cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens
|
| 1036 |
+
|
| 1037 |
+
@staticmethod
|
| 1038 |
+
def _create_coordinate_grid(size, start=None, device=None):
|
| 1039 |
+
"""
|
| 1040 |
+
Creates a 3D coordinate grid.
|
| 1041 |
+
|
| 1042 |
+
Args:
|
| 1043 |
+
size (Tuple[int]): The dimensions of the grid (F, H, W).
|
| 1044 |
+
start (Tuple[int], optional): The starting coordinates for each axis. Defaults to (0, 0, 0).
|
| 1045 |
+
device (torch.device, optional): The device to create the tensor on. Defaults to None.
|
| 1046 |
+
|
| 1047 |
+
Returns:
|
| 1048 |
+
torch.Tensor: The coordinate grid tensor.
|
| 1049 |
+
"""
|
| 1050 |
+
if start is None:
|
| 1051 |
+
start = (0 for _ in size)
|
| 1052 |
+
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
| 1053 |
+
grids = torch.meshgrid(axes, indexing="ij")
|
| 1054 |
+
return torch.stack(grids, dim=-1)
|
| 1055 |
+
|
| 1056 |
+
def _apply_transformer_blocks(self, hidden_states, layers, checkpoint_ratio=0.5, **kwargs):
|
| 1057 |
+
"""
|
| 1058 |
+
Applies a list of transformer layers to the hidden states, with optional selective gradient checkpointing.
|
| 1059 |
+
|
| 1060 |
+
Args:
|
| 1061 |
+
hidden_states (torch.Tensor): The input tensor.
|
| 1062 |
+
layers (nn.ModuleList): The list of transformer layers to apply.
|
| 1063 |
+
checkpoint_ratio (float, optional): The ratio of layers to apply gradient checkpointing to. Defaults to 0.5.
|
| 1064 |
+
**kwargs: Additional keyword arguments to pass to each layer's forward method.
|
| 1065 |
+
|
| 1066 |
+
Returns:
|
| 1067 |
+
torch.Tensor: The output tensor after applying all layers.
|
| 1068 |
+
"""
|
| 1069 |
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
| 1070 |
+
|
| 1071 |
+
def create_custom_forward(module, **static_kwargs):
|
| 1072 |
+
def custom_forward(*inputs):
|
| 1073 |
+
return module(*inputs, **static_kwargs)
|
| 1074 |
+
|
| 1075 |
+
return custom_forward
|
| 1076 |
+
|
| 1077 |
+
ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
| 1078 |
+
|
| 1079 |
+
checkpoint_every_n = max(1, int(1.0 / checkpoint_ratio)) if checkpoint_ratio > 0 else len(layers) + 1
|
| 1080 |
+
|
| 1081 |
+
for i, layer in enumerate(layers):
|
| 1082 |
+
if i % checkpoint_every_n == 0:
|
| 1083 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
| 1084 |
+
create_custom_forward(layer, **kwargs),
|
| 1085 |
+
hidden_states,
|
| 1086 |
+
**ckpt_kwargs,
|
| 1087 |
+
)
|
| 1088 |
+
else:
|
| 1089 |
+
hidden_states = layer(hidden_states, **kwargs)
|
| 1090 |
+
else:
|
| 1091 |
+
for layer in layers:
|
| 1092 |
+
hidden_states = layer(hidden_states, **kwargs)
|
| 1093 |
+
|
| 1094 |
+
return hidden_states
|
| 1095 |
+
|
| 1096 |
+
def _prepare_control_inputs(self, control_context, cap_feats_ref, t, patch_size, f_patch_size, device):
|
| 1097 |
+
"""
|
| 1098 |
+
Prepares the control context for the transformer, including patchifying, embedding, and generating
|
| 1099 |
+
positional information. Includes a fast path for batches with uniform shapes.
|
| 1100 |
+
|
| 1101 |
+
Args:
|
| 1102 |
+
control_context (torch.Tensor or List[torch.Tensor]): The control context input.
|
| 1103 |
+
cap_feats_ref (List[torch.Tensor]): A reference to caption features for padding calculation.
|
| 1104 |
+
t (torch.Tensor): The timestep tensor.
|
| 1105 |
+
patch_size (int): The spatial patch size.
|
| 1106 |
+
f_patch_size (int): The frame/temporal patch size.
|
| 1107 |
+
device (torch.device): The target device.
|
| 1108 |
+
|
| 1109 |
+
Returns:
|
| 1110 |
+
Dict: A dictionary containing the processed control tensors ('c', 'c_item_seqlens', 'attn_mask', etc.).
|
| 1111 |
+
"""
|
| 1112 |
+
bsz = control_context.shape[0]
|
| 1113 |
+
|
| 1114 |
+
if isinstance(control_context, torch.Tensor) and control_context.ndim == 5:
|
| 1115 |
+
control_list = list(torch.unbind(control_context, dim=0))
|
| 1116 |
+
else:
|
| 1117 |
+
control_list = control_context
|
| 1118 |
+
|
| 1119 |
+
pH = pW = patch_size
|
| 1120 |
+
pF = f_patch_size
|
| 1121 |
+
cap_padding_len = cap_feats_ref[0].size(0) if isinstance(cap_feats_ref, list) else cap_feats_ref.shape[1]
|
| 1122 |
+
|
| 1123 |
+
shapes = [c.shape for c in control_list]
|
| 1124 |
+
same_shape = all(s == shapes[0] for s in shapes)
|
| 1125 |
+
|
| 1126 |
+
if same_shape and bsz >= 2:
|
| 1127 |
+
control_batch = torch.stack(control_list, dim=0)
|
| 1128 |
+
B, C, F, H, W = control_batch.shape
|
| 1129 |
+
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
| 1130 |
+
|
| 1131 |
+
control_batch = control_batch.view(B, C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
| 1132 |
+
control_batch = control_batch.permute(0, 2, 4, 6, 3, 5, 7, 1).reshape(B, F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
| 1133 |
+
|
| 1134 |
+
ori_len = control_batch.shape[1]
|
| 1135 |
+
padding_len = (-ori_len) % SEQ_MULTI_OF
|
| 1136 |
+
|
| 1137 |
+
if padding_len > 0:
|
| 1138 |
+
pad_tensor = control_batch[:, -1:, :].repeat(1, padding_len, 1)
|
| 1139 |
+
control_batch = torch.cat([control_batch, pad_tensor], dim=1)
|
| 1140 |
+
|
| 1141 |
+
c = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_batch)
|
| 1142 |
+
|
| 1143 |
+
final_seq_len = control_batch.shape[1]
|
| 1144 |
+
pos_ids_ori = self._create_coordinate_grid(
|
| 1145 |
+
size=(F_tokens, H_tokens, W_tokens),
|
| 1146 |
+
start=(cap_padding_len + 1, 0, 0),
|
| 1147 |
+
device=device,
|
| 1148 |
+
).flatten(0, 2) # [ori_len, 3]
|
| 1149 |
+
|
| 1150 |
+
pos_ids_pad = torch.zeros((padding_len, 3), dtype=torch.int32, device=device)
|
| 1151 |
+
pos_ids_padded = torch.cat([pos_ids_ori, pos_ids_pad], dim=0)
|
| 1152 |
+
|
| 1153 |
+
c_freqs_cis_single = self.rope_embedder(pos_ids_padded)
|
| 1154 |
+
c_freqs_cis = c_freqs_cis_single.unsqueeze(0).repeat(B, 1, 1, 1)
|
| 1155 |
+
c_attn_mask = torch.ones((B, final_seq_len), dtype=torch.bool, device=device)
|
| 1156 |
+
|
| 1157 |
+
return {"c": c, "c_item_seqlens": [final_seq_len] * B, "attn_mask": c_attn_mask, "freqs_cis": c_freqs_cis, "adaln_input": t.type_as(c)}
|
| 1158 |
+
|
| 1159 |
+
(c_patches, _, c_pos_ids, c_inner_pad_mask) = self._patchify(control_list, patch_size, f_patch_size, cap_padding_len)
|
| 1160 |
+
|
| 1161 |
+
c_item_seqlens = [len(p) for p in c_patches]
|
| 1162 |
+
c_max_item_seqlen = max(c_item_seqlens)
|
| 1163 |
+
|
| 1164 |
+
c = torch.cat(c_patches, dim=0)
|
| 1165 |
+
c = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](c)
|
| 1166 |
+
c[torch.cat(c_inner_pad_mask)] = self.x_pad_token
|
| 1167 |
+
c = list(c.split(c_item_seqlens, dim=0))
|
| 1168 |
+
|
| 1169 |
+
c_freqs_cis_list = []
|
| 1170 |
+
for pos_ids in c_pos_ids:
|
| 1171 |
+
c_freqs_cis_list.append(self.rope_embedder(pos_ids))
|
| 1172 |
+
|
| 1173 |
+
c_padded = pad_sequence(c, batch_first=True, padding_value=0.0)
|
| 1174 |
+
c_freqs_cis_padded = pad_sequence(c_freqs_cis_list, batch_first=True, padding_value=0.0)
|
| 1175 |
+
|
| 1176 |
+
seq_lens_tensor = torch.tensor(c_item_seqlens, device=device, dtype=torch.int32)
|
| 1177 |
+
arange = torch.arange(c_max_item_seqlen, device=device, dtype=torch.int32)
|
| 1178 |
+
c_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
|
| 1179 |
+
|
| 1180 |
+
return {"c": c_padded, "c_item_seqlens": c_item_seqlens, "attn_mask": c_attn_mask, "freqs_cis": c_freqs_cis_padded, "adaln_input": t.type_as(c_padded)}
|
| 1181 |
+
|
| 1182 |
+
def _patchify_and_embed_batch_optimized(self, all_image, all_cap_feats, patch_size, f_patch_size):
|
| 1183 |
+
"""
|
| 1184 |
+
An optimized version of _patchify_and_embed for batches where all images and captions have
|
| 1185 |
+
uniform shapes. It processes the entire batch using vectorized operations instead of a loop.
|
| 1186 |
+
|
| 1187 |
+
Args:
|
| 1188 |
+
all_image (List[torch.Tensor]): List of image tensors, all of the same shape.
|
| 1189 |
+
all_cap_feats (List[torch.Tensor]): List of caption features, all of the same shape.
|
| 1190 |
+
patch_size (int): The spatial patch size.
|
| 1191 |
+
f_patch_size (int): The frame/temporal patch size.
|
| 1192 |
+
|
| 1193 |
+
Returns:
|
| 1194 |
+
Tuple: A tuple containing all processed data structures, matching the output of the standard method.
|
| 1195 |
+
"""
|
| 1196 |
+
pH = pW = patch_size
|
| 1197 |
+
pF = f_patch_size
|
| 1198 |
+
device = all_image[0].device
|
| 1199 |
+
|
| 1200 |
+
image_shapes = [img.shape for img in all_image]
|
| 1201 |
+
cap_shapes = [cap.shape for cap in all_cap_feats]
|
| 1202 |
+
|
| 1203 |
+
same_image_shape = all(s == image_shapes[0] for s in image_shapes)
|
| 1204 |
+
same_cap_shape = all(s == cap_shapes[0] for s in cap_shapes)
|
| 1205 |
+
|
| 1206 |
+
if not (same_image_shape and same_cap_shape):
|
| 1207 |
+
return self._patchify_and_embed(all_image, all_cap_feats, patch_size, f_patch_size)
|
| 1208 |
+
|
| 1209 |
+
images_batch = torch.stack(all_image, dim=0)
|
| 1210 |
+
caps_batch = torch.stack(all_cap_feats, dim=0)
|
| 1211 |
+
|
| 1212 |
+
B, C, Fr, H, W = images_batch.shape
|
| 1213 |
+
cap_ori_len = caps_batch.shape[1]
|
| 1214 |
+
|
| 1215 |
+
cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
|
| 1216 |
+
cap_total_len = cap_ori_len + cap_padding_len
|
| 1217 |
+
|
| 1218 |
+
if cap_padding_len > 0:
|
| 1219 |
+
cap_pad = caps_batch[:, -1:, :].repeat(1, cap_padding_len, 1)
|
| 1220 |
+
caps_batch = torch.cat([caps_batch, cap_pad], dim=1)
|
| 1221 |
+
|
| 1222 |
+
cap_pos_ids = self._create_coordinate_grid(size=(cap_total_len, 1, 1), start=(1, 0, 0), device=device).flatten(0, 2).unsqueeze(0).repeat(B, 1, 1)
|
| 1223 |
+
|
| 1224 |
+
cap_mask = torch.zeros((B, cap_total_len), dtype=torch.bool, device=device)
|
| 1225 |
+
if cap_padding_len > 0:
|
| 1226 |
+
cap_mask[:, cap_ori_len:] = True
|
| 1227 |
+
|
| 1228 |
+
F_tokens, H_tokens, W_tokens = Fr // pF, H // pH, W // pW
|
| 1229 |
+
images_reshaped = (
|
| 1230 |
+
images_batch.view(B, C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
| 1231 |
+
.permute(0, 2, 4, 6, 3, 5, 7, 1)
|
| 1232 |
+
.reshape(B, F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
| 1233 |
+
)
|
| 1234 |
+
|
| 1235 |
+
image_ori_len = images_reshaped.shape[1]
|
| 1236 |
+
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
| 1237 |
+
image_total_len = image_ori_len + image_padding_len
|
| 1238 |
+
|
| 1239 |
+
if image_padding_len > 0:
|
| 1240 |
+
img_pad = images_reshaped[:, -1:, :].repeat(1, image_padding_len, 1)
|
| 1241 |
+
images_reshaped = torch.cat([images_reshaped, img_pad], dim=1)
|
| 1242 |
+
|
| 1243 |
+
image_pos_ids = (
|
| 1244 |
+
self._create_coordinate_grid(size=(F_tokens, H_tokens, W_tokens), start=(cap_total_len + 1, 0, 0), device=device)
|
| 1245 |
+
.flatten(0, 2)
|
| 1246 |
+
.unsqueeze(0)
|
| 1247 |
+
.repeat(B, 1, 1)
|
| 1248 |
+
)
|
| 1249 |
+
|
| 1250 |
+
if image_padding_len > 0:
|
| 1251 |
+
img_pos_pad = torch.zeros((B, image_padding_len, 3), dtype=torch.int32, device=device)
|
| 1252 |
+
image_pos_ids = torch.cat([image_pos_ids, img_pos_pad], dim=1)
|
| 1253 |
+
|
| 1254 |
+
image_mask = torch.zeros((B, image_total_len), dtype=torch.bool, device=device)
|
| 1255 |
+
if image_padding_len > 0:
|
| 1256 |
+
image_mask[:, image_ori_len:] = True
|
| 1257 |
+
|
| 1258 |
+
all_image_size = [(Fr, H, W)] * B
|
| 1259 |
+
|
| 1260 |
+
return (
|
| 1261 |
+
list(torch.unbind(images_reshaped, dim=0)),
|
| 1262 |
+
list(torch.unbind(caps_batch, dim=0)),
|
| 1263 |
+
all_image_size,
|
| 1264 |
+
list(torch.unbind(image_pos_ids, dim=0)),
|
| 1265 |
+
list(torch.unbind(cap_pos_ids, dim=0)),
|
| 1266 |
+
list(torch.unbind(image_mask, dim=0)),
|
| 1267 |
+
list(torch.unbind(cap_mask, dim=0)),
|
| 1268 |
+
)
|
| 1269 |
+
|
| 1270 |
+
def forward(
|
| 1271 |
+
self,
|
| 1272 |
+
x: List[torch.Tensor],
|
| 1273 |
+
t,
|
| 1274 |
+
cap_feats: List[torch.Tensor],
|
| 1275 |
+
patch_size=2,
|
| 1276 |
+
f_patch_size=1,
|
| 1277 |
+
control_context=None,
|
| 1278 |
+
conditioning_scale=1.0,
|
| 1279 |
+
refiner_conditioning_scale=1.0,
|
| 1280 |
+
):
|
| 1281 |
+
"""
|
| 1282 |
+
The main forward pass of the transformer model.
|
| 1283 |
+
|
| 1284 |
+
Args:
|
| 1285 |
+
x (List[torch.Tensor]):
|
| 1286 |
+
A list of latent image tensors.
|
| 1287 |
+
t (torch.Tensor):
|
| 1288 |
+
A batch of timesteps.
|
| 1289 |
+
cap_feats (List[torch.Tensor]):
|
| 1290 |
+
A list of caption feature tensors.
|
| 1291 |
+
patch_size (int, optional):
|
| 1292 |
+
The spatial patch size to use. Defaults to 2.
|
| 1293 |
+
f_patch_size (int, optional):
|
| 1294 |
+
The frame/temporal patch size to use. Defaults to 1.
|
| 1295 |
+
control_context (torch.Tensor, optional):
|
| 1296 |
+
The control context tensor. Defaults to None.
|
| 1297 |
+
conditioning_scale (float, optional):
|
| 1298 |
+
The scale for applying control hints. Defaults to 1.0.
|
| 1299 |
+
refiner_conditioning_scale (float, optional):
|
| 1300 |
+
The scale for applying refiner control hints. Defaults to 1.0.
|
| 1301 |
+
|
| 1302 |
+
Returns:
|
| 1303 |
+
Transformer2DModelOutput: An object containing the final denoised sample.
|
| 1304 |
+
"""
|
| 1305 |
+
|
| 1306 |
+
is_control_mode = self.use_controlnet and control_context is not None and conditioning_scale > 0
|
| 1307 |
+
if refiner_conditioning_scale is None:
|
| 1308 |
+
refiner_conditioning_scale = conditioning_scale or 1.0
|
| 1309 |
+
|
| 1310 |
+
assert patch_size in self.all_patch_size
|
| 1311 |
+
assert f_patch_size in self.all_f_patch_size
|
| 1312 |
+
|
| 1313 |
+
bsz = len(x)
|
| 1314 |
+
device = x[0].device
|
| 1315 |
+
|
| 1316 |
+
t = t * self.t_scale
|
| 1317 |
+
t = self.t_embedder(t)
|
| 1318 |
+
|
| 1319 |
+
can_optimize_patchify = (
|
| 1320 |
+
bsz == len(cap_feats) and bsz >= 2 and all(img.shape == x[0].shape for img in x) and all(cap.shape == cap_feats[0].shape for cap in cap_feats)
|
| 1321 |
+
)
|
| 1322 |
+
|
| 1323 |
+
if can_optimize_patchify:
|
| 1324 |
+
(x_list, cap_feats_list, x_size, x_pos_ids, cap_pos_ids, x_inner_pad_mask, cap_inner_pad_mask) = self._patchify_and_embed_batch_optimized(
|
| 1325 |
+
x, cap_feats, patch_size, f_patch_size
|
| 1326 |
+
)
|
| 1327 |
+
else:
|
| 1328 |
+
(x_list, cap_feats_list, x_size, x_pos_ids, cap_pos_ids, x_inner_pad_mask, cap_inner_pad_mask) = self._patchify_and_embed(
|
| 1329 |
+
x, cap_feats, patch_size, f_patch_size
|
| 1330 |
+
)
|
| 1331 |
+
|
| 1332 |
+
x_item_seqlens = [len(i) for i in x_list]
|
| 1333 |
+
x_max_item_seqlen = max(x_item_seqlens) if x_item_seqlens else 0
|
| 1334 |
+
x_cat = torch.cat(x_list, dim=0) if x_list else torch.empty(0, x_list[0].shape[1] if x_list else 0, device=device)
|
| 1335 |
+
x_embedded = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_cat)
|
| 1336 |
+
if x_inner_pad_mask and torch.cat(x_inner_pad_mask).any():
|
| 1337 |
+
x_embedded[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
| 1338 |
+
x = pad_sequence(list(x_embedded.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
|
| 1339 |
+
adaln_input = t.to(device).type_as(x)
|
| 1340 |
+
|
| 1341 |
+
cap_feats_padded, cap_freqs_cis, cap_attn_mask, cap_item_seqlens = self._process_cap_feats_with_cfg_cache(
|
| 1342 |
+
cap_feats_list, cap_pos_ids, cap_inner_pad_mask
|
| 1343 |
+
)
|
| 1344 |
+
|
| 1345 |
+
x_freqs_cis_cat = self.rope_embedder(torch.cat(x_pos_ids, dim=0)) if x_pos_ids else torch.empty(0, device=device)
|
| 1346 |
+
x_freqs_cis = pad_sequence(list(x_freqs_cis_cat.split(x_item_seqlens, dim=0)), batch_first=True, padding_value=0.0)
|
| 1347 |
+
|
| 1348 |
+
seq_lens_tensor = torch.tensor(x_item_seqlens, device=device, dtype=torch.int32)
|
| 1349 |
+
arange = torch.arange(x_max_item_seqlen, device=device, dtype=torch.int32)
|
| 1350 |
+
x_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
|
| 1351 |
+
|
| 1352 |
+
refiner_hints = None
|
| 1353 |
+
if is_control_mode and self.is_two_stage_control:
|
| 1354 |
+
prepared_control = self._prepare_control_inputs(control_context, cap_feats_padded, t, patch_size, f_patch_size, device)
|
| 1355 |
+
c = prepared_control["c"]
|
| 1356 |
+
kwargs_for_control_refiner = {
|
| 1357 |
+
"x": x,
|
| 1358 |
+
"attn_mask": prepared_control["attn_mask"],
|
| 1359 |
+
"freqs_cis": prepared_control["freqs_cis"],
|
| 1360 |
+
"adaln_input": prepared_control["adaln_input"],
|
| 1361 |
+
}
|
| 1362 |
+
c_processed = self._apply_transformer_blocks(
|
| 1363 |
+
c,
|
| 1364 |
+
self.control_noise_refiner if self.add_control_noise_refiner else self.control_layers,
|
| 1365 |
+
checkpoint_ratio=self.checkpoint_ratio,
|
| 1366 |
+
**kwargs_for_control_refiner,
|
| 1367 |
+
)
|
| 1368 |
+
refiner_hints = torch.unbind(c_processed)[:-1]
|
| 1369 |
+
control_context_processed = torch.unbind(c_processed)[-1]
|
| 1370 |
+
control_context_item_seqlens = prepared_control["c_item_seqlens"]
|
| 1371 |
+
|
| 1372 |
+
kwargs_for_refiner = {
|
| 1373 |
+
"attn_mask": x_attn_mask,
|
| 1374 |
+
"freqs_cis": x_freqs_cis,
|
| 1375 |
+
"adaln_input": adaln_input,
|
| 1376 |
+
"context_scale": refiner_conditioning_scale,
|
| 1377 |
+
}
|
| 1378 |
+
if refiner_hints is not None:
|
| 1379 |
+
kwargs_for_refiner["hints"] = refiner_hints
|
| 1380 |
+
x = self._apply_transformer_blocks(x, self.noise_refiner, checkpoint_ratio=1.0, **kwargs_for_refiner)
|
| 1381 |
+
|
| 1382 |
+
kwargs_for_context = {"attn_mask": cap_attn_mask, "freqs_cis": cap_freqs_cis}
|
| 1383 |
+
cap_feats = self._apply_transformer_blocks(cap_feats_padded, self.context_refiner, checkpoint_ratio=1.0, **kwargs_for_context)
|
| 1384 |
+
|
| 1385 |
+
unified_item_seqlens = [a + b for a, b in zip(x_item_seqlens, cap_item_seqlens)]
|
| 1386 |
+
unified_max_item_seqlen = max(unified_item_seqlens) if unified_item_seqlens else 0
|
| 1387 |
+
unified = torch.zeros((bsz, unified_max_item_seqlen, x.shape[-1]), dtype=x.dtype, device=device)
|
| 1388 |
+
unified_freqs_cis = torch.zeros((bsz, unified_max_item_seqlen, x_freqs_cis.shape[-2], x_freqs_cis.shape[-1]), dtype=x_freqs_cis.dtype, device=device)
|
| 1389 |
+
|
| 1390 |
+
for i in range(bsz):
|
| 1391 |
+
x_len = x_item_seqlens[i]
|
| 1392 |
+
cap_len = cap_item_seqlens[i]
|
| 1393 |
+
unified[i, :x_len] = x[i, :x_len]
|
| 1394 |
+
unified[i, x_len : x_len + cap_len] = cap_feats[i, :cap_len]
|
| 1395 |
+
unified_freqs_cis[i, :x_len] = x_freqs_cis[i, :x_len]
|
| 1396 |
+
unified_freqs_cis[i, x_len : x_len + cap_len] = cap_freqs_cis[i, :cap_len]
|
| 1397 |
+
|
| 1398 |
+
seq_lens_tensor = torch.tensor(unified_item_seqlens, device=device, dtype=torch.int32)
|
| 1399 |
+
arange = torch.arange(unified_max_item_seqlen, device=device, dtype=torch.int32)
|
| 1400 |
+
unified_attn_mask = arange[None, :] < seq_lens_tensor[:, None]
|
| 1401 |
+
|
| 1402 |
+
hints = None
|
| 1403 |
+
if is_control_mode:
|
| 1404 |
+
kwargs_for_hints = {
|
| 1405 |
+
"attn_mask": unified_attn_mask,
|
| 1406 |
+
"freqs_cis": unified_freqs_cis,
|
| 1407 |
+
"adaln_input": adaln_input,
|
| 1408 |
+
}
|
| 1409 |
+
if self.is_two_stage_control:
|
| 1410 |
+
control_context_unified_list = [
|
| 1411 |
+
torch.cat([control_context_processed[i][: control_context_item_seqlens[i]], cap_feats[i, : cap_item_seqlens[i]]], dim=0) for i in range(bsz)
|
| 1412 |
+
]
|
| 1413 |
+
c = pad_sequence(control_context_unified_list, batch_first=True, padding_value=0.0)
|
| 1414 |
+
new_kwargs = dict(x=unified, **kwargs_for_hints)
|
| 1415 |
+
c_processed = self._apply_transformer_blocks(c, self.control_layers, checkpoint_ratio=self.checkpoint_ratio, **new_kwargs)
|
| 1416 |
+
hints = torch.unbind(c_processed)[:-1]
|
| 1417 |
+
else:
|
| 1418 |
+
prepared_control = self._prepare_control_inputs(control_context, cap_feats_padded, t, patch_size, f_patch_size, device)
|
| 1419 |
+
c = prepared_control["c"]
|
| 1420 |
+
kwargs_for_v1_refiner = {
|
| 1421 |
+
"attn_mask": prepared_control["attn_mask"],
|
| 1422 |
+
"freqs_cis": prepared_control["freqs_cis"],
|
| 1423 |
+
"adaln_input": prepared_control["adaln_input"],
|
| 1424 |
+
}
|
| 1425 |
+
c = self._apply_transformer_blocks(c, self.control_noise_refiner, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_v1_refiner)
|
| 1426 |
+
c_item_seqlens = prepared_control["c_item_seqlens"]
|
| 1427 |
+
control_context_unified_list = [torch.cat([c[i, : c_item_seqlens[i]], cap_feats[i, : cap_item_seqlens[i]]], dim=0) for i in range(bsz)]
|
| 1428 |
+
c_unified = pad_sequence(control_context_unified_list, batch_first=True, padding_value=0.0)
|
| 1429 |
+
new_kwargs = dict(x=unified, **kwargs_for_hints)
|
| 1430 |
+
c_processed = self._apply_transformer_blocks(c_unified, self.control_layers, checkpoint_ratio=self.checkpoint_ratio, **new_kwargs)
|
| 1431 |
+
hints = torch.unbind(c_processed)[:-1]
|
| 1432 |
+
|
| 1433 |
+
kwargs_for_layers = {"attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input}
|
| 1434 |
+
if hints is not None:
|
| 1435 |
+
kwargs_for_layers["hints"] = hints
|
| 1436 |
+
kwargs_for_layers["context_scale"] = conditioning_scale
|
| 1437 |
+
unified = self._apply_transformer_blocks(unified, self.layers, checkpoint_ratio=self.checkpoint_ratio, **kwargs_for_layers)
|
| 1438 |
+
|
| 1439 |
+
unified_out = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
|
| 1440 |
+
x_image_tokens = unified_out[:, :x_max_item_seqlen]
|
| 1441 |
+
x_final_tensor = self._unpatchify(x_image_tokens, x_size, patch_size, f_patch_size)
|
| 1442 |
+
|
| 1443 |
+
return Transformer2DModelOutput(sample=x_final_tensor)
|
infer_controlnet.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from diffusers import FlowMatchEulerDiscreteScheduler, GGUFQuantizationConfig
|
| 7 |
+
from diffusers.utils import load_image
|
| 8 |
+
from diffusers_local import patch # Apply necessary patches for local diffusers components
|
| 9 |
+
|
| 10 |
+
# 1. Import all necessary components
|
| 11 |
+
from diffusers_local.pipeline_z_image_control_unified import ZImageControlUnifiedPipeline
|
| 12 |
+
from diffusers_local.z_image_control_transformer_2d import ZImageControlTransformer2DModel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,garbage_collection_threshold:0.7,max_split_size_mb:1024"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def main():
|
| 19 |
+
# 1. Set params
|
| 20 |
+
BASE_MODEL_ID = "."
|
| 21 |
+
GGUF_MODEL_FILE = "./transformer/z_image_turbo_control_unified_v2.1_q4_k_m.gguf"
|
| 22 |
+
|
| 23 |
+
use_gguf = True
|
| 24 |
+
|
| 25 |
+
control_mode = "depth" # (pose, canny, depth, hed, mlsd)
|
| 26 |
+
negative_prompt = "low quality, blurry, ugly, deformed fingers, extra fingers, bad hand, bad anatomy, noise, overexposed, underexposed"
|
| 27 |
+
guidance_scale = 0
|
| 28 |
+
seed = 43
|
| 29 |
+
shift = 3.0
|
| 30 |
+
controlnet_conditioning_refiner_scale = None
|
| 31 |
+
|
| 32 |
+
if control_mode == "pose":
|
| 33 |
+
#prompt="一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动,裙摆轻盈飞扬。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,肤色白皙细腻,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,手指清晰可见、五指完整、指节自然、姿势优雅放松,仿佛沉浸在思绪之中。背景是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕,海浪轻轻拍打沙滩,天空湛蓝云朵稀薄。整体画面高清锐利、细节丰富、色彩鲜艳、焦点清晰、8K分辨率、杰作、最佳质量、无模糊、无噪点、无畸变、自然光照、电影级渲染。"
|
| 34 |
+
prompt = "Photorealistic portrait of a beautiful young East Asian woman with long, vibrant purple hair and a black bow. She is wearing a flowing white summer dress, standing on a sunny beach with a sparkling ocean and clear blue sky in the background. Bright natural sunlight, sharp focus, ultra-detailed."
|
| 35 |
+
control_image = load_image("assets/pose.jpg")
|
| 36 |
+
target_height = 1728
|
| 37 |
+
target_width = 992
|
| 38 |
+
num_inference_steps = 25
|
| 39 |
+
controlnet_conditioning_scale = 0.75
|
| 40 |
+
controlnet_conditioning_refiner_scale = 1.0
|
| 41 |
+
#guidance_scale = 2.5
|
| 42 |
+
elif control_mode == "canny":
|
| 43 |
+
prompt = "A jaguar in the forest, soft cinematic lighting, balanced exposure, 8K UHD, DSLR camera, sharp focus, realistic texture."
|
| 44 |
+
prompt = "A masterpiece photograph, the face of a stunning leopard in a moment of calm intensity. It is peeking from a hideaway of dark green ivy leaves and tiny jasmine flowers. Its amber-colored eyes are the focal point, locked onto the viewer with a piercing intelligence. The light is cinematic, soft and directional, sculpting the animal's features and highlighting the velvety texture of its fur and the wet gleam of its nose. Intimate and silent atmosphere. 4K quality, ultra-realistic."
|
| 45 |
+
control_image = load_image("assets/canny.jpg")
|
| 46 |
+
target_height = 1328
|
| 47 |
+
target_width = 880
|
| 48 |
+
num_inference_steps = 25
|
| 49 |
+
controlnet_conditioning_scale = 1.0
|
| 50 |
+
elif control_mode == "depth":
|
| 51 |
+
prompt = "Photorealistic portrait of a fluffy long-haired cat, sitting in a forest at night. The cat is in the foreground, close-up, and in sharp focus. The background with trees is heavily blurred, creating a strong bokeh effect. Soft lighting from the front illuminates the cat."
|
| 52 |
+
control_image = load_image("assets/depth_cat.png")
|
| 53 |
+
target_height = 1024
|
| 54 |
+
target_width = 1024
|
| 55 |
+
num_inference_steps = 15
|
| 56 |
+
controlnet_conditioning_scale = 0.7
|
| 57 |
+
guidance_scale = 1.5
|
| 58 |
+
elif control_mode == "hed":
|
| 59 |
+
# prompt="raw photo, portrait of a handsome Asian man sitting at a wooden table, holding a green glass bottle, wearing a black sweater, wristwatch, highly detailed skin texture, realistic pores, serious gaze, soft cinematic lighting, rim lighting, balanced exposure, 8k uhd, dslr, sharp focus, wood grain texture."
|
| 60 |
+
prompt = "Cinematic film still, an ultra-realistic portrait of a melancholic Korean man in a dimly lit room. He is sitting at a dark wooden table, his hands wrapped around a green soju bottle. Rembrandt-style lighting, with a soft key light from the side sculpting his features and casting the other side in deep shadow (chiaroscuro). Sharp focus on his weary, expressive eyes. Shallow depth of field, with the dark background blurred out. Subtle film grain, art-house cinema aesthetic."
|
| 61 |
+
negative_prompt = "underexposed, crushed blacks, too dark, heavy shadows, makeup, smooth skin, plastic, wax, cartoon, illustration, distorted hands, bad anatomy, blur, haze, flat lighting"
|
| 62 |
+
control_image = load_image("assets/man_hed.png")
|
| 63 |
+
target_height = 1024
|
| 64 |
+
target_width = 768
|
| 65 |
+
num_inference_steps = 25
|
| 66 |
+
controlnet_conditioning_scale = 0.7
|
| 67 |
+
guidance_scale = 2.5
|
| 68 |
+
elif control_mode == "mlsd":
|
| 69 |
+
prompt = "RAW photo, professional interior design photography of a bright and clean contemporary home office. A sleek white desk with distressed wood grain drawers and chrome handles sits before a large window. A modern white ergonomic chair. To the left, a tall built-in white bookshelf with warm, backlit shelves. A dark wood accent wall. Cozy beige chaise lounge with a decorative red pillow."
|
| 70 |
+
control_image = load_image("assets/room_mlsd.png")
|
| 71 |
+
target_height = 1024
|
| 72 |
+
target_width = 1024
|
| 73 |
+
num_inference_steps = 25
|
| 74 |
+
controlnet_conditioning_scale = 0.85
|
| 75 |
+
controlnet_conditioning_refiner_scale = 1.0
|
| 76 |
+
|
| 77 |
+
generator = torch.Generator("cuda").manual_seed(seed)
|
| 78 |
+
|
| 79 |
+
print("Loading Pipeline...")
|
| 80 |
+
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
|
| 81 |
+
|
| 82 |
+
if use_gguf:
|
| 83 |
+
transformer = ZImageControlTransformer2DModel.from_single_file(
|
| 84 |
+
GGUF_MODEL_FILE,
|
| 85 |
+
torch_dtype=torch.bfloat16,
|
| 86 |
+
config=str(Path(GGUF_MODEL_FILE).parent),
|
| 87 |
+
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
| 88 |
+
add_control_noise_refiner=True, # <== If you don't want to use the control noise refiner disable here.
|
| 89 |
+
)
|
| 90 |
+
else:
|
| 91 |
+
transformer = ZImageControlTransformer2DModel.from_pretrained(
|
| 92 |
+
BASE_MODEL_ID,
|
| 93 |
+
subfolder="transformer",
|
| 94 |
+
torch_dtype=torch.bfloat16,
|
| 95 |
+
add_control_noise_refiner=True, # <== If you don't want to use the control noise refiner disable here.
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
pipe = ZImageControlUnifiedPipeline.from_pretrained(
|
| 99 |
+
BASE_MODEL_ID,
|
| 100 |
+
torch_dtype=torch.bfloat16,
|
| 101 |
+
transformer=transformer, # You don't need to load the transformer here if you don't intend to disable add_control_noise_refiner (only for diffusers model, gguf is required).
|
| 102 |
+
)
|
| 103 |
+
pipe.scheduler = scheduler
|
| 104 |
+
|
| 105 |
+
# Apply optimization (Optional)
|
| 106 |
+
pipe.enable_group_offload(
|
| 107 |
+
onload_device="cuda",
|
| 108 |
+
offload_device="cpu",
|
| 109 |
+
offload_type="block_level",
|
| 110 |
+
num_blocks_per_group=1,
|
| 111 |
+
low_cpu_mem_usage=True,
|
| 112 |
+
use_stream=True,
|
| 113 |
+
)
|
| 114 |
+
pipe.vae.use_tiling = True
|
| 115 |
+
# ---
|
| 116 |
+
|
| 117 |
+
print("\nRunning Inference...")
|
| 118 |
+
start_inference_time = time.time()
|
| 119 |
+
|
| 120 |
+
generated_image = pipe(
|
| 121 |
+
prompt=prompt,
|
| 122 |
+
negative_prompt=negative_prompt,
|
| 123 |
+
control_image=control_image,
|
| 124 |
+
height=target_height,
|
| 125 |
+
width=target_width,
|
| 126 |
+
num_inference_steps=num_inference_steps,
|
| 127 |
+
guidance_scale=guidance_scale,
|
| 128 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
| 129 |
+
controlnet_refiner_conditioning_scale=controlnet_conditioning_refiner_scale,
|
| 130 |
+
generator=generator,
|
| 131 |
+
).images[0]
|
| 132 |
+
|
| 133 |
+
end_inference_time = time.time()
|
| 134 |
+
print(f"\nGeneration finished in {end_inference_time - start_inference_time:.2f} seconds.")
|
| 135 |
+
|
| 136 |
+
# Save Output
|
| 137 |
+
if not os.path.exists("outputs"):
|
| 138 |
+
os.makedirs("outputs")
|
| 139 |
+
output_filename = f"outputs/z_image_controlnet_result_control_{control_mode}.png"
|
| 140 |
+
generated_image.save(output_filename)
|
| 141 |
+
print(f"Image successfully saved as '{output_filename}'")
|
| 142 |
+
generated_image.show()
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
if __name__ == "__main__":
|
| 146 |
+
main()
|
infer_i2i.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from diffusers import FlowMatchEulerDiscreteScheduler, GGUFQuantizationConfig
|
| 7 |
+
from diffusers.utils import load_image
|
| 8 |
+
from diffusers_local import patch # Apply necessary patches for local diffusers components
|
| 9 |
+
|
| 10 |
+
# 1. Import all necessary components
|
| 11 |
+
from diffusers_local.pipeline_z_image_control_unified import ZImageControlUnifiedPipeline
|
| 12 |
+
from diffusers_local.z_image_control_transformer_2d import ZImageControlTransformer2DModel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,garbage_collection_threshold:0.7,max_split_size_mb:1024"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def main():
|
| 19 |
+
# 1. Set params
|
| 20 |
+
BASE_MODEL_ID = "."
|
| 21 |
+
GGUF_MODEL_FILE = "./transformer/z_image_turbo_control_unified_v2.1_q4_k_m.gguf"
|
| 22 |
+
|
| 23 |
+
use_gguf = True
|
| 24 |
+
|
| 25 |
+
prompt = "a asian man with a bottle"
|
| 26 |
+
negative_prompt = "Low quality, blurry, ugly, deformed fingers, extra fingers, bad hand, bad anatomy, noise, overexposed, underexposed"
|
| 27 |
+
|
| 28 |
+
target_height = 1024
|
| 29 |
+
target_width = 768
|
| 30 |
+
num_inference_steps = 9
|
| 31 |
+
guidance_scale = 0
|
| 32 |
+
strength = 0.75
|
| 33 |
+
seed = 43
|
| 34 |
+
shift = 3.0
|
| 35 |
+
input_image = load_image("assets/bottle.jpg")
|
| 36 |
+
generator = torch.Generator("cuda").manual_seed(seed)
|
| 37 |
+
|
| 38 |
+
print("Loading Pipeline...")
|
| 39 |
+
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
|
| 40 |
+
|
| 41 |
+
if use_gguf:
|
| 42 |
+
transformer = ZImageControlTransformer2DModel.from_single_file(
|
| 43 |
+
GGUF_MODEL_FILE,
|
| 44 |
+
torch_dtype=torch.bfloat16,
|
| 45 |
+
config=str(Path(GGUF_MODEL_FILE).parent),
|
| 46 |
+
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
| 47 |
+
use_controlnet=False, # <== Disable control layers to inference speedy
|
| 48 |
+
)
|
| 49 |
+
else:
|
| 50 |
+
transformer = ZImageControlTransformer2DModel.from_pretrained(
|
| 51 |
+
BASE_MODEL_ID,
|
| 52 |
+
subfolder="transformer",
|
| 53 |
+
torch_dtype=torch.bfloat16,
|
| 54 |
+
use_controlnet=False, # <== Disable control layers to inference speedy
|
| 55 |
+
)
|
| 56 |
+
pipe = ZImageControlUnifiedPipeline.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.bfloat16, transformer=transformer)
|
| 57 |
+
pipe.scheduler = scheduler
|
| 58 |
+
|
| 59 |
+
# Apply optimization (Optional)
|
| 60 |
+
pipe.enable_group_offload(
|
| 61 |
+
onload_device="cuda", offload_device="cpu", offload_type="block_level", num_blocks_per_group=1, low_cpu_mem_usage=True, use_stream=True
|
| 62 |
+
)
|
| 63 |
+
pipe.vae.use_tiling = True
|
| 64 |
+
# ---
|
| 65 |
+
|
| 66 |
+
print("\nRunning Inference...")
|
| 67 |
+
start_inference_time = time.time()
|
| 68 |
+
|
| 69 |
+
generated_image = pipe(
|
| 70 |
+
prompt=prompt,
|
| 71 |
+
image=input_image,
|
| 72 |
+
strength=strength,
|
| 73 |
+
negative_prompt=negative_prompt,
|
| 74 |
+
height=target_height,
|
| 75 |
+
width=target_width,
|
| 76 |
+
num_inference_steps=num_inference_steps,
|
| 77 |
+
guidance_scale=guidance_scale,
|
| 78 |
+
generator=generator,
|
| 79 |
+
).images[0]
|
| 80 |
+
|
| 81 |
+
end_inference_time = time.time()
|
| 82 |
+
print(f"\nGeneration finished in {end_inference_time - start_inference_time:.2f} seconds.")
|
| 83 |
+
|
| 84 |
+
# Save Output
|
| 85 |
+
if not os.path.exists("outputs"):
|
| 86 |
+
os.makedirs("outputs")
|
| 87 |
+
output_filename = "outputs/z_image_controlnet_result_i2i.png"
|
| 88 |
+
generated_image.save(output_filename)
|
| 89 |
+
print(f"Image successfully saved as '{output_filename}'")
|
| 90 |
+
generated_image.show()
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
if __name__ == "__main__":
|
| 94 |
+
main()
|
infer_inpaint.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from diffusers import FlowMatchEulerDiscreteScheduler, GGUFQuantizationConfig
|
| 7 |
+
from diffusers.utils import load_image
|
| 8 |
+
from diffusers_local import patch # Apply necessary patches for local diffusers components
|
| 9 |
+
|
| 10 |
+
# 1. Import all necessary components
|
| 11 |
+
from diffusers_local.pipeline_z_image_control_unified import ZImageControlUnifiedPipeline
|
| 12 |
+
from diffusers_local.z_image_control_transformer_2d import ZImageControlTransformer2DModel
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,garbage_collection_threshold:0.7,max_split_size_mb:1024"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def main():
|
| 19 |
+
# 1. Set params
|
| 20 |
+
BASE_MODEL_ID = "."
|
| 21 |
+
GGUF_MODEL_FILE = "./transformer/z_image_turbo_control_unified_v2.1_q4_k_m.gguf"
|
| 22 |
+
|
| 23 |
+
use_gguf = True
|
| 24 |
+
|
| 25 |
+
# prompt="一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动,裙摆轻盈飞扬。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,肤色白皙细腻,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,手指清晰可见、五指完整、指节自然、姿势优雅放松,仿佛沉浸在思绪之中。背景是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕,海浪轻轻拍打沙滩,天空湛蓝云朵稀薄。整体画面高清锐利、细节丰富、色彩鲜艳、焦点清晰、8K分辨率、杰作、最佳质量、无模糊、无噪点、无畸变、自然光照、电影级渲染。"
|
| 26 |
+
prompt = "Photorealistic portrait of a beautiful young East Asian woman with long, vibrant purple hair and a black bow. She is wearing a flowing white summer dress, standing on a sunny beach with a sparkling ocean and clear blue sky in the background. Bright natural sunlight, sharp focus, ultra-detailed."
|
| 27 |
+
negative_prompt = "Low quality, blurry, ugly, deformed fingers, extra fingers, bad hand, bad anatomy, noise, overexposed, underexposed"
|
| 28 |
+
|
| 29 |
+
target_height = 1728
|
| 30 |
+
target_width = 992
|
| 31 |
+
num_inference_steps = 20
|
| 32 |
+
guidance_scale = 0 # 2.5
|
| 33 |
+
controlnet_conditioning_scale = 0.7
|
| 34 |
+
controlnet_conditioning_refiner_scale = 0.75
|
| 35 |
+
mask_blur_radius = 8.0
|
| 36 |
+
seed = 42
|
| 37 |
+
shift = 3.0
|
| 38 |
+
generator = torch.Generator("cuda").manual_seed(seed)
|
| 39 |
+
|
| 40 |
+
print("Loading Pipeline...")
|
| 41 |
+
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
|
| 42 |
+
|
| 43 |
+
if use_gguf:
|
| 44 |
+
transformer = ZImageControlTransformer2DModel.from_single_file(
|
| 45 |
+
GGUF_MODEL_FILE,
|
| 46 |
+
torch_dtype=torch.bfloat16,
|
| 47 |
+
config=str(Path(GGUF_MODEL_FILE).parent),
|
| 48 |
+
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
| 49 |
+
add_control_noise_refiner=True, # <== If you don't want to use the control noise refiner disable here.
|
| 50 |
+
)
|
| 51 |
+
else:
|
| 52 |
+
transformer = ZImageControlTransformer2DModel.from_pretrained(
|
| 53 |
+
BASE_MODEL_ID,
|
| 54 |
+
subfolder="transformer",
|
| 55 |
+
torch_dtype=torch.bfloat16,
|
| 56 |
+
add_control_noise_refiner=True, # <== If you don't want to use the control noise refiner disable here.
|
| 57 |
+
)
|
| 58 |
+
pipe = ZImageControlUnifiedPipeline.from_pretrained(
|
| 59 |
+
BASE_MODEL_ID,
|
| 60 |
+
torch_dtype=torch.bfloat16,
|
| 61 |
+
transformer=transformer, # You don't need to load the transformer here if you don't intend to disable add_control_noise_refiner (only for diffusers model, gguf is required).
|
| 62 |
+
)
|
| 63 |
+
pipe.scheduler = scheduler
|
| 64 |
+
|
| 65 |
+
# Apply optimization (Optional)
|
| 66 |
+
pipe.enable_group_offload(
|
| 67 |
+
onload_device="cuda", offload_device="cpu", offload_type="block_level", num_blocks_per_group=1, low_cpu_mem_usage=True, use_stream=True
|
| 68 |
+
)
|
| 69 |
+
pipe.vae.use_tiling = True
|
| 70 |
+
# ---
|
| 71 |
+
|
| 72 |
+
print("\nRunning Inference...")
|
| 73 |
+
|
| 74 |
+
pose_image = load_image("assets/pose.jpg")
|
| 75 |
+
inpaint_image = load_image("assets/inpaint.jpg")
|
| 76 |
+
mask_image = load_image("assets/mask_inpaint.jpg")
|
| 77 |
+
|
| 78 |
+
start_inference_time = time.time()
|
| 79 |
+
|
| 80 |
+
generated_image = pipe(
|
| 81 |
+
prompt=prompt,
|
| 82 |
+
negative_prompt=negative_prompt,
|
| 83 |
+
image=inpaint_image,
|
| 84 |
+
control_image=pose_image,
|
| 85 |
+
mask_image=mask_image,
|
| 86 |
+
mask_blur_radius=mask_blur_radius,
|
| 87 |
+
height=target_height,
|
| 88 |
+
width=target_width,
|
| 89 |
+
num_inference_steps=num_inference_steps,
|
| 90 |
+
guidance_scale=guidance_scale,
|
| 91 |
+
controlnet_conditioning_scale=controlnet_conditioning_scale,
|
| 92 |
+
controlnet_refiner_conditioning_scale=controlnet_conditioning_refiner_scale,
|
| 93 |
+
generator=generator,
|
| 94 |
+
).images[0]
|
| 95 |
+
|
| 96 |
+
end_inference_time = time.time()
|
| 97 |
+
print(f"\nGeneration finished in {end_inference_time - start_inference_time:.2f} seconds.")
|
| 98 |
+
|
| 99 |
+
# Save Output
|
| 100 |
+
if not os.path.exists("outputs"):
|
| 101 |
+
os.makedirs("outputs")
|
| 102 |
+
output_filename = "outputs/z_image_controlnet_result_inpaint.png"
|
| 103 |
+
generated_image.save(output_filename)
|
| 104 |
+
print(f"Image successfully saved as '{output_filename}'")
|
| 105 |
+
generated_image.show()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
main()
|
infer_t2i.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from diffusers import FlowMatchEulerDiscreteScheduler, GGUFQuantizationConfig
|
| 7 |
+
from diffusers_local import patch # Apply necessary patches for local diffusers components
|
| 8 |
+
|
| 9 |
+
# 1. Import all necessary components
|
| 10 |
+
from diffusers_local.pipeline_z_image_control_unified import ZImageControlUnifiedPipeline
|
| 11 |
+
from diffusers_local.z_image_control_transformer_2d import ZImageControlTransformer2DModel
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,garbage_collection_threshold:0.7,max_split_size_mb:1024"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def main():
|
| 18 |
+
# 1. Set params
|
| 19 |
+
BASE_MODEL_ID = "."
|
| 20 |
+
GGUF_MODEL_FILE = "./z_image_turbo_control_unified_v2.1_q4_k_m.gguf"
|
| 21 |
+
|
| 22 |
+
use_gguf = True
|
| 23 |
+
|
| 24 |
+
prompt = "一位年轻女子站在阳光明媚的海岸线上,白裙在轻拂的海风中微微飘动,裙摆轻盈飞扬。她拥有一头鲜艳的紫色长发,在风中轻盈舞动,发间系着一个精致的黑色蝴蝶结,与身后柔和的蔚蓝天空形成鲜明对比。她面容清秀,眉目精致,肤色白皙细腻,透着一股甜美的青春气息;神情柔和,略带羞涩,目光静静地凝望着远方的地平线,双手自然交叠于身前,手指清晰可见、五指完整、指节自然、姿势优雅放松,仿佛沉浸在思绪之中。背景是辽阔无垠、波光粼粼的大海,阳光洒在海面上,映出温暖的金色光晕,海浪轻轻拍打沙滩,天空湛蓝云朵稀薄。整体画面高清锐利、细节丰富、色彩鲜艳、焦点清晰、8K分辨率、杰作、最佳质量、无模糊、无噪点、无畸变、自然光照、电影级渲染。"
|
| 25 |
+
negative_prompt = "Low quality, blurry, ugly, deformed fingers, extra fingers, bad hand, bad anatomy, noise, overexposed, underexposed"
|
| 26 |
+
|
| 27 |
+
target_height = 1728
|
| 28 |
+
target_width = 992
|
| 29 |
+
num_inference_steps = 9
|
| 30 |
+
guidance_scale = 0
|
| 31 |
+
seed = 43
|
| 32 |
+
shift = 3.0
|
| 33 |
+
generator = torch.Generator("cuda").manual_seed(seed)
|
| 34 |
+
|
| 35 |
+
print("Loading Pipeline...")
|
| 36 |
+
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=shift)
|
| 37 |
+
|
| 38 |
+
if use_gguf:
|
| 39 |
+
transformer = ZImageControlTransformer2DModel.from_single_file(
|
| 40 |
+
GGUF_MODEL_FILE,
|
| 41 |
+
torch_dtype=torch.bfloat16,
|
| 42 |
+
config=str(Path(GGUF_MODEL_FILE).parent),
|
| 43 |
+
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
|
| 44 |
+
use_controlnet=False, # <== Disable control layers to inference speedy
|
| 45 |
+
)
|
| 46 |
+
else:
|
| 47 |
+
transformer = ZImageControlTransformer2DModel.from_pretrained(
|
| 48 |
+
BASE_MODEL_ID,
|
| 49 |
+
subfolder="transformer",
|
| 50 |
+
torch_dtype=torch.bfloat16,
|
| 51 |
+
use_controlnet=False, # <== Disable control layers to inference speedy
|
| 52 |
+
)
|
| 53 |
+
pipe = ZImageControlUnifiedPipeline.from_pretrained(BASE_MODEL_ID, torch_dtype=torch.bfloat16, transformer=transformer)
|
| 54 |
+
pipe.scheduler = scheduler
|
| 55 |
+
|
| 56 |
+
# Apply optimization (Optional)
|
| 57 |
+
pipe.enable_group_offload(
|
| 58 |
+
onload_device="cuda", offload_device="cpu", offload_type="block_level", num_blocks_per_group=1, low_cpu_mem_usage=True, use_stream=True
|
| 59 |
+
)
|
| 60 |
+
pipe.vae.use_tiling = True
|
| 61 |
+
# ---
|
| 62 |
+
|
| 63 |
+
print("\nRunning Inference...")
|
| 64 |
+
start_inference_time = time.time()
|
| 65 |
+
|
| 66 |
+
generated_image = pipe(
|
| 67 |
+
prompt=prompt,
|
| 68 |
+
negative_prompt=negative_prompt,
|
| 69 |
+
height=target_height,
|
| 70 |
+
width=target_width,
|
| 71 |
+
num_inference_steps=num_inference_steps,
|
| 72 |
+
guidance_scale=guidance_scale,
|
| 73 |
+
generator=generator,
|
| 74 |
+
).images[0]
|
| 75 |
+
|
| 76 |
+
end_inference_time = time.time()
|
| 77 |
+
print(f"\nGeneration finished in {end_inference_time - start_inference_time:.2f} seconds.")
|
| 78 |
+
|
| 79 |
+
# Save Output
|
| 80 |
+
if not os.path.exists("outputs"):
|
| 81 |
+
os.makedirs("outputs")
|
| 82 |
+
output_filename = "outputs/z_image_controlnet_result_t2i.png"
|
| 83 |
+
generated_image.save(output_filename)
|
| 84 |
+
print(f"Image successfully saved as '{output_filename}'")
|
| 85 |
+
generated_image.show()
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
if __name__ == "__main__":
|
| 89 |
+
main()
|
model_index.json
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_class_name": "ZImagePipeline",
|
| 3 |
+
"_diffusers_version": "0.36.0.dev0",
|
| 4 |
+
"scheduler": [
|
| 5 |
+
"diffusers",
|
| 6 |
+
"FlowMatchEulerDiscreteScheduler"
|
| 7 |
+
],
|
| 8 |
+
"text_encoder": [
|
| 9 |
+
"transformers",
|
| 10 |
+
"Qwen3Model"
|
| 11 |
+
],
|
| 12 |
+
"tokenizer": [
|
| 13 |
+
"transformers",
|
| 14 |
+
"Qwen2Tokenizer"
|
| 15 |
+
],
|
| 16 |
+
"transformer": [
|
| 17 |
+
"diffusers",
|
| 18 |
+
"ZImageControlTransformer2DModel"
|
| 19 |
+
],
|
| 20 |
+
"vae": [
|
| 21 |
+
"diffusers",
|
| 22 |
+
"AutoencoderKL"
|
| 23 |
+
]
|
| 24 |
+
}
|
requirements.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu126
|
| 2 |
+
|
| 3 |
+
torch==2.8.0+cu126
|
| 4 |
+
torchvision==0.23.0+cu126
|
| 5 |
+
torchaudio==2.8.0+cu126
|
| 6 |
+
transformers==4.56.0
|
| 7 |
+
bitsandbytes==0.48.1
|
| 8 |
+
xformers==0.0.32.post2
|
| 9 |
+
git+https://github.com/huggingface/diffusers
|
| 10 |
+
hf_xet
|
| 11 |
+
gguf
|
| 12 |
+
accelerate
|
| 13 |
+
protobuf
|
| 14 |
+
einops
|
| 15 |
+
matplotlib
|
| 16 |
+
sacremoses
|
| 17 |
+
scikit-image
|
| 18 |
+
sentencepiece
|
| 19 |
+
scipy
|
| 20 |
+
opencv-python
|
| 21 |
+
triton-windows<3.5; sys_platform == 'win32'
|
| 22 |
+
triton==3.4.0; sys_platform != 'win32'
|
results/canny.png
ADDED
|
Git LFS Details
|
results/depth.png
ADDED
|
Git LFS Details
|
results/hed.png
ADDED
|
Git LFS Details
|
results/new_tests/controlnet_result_i2i.png
ADDED
|
Git LFS Details
|
results/new_tests/result_control_canny.png
ADDED
|
Git LFS Details
|
results/new_tests/result_control_depth.png
ADDED
|
Git LFS Details
|
results/new_tests/result_control_hed.png
ADDED
|
Git LFS Details
|
results/new_tests/result_control_inpaint_original_mask.png
ADDED
|
Git LFS Details
|
results/new_tests/result_control_mlsd.png
ADDED
|
Git LFS Details
|
results/new_tests/result_control_pose.png
ADDED
|
Git LFS Details
|
results/new_tests/result_inpaint.png
ADDED
|
Git LFS Details
|
results/new_tests/result_t2i.png
ADDED
|
Git LFS Details
|