| | |
| |
|
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| | from transformers import AutoTokenizer, AutoModel |
| | from transformers import T5Tokenizer, T5EncoderModel |
| | from diffusers import StableDiffusionPipeline |
| | from diffusers.utils import logging |
| |
|
| | from transformers import CLIPTextModelWithProjection |
| | from transformers.models.clip.modeling_clip import CLIPTextModel, CLIPTextConfig |
| |
|
# Module-level logger via diffusers' logging wrapper (configured by diffusers.utils.logging).
logger = logging.get_logger(__name__)

# Hugging Face Hub id of the encoder-only T5-v1.1-XXL checkpoint used as the
# default tokenizer/text-encoder when the caller supplies neither.
T5_NAME="mcmonkey/google_t5-v1_1-xxl_encoderonly"
| |
|
| |
|
class StableDiffusionT5Pipeline(StableDiffusionPipeline):
    """Stable Diffusion pipeline whose text encoder is a T5 encoder.

    T5 last hidden states (width 4096) are mapped into the UNet's
    cross-attention width (768) through the ``text_projection`` linear layer
    of a minimal ``CLIPTextModelWithProjection`` holder (see
    ``create_clipholder``), then multiplied by
    ``t5_projection.config.scaling_factor``.
    """

    # Allow from_pretrained() to load repos that do not ship a serialized
    # text encoder / projection; missing pieces are rebuilt in __init__.
    _optional_components = StableDiffusionPipeline._optional_components + ["text_encoder", "t5_projection"]

    def create_clipholder(self):
        """Build a minimal CLIP "holder" model for its projection layer.

        The config deliberately describes an almost-empty transformer
        (0 hidden layers, vocab of 1): the only part of this model the
        pipeline uses is its ``text_projection`` linear layer, which maps
        T5-XXL's 4096-wide hidden states down to the 768-wide embeddings
        the UNet cross-attends over.

        Returns:
            CLIPTextModelWithProjection: freshly initialised holder.
        """
        config = CLIPTextConfig(
            vocab_size=1,          # unused: the CLIP embedder is never run
            hidden_size=4096,      # input width = T5-XXL hidden size
            projection_dim=768,    # output width expected by the UNet
            num_hidden_layers=0,   # no transformer body, projection only
            num_attention_heads=1,
            intermediate_size=4,
        )
        return CLIPTextModelWithProjection(config)

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker=None,
        feature_extractor=None,
        image_encoder=None,
        requires_safety_checker=True,
        t5_projection=None,
    ):
        """Assemble the pipeline, filling in T5 defaults where needed.

        Args:
            vae, unet, scheduler: standard Stable Diffusion components,
                forwarded unchanged to the base pipeline.
            text_encoder: T5 encoder; when None, ``T5EncoderModel`` is
                loaded from ``T5_NAME`` in the UNet's dtype.
            tokenizer: T5 tokenizer; when None, ``T5Tokenizer`` is loaded
                from ``T5_NAME``.
            safety_checker, feature_extractor, image_encoder,
            requires_safety_checker: forwarded to the base pipeline.
            t5_projection: a ``CLIPTextModelWithProjection`` whose
                ``text_projection`` maps T5 states to UNet width. When
                None, the user is prompted interactively before a fresh
                (randomly initialised) holder is generated.

        Raises:
            SystemExit: if ``t5_projection`` is None and the user declines
                auto-generation.
            TypeError: if ``t5_projection`` is not a
                ``CLIPTextModelWithProjection``.
        """
        # Default to the encoder-only T5-XXL checkpoint when the caller
        # supplies no tokenizer / text encoder.
        self.tokenizer = (
            tokenizer
            if tokenizer is not None
            else T5Tokenizer.from_pretrained(T5_NAME, torch_dtype=unet.dtype)
        )
        if text_encoder is None:
            self.text_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)
        else:
            self.text_encoder = text_encoder

        super().__init__(
            vae=vae,
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            requires_safety_checker=requires_safety_checker,
        )

        if t5_projection is None:
            print("WARNING: no CLIPTextModelWithProjection found. This may indicate an error")
            answer = input("Should I auto-generate one? type 'Yes' to proceed")
            if answer != "Yes":
                # Fix: `exit()` is injected by the `site` module and is not
                # guaranteed to exist (e.g. `python -S`, embedded
                # interpreters); raise SystemExit explicitly instead.
                raise SystemExit(1)
            # NOTE(review): the generated projection has random weights —
            # embeddings will be meaningless until it is trained/loaded.
            self.t5_projection = self.create_clipholder().to(vae.device, dtype=vae.dtype)
        elif isinstance(t5_projection, CLIPTextModelWithProjection):
            self.t5_projection = t5_projection
        else:
            raise TypeError("Error: expected t5_projection to be type CLIPTextModelWithProjection")

        # Fix: compare against None rather than truthiness so an explicit
        # (if unusual) scaling_factor of 0.0 is not silently overwritten.
        if getattr(self.t5_projection.config, "scaling_factor", None) is None:
            scaling_factor = 1.8
            print("INFO: Pipeline setting empty t5 scaling factor to", scaling_factor)
            self.t5_projection.config.scaling_factor = scaling_factor

        # Register so save_pretrained() / to() / device offloading also
        # cover the projection module.
        self.register_modules(t5_projection=self.t5_projection)

    def encode_prompt_t5(
        self,
        prompt,
        negative_prompt,
        device,
        padding=None,
    ):
        """Tokenize and T5-encode the positive and negative prompts.

        Args:
            prompt: positive prompt text (str or list of str).
            negative_prompt: negative prompt text; None is treated as "".
            device: device the token tensors are moved to.
            padding: padding strategy forwarded to the tokenizer. None
                defaults to "max_length" (see fix note below).

        Returns:
            tuple: (pos_hidden, neg_hidden) T5 ``last_hidden_state``
            tensors — presumably (batch, seq_len, 4096) for T5-XXL.
        """
        # Fix: `padding=None` is not a valid transformers padding strategy
        # (PaddingStrategy(None) raises ValueError inside the tokenizer).
        # "max_length" also guarantees pos/neg encode to the same sequence
        # length, which classifier-free guidance needs to concatenate them.
        if padding is None:
            padding = "max_length"

        def _tok(text):
            # Tokenize to fixed-length ids + attention mask on `device`.
            out = self.tokenizer(
                text,
                return_tensors="pt",
                padding=padding,
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            )
            return out.input_ids.to(device=device, dtype=torch.long), out.attention_mask.to(device)

        pos_ids, pos_mask = _tok(prompt)
        pos_hidden = self.text_encoder(pos_ids, attention_mask=pos_mask).last_hidden_state

        neg_ids, neg_mask = _tok(negative_prompt if negative_prompt is not None else "")
        neg_hidden = self.text_encoder(neg_ids, attention_mask=neg_mask).last_hidden_state

        return pos_hidden, neg_hidden

    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        padding=None,
        **kwargs,
    ):
        """Encode prompts into UNet cross-attention embeddings.

        Overrides the base pipeline's CLIP-based encode_prompt: prompts
        are T5-encoded, projected to UNet width through
        ``t5_projection.text_projection`` and scaled by
        ``t5_projection.config.scaling_factor``. Extra ``kwargs`` accepted
        by the base implementation are ignored.

        Returns:
            With CFG: ``[negative_embeds, positive_embeds]`` (list, uncond
            first). Without CFG: the positive embeddings tensor only.
            NOTE(review): this differs from diffusers' base encode_prompt,
            which returns (prompt_embeds, negative_prompt_embeds); the
            custom __call__ presumably consumes this order — confirm
            before changing.
        """
        scaling_factor = self.t5_projection.config.scaling_factor

        pos_hidden, neg_hidden = self.encode_prompt_t5(prompt, negative_prompt, device, padding=padding)

        def _project(hidden):
            # Linear 4096 -> 768 projection, global scaling, then one copy
            # of each prompt row per requested image.
            embeds = self.t5_projection.text_projection(hidden) * scaling_factor
            return embeds.repeat_interleave(num_images_per_prompt, dim=0)

        pos_embeds = _project(pos_hidden)
        if do_classifier_free_guidance:
            return [_project(neg_hidden), pos_embeds]
        return pos_embeds
| |
|