# pipeline.py
# Subclass the Stable Diffusion pipeline to replace CLIP-L with a T5 text encoder.

import torch
from transformers import T5Tokenizer, T5EncoderModel
from transformers import CLIPTextConfig, CLIPTextModelWithProjection
from diffusers import StableDiffusionPipeline
from diffusers.utils import logging

logger = logging.get_logger(__name__)

T5_NAME = "mcmonkey/google_t5-v1_1-xxl_encoderonly"

class StableDiffusionT5Pipeline(StableDiffusionPipeline):
    ###################################################################################
    # Create a minimal CLIPTextModelWithProjection: minimal layers and vocab size,
    # plus the projection layer we need.
    # We only actually care about the .text_projection,
    # but wrapping it in a full model is the only (easy) way to save it
    # via pipe.save_pretrained().
    def create_clipholder(self):
        config = CLIPTextConfig(
            vocab_size=1,           # minimal vocab size
            hidden_size=4096,       # input hidden size (T5-XXL width)
            projection_dim=768,     # output dimension of the projection
            num_hidden_layers=0,    # no transformer layers
            num_attention_heads=1,
            intermediate_size=4,    # minimal intermediate size
        )
        model = CLIPTextModelWithProjection(config)
        # The constructor generates the one layer we actually use:
        #   model.text_projection = nn.Linear(4096, 768, bias=False)
        return model
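
    # Quick shape sanity check (illustrative sketch, deliberately left as
    # comments so nothing runs at import time):
    #   holder = self.create_clipholder()
    #   x = torch.randn(1, 77, 4096)          # fake T5 hidden states
    #   holder.text_projection(x).shape       # -> torch.Size([1, 77, 768])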

    ###################################################
    # Override this so we can auto-init text_encoder.
    # For reference, the base-class defaults are:
    #   _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
    # t5_projection is not truly optional, but listing it here keeps
    # diffusers' internal component checks quiet.
    _optional_components = StableDiffusionPipeline._optional_components + ["text_encoder", "t5_projection"]

    def __init__(
        self,
        vae,
        text_encoder,
        tokenizer,
        unet,
        scheduler,
        safety_checker=None,
        feature_extractor=None,
        image_encoder=None,
        requires_safety_checker=True,
        t5_projection=None,
    ):
        self.tokenizer = (
            tokenizer
            if tokenizer is not None
            else T5Tokenizer.from_pretrained(T5_NAME)  # tokenizers take no dtype
        )
        if text_encoder is None:
            self.text_encoder = T5EncoderModel.from_pretrained(T5_NAME, torch_dtype=unet.dtype)
        else:
            self.text_encoder = text_encoder

        super().__init__(
            vae=vae,
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
            image_encoder=image_encoder,
            requires_safety_checker=requires_safety_checker,
        )

        if t5_projection is None:
            print("WARNING: no CLIPTextModelWithProjection found. This may indicate an error.")
            answer = input("Should I auto-generate one? Type 'Yes' to proceed: ")
            if answer != "Yes":
                exit(1)
            self.t5_projection = self.create_clipholder().to(vae.device, dtype=vae.dtype)
        else:
            if isinstance(t5_projection, CLIPTextModelWithProjection):
                self.t5_projection = t5_projection
            else:
                raise TypeError("Expected t5_projection to be of type CLIPTextModelWithProjection")

        checkval = getattr(self.t5_projection.config, "scaling_factor", None)
        if not checkval:
            # 0.013    would be my roughly calculated factor, for norms ~ 1.0
            # 0.13025  would be the VAE scaling factor
            # 0.035    is a commonly used factor for T5
            # ...but to make the output stdD similar to CLIP's, use 1.8
            # (see check-cache-stdd-t5.py)
            scaling_factor = 1.8
            print("INFO: Pipeline setting empty t5 scaling factor to", scaling_factor)
            self.t5_projection.config.scaling_factor = scaling_factor

        # Ensure everything is properly registered, both for .to("cuda")
        # and for saving via save_pretrained().
        self.register_modules(t5_projection=self.t5_projection)
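
        # How a factor like 1.8 might be estimated (illustrative sketch;
        # presumably what check-cache-stdd-t5.py measures):
        #   clip_std = clip_embeds.float().std()
        #   t5_std   = projected_t5_embeds.float().std()
        #   scaling_factor = (clip_std / t5_std).item()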

    # Returns the raw 4096-dim T5 embeddings, not the 768-dim projected ones.
    def encode_prompt_t5(
        self,
        prompt,
        negative_prompt,  # can be None
        device,
        padding=None,
    ):
        if padding is None:
            # Pad to a fixed length so the positive and negative embeddings
            # share a sequence length; the classifier-free-guidance path
            # concatenates them, so their shapes must match.
            padding = "max_length"

        def _tok(text):
            out = self.tokenizer(
                text,
                return_tensors="pt",
                padding=padding,
                max_length=self.tokenizer.model_max_length,
                truncation=True,
            )
            return out.input_ids.to(device=device, dtype=torch.long), out.attention_mask.to(device)

        pos_ids, pos_mask = _tok(prompt)
        pos_hidden = self.text_encoder(pos_ids, attention_mask=pos_mask).last_hidden_state

        neg_prompt = negative_prompt if negative_prompt is not None else ""
        neg_ids, neg_mask = _tok(neg_prompt)
        neg_hidden = self.text_encoder(neg_ids, attention_mask=neg_mask).last_hidden_state

        return pos_hidden, neg_hidden

    def encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        padding=None,
        **kwargs,
    ):
        scaling_factor = self.t5_projection.config.scaling_factor
        pos_hidden, neg_hidden = self.encode_prompt_t5(prompt, negative_prompt, device, padding=padding)

        pos_embeds = self.t5_projection.text_projection(pos_hidden)
        pos_embeds = pos_embeds * scaling_factor

        if do_classifier_free_guidance:
            neg_embeds = self.t5_projection.text_projection(neg_hidden)
            neg_embeds = neg_embeds * scaling_factor
            pos_embeds = pos_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            neg_embeds = neg_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            # Return order matches the base-class contract:
            # (prompt_embeds, negative_prompt_embeds).
            return pos_embeds, neg_embeds
        else:
            pos_embeds = pos_embeds.repeat_interleave(num_images_per_prompt, dim=0)
            return pos_embeds, None
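

# ----------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the pipeline class).
# The repo id below is a placeholder; point from_pretrained() at
# whichever checkpoint actually bundles these components.
if __name__ == "__main__":
    pipe = StableDiffusionT5Pipeline.from_pretrained(
        "somewhere/sd-t5-checkpoint",  # hypothetical repo id
        torch_dtype=torch.float16,
    ).to("cuda")
    image = pipe("a photo of an astronaut riding a horse on mars").images[0]
    image.save("out.png")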