Text-to-Image
Diffusers
Safetensors
recoilme commited on
Commit
7b12282
·
1 Parent(s): ec0cd71
girl.jpg CHANGED

Git LFS Details

  • SHA256: d98b27cf3ae2b91022856db57f7f5e07380cd59b0cb5f01077c556079c93adc7
  • Pointer size: 130 Bytes
  • Size of remote file: 24.9 kB

Git LFS Details

  • SHA256: 142051a7a89100e8614349a8b0fba903d6eb044af08092bb0d9a17478bb8900c
  • Pointer size: 130 Bytes
  • Size of remote file: 53.7 kB
media/result_grid.jpg CHANGED

Git LFS Details

  • SHA256: d5b008a5ca3136fc5b747c4e02d09fa6605e4c04a02b39cec9ec47b0795dd0f2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.73 MB

Git LFS Details

  • SHA256: e2468f4c8c3e6d55a01af9e6d8af568cf771f95637ba008402fcf3492697d58a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.68 MB
model_index.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b6d71e1f562601e1a6bfcd1f0f7f81e021003b5512637d1c14dd77aba88144c8
3
- size 412
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0d5bb629e0658077c2dc01e262ac002c2235c10c6bccdbbb13986db25d810a5
3
+ size 545
pipeline_sdxs.py CHANGED
@@ -7,18 +7,21 @@ from dataclasses import dataclass
7
  from diffusers import DiffusionPipeline
8
  from diffusers.utils import BaseOutput
9
  from tqdm import tqdm
 
10
 
11
  @dataclass
12
  class SdxsPipelineOutput(BaseOutput):
13
  images: Union[List[Image.Image], np.ndarray]
14
 
15
  class SdxsPipeline(DiffusionPipeline):
16
- def __init__(self, vae, text_encoder, tokenizer, unet, scheduler):
17
  super().__init__()
18
  self.register_modules(
19
  vae=vae,
20
- text_encoder=text_encoder,
21
- tokenizer=tokenizer,
 
 
22
  unet=unet,
23
  scheduler=scheduler
24
  )
@@ -98,7 +101,7 @@ class SdxsPipeline(DiffusionPipeline):
98
  return latents
99
 
100
  def encode_prompt(self, prompt, negative_prompt, device, dtype):
101
- def get_single_encode(texts, is_negative=False):
102
  if texts is None or texts == "":
103
  hidden_dim = self.text_encoder.config.hidden_size
104
  shape = (1, self.text_encoder.config.max_position_embeddings, hidden_dim)
@@ -130,19 +133,63 @@ class SdxsPipeline(DiffusionPipeline):
130
  prompt_embeds = final_layer_norm(prompt_embeds)
131
 
132
  return prompt_embeds, toks.attention_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  pos_embeds, pos_mask = get_single_encode(prompt)
135
- neg_embeds, neg_mask = get_single_encode(negative_prompt, is_negative=True)
 
 
136
 
137
  batch_size = pos_embeds.shape[0]
138
  if neg_embeds.shape[0] != batch_size:
139
  neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
140
  neg_mask = neg_mask.repeat(batch_size, 1)
 
 
 
 
141
 
142
  text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
143
  final_mask = torch.cat([neg_mask, pos_mask], dim=0)
 
144
 
145
- return text_embeddings.to(dtype=dtype), final_mask.to(dtype=torch.int64)
146
 
147
  @torch.no_grad()
148
  def __call__(
@@ -170,7 +217,7 @@ class SdxsPipeline(DiffusionPipeline):
170
  generator = torch.Generator(device=device).manual_seed(seed)
171
 
172
  # 1. Encode prompt (твой код оставляем без изменений)
173
- text_embeddings, attention_mask = self.encode_prompt(
174
  prompt, negative_prompt, device, dtype
175
  )
176
  batch_size = 1 if isinstance(prompt, str) else len(prompt)
@@ -226,12 +273,17 @@ class SdxsPipeline(DiffusionPipeline):
226
  # ==================== DENOISING LOOP (одинаковый для txt2img и img2img) ====================
227
  for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
228
  latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
 
 
 
 
229
 
230
  model_out = self.unet(
231
  latent_model_input,
232
  t,
233
  encoder_hidden_states=text_embeddings,
234
  encoder_attention_mask=attention_mask,
 
235
  return_dict=False,
236
  )[0]
237
 
 
7
  from diffusers import DiffusionPipeline
8
  from diffusers.utils import BaseOutput
9
  from tqdm import tqdm
10
+ from transformers import Qwen3ForCausalLM, Qwen2Tokenizer
11
 
12
  @dataclass
13
  class SdxsPipelineOutput(BaseOutput):
14
  images: Union[List[Image.Image], np.ndarray]
15
 
16
  class SdxsPipeline(DiffusionPipeline):
17
+ def __init__(self, vae, text_encoder, text_encoder2, tokenizer, tokenizer2, unet, scheduler):
18
  super().__init__()
19
  self.register_modules(
20
  vae=vae,
21
+ text_encoder=text_encoder,
22
+ text_encoder2=text_encoder2,
23
+ tokenizer=tokenizer,
24
+ tokenizer2=tokenizer2,
25
  unet=unet,
26
  scheduler=scheduler
27
  )
 
101
  return latents
102
 
103
  def encode_prompt(self, prompt, negative_prompt, device, dtype):
104
+ def get_single_encode(texts):
105
  if texts is None or texts == "":
106
  hidden_dim = self.text_encoder.config.hidden_size
107
  shape = (1, self.text_encoder.config.max_position_embeddings, hidden_dim)
 
133
  prompt_embeds = final_layer_norm(prompt_embeds)
134
 
135
  return prompt_embeds, toks.attention_mask
136
+
137
+ def get_pooled_encode(texts):
138
+ if texts is None or texts == "":
139
+ hidden_dim = self.text_encoder2.config.hidden_size
140
+ shape = (1, self.text_encoder.config.max_position_embeddings, hidden_dim)#248
141
+ emb = torch.zeros(shape, dtype=dtype, device=device)
142
+ return emb
143
+
144
+ if isinstance(texts, str):
145
+ texts = [texts]
146
+
147
+ with torch.no_grad():
148
+ messages = [{"role": "user", "content": texts}]
149
+ with open("tokenizer2/chat_template.jinja", "r", encoding="utf-8") as f:
150
+ custom_template = f.read().strip()
151
+ text = self.tokenizer2.apply_chat_template(messages, add_generation_prompt=False, tokenize=False, chat_template=custom_template)
152
+ #text = self.tokenizer2.apply_chat_template(messages, add_generation_prompt=False, tokenize=False)
153
+ toks = self.tokenizer2(
154
+ text,
155
+ padding="max_length",
156
+ max_length=self.text_encoder.config.max_position_embeddings,
157
+ truncation=True,
158
+ return_tensors="pt"
159
+ ).to(device)
160
+
161
+ outputs = self.text_encoder2(
162
+ input_ids=toks.input_ids,
163
+ attention_mask=toks.attention_mask,
164
+ output_hidden_states=True
165
+ )
166
+
167
+ layer_index = -2
168
+ last_hidden = outputs.hidden_states[layer_index]
169
+ seq_len = toks.attention_mask.sum(dim=1) - 1
170
+ pooled = last_hidden[torch.arange(len(last_hidden)), seq_len.clamp(min=0)]
171
+
172
+ return pooled
173
 
174
  pos_embeds, pos_mask = get_single_encode(prompt)
175
+ neg_embeds, neg_mask = get_single_encode(negative_prompt)
176
+ pos_pooled = get_pooled_encode(prompt)
177
+ neg_pooled = get_pooled_encode(negative_prompt)
178
 
179
  batch_size = pos_embeds.shape[0]
180
  if neg_embeds.shape[0] != batch_size:
181
  neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
182
  neg_mask = neg_mask.repeat(batch_size, 1)
183
+ neg_pooled = neg_pooled.repeat(batch_size, 1)
184
+
185
+ if pos_pooled.shape[0] != batch_size:
186
+ pos_pooled = pos_pooled.repeat(batch_size, 1)
187
 
188
  text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
189
  final_mask = torch.cat([neg_mask, pos_mask], dim=0)
190
+ pooled_embeds = torch.cat([neg_pooled, pos_pooled], dim=0)
191
 
192
+ return text_embeddings.to(dtype=dtype), final_mask.to(dtype=torch.int64), pooled_embeds.to(dtype=dtype)
193
 
194
  @torch.no_grad()
195
  def __call__(
 
217
  generator = torch.Generator(device=device).manual_seed(seed)
218
 
219
  # 1. Encode prompt (твой код оставляем без изменений)
220
+ text_embeddings, attention_mask, pooled_embeds = self.encode_prompt(
221
  prompt, negative_prompt, device, dtype
222
  )
223
  batch_size = 1 if isinstance(prompt, str) else len(prompt)
 
273
  # ==================== DENOISING LOOP (одинаковый для txt2img и img2img) ====================
274
  for i, t in enumerate(tqdm(timesteps, desc="Sampling")):
275
  latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents
276
+
277
+ added_cond_kwargs = {
278
+ "text_embeds": pooled_embeds,
279
+ }
280
 
281
  model_out = self.unet(
282
  latent_model_input,
283
  t,
284
  encoder_hidden_states=text_embeddings,
285
  encoder_attention_mask=attention_mask,
286
+ added_cond_kwargs=added_cond_kwargs,
287
  return_dict=False,
288
  )[0]
289
 
samples/unet_320x640_0.jpg CHANGED

Git LFS Details

  • SHA256: d8d46684ae7e461e5b8cd03937a50d01ae89532e43ab3a47456a3e687af19549
  • Pointer size: 130 Bytes
  • Size of remote file: 35.5 kB

Git LFS Details

  • SHA256: fdc51c08531f2a84eacc51ab1c864f052bb271abb5cd2265a3e16c3f08e66eb7
  • Pointer size: 130 Bytes
  • Size of remote file: 46.1 kB
samples/unet_352x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 3efd68b90e3bfe131e8516e9c681b369f9e26d8d4e96159fb38e1a860756c0c5
  • Pointer size: 130 Bytes
  • Size of remote file: 75.6 kB

Git LFS Details

  • SHA256: d5c026a5184bfbfc57f85323ef15b41465a544254e551f4cb802f613e2fbf7cb
  • Pointer size: 131 Bytes
  • Size of remote file: 136 kB
samples/unet_384x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 0e73dd8cabd5bfdd49c677639c8557f10275f4bffd0483f25e4f82a8401d7c85
  • Pointer size: 131 Bytes
  • Size of remote file: 158 kB

Git LFS Details

  • SHA256: 9b008824a24f39ea1eb081bee307076f4a7f9cc49f72324322683a7f6682b919
  • Pointer size: 131 Bytes
  • Size of remote file: 101 kB
samples/unet_416x640_0.jpg CHANGED

Git LFS Details

  • SHA256: dbf30fdc26f79b0ff8dde4577ab91f96033f926ed0e9dd9df00067c154eb194e
  • Pointer size: 131 Bytes
  • Size of remote file: 132 kB

Git LFS Details

  • SHA256: 12fea08c2af57b7f7e3ba8b6aeefb7eec7b3bcd0d4e242d864adbae45a24cd6e
  • Pointer size: 130 Bytes
  • Size of remote file: 65.3 kB
samples/unet_448x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 344e2e5df5d11c05a64be72dc42baa50299160e152d854da70f8b29b31640f0b
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB

Git LFS Details

  • SHA256: 91f1cad64c89b2ae069d77b0b1eb6f48ec939ad1baa5f15c4dc88e58a48defd5
  • Pointer size: 131 Bytes
  • Size of remote file: 155 kB
samples/unet_480x640_0.jpg CHANGED

Git LFS Details

  • SHA256: cd7fd9eae4dc639ab8c96ce643ecffce2ad9295499ea490ede8c3c6800816c22
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB

Git LFS Details

  • SHA256: 5039b32dd4618e05abf17e2c82f5e170ec0a7fadcbc368901128c076f6bacef5
  • Pointer size: 131 Bytes
  • Size of remote file: 197 kB
samples/unet_512x640_0.jpg CHANGED

Git LFS Details

  • SHA256: b3a059eedb669e51bf50ea0892cb67b4177385b14a52fff88b74ceef222a2eb2
  • Pointer size: 131 Bytes
  • Size of remote file: 177 kB

Git LFS Details

  • SHA256: 19513d1153dac5ed907098d5fba226adfb7e3130e7ee155500e6c33578711fb7
  • Pointer size: 130 Bytes
  • Size of remote file: 76.8 kB
samples/unet_544x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 006fedadb8d1630ba4ecf8e9e73657fd9789f7957d2f9af9d97facf5feab2aec
  • Pointer size: 130 Bytes
  • Size of remote file: 92.2 kB

Git LFS Details

  • SHA256: 307a075166de57168469dbbc9c0c7316b5037c5bbb50fa4fe09f572dc97ea327
  • Pointer size: 131 Bytes
  • Size of remote file: 153 kB
samples/unet_576x640_0.jpg CHANGED

Git LFS Details

  • SHA256: c7210149d438abf21d40b0821b949c81ed922754b6bd5b7db1b69b87dcfe0bc1
  • Pointer size: 131 Bytes
  • Size of remote file: 230 kB

Git LFS Details

  • SHA256: 232c9214613fe3aa95b6bc39b67a63e4a96eb10f59354d317250f587c110a5c9
  • Pointer size: 131 Bytes
  • Size of remote file: 245 kB
samples/unet_608x640_0.jpg CHANGED

Git LFS Details

  • SHA256: f75948ed9916efe45dab024bea702d0a3714342ebec6f4d354102f05277091a9
  • Pointer size: 131 Bytes
  • Size of remote file: 152 kB

Git LFS Details

  • SHA256: 6f72b0ea6d9074617b892593b17fe81c64a51337dd2d4a5b283ed8ac1841b064
  • Pointer size: 131 Bytes
  • Size of remote file: 103 kB
samples/unet_640x320_0.jpg CHANGED

Git LFS Details

  • SHA256: eeb51da2f098f450eddbb24608683f888472a46eae32b808a9d117caffebe49b
  • Pointer size: 130 Bytes
  • Size of remote file: 98.5 kB

Git LFS Details

  • SHA256: da499eeda937acf345c49609f7a93f7caa2b76c4fd4d1f03cb40652711e316dd
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
samples/unet_640x352_0.jpg CHANGED

Git LFS Details

  • SHA256: 2ce8e3c3eb464eb0650b8ba139f2b1ed9c13cd842ee34f11ef5ed6346ccdb8d0
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB

Git LFS Details

  • SHA256: ddca2967951f48c8c15c23c82e1950d5f2f6099f86544d6038ea087e04794505
  • Pointer size: 131 Bytes
  • Size of remote file: 157 kB
samples/unet_640x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 25b4325358bbcc2c787882393313616d2e3c89b6960a57b77afcb94dab04f1e6
  • Pointer size: 131 Bytes
  • Size of remote file: 117 kB

Git LFS Details

  • SHA256: 7bb0cb8197d73293384b2cfc773e8fd3fdaa4b2248c77eb0622b7f4002317508
  • Pointer size: 131 Bytes
  • Size of remote file: 176 kB
samples/unet_640x416_0.jpg CHANGED

Git LFS Details

  • SHA256: ebe19c459af583d3a8c5bd53e73dc0a4b88c9ee564fe1036483d734b0f514788
  • Pointer size: 131 Bytes
  • Size of remote file: 181 kB

Git LFS Details

  • SHA256: c90e25f76319168e9bf34e55ff41f080d8c72ba46090e95a4a7d0c8c191c935a
  • Pointer size: 131 Bytes
  • Size of remote file: 114 kB
samples/unet_640x448_0.jpg CHANGED

Git LFS Details

  • SHA256: 13e903e484a909c74fe1a65378d471ac530c86a75a8b550d452ed56cc91e30c1
  • Pointer size: 131 Bytes
  • Size of remote file: 111 kB

Git LFS Details

  • SHA256: 2c721fc590cd2b69e48646ad40c05149d1ce65c0ed26f4d1dfe0f5545a9e13c6
  • Pointer size: 131 Bytes
  • Size of remote file: 244 kB
samples/unet_640x480_0.jpg CHANGED

Git LFS Details

  • SHA256: a7f615e27eaa6c42180e5e09a61924084050befa2e6638e7386a0b55fa060dbb
  • Pointer size: 131 Bytes
  • Size of remote file: 174 kB

Git LFS Details

  • SHA256: 8c98c99e90ea48882546809e92d5fa1ce38d63910a6e76e8e432c987632c9ef2
  • Pointer size: 130 Bytes
  • Size of remote file: 59.9 kB
samples/unet_640x512_0.jpg CHANGED

Git LFS Details

  • SHA256: d02cf8e05c137f816b98f7baabe4167498ae6318ce5fb3b2274ae77ad86a6c16
  • Pointer size: 131 Bytes
  • Size of remote file: 115 kB

Git LFS Details

  • SHA256: baf9e78bbb24d30dc34c36dd9d64fe96793facec7d74714ca20203dd2a6e611d
  • Pointer size: 131 Bytes
  • Size of remote file: 119 kB
samples/unet_640x544_0.jpg CHANGED

Git LFS Details

  • SHA256: aade5739843610c92f0f32009eb4c27a313b7f2e573affe372c357a28af01c89
  • Pointer size: 131 Bytes
  • Size of remote file: 316 kB

Git LFS Details

  • SHA256: 1eae49806b4ac286ddd0dae65949b6c752a34079ecdccd113f16cebaeb624ba2
  • Pointer size: 131 Bytes
  • Size of remote file: 167 kB
samples/unet_640x576_0.jpg CHANGED

Git LFS Details

  • SHA256: c1ebebfc600d8f9f9fb7e123941387dbc51e447e807b994da7d9e5cd3798ee77
  • Pointer size: 131 Bytes
  • Size of remote file: 156 kB

Git LFS Details

  • SHA256: 62222f1de64715706e86befdef44c895b10044c5aca35919fa71fb0390e54b31
  • Pointer size: 130 Bytes
  • Size of remote file: 87 kB
samples/unet_640x608_0.jpg CHANGED

Git LFS Details

  • SHA256: d91fc031fed4bebf7c11527b71534367c26439df1135ee00aa25ee4157f5daa7
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB

Git LFS Details

  • SHA256: 8154e648e834b93d713432707601f0c0fc6ddeed9eb05896cc83abdbf25c8406
  • Pointer size: 131 Bytes
  • Size of remote file: 166 kB
samples/unet_640x640_0.jpg CHANGED

Git LFS Details

  • SHA256: 0feca1acd71ff8c0e7502406dd1132171178351f0e3887be646ecc70112cc48b
  • Pointer size: 131 Bytes
  • Size of remote file: 264 kB

Git LFS Details

  • SHA256: dd869ffb1b6774c3ac82bc62040e6fdcda3a90f2403ca623bb61ab8a9b51252a
  • Pointer size: 131 Bytes
  • Size of remote file: 233 kB
test.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ac34e2fd2c559d6e716bb9603c9834bdc4034ba6466994a7d17baec6443aeddf
3
- size 2650007
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:57474740f5a2bece3336537eca97110854b675a1213a6311971f016a4b5f5e3d
3
+ size 787689
text_encoder2/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e932ef2a43c3da60f4ddb09e3cb4bc242ea648d4f6720e90a9a59bd2478f1ed1
3
+ size 1506
text_encoder2/generation_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ad07e05bf667caeda7d50a747f9f7b0bc099ff85cd75d3b29c304881be086dbb
3
+ size 204
text_encoder2/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f7db7f79952f3f268880b4ce26ac6547265a3671ec0c406ecd45fda35976f25
3
+ size 3441185296
tmp.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ebdbfa4158fc271c58538c4e5678089d5de991910e2cad8c475a86c1c49b707
3
+ size 823112
tokenizer2/chat_template.jinja ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {%- if tools %}
2
+ {{- '<|im_start|>system\n' }}
3
+ {%- if messages[0].role == 'system' %}
4
+ {%- if messages[0].content is string %}
5
+ {{- messages[0].content }}
6
+ {%- else %}
7
+ {%- for content in messages[0].content %}
8
+ {%- if 'text' in content %}
9
+ {{- content.text }}
10
+ {%- endif %}
11
+ {%- endfor %}
12
+ {%- endif %}
13
+ {{- '\n\n' }}
14
+ {%- endif %}
15
+ {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>" }}
16
+ {%- for tool in tools %}
17
+ {{- "\n" }}
18
+ {{- tool | tojson }}
19
+ {%- endfor %}
20
+ {{- "\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{\"name\": <function-name>, \"arguments\": <args-json-object>}\n</tool_call><|im_end|>\n" }}
21
+ {%- else %}
22
+ {%- if messages[0].role == 'system' %}
23
+ {{- '<|im_start|>system\n' }}
24
+ {%- if messages[0].content is string %}
25
+ {{- messages[0].content }}
26
+ {%- else %}
27
+ {%- for content in messages[0].content %}
28
+ {%- if 'text' in content %}
29
+ {{- content.text }}
30
+ {%- endif %}
31
+ {%- endfor %}
32
+ {%- endif %}
33
+ {{- '<|im_end|>\n' }}
34
+ {%- endif %}
35
+ {%- endif %}
36
+ {%- set image_count = namespace(value=0) %}
37
+ {%- set video_count = namespace(value=0) %}
38
+ {%- for message in messages %}
39
+ {%- if message.role == "user" %}
40
+ {{- '<|im_start|>' + message.role + '\n' }}
41
+ {%- if message.content is string %}
42
+ {{- message.content }}
43
+ {%- else %}
44
+ {%- for content in message.content %}
45
+ {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}
46
+ {%- set image_count.value = image_count.value + 1 %}
47
+ {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
48
+ <|vision_start|><|image_pad|><|vision_end|>
49
+ {%- elif content.type == 'video' or 'video' in content %}
50
+ {%- set video_count.value = video_count.value + 1 %}
51
+ {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
52
+ <|vision_start|><|video_pad|><|vision_end|>
53
+ {%- elif 'text' in content %}
54
+ {{- content.text }}
55
+ {%- endif %}
56
+ {%- endfor %}
57
+ {%- endif %}
58
+ {{- '<|im_end|>\n' }}
59
+ {%- elif message.role == "assistant" %}
60
+ {{- '<|im_start|>' + message.role + '\n' }}
61
+ {%- if message.content is string %}
62
+ {{- message.content }}
63
+ {%- else %}
64
+ {%- for content_item in message.content %}
65
+ {%- if 'text' in content_item %}
66
+ {{- content_item.text }}
67
+ {%- endif %}
68
+ {%- endfor %}
69
+ {%- endif %}
70
+ {%- if message.tool_calls %}
71
+ {%- for tool_call in message.tool_calls %}
72
+ {%- if (loop.first and message.content) or (not loop.first) %}
73
+ {{- '\n' }}
74
+ {%- endif %}
75
+ {%- if tool_call.function %}
76
+ {%- set tool_call = tool_call.function %}
77
+ {%- endif %}
78
+ {{- '<tool_call>\n{"name": "' }}
79
+ {{- tool_call.name }}
80
+ {{- '", "arguments": ' }}
81
+ {%- if tool_call.arguments is string %}
82
+ {{- tool_call.arguments }}
83
+ {%- else %}
84
+ {{- tool_call.arguments | tojson }}
85
+ {%- endif %}
86
+ {{- '}\n</tool_call>' }}
87
+ {%- endfor %}
88
+ {%- endif %}
89
+ {{- '<|im_end|>\n' }}
90
+ {%- elif message.role == "tool" %}
91
+ {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %}
92
+ {{- '<|im_start|>user' }}
93
+ {%- endif %}
94
+ {{- '\n<tool_response>\n' }}
95
+ {%- if message.content is string %}
96
+ {{- message.content }}
97
+ {%- else %}
98
+ {%- for content in message.content %}
99
+ {%- if content.type == 'image' or 'image' in content or 'image_url' in content %}
100
+ {%- set image_count.value = image_count.value + 1 %}
101
+ {%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
102
+ <|vision_start|><|image_pad|><|vision_end|>
103
+ {%- elif content.type == 'video' or 'video' in content %}
104
+ {%- set video_count.value = video_count.value + 1 %}
105
+ {%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
106
+ <|vision_start|><|video_pad|><|vision_end|>
107
+ {%- elif 'text' in content %}
108
+ {{- content.text }}
109
+ {%- endif %}
110
+ {%- endfor %}
111
+ {%- endif %}
112
+ {{- '\n</tool_response>' }}
113
+ {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
114
+ {{- '<|im_end|>\n' }}
115
+ {%- endif %}
116
+ {%- endif %}
117
+ {%- endfor %}
118
+ {%- if add_generation_prompt %}
119
+ {{- '<|im_start|>assistant\n' }}
120
+ {%- endif %}
tokenizer2/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:be75606093db2094d7cd20f3c2f385c212750648bd6ea4fb2bf507a6a4c55506
3
+ size 11422650
tokenizer2/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:013cf6e7def2d3260ae8bbb909a0dc024cd27732b8df318a3d00ee7c724b3215
3
+ size 390
train.py CHANGED
@@ -128,6 +128,8 @@ if accelerator.is_main_process:
128
  vae = AutoencoderKLFlux2.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
129
  tokenizer = AutoTokenizer.from_pretrained("tokenizer")
130
  text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
 
 
131
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
132
 
133
  def encode_texts(texts, max_length=max_length):
@@ -166,7 +168,30 @@ def encode_texts(texts, max_length=max_length):
166
  final_layer_norm = text_model.text_model.final_layer_norm
167
  prompt_embeds = final_layer_norm(prompt_embeds)
168
 
169
- return prompt_embeds, attention_mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
172
  if shift_factor is None: shift_factor = 0.0
@@ -300,9 +325,9 @@ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
300
  texts = [item["text"] for item in samples_data]
301
 
302
  # Кодируем тексты на лету, чтобы получить маски и пулинг
303
- embeddings, masks = encode_texts(texts)
304
 
305
- fixed_samples[size] = (latents, embeddings, masks, texts)
306
 
307
  print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
308
  return fixed_samples
@@ -336,12 +361,12 @@ def collate_fn_simple(batch):
336
  ]
337
  # 3. Кодируем на лету
338
  # Возвращает: hidden (B, L, D), mask (B, L)
339
- embeddings, attention_mask = encode_texts(texts)
340
 
341
  # attention_mask от токенизатора уже имеет нужный формат, но на всякий случай приведем к long
342
  attention_mask = attention_mask.to(dtype=torch.int64)
343
 
344
- return latents, embeddings, attention_mask
345
 
346
  batch_sampler = DistributedResolutionBatchSampler(
347
  dataset=dataset,
@@ -447,22 +472,21 @@ def get_negative_embedding(neg_prompt="", batch_size=1):
447
  empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
448
  return empty_emb, empty_mask
449
 
450
- uncond_emb, uncond_mask = encode_texts([neg_prompt])
451
  uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
452
  uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
 
453
 
454
- return uncond_emb, uncond_mask
455
 
456
  # Получаем негативные (пустые) условия для валидации
457
- uncond_emb, uncond_mask = get_negative_embedding("low quality")
458
-
459
-
460
 
461
  # --- Функция генерации семплов ---
462
  @torch.compiler.disable()
463
  @torch.no_grad()
464
  def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
465
- uncond_emb, uncond_mask = uncond_data
466
 
467
  original_model = None
468
  try:
@@ -477,11 +501,12 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
477
  all_captions = []
478
 
479
  # Распаковываем 5 элементов (добавились mask)
480
- for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
481
  width, height = size
482
  sample_latents = sample_latents.to(dtype=dtype, device=device)
483
  sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
484
  sample_mask = sample_mask.to(device=device)
 
485
 
486
  latents = torch.randn(
487
  sample_latents.shape,
@@ -509,17 +534,25 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
509
  neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
510
  attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
511
 
 
 
 
512
  else:
513
  latent_model_input = latents
514
  text_embeddings_batch = sample_text_embeddings
515
  attention_mask_batch = sample_mask
 
516
 
 
 
 
517
  # Предсказание с передачей всех условий
518
  model_out = original_model(
519
  latent_model_input,
520
  t,
521
  encoder_hidden_states=text_embeddings_batch,
522
  encoder_attention_mask=attention_mask_batch,
 
523
  )
524
  flow = getattr(model_out, "sample", model_out)
525
 
@@ -582,7 +615,7 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
582
  del all_generated_images, all_captions
583
  del latents, current_latents, latent_model_input, flow
584
  del decoded, decoded_fp32
585
- del sample_latents, sample_text_embeddings, sample_mask # Копии на GPU
586
  del model_out
587
  except UnboundLocalError:
588
  pass
@@ -597,7 +630,7 @@ def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
597
  if accelerator.is_main_process:
598
  if save_model:
599
  print("Генерация сэмплов до старта обучения...")
600
- generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
601
  accelerator.wait_for_everyone()
602
 
603
  def save_checkpoint(unet, variant=""):
@@ -638,7 +671,7 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
638
  accelerator.wait_for_everyone()
639
  unet.train()
640
 
641
- for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
642
  with accelerator.accumulate(unet):
643
  if save_model == False and epoch == 0 and step == 5 :
644
  used_gb = torch.cuda.max_memory_allocated() / 1024**3
@@ -655,13 +688,17 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
655
  # делаем integer timesteps для UNet
656
  timesteps = t.to(torch.float32).mul(999.0)
657
  timesteps = timesteps.clamp(0, scheduler.config.num_train_timesteps - 1)
658
-
 
 
 
659
  # --- Вызов UNet с маской ---
660
  model_pred = unet(
661
  noisy_latents,
662
  timesteps,
663
  encoder_hidden_states=embeddings,
664
- encoder_attention_mask=attention_mask
 
665
  ).sample
666
 
667
  target = noise - latents
@@ -715,9 +752,9 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
715
  if global_step % sample_interval == 0 or global_step==50:
716
  # Передаем tuple (emb, mask) для негатива
717
  if save_model:
718
- generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
719
  elif epoch % 10 == 0:
720
- generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
721
  last_n = sample_interval
722
 
723
  if save_model:
 
128
  vae = AutoencoderKLFlux2.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
129
  tokenizer = AutoTokenizer.from_pretrained("tokenizer")
130
  text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
131
+ tokenizer2 = AutoTokenizer.from_pretrained("tokenizer")
132
+ text_model2 = AutoModel.from_pretrained("text_encoder").to(device).eval()
133
  scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
134
 
135
  def encode_texts(texts, max_length=max_length):
 
168
  final_layer_norm = text_model.text_model.final_layer_norm
169
  prompt_embeds = final_layer_norm(prompt_embeds)
170
 
171
+ messages = [{"role": "user", "content": texts}]
172
+ with open("tokenizer2/chat_template.jinja", "r", encoding="utf-8") as f:
173
+ custom_template = f.read().strip()
174
+ text = tokenizer2.apply_chat_template(messages, add_generation_prompt=False, tokenize=False, chat_template=custom_template)
175
+ toks = tokenizer2(
176
+ text,
177
+ padding="max_length",
178
+ max_length=max_length,
179
+ truncation=True,
180
+ return_tensors="pt"
181
+ ).to(device)
182
+
183
+ outputs = text_model2(
184
+ input_ids=toks.input_ids,
185
+ attention_mask=toks.attention_mask,
186
+ output_hidden_states=True
187
+ )
188
+
189
+ layer_index = -2
190
+ last_hidden = outputs.hidden_states[layer_index]
191
+ seq_len = toks.attention_mask.sum(dim=1) - 1
192
+ pooled = last_hidden[torch.arange(len(last_hidden)), seq_len.clamp(min=0)]
193
+
194
+ return prompt_embeds, attention_mask, pooled
195
 
196
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
197
  if shift_factor is None: shift_factor = 0.0
 
325
  texts = [item["text"] for item in samples_data]
326
 
327
  # Кодируем тексты на лету, чтобы получить маски и пулинг
328
+ embeddings, masks, pooled = encode_texts(texts)
329
 
330
+ fixed_samples[size] = (latents, embeddings, masks, texts, pooled)
331
 
332
  print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
333
  return fixed_samples
 
361
  ]
362
  # 3. Кодируем на лету
363
  # Возвращает: hidden (B, L, D), mask (B, L)
364
+ embeddings, attention_mask, pooled = encode_texts(texts)
365
 
366
  # attention_mask от токенизатора уже имеет нужный формат, но на всякий случай приведем к long
367
  attention_mask = attention_mask.to(dtype=torch.int64)
368
 
369
+ return latents, embeddings, attention_mask, pooled
370
 
371
  batch_sampler = DistributedResolutionBatchSampler(
372
  dataset=dataset,
 
472
  empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
473
  return empty_emb, empty_mask
474
 
475
+ uncond_emb, uncond_mask, uncond_pooled = encode_texts([neg_prompt])
476
  uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
477
  uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
478
+ uncond_pooled = uncond_pooled.to(device=device).repeat(batch_size, 1)
479
 
480
+ return uncond_emb, uncond_mask, uncond_pooled
481
 
482
  # Получаем негативные (пустые) условия для валидации
483
+ uncond_emb, uncond_mask, uncond_pooled = get_negative_embedding("low quality")
 
 
484
 
485
  # --- Функция генерации семплов ---
486
  @torch.compiler.disable()
487
  @torch.no_grad()
488
  def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
489
+ uncond_emb, uncond_mask, uncond_pooled = uncond_data
490
 
491
  original_model = None
492
  try:
 
501
  all_captions = []
502
 
503
  # Распаковываем 5 элементов (добавились mask)
504
+ for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text, sample_pooled) in fixed_samples_cpu.items():
505
  width, height = size
506
  sample_latents = sample_latents.to(dtype=dtype, device=device)
507
  sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
508
  sample_mask = sample_mask.to(device=device)
509
+ sample_pooled = sample_pooled.to(dtype=dtype, device=device)
510
 
511
  latents = torch.randn(
512
  sample_latents.shape,
 
534
  neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
535
  attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
536
 
537
+ neg_pooled_batch = uncond_pooled[0:1].expand(curr_batch_size, -1)
538
+ attention_pooled_batch = torch.cat([neg_pooled_batch, sample_pooled], dim=0)
539
+
540
  else:
541
  latent_model_input = latents
542
  text_embeddings_batch = sample_text_embeddings
543
  attention_mask_batch = sample_mask
544
+ attention_pooled_batch = sample_pooled
545
 
546
+ added_cond_kwargs = {
547
+ "text_embeds": attention_pooled_batch,
548
+ }
549
  # Предсказание с передачей всех условий
550
  model_out = original_model(
551
  latent_model_input,
552
  t,
553
  encoder_hidden_states=text_embeddings_batch,
554
  encoder_attention_mask=attention_mask_batch,
555
+ added_cond_kwargs=added_cond_kwargs,
556
  )
557
  flow = getattr(model_out, "sample", model_out)
558
 
 
615
  del all_generated_images, all_captions
616
  del latents, current_latents, latent_model_input, flow
617
  del decoded, decoded_fp32
618
+ del sample_latents, sample_text_embeddings, sample_mask, sample_pooled # Копии на GPU
619
  del model_out
620
  except UnboundLocalError:
621
  pass
 
630
  if accelerator.is_main_process:
631
  if save_model:
632
  print("Генерация сэмплов до старта обучения...")
633
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask,uncond_pooled), 0)
634
  accelerator.wait_for_everyone()
635
 
636
  def save_checkpoint(unet, variant=""):
 
671
  accelerator.wait_for_everyone()
672
  unet.train()
673
 
674
+ for step, (latents, embeddings, attention_mask, pooled) in enumerate(dataloader):
675
  with accelerator.accumulate(unet):
676
  if save_model == False and epoch == 0 and step == 5 :
677
  used_gb = torch.cuda.max_memory_allocated() / 1024**3
 
688
  # делаем integer timesteps для UNet
689
  timesteps = t.to(torch.float32).mul(999.0)
690
  timesteps = timesteps.clamp(0, scheduler.config.num_train_timesteps - 1)
691
+
692
+ added_cond_kwargs = {
693
+ "text_embeds": pooled,
694
+ }
695
  # --- Вызов UNet с маской ---
696
  model_pred = unet(
697
  noisy_latents,
698
  timesteps,
699
  encoder_hidden_states=embeddings,
700
+ encoder_attention_mask=attention_mask,
701
+ added_cond_kwargs=added_cond_kwargs,
702
  ).sample
703
 
704
  target = noise - latents
 
752
  if global_step % sample_interval == 0 or global_step==50:
753
  # Передаем tuple (emb, mask) для негатива
754
  if save_model:
755
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask,uncond_pooled), global_step)
756
  elif epoch % 10 == 0:
757
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask,uncond_pooled), global_step)
758
  last_n = sample_interval
759
 
760
  if save_model:
train1te.py ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #from comet_ml import Experiment
2
+ import os
3
+ os.environ["NCCL_P2P_DISABLE"] = "1"
4
+ os.environ["NCCL_IB_DISABLE"] = "1"
5
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
6
+ import math
7
+ import torch
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from torch.utils.data import DataLoader, Sampler
11
+ from torch.utils.data.distributed import DistributedSampler
12
+ from torch.optim.lr_scheduler import LambdaLR
13
+ from collections import defaultdict
14
+ from diffusers import UNet2DConditionModel, AutoencoderKL,AutoencoderKLFlux2,AsymmetricAutoencoderKL,FlowMatchEulerDiscreteScheduler
15
+ from accelerate import Accelerator, DeepSpeedPlugin
16
+ from datasets import load_from_disk
17
+ from tqdm import tqdm
18
+ from PIL import Image, ImageOps
19
+ import wandb
20
+ import random
21
+ import gc
22
+ from accelerate.state import DistributedType
23
+ from torch.distributed import broadcast_object_list
24
+ from torch.utils.checkpoint import checkpoint
25
+ from diffusers.models.attention_processor import AttnProcessor2_0
26
+ from datetime import datetime
27
+ import bitsandbytes as bnb
28
+ import torch.nn.functional as F
29
+ from collections import deque
30
+ from transformers import AutoTokenizer, AutoModel
31
+
32
+ # --------------------------- Параметры ---------------------------
33
+ ds_path = "/workspace/sdxs-1b/datasets/ds1234_flux32"
34
+ project = "unet"
35
+ ## total batch (split // num `GPU)
36
+ batch_size = 32
37
+ base_learning_rate = 3e-5
38
+ min_learning_rate = 1e-5
39
+ num_epochs = 8
40
+ sample_interval_share = 20
41
+ cfg_dropout = 0.10
42
+ max_length = 248
43
+ use_wandb = True
44
+ use_comet_ml = False
45
+ save_model = True
46
+ use_decay = True
47
+ fbp = False
48
+ optimizer_type = "adam8bit"
49
+ torch_compile = False
50
+ unet_gradient = True
51
+ loss_normalize = False
52
+ fixed_seed = False
53
+ shuffle = True
54
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
55
+ comet_ml_workspace = "recoilme"
56
+ torch.backends.cuda.matmul.allow_tf32 = True
57
+ torch.backends.cudnn.allow_tf32 = True
58
+ # Включение Flash Attention 2/SDPA #MAX_JOBS=4 pip install flash-attn --no-build-isolation
59
+ torch.backends.cuda.enable_flash_sdp(True)
60
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
61
+ torch.backends.cuda.enable_math_sdp(False) # Отключаем медленный вариант
62
+ save_barrier = 1.5
63
+ warmup_percent = 0.03
64
+ #percentile_clipping = 95
65
+ betta2 = 0.995
66
+ eps = 1e-7
67
+ clip_grad_norm = 1.0
68
+ limit = 0
69
+ checkpoints_folder = ""
70
+ gradient_accumulation_steps = 1
71
+ dtype = torch.float32
72
+ mixed_precision = "no"
73
+
74
+ # Параметры для диффузии
75
+ n_diffusion_steps = 40
76
+ samples_to_generate = 12
77
+ guidance_scale = 4
78
+
79
+ # Папки для сохранения результатов
80
+ generated_folder = "samples"
81
+ os.makedirs(generated_folder, exist_ok=True)
82
+
83
+ # Настройка seed
84
+ current_date = datetime.now()
85
+ seed = int(current_date.strftime("%Y%m%d")) + 1
86
+ if fixed_seed:
87
+ torch.manual_seed(seed)
88
+ np.random.seed(seed)
89
+ random.seed(seed)
90
+ if torch.cuda.is_available():
91
+ torch.cuda.manual_seed_all(seed)
92
+
93
+ accelerator = Accelerator(
94
+ mixed_precision=mixed_precision,
95
+ gradient_accumulation_steps=gradient_accumulation_steps
96
+ )
97
+ device = accelerator.device
98
+
99
+ print("init")
100
+
101
+ # --------------------------- Инициализация WandB ---------------------------
102
+ if accelerator.is_main_process:
103
+ if use_wandb:
104
+ wandb.init(project=project, config={
105
+ "batch_size": batch_size,
106
+ "base_learning_rate": base_learning_rate,
107
+ "num_epochs": num_epochs,
108
+ "optimizer_type": optimizer_type,
109
+ })
110
+ if use_comet_ml:
111
+ from comet_ml import Experiment
112
+ comet_experiment = Experiment(
113
+ api_key=comet_ml_api_key,
114
+ project_name=project,
115
+ workspace=comet_ml_workspace
116
+ )
117
+ hyper_params = {
118
+ "batch_size": batch_size,
119
+ "base_learning_rate": base_learning_rate,
120
+ "num_epochs": num_epochs,
121
+ }
122
+ comet_experiment.log_parameters(hyper_params)
123
+
124
+ # --------------------------- Загрузка моделей ---------------------------
125
+ #vae = AutoencoderKL.from_pretrained("vae", torch_dtype=dtype).to("cpu").eval()
126
+ #vae = AutoencoderKLFlux2.from_pretrained("black-forest-labs/FLUX.2-dev",subfolder="vae",torch_dtype=dtype).to(device).eval()
127
+ #vae = AsymmetricAutoencoderKL.from_pretrained("vae",torch_dtype=dtype).to(device).eval()
128
+ vae = AutoencoderKLFlux2.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
129
+ tokenizer = AutoTokenizer.from_pretrained("tokenizer")
130
+ text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
131
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
132
+
133
def encode_texts(texts, max_length=max_length):
    """Tokenize *texts* and run them through the text encoder.

    Returns a tuple ``(prompt_embeds, attention_mask)`` where
    ``prompt_embeds`` is the penultimate hidden layer passed through the
    encoder's final LayerNorm, shape (B, max_length, hidden_dim), and
    ``attention_mask`` is the tokenizer's padding mask, shape (B, max_length).
    """
    if texts is None:
        texts = [""]
    if isinstance(texts, str):
        texts = [texts]

    with torch.no_grad():
        # Tokenize with fixed-length padding so every batch has the same seq len.
        toks = tokenizer(
            texts,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        ).to(device)

        outputs = text_model(
            input_ids=toks.input_ids,
            attention_mask=toks.attention_mask,
            output_hidden_states=True,  # needed to reach the penultimate layer
        )

        # Penultimate layer (common choice for conditioning), then the
        # encoder's own final LayerNorm to keep the activation statistics sane.
        prompt_embeds = outputs.hidden_states[-2]
        prompt_embeds = text_model.text_model.final_layer_norm(prompt_embeds)

    return prompt_embeds, toks.attention_mask
170
+
171
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
172
+ if shift_factor is None: shift_factor = 0.0
173
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
174
+ if scaling_factor is None: scaling_factor = 1.0
175
+
176
+ def _patchify_latents(latents):
177
+ batch_size, num_channels_latents, height, width = latents.shape
178
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
179
+ latents = latents.permute(0, 1, 3, 5, 2, 4)
180
+ latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
181
+ return latents
182
+
183
+ @staticmethod
184
+ def _unpatchify_latents(latents):
185
+ batch_size, num_channels_latents, height, width = latents.shape
186
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
187
+ latents = latents.permute(0, 1, 4, 2, 5, 3)
188
+ latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
189
+ return latents
190
+
191
def flux_encode(vae, latents):
    """Normalize latents with the VAE's BatchNorm running statistics.

    Patchifies to the VAE's patch space, applies (x - mean) / std using the
    frozen BN running stats, then unpatchifies back. Inverse of flux_decode.
    """
    patched = _patchify_latents(latents)
    mean = vae.bn.running_mean.view(1, -1, 1, 1).to(patched.device, patched.dtype)
    # Fix: the original moved only the mean to the latents' device/dtype; the
    # variance stayed on the VAE's device, which breaks when vae and latents
    # live on different devices. Move both for consistency.
    std = torch.sqrt(
        vae.bn.running_var.view(1, -1, 1, 1).to(patched.device, patched.dtype)
        + vae.config.batch_norm_eps
    )
    normalized = (patched - mean) / std
    return _unpatchify_latents(normalized)
201
+
202
def flux_decode(vae, latents):
    """De-normalize latents with the VAE's BatchNorm running statistics.

    Patchifies, applies x * std + mean with the frozen BN running stats,
    then unpatchifies. Inverse of flux_encode.
    """
    patched = _patchify_latents(latents)
    mean = vae.bn.running_mean.view(1, -1, 1, 1).to(patched.device, patched.dtype)
    # Fix: move the variance to the latents' device/dtype as well (the
    # original only did this for the mean — device-mismatch hazard).
    std = torch.sqrt(
        vae.bn.running_var.view(1, -1, 1, 1).to(patched.device, patched.dtype)
        + vae.config.batch_norm_eps
    )
    denormalized = patched * std + mean
    return _unpatchify_latents(denormalized)
212
+
213
class DistributedResolutionBatchSampler(Sampler):
    """Batch sampler that yields same-resolution batches sharded across replicas.

    Samples are grouped by (width, height); only full global batches
    (per_gpu_batch * num_replicas) are kept, and each rank receives its own
    contiguous slice of every global batch.
    """

    def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
        self.dataset = dataset
        # Per-GPU batch size (the given batch_size is the global one).
        self.batch_size = max(1, batch_size // num_replicas)
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0

        try:
            widths = np.array(dataset["width"])
            heights = np.array(dataset["height"])
        except KeyError:
            # No resolution columns: everything falls into one (0, 0) group.
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))

        self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
        self.size_groups = {}
        for w, h in self.size_keys:
            self.size_groups[(w, h)] = np.where((widths == w) & (heights == h))[0]

        # Count only full global batches per resolution group.
        self.group_num_batches = {}
        total = 0
        for size, idxs in self.size_groups.items():
            full = len(idxs) // (self.batch_size * self.num_replicas)
            self.group_num_batches[size] = full
            total += full
        # Round down so every replica sees the same number of batches.
        self.num_batches = (total // self.num_replicas) * self.num_replicas

    def __iter__(self):
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Epoch-seeded RNG so all ranks shuffle identically.
        rng = np.random.RandomState(self.epoch)
        my_batches = []

        for size, idxs in self.size_groups.items():
            idxs = idxs.copy()
            if self.shuffle:
                rng.shuffle(idxs)
            full = self.group_num_batches[size]
            if full == 0:
                continue
            usable = idxs[: full * self.batch_size * self.num_replicas]
            global_batches = usable.reshape(-1, self.batch_size * self.num_replicas)
            # This rank's contiguous slice of every global batch.
            lo = self.rank * self.batch_size
            my_batches.extend(global_batches[:, lo: lo + self.batch_size])

        if self.shuffle:
            rng.shuffle(my_batches)
        accelerator.wait_for_everyone()
        return iter(my_batches)

    def __len__(self):
        return self.num_batches

    def set_epoch(self, epoch):
        self.epoch = epoch
275
+
276
+ # --- [UPDATED] Функция для фиксированных семплов ---
277
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    """Pick fixed preview samples per (width, height) group for validation.

    Returns ``{(w, h): (latents, embeddings, masks, texts)}`` with latents on
    the training device and text conditions encoded on the fly.
    """
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for idx, size in enumerate(zip(widths, heights)):
        size_groups[size].append(idx)

    fixed_samples = {}
    for size, indices in size_groups.items():
        n_samples = min(samples_per_group, len(indices))
        if len(size_groups) == 1:
            # Single resolution in the dataset: spend the full preview budget.
            n_samples = samples_to_generate
        if n_samples == 0:
            continue
        chosen = random.sample(indices, n_samples)
        rows = [dataset[i] for i in chosen]

        latents = torch.tensor(np.array([row["vae"] for row in rows])).to(device=device, dtype=dtype)
        texts = [row["text"] for row in rows]
        # Encode on the fly to obtain embeddings and attention masks.
        embeddings, masks = encode_texts(texts)

        fixed_samples[size] = (latents, embeddings, masks, texts)

    print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
    return fixed_samples
309
+
310
+ if limit > 0:
311
+ dataset = load_from_disk(ds_path).select(range(limit))
312
+ else:
313
+ dataset = load_from_disk(ds_path)
314
+
315
+ dataset = dataset.filter(
316
+ lambda x: [not (path.startswith("/workspace/dataset/animesfw") or path.startswith("/workspace/dataset/d4/animesfw")) for path in x["image_path"]],
317
+ batched=True,
318
+ batch_size=10000, # обрабатываем по 10к строк за раз
319
+ num_proc=8
320
+ )
321
+ print(f"Осталось примеров после фильтрации: {len(dataset)}")
322
+
323
+ # --- Collate Function ---
324
def collate_fn_simple(batch):
    """Collate dataset rows into (latents, text embeddings, attention mask)."""
    # Precomputed VAE latents straight to device.
    latents = torch.tensor(np.array([row["vae"] for row in batch])).to(device, dtype=dtype)

    # Captions: CFG dropout plus light prefix cleanup.
    texts = []
    for t in (row["text"] for row in batch):
        if t.lower().startswith("zero") or random.random() < cfg_dropout:
            # "zero*"-tagged rows and randomly dropped rows train the
            # unconditional branch.
            texts.append("")
        elif t.startswith("."):
            texts.append(t[1:].lstrip())
        else:
            texts.append(
                t.replace("The image shows ", "")
                .replace("The image is ", "")
                .replace("This image captures ", "")
                .strip()
            )

    # Encode on the fly: hidden states (B, L, D) and padding mask (B, L).
    embeddings, attention_mask = encode_texts(texts)
    # Normalize the mask dtype for the UNet call.
    attention_mask = attention_mask.to(dtype=torch.int64)

    return latents, embeddings, attention_mask
345
+
346
+ batch_sampler = DistributedResolutionBatchSampler(
347
+ dataset=dataset,
348
+ batch_size=batch_size,
349
+ num_replicas=accelerator.num_processes,
350
+ rank=accelerator.process_index,
351
+ shuffle=shuffle
352
+ )
353
+
354
+ dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
355
+ if accelerator.is_main_process:
356
+ print("Total samples", len(dataloader))
357
+ dataloader = accelerator.prepare(dataloader)
358
+
359
+ start_epoch = 0
360
+ global_step = 0
361
+ total_training_steps = (len(dataloader) * num_epochs)
362
+ world_size = accelerator.state.num_processes
363
+
364
+ # Загрузка UNet
365
+ latest_checkpoint = os.path.join(checkpoints_folder, project)
366
+ if os.path.isdir(latest_checkpoint):
367
+ print("Загружаем UNet из чекпоинта:", latest_checkpoint)
368
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
369
+ if unet_gradient:
370
+ unet.enable_gradient_checkpointing()
371
+ unet.set_use_memory_efficient_attention_xformers(False)
372
+ try:
373
+ unet.set_attn_processor(AttnProcessor2_0())
374
+ except Exception as e:
375
+ print(f"Ошибка при включении SDPA: {e}")
376
+ unet.set_use_memory_efficient_attention_xformers(True)
377
+ else:
378
+ raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")
379
+
380
+
381
def create_optimizer(name, params):
    """Build the optimizer selected by *name* for *params*.

    Raises ValueError for unknown optimizer names.
    """
    if name == "adam8bit":
        # 8-bit AdamW from bitsandbytes (memory-efficient states).
        return bnb.optim.AdamW8bit(
            params,
            lr=base_learning_rate,
            betas=(0.9, betta2),
            eps=eps,
            weight_decay=0.01,
        )
    if name == "adam":
        return torch.optim.AdamW(
            params,
            lr=base_learning_rate,
            betas=(0.9, betta2),
            eps=1e-8,
            weight_decay=0.01,
        )
    raise ValueError(f"Unknown optimizer: {name}")
393
+
394
+ if fbp:
395
+ trainable_params = list(unet.parameters())
396
+ optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
397
+ def optimizer_hook(param):
398
+ optimizer_dict[param].step()
399
+ optimizer_dict[param].zero_grad(set_to_none=True)
400
+ for param in trainable_params:
401
+ param.register_post_accumulate_grad_hook(optimizer_hook)
402
+ unet, optimizer = accelerator.prepare(unet, optimizer_dict)
403
+ else:
404
+ # 1. Сначала замораживаем ВСЕ параметры UNet
405
+ #unet.requires_grad_(False)
406
+
407
+ # 2. Размораживаем только нужные
408
+ #trainable_params_names = ["conv_in.weight", "conv_in.bias", "conv_out.weight", "conv_out.bias"]
409
+ #train_params = []
410
+
411
+ #for name, param in unet.named_parameters():
412
+ # if any(target in name for target in trainable_params_names):
413
+ # param.requires_grad = True
414
+ # train_params.append(param)
415
+ # print(f"Обучаемый слой: {name}")
416
+
417
+ unet.requires_grad_(True)
418
+ optimizer = create_optimizer(optimizer_type, unet.parameters())
419
+
420
def lr_schedule(step):
    """Absolute learning rate: linear warmup then cosine decay.

    Progress is computed over total steps across all processes; returns
    ``base_learning_rate`` unchanged when decay is disabled.
    """
    if not use_decay:
        return base_learning_rate
    progress = step / (total_training_steps * world_size)
    if progress < warmup_percent:
        # Linear warmup from min LR up to base LR.
        return min_learning_rate + (base_learning_rate - min_learning_rate) * (progress / warmup_percent)
    # Cosine decay from base LR back down to min LR.
    decay_ratio = (progress - warmup_percent) / (1 - warmup_percent)
    cosine = 0.5 * (1 + math.cos(math.pi * decay_ratio))
    return min_learning_rate + (base_learning_rate - min_learning_rate) * cosine
430
+ lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
431
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
432
+
433
+ if torch_compile:
434
+ print("compiling")
435
+ unet = torch.compile(unet)
436
+ print("compiling - ok")
437
+
438
+ # Фиксированные семплы
439
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
440
+
441
+ # --- Negative-prompt embedding helper (returns 2 elements: emb, mask) ---
442
def get_negative_embedding(neg_prompt="", batch_size=1):
    """Return (embeddings, attention_mask) for the negative prompt.

    An empty prompt yields zero embeddings with a full attention mask;
    otherwise the prompt is encoded once and repeated to *batch_size*.
    """
    if not neg_prompt:
        # Unconditional branch: zeros, but the mask stays all-ones.
        # 2048 here mirrors the text encoder's hidden size — assumption
        # carried over from the original; confirm against the encoder config.
        empty_emb = torch.zeros((batch_size, max_length, 2048), dtype=dtype, device=device)
        empty_mask = torch.ones((batch_size, max_length), dtype=torch.int64, device=device)
        return empty_emb, empty_mask

    emb, mask = encode_texts([neg_prompt])
    emb = emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
    mask = mask.to(device=device).repeat(batch_size, 1)
    return emb, mask
455
+
456
+ # Получаем негативные (пустые) условия для валидации
457
+ uncond_emb, uncond_mask = get_negative_embedding("low quality")
458
+
459
+
460
+
461
+ # --- Функция генерации семплов ---
462
@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    """Generate preview images for every fixed sample group and log them.

    fixed_samples_cpu: {(w, h): (latents, embeddings, masks, texts)}.
    uncond_data: (uncond_emb, uncond_mask) used as the CFG negative branch.
    step: global training step; at step 0 the reference latents are decoded
    instead of generated ones (sanity check of the VAE path).
    """
    uncond_emb, uncond_mask = uncond_data

    original_model = None
    try:
        if not torch_compile:
            original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
        else:
            original_model = unet.eval()

        # VAE is kept on CPU between calls; bring it up only for decoding.
        vae.to(device=device).eval()

        all_generated_images = []
        all_captions = []

        # Unpack the 4 elements stored per size group (latents, emb, mask, texts).
        for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
            width, height = size
            sample_latents = sample_latents.to(dtype=dtype, device=device)
            sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
            sample_mask = sample_mask.to(device=device)

            # Fixed seed so previews are comparable across training steps.
            latents = torch.randn(
                sample_latents.shape,
                device=device,
                dtype=sample_latents.dtype,
                generator=torch.Generator(device=device).manual_seed(seed)
            )

            scheduler.set_timesteps(n_diffusion_steps, device=device)

            for t in scheduler.timesteps:
                if guidance_scale != 1:
                    # Classifier-free guidance: duplicate latents for the
                    # negative + positive passes.
                    latent_model_input = torch.cat([latents, latents], dim=0)

                    # 1. Embeddings (negative first, then positive).
                    curr_batch_size = sample_text_embeddings.shape[0]
                    seq_len = sample_text_embeddings.shape[1]
                    hidden_dim = sample_text_embeddings.shape[2]

                    neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
                    text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)

                    # 2. Attention masks, same ordering.
                    neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
                    attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)

                else:
                    latent_model_input = latents
                    text_embeddings_batch = sample_text_embeddings
                    attention_mask_batch = sample_mask

                # UNet prediction with all conditioning inputs.
                model_out = original_model(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings_batch,
                    encoder_attention_mask=attention_mask_batch,
                )
                flow = getattr(model_out, "sample", model_out)

                if guidance_scale != 1:
                    # Combine the two halves with the CFG formula.
                    flow_uncond, flow_cond = flow.chunk(2)
                    flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)

                latents = scheduler.step(flow, t, latents).prev_sample

            current_latents = latents
            if step==0:
                # Step 0: decode the stored reference latents instead,
                # to validate the VAE round-trip before training starts.
                current_latents = sample_latents

            # Undo the training-time latent scaling, then BN de-normalize.
            latents = current_latents.detach() * scaling_factor + shift_factor
            latents = flux_decode(vae,latents)
            decoded = vae.decode(latents.to(torch.float32)).sample
            decoded_fp32 = decoded.to(torch.float32)

            for img_idx, img_tensor in enumerate(decoded_fp32):
                # [-1, 1] -> [0, 1] -> HWC numpy.
                img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
                img = img.transpose(1, 2, 0)

                if np.isnan(img).any():
                    # NOTE(review): NaNs are only reported, not skipped —
                    # the image is still saved below.
                    print("NaNs found, saving stopped! Step:", step)
                pil_img = Image.fromarray((img * 255).astype("uint8"))

                # Pad every preview to the largest resolution for a uniform grid.
                max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
                max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
                max_w_overall = max(255, max_w_overall)
                max_h_overall = max(255, max_h_overall)

                padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
                all_generated_images.append(padded_img)

                caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
                all_captions.append(caption_text)

                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=96)

        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images})
        if use_comet_ml and accelerator.is_main_process:
            for i, img in enumerate(all_generated_images):
                comet_experiment.log_image(
                    image_data=img,
                    name=f"step_{step}_img_{i}",
                    step=step,
                    metadata={"caption": all_captions[i]}
                )
    finally:
        # Always evict the VAE back to CPU and free GPU copies.
        vae.to("cpu")
        try:
            all_generated_images.clear()
            all_captions.clear()
            del all_generated_images, all_captions
            del latents, current_latents, latent_model_input, flow
            del decoded, decoded_fp32
            del sample_latents, sample_text_embeddings, sample_mask  # GPU copies
            del model_out
        except UnboundLocalError:
            # Generation may have failed before some names were bound.
            pass

        # Synchronize CUDA before cleanup...
        torch.cuda.synchronize()
        # ...then flush the allocator cache and run the GC.
        torch.cuda.empty_cache()
        gc.collect()
595
+
596
+ # --------------------------- Генерация сэмплов перед обучением ---------------------------
597
+ if accelerator.is_main_process:
598
+ if save_model:
599
+ print("Генерация сэмплов до старта обучения...")
600
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
601
+ accelerator.wait_for_everyone()
602
+
603
def save_checkpoint(unet, variant=""):
    """Save the (unwrapped) UNet to the checkpoints folder.

    A non-empty *variant* saves an fp16 copy under that variant name.
    Only the main process writes; all processes run the CUDA cleanup.
    """
    if accelerator.is_main_process:
        model_to_save = None
        if not torch_compile:
            model_to_save = accelerator.unwrap_model(unet)
        else:
            model_to_save = unet

        if variant != "":
            # NOTE(review): .to(dtype=torch.float16) converts the live model
            # in place — this looks safe only because the fp16 variant is
            # saved at the very end of training; confirm before calling
            # mid-training (see the TODO below about converting back).
            model_to_save.to(dtype=torch.float16).save_pretrained(
                os.path.join(checkpoints_folder, f"{project}"), variant=variant
            )
        else:
            model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))

    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    gc.collect()
    #unet = unet.to(dtype=dtype) #TODO: wtf???
622
+
623
+ # --------------------------- Тренировочный цикл ---------------------------
624
+ if accelerator.is_main_process:
625
+ print(f"Total steps per GPU: {total_training_steps}")
626
+
627
+ epoch_loss_points = []
628
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
629
+
630
+ steps_per_epoch = len(dataloader)
631
+ sample_interval = max(1, steps_per_epoch // sample_interval_share)
632
+ min_loss = 4.
633
+
634
+ for epoch in range(start_epoch, start_epoch + num_epochs):
635
+ batch_losses = []
636
+ batch_grads = []
637
+ batch_sampler.set_epoch(epoch)
638
+ accelerator.wait_for_everyone()
639
+ unet.train()
640
+
641
+ for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
642
+ with accelerator.accumulate(unet):
643
+ if save_model == False and epoch == 0 and step == 5 :
644
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
645
+ print(f"Шаг {step}: {used_gb:.2f} GB")
646
+
647
+ # шум
648
+ noise = torch.randn_like(latents, dtype=latents.dtype)
649
+
650
+ # 3. Время t (сэмплим, как и раньше, но чуть сжимаем края)
651
+ u = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
652
+ t = u * (1 - 2 * 1e-5) + 1e-5 # Теперь t строго в (0.00001 ... 0.99999)
653
+ # интерполяция между x0 и шумом
654
+ noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
655
+ # делаем integer timesteps для UNet
656
+ timesteps = t.to(torch.float32).mul(999.0)
657
+ timesteps = timesteps.clamp(0, scheduler.config.num_train_timesteps - 1)
658
+
659
+ # --- Вызов UNet с маской ---
660
+ model_pred = unet(
661
+ noisy_latents,
662
+ timesteps,
663
+ encoder_hidden_states=embeddings,
664
+ encoder_attention_mask=attention_mask
665
+ ).sample
666
+
667
+ target = noise - latents
668
+
669
+ mse_loss = F.mse_loss(model_pred.float(), target.float())
670
+ batch_losses.append(mse_loss.detach().item())
671
+
672
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
673
+ accelerator.wait_for_everyone()
674
+
675
+ losses_dict = {}
676
+ losses_dict["mse"] = mse_loss
677
+
678
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
679
+ accelerator.wait_for_everyone()
680
+
681
+ accelerator.backward(mse_loss)
682
+
683
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
684
+ accelerator.wait_for_everyone()
685
+
686
+ grad = 0.0
687
+ if not fbp:
688
+ if accelerator.sync_gradients:
689
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
690
+ grad = grad_val.float().item() if torch.is_tensor(grad_val) else float(grad_val)
691
+ optimizer.step()
692
+ lr_scheduler.step()
693
+ optimizer.zero_grad(set_to_none=True)
694
+
695
+ if accelerator.sync_gradients:
696
+ global_step += 1
697
+ progress_bar.update(1)
698
+ if accelerator.is_main_process:
699
+ if fbp:
700
+ current_lr = base_learning_rate
701
+ else:
702
+ current_lr = lr_scheduler.get_last_lr()[0]
703
+ batch_grads.append(grad)
704
+
705
+ log_data = {}
706
+ log_data["loss_mse"] = mse_loss.detach().item()
707
+ log_data["lr"] = current_lr
708
+ log_data["grad"] = grad
709
+ if accelerator.sync_gradients:
710
+ if use_wandb:
711
+ wandb.log(log_data, step=global_step)
712
+ if use_comet_ml:
713
+ comet_experiment.log_metrics(log_data, step=global_step)
714
+
715
+ if global_step % sample_interval == 0 or global_step==50:
716
+ # Передаем tuple (emb, mask) для негатива
717
+ if save_model:
718
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
719
+ elif epoch % 10 == 0:
720
+ generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
721
+ last_n = sample_interval
722
+
723
+ if save_model:
724
+ has_losses = len(batch_losses) > 0
725
+ avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if has_losses else 0.0
726
+ last_loss = batch_losses[-1] if has_losses else 0.0
727
+ max_loss = max(avg_sample_loss, last_loss)
728
+ should_save = max_loss < min_loss * save_barrier
729
+ print(
730
+ f"Saving: {should_save} | Max: {max_loss:.4f} | "
731
+ f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
732
+ )
733
+ # 6. Сохранение и обновление
734
+ if should_save:
735
+ min_loss = max_loss
736
+ save_checkpoint(unet)
737
+ unet.train()
738
+
739
+ if accelerator.is_main_process:
740
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
741
+ avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
742
+
743
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
744
+ log_data_ep = {
745
+ "epoch_loss": avg_epoch_loss,
746
+ "epoch_grad": avg_epoch_grad,
747
+ "epoch": epoch + 1,
748
+ }
749
+ if use_wandb:
750
+ wandb.log(log_data_ep)
751
+ if use_comet_ml:
752
+ comet_experiment.log_metrics(log_data_ep)
753
+
754
+ if accelerator.is_main_process:
755
+ print("Обучение завершено! Сохраняем финальную модель...")
756
+ #if save_model:
757
+ save_checkpoint(unet,"fp16")
758
+ if use_comet_ml:
759
+ comet_experiment.end()
760
+ accelerator.free_memory()
761
+ if torch.distributed.is_initialized():
762
+ torch.distributed.destroy_process_group()
763
+
764
+ print("Готово!")
unet/config.json CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:013830d918f73de1ccea49842f1c9f0e48f5cd2257af63f8ca8702b314f624a8
3
- size 1878
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c10e81c0737a03d3518c2d3034358b8aec858e40b021ae637fe3b8c44d26ec4
3
+ size 1879
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ac2dc73cd4009b5cbba85dc5bf5f4cfb303c51f7aa10d90ef6417cdd4791f467
3
- size 5935560296
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df65e034d1b9810fe6b1b41c190af19a24e58d47a369c135194456d8d3292ab8
3
+ size 5946605448
unet1.5b-2TE-text.ipynb ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:256e5ab08659d487f86da10282a658a1876ab12e703cbe8705e76b9abee8e0ac
3
+ size 44131
{unet → unet_1te}/config-Copy1.txt RENAMED
File without changes
unet_1te/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:013830d918f73de1ccea49842f1c9f0e48f5cd2257af63f8ca8702b314f624a8
3
+ size 1878
unet_1te/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac2dc73cd4009b5cbba85dc5bf5f4cfb303c51f7aa10d90ef6417cdd4791f467
3
+ size 5935560296