# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://github.com/Tencent-Hunyuan/HunyuanImage-3.0/blob/main/LICENSE # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== import math import warnings import random from typing import List, Optional, Union, Dict, Any from collections import defaultdict from copy import deepcopy import numpy as np import torch import torch.nn.functional as F from transformers import AutoTokenizer from diffusers.utils import BaseOutput def default(value, default_value): return value if value is not None else default_value def ensure_list(value): if value is None: return [] if isinstance(value, (list, tuple)): return list(value) return [value] class Resolution(object): def __init__(self, size, *args): if isinstance(size, str): if 'x' in size: size = size.split('x') size = (int(size[0]), int(size[1])) else: size = int(size) if len(args) > 0: size = (size, args[0]) if isinstance(size, int): size = (size, size) self.h = self.height = size[0] self.w = self.width = size[1] self.r = self.ratio = self.height / self.width def __getitem__(self, idx): if idx == 0: return self.h elif idx == 1: return self.w else: raise IndexError(f'Index {idx} out of range') def __str__(self): return f'{self.h}x{self.w}' class ResolutionGroup(object): def __init__( self, base_size=None, step=None, align=1, min_multiple=0.5, max_multiple=2.0, max_entries=33, presets=None, ): self.align = int(align) if self.align <= 0: raise ValueError(f'align must be positive, got {align}') self.base_size = base_size if base_size is None or not isinstance(base_size, int): raise ValueError(f'base_size must be an int, but got {base_size!r}') if base_size % self.align != 0: raise ValueError(f'base_size {base_size} is not divisible by align {self.align}') self.min_multiple = float(min_multiple) self.max_multiple = float(max_multiple) if not (0 < self.min_multiple < self.max_multiple): raise ValueError( f"min_multiple ({self.min_multiple}) must be positive and smaller than max_multiple ({self.max_multiple})" ) self.max_entries = max_entries if max_entries is None else int(max_entries) if self.max_entries is not None and self.max_entries <= 0: raise ValueError(f'max_entries must be positive when provided, got {self.max_entries}') if step is None: step = max(self.align, base_size // 16) else: if step <= 0: raise ValueError(f'step must be positive, got {step}') step = max(self.align, int(math.ceil(step / self.align)) * self.align) if step > base_size * max(1, int(self.max_multiple * 2)): raise ValueError( f'step {step} is too large for base_size {base_size} and max_multiple {self.max_multiple}' ) min_height = max(self.align, int(math.ceil(base_size * self.min_multiple / self.align)) * self.align) min_width = min_height max_height = int(math.floor(base_size * self.max_multiple / self.align)) * self.align max_width = max_height if min_height >= max_height: raise ValueError( f'min height ({min_height}) must be smaller than max height ({max_height})' ) span_up = max_height - base_size span_down = base_size - min_height if self.max_entries is not None: allowed_steps = max(self.max_entries - 1, 0) def _steps_for(candidate: int) -> int: up_steps = math.ceil(span_up / candidate) if span_up else 0 down_steps = math.ceil(span_down / candidate) if span_down else 0 return up_steps + down_steps candidate_step = step while _steps_for(candidate_step) > allowed_steps: candidate_step += self.align step = candidate_step self.step = step self.min_height = min_height self.min_width = min_width self.max_height = max_height self.max_width = max_width if self.min_height >= self.max_height: raise ValueError( f'min height ({self.min_height}) must be smaller than max height ({self.max_height})' ) self.presets = presets if presets: self.data = self._build_from_presets(presets) else: self.data = self._calc_by_step() if not self.data: raise ValueError('ResolutionGroup has no valid entries') self.ratio = np.array([x.ratio for x in self.data]) self.attr = ['' for _ in range(len(self.data))] self.prefix_space = 0 self._lookup = {(res.height, res.width): res for res in self.data} def __len__(self): return len(self.data) def __getitem__(self, idx): return self.data[idx] def __repr__(self): prefix = self.prefix_space * ' ' prefix_close = (self.prefix_space - 4) * ' ' res_str = f'ResolutionGroup(base_size={self.base_size}, step={self.step}, data=' attr_maxlen = max([len(x) for x in self.attr] + [5]) res_str += \ f'\n{prefix}ID: height width ratio {" " * max(0, attr_maxlen - 4)}count h/16 w/16 tokens\n{prefix}' res_str += \ ('\n' + prefix).join([f'{i:2d}: ({x.h:4d}, {x.w:4d}) {self.ratio[i]:.4f} {self.attr[i]:>{attr_maxlen}s} ' f'({x.h // 16:3d}, {x.w // 16:3d}) {x.h // 16 * x.w // 16:6d}' for i, x in enumerate(self.data)]) res_str += f'\n{prefix_close})' return res_str def _calc_by_step(self): assert self.align <= self.step, f'align {self.align} must be smaller than step {self.step}' min_height = self.min_height min_width = self.min_width max_height = self.max_height max_width = self.max_width resolutions = [Resolution(self.base_size, self.base_size)] cur_height, cur_width = self.base_size, self.base_size while True: if cur_height >= max_height and cur_width <= min_width: break cur_height = min(cur_height + self.step, max_height) cur_width = max(cur_width - self.step, min_width) resolutions.append(Resolution(cur_height // self.align * self.align, cur_width // self.align * self.align)) cur_height, cur_width = self.base_size, self.base_size while True: if cur_height <= min_height and cur_width >= max_width: break cur_height = max(cur_height - self.step, min_height) cur_width = min(cur_width + self.step, max_width) resolutions.append(Resolution(cur_height // self.align * self.align, cur_width // self.align * self.align)) resolutions = sorted(resolutions, key=lambda x: x.ratio) return self._enforce_entry_limit(resolutions) def _build_from_presets(self, presets): resolutions = [] seen = set() for entry in presets: if isinstance(entry, str): if 'x' not in entry: continue parts = entry.split('x') if len(parts) != 2: continue try: height = int(parts[0]) width = int(parts[1]) except ValueError: continue elif isinstance(entry, (list, tuple)) and len(entry) == 2: height, width = entry try: height = int(height) width = int(width) except ValueError: continue else: continue height = max(self.align, int(round(height / self.align)) * self.align) width = max(self.align, int(round(width / self.align)) * self.align) if not (self.min_height <= height <= self.max_height and self.min_width <= width <= self.max_width): continue key = (height, width) if key in seen: continue seen.add(key) resolutions.append(Resolution(height, width)) if not any(reso.height == self.base_size and reso.width == self.base_size for reso in resolutions): resolutions.append(Resolution(self.base_size, self.base_size)) resolutions = sorted(resolutions, key=lambda x: x.ratio) return self._enforce_entry_limit(resolutions) def _enforce_entry_limit(self, resolutions): if self.max_entries is None or len(resolutions) <= self.max_entries: return resolutions step = (len(resolutions) - 1) / (self.max_entries - 1) if self.max_entries > 1 else 1 indices = [] for i in range(self.max_entries): idx = int(round(i * step)) idx = max(0, min(idx, len(resolutions) - 1)) if indices and idx <= indices[-1]: idx = indices[-1] + 1 if idx >= len(resolutions): idx = len(resolutions) - 1 indices.append(idx) indices = sorted(set(indices)) for idx in range(len(resolutions)): if len(indices) >= self.max_entries: break if idx not in indices: indices.append(idx) indices.sort() base_index = next((i for i, reso in enumerate(resolutions) if reso.height == self.base_size and reso.width == self.base_size), None) if base_index is not None and base_index not in indices: indices[len(indices) // 2] = base_index indices.sort() return [resolutions[i] for i in indices[:self.max_entries]] def get_target_size(self, width, height): key = (int(round(height)), int(round(width))) reso = self._lookup.get(key) if reso is None: ratio = height / width idx = np.argmin(np.abs(self.ratio - ratio)) reso = self.data[idx] return reso.w, reso.h def get_base_size_and_ratio_index(self, width, height): key = (int(round(height)), int(round(width))) reso = self._lookup.get(key) if reso is not None: return self.base_size, self.data.index(reso) ratio = height / width idx = np.argmin(np.abs(self.ratio - ratio)) return self.base_size, idx class ImageInfo: """ Class to store image information for processing and generation. """ def __init__( self, image_type: str = None, image_tensor: torch.Tensor = None, image_width: int = None, image_height: int = None, token_width: int = None, token_height: int = None, image_token_length: int = None, base_size: int = None, ratio_index: int = None, **kwargs, ): self.image_type = image_type self.image_tensor = image_tensor self.image_width = image_width self.w = image_width self.image_height = image_height self.h = image_height self.token_width = token_width self.tk_w = token_width self.token_height = token_height self.tk_h = token_height self.image_token_length = default( image_token_length, token_width * token_height if token_width is not None and token_height is not None else None ) self.base_size = base_size self.ratio_index = ratio_index self.add_timestep_token = kwargs.get("add_timestep_token", True) self.add_guidance_token = kwargs.get("add_guidance_token", False) self.use_front_boi_token = kwargs.get("use_front_boi_token", True) self.add_image_shape_token = kwargs.get("add_image_shape_token", True) def __getitem__(self, key: str) -> Any: """Allow dictionary-like access to attributes.""" if hasattr(self, key): return getattr(self, key) raise KeyError(f"Key '{key}' not found in ImageInfo") def __setitem__(self, key: str, value: Any) -> None: """Allow dictionary-like assignment to attributes.""" if hasattr(self, key): setattr(self, key, value) else: raise KeyError(f"Key '{key}' not found in ImageInfo") def __contains__(self, key: str) -> bool: """Check if the key exists in the ImageInfo object.""" return hasattr(self, key) def __repr__(self): return (f"ImageInfo(image_type={self.image_type}, image_tensor={self.image_tensor}, " f"image_width={self.image_width}, image_height={self.image_height}, " f"token_width={self.token_width}, token_height={self.token_height}, " f"image_token_length={self.image_token_length}, " f"base_size={self.base_size}, ratio_index={self.ratio_index}") @property def meta_info(self): # Used for image sections of tkwrapper.encode_general() if self.image_type in ["vae", "gen_image"]: return dict( token_length=self.image_token_length, add_timestep_token=self.add_timestep_token, add_guidance_token=self.add_guidance_token, use_front_boi_token=self.use_front_boi_token, add_image_shape_token=self.add_image_shape_token, base_size=self.base_size, ratio_idx=self.ratio_index, # for rope 2d token_height=self.token_height, token_width=self.token_width, # for bc image_height=self.image_height, image_width=self.image_width, ) elif self.image_type in ["vit"]: return dict( token_length=self.image_token_length, use_front_boi_token=self.use_front_boi_token, add_image_shape_token=self.add_image_shape_token, # for rope 2d token_height=self.token_height, token_width=self.token_width, # for bc image_height=self.image_height, image_width=self.image_width, ) else: raise ValueError(f"Unknown image type '{self.image_type}'") @property def num_special_tokens(self): if self.args is None: raise ValueError("meta_info requires `args` attribute to be set.") if self.image_type in ["vae", "src_image", "gen_image"]: count = ( 2 + # + or + (1 if self.add_timestep_token else 0) + (1 if self.add_guidance_token else 0) + (2 if self.add_image_shape_token else 0) ) else: raise ValueError(f"Unknown image_type: {self.image_type}") return count def copy(self, copy_image_tensor=True): if copy_image_tensor and self.image_tensor is None: raise ValueError("image_tensor is None, cannot copy") return ImageInfo( image_type=self.image_type, image_tensor=self.image_tensor.clone() if copy_image_tensor else None, image_width=self.image_width, image_height=self.image_height, token_width=self.token_width, token_height=self.token_height, image_token_length=self.image_token_length, base_size=self.base_size, ratio_index=self.ratio_index, ) def zeros_(self): self.image_tensor = torch.zeros_like(self.image_tensor) class ImageTensor(torch.Tensor): # This class is just for type hinting purposes. Attribute `i` should be defined # as an instance attribute of the torch.Tensor instance, like: tensor.i = ImageInfo(...) i: ImageInfo vision_encoder_kwargs: dict class JointImageInfo(object): def __init__(self, vae_image_info: ImageInfo, vision_image_info: ImageInfo, vision_encoder_kwargs: dict = None): self.vae_image_info = vae_image_info self.vision_image_info = vision_image_info self.vision_encoder_kwargs = vision_encoder_kwargs # Define key attributes to align with ImageInfo for uniformity self.image_type = "joint_image" self.image_token_length = vae_image_info.image_token_length + vision_image_info.image_token_length self.add_timestep_token = vae_image_info.add_timestep_token self.use_front_boi_token = vae_image_info.use_front_boi_token self.add_image_shape_token = vae_image_info.add_image_shape_token def __repr__(self): return f"JointImageInfo(vae_image={self.vae_image_info}, vision_image={self.vision_image_info})" @property def meta_info(self): # Used for image sections of tkwrapper.encode_general() return dict( token_length=[self.vae_image_info.image_token_length, self.vision_image_info.image_token_length], add_timestep_token=self.add_timestep_token, use_front_boi_token=self.use_front_boi_token, add_image_shape_token=self.add_image_shape_token, base_size=self.vae_image_info.base_size, ratio_idx=self.vae_image_info.ratio_index, # for rope 2d token_height=[self.vae_image_info.token_height, self.vision_image_info.token_height], token_width=[self.vae_image_info.token_width, self.vision_image_info.token_width], # for bc image_height=[self.vae_image_info.image_height, self.vision_image_info.image_height], image_width=[self.vae_image_info.image_width, self.vision_image_info.image_width], ) @property def num_special_tokens(self): return ( 2 + # + (1 if self.add_timestep_token else 0) + (2 if self.add_image_shape_token else 0) + 1 # ) def copy(self, copy_image_tensor=True): if copy_image_tensor and ( self.vae_image_info.image_tensor is None or self.vision_image_info.image_tensor is None): raise ValueError("image_tensor is None, cannot copy") return JointImageInfo( self.vae_image_info.copy(copy_image_tensor), self.vision_image_info.copy(copy_image_tensor), self.vision_encoder_kwargs, ) def zeros_(self): self.vae_image_info.zeros_() self.vision_image_info.zeros_() class JointImage(object): def __init__(self, vae_image: ImageTensor, vision_image: ImageTensor): self.vae_image = vae_image self.vision_image = vision_image self.i = JointImageInfo(vae_image.i, vision_image.i) class TokenizerEncodeOutput(BaseOutput): tokens: torch.Tensor = None timestep_scatter_index: Optional[torch.Tensor] = None guidance_scatter_index: Optional[torch.Tensor] = None text_slices: Optional[List[slice]] = None gen_image_slices: Optional[List[slice]] = None joint_image_slices: Optional[List[slice]] = None cond_vae_image_slices: Optional[List[slice]] = None cond_vit_image_slices: Optional[List[slice]] = None text_mask: Optional[torch.Tensor] = None gen_image_mask: Optional[torch.Tensor] = None cond_vae_image_mask: Optional[torch.Tensor] = None cond_vit_image_mask: Optional[torch.Tensor] = None real_pos: Optional[torch.Tensor] = None all_image_slices: Optional[List[slice]] = None cond_timestep_scatter_index: Optional[torch.Tensor] = None gen_timestep_scatter_index: Optional[torch.Tensor] = None class Conversation: roles: List[str] = ["User", "Assistant"] sep: str = "\n\n" class TokenizerWrapper(object): def __init__(self, tokenizer): if isinstance(tokenizer, str): self.tokenizer = AutoTokenizer.from_pretrained(tokenizer) else: self.tokenizer = tokenizer # Define short names self.bos_token_id = self.tokenizer.bos_token_id self.eos_token_id = self.tokenizer.eos_token_id self.pad_token_id = self.tokenizer.pad_token_id self.boi_token_id = self.tokenizer.convert_tokens_to_ids("") self.eoi_token_id = self.tokenizer.convert_tokens_to_ids("") self.img_token_id = self.tokenizer.convert_tokens_to_ids("") self.cfg_token_id = self.tokenizer.convert_tokens_to_ids("") self.end_answer_token_id = self.tokenizer.convert_tokens_to_ids("") self.end_recaption_token_id = self.tokenizer.convert_tokens_to_ids("") self.ratio_token_offset = self.tokenizer.convert_tokens_to_ids("") self.special_token_map = self.tokenizer.added_tokens_encoder def pad(self, tensor_list, dim=0, pad_val=None): if pad_val is None: pad_val = self.pad_token_id max_len = max([t.shape[dim] for t in tensor_list]) padded_tensor_list = [] for t in tensor_list: if t.shape[dim] < max_len: assert pad_val is not False, "Not allowed pad." t = F.pad(t, (0, max_len - t.shape[dim]), value=pad_val) padded_tensor_list.append(t) return padded_tensor_list def encode(self, *args, **kwargs): return self.tokenizer.encode(*args, **kwargs) def decode(self, *args, **kwargs): return self.tokenizer.decode(*args, **kwargs) def encode_text( self, *texts, uncond_enabled: Optional[Union[bool, List[bool]]] = None, uncond_p: Optional[float] = None, max_length: Optional[int] = None, pad: Optional[str] = None, return_lengths: bool = False, ): """ Encode text and image for AR-like model training of the text-to-image/instruction tuning tasks. Support encode multiple texts at once. Each text can be separately conditioned or unconditioned based on the uncond_flags and a uniform uncond_p. ** token is always prepended to the text tokens.** Parameters ---------- texts: str or List[str] List of texts to be encoded. uncond_enabled: bool or List[bool] List of flags to indicate whether the text should be unconditioned. If False, the text will never be unconditioned. If True, the text will be unconditioned with uncond_p. uncond_p: float Probability to the unconditional text. Only works when uncond_enabled is True. max_length: int Maximum length of the encoded text. pad: Optional[str] Padding method. Can be 'left' or 'right'. return_lengths: bool Whether to return the length of each encoded text. """ if pad is not None: assert max_length is not None, "max_length should be provided when pad is not None." if uncond_enabled is None: uncond_enabled = [True] * len(texts) elif isinstance(uncond_enabled, bool): uncond_enabled = [uncond_enabled] * len(texts) if len(uncond_enabled) != len(texts): print(uncond_enabled, texts) assert len(uncond_enabled) == len(texts), ( f"Length of uncond_flags should be equal to the number of texts, " f"but got {len(uncond_enabled)} and {len(texts)}." ) # Prepare text/uncond tokens # TODO: If len(texts) > 1, such as instruction + prompt in inpainting, we need to determine how to do uncond. # Now all texts will be cond or uncond at the same time. do_uncond_drop = (uncond_p is not None) and (random.random() < uncond_p) text_tokens, lengths = [], [] cum_length = 0 for text, uncond_flag in zip(texts, uncond_enabled): # If reach the max_length and there still have unencoded texts, give a warning message and break the loop. if max_length is not None and cum_length >= max_length: warnings.warn( f"Text length exceeds the max_length({max_length}). The remaining texts will be ignored: " f"{text[:80]}..." ) break # Set add_special_tokens=False to avoid adding token in some LLMs. if isinstance(text, str): text_token = self.tokenizer.encode(text, add_special_tokens=False) else: text_token = text if uncond_flag and do_uncond_drop: text_token = [self.cfg_token_id] * len(text_token) # Cutoff the text by max_length if necessary if max_length is not None and (cum_length + len(text_token)) > max_length: text_token = text_token[:max_length - cum_length] text_tokens.extend(text_token) lengths.append(len(text_token)) cum_length += len(text_token) # Prepend/Append tokens if applicable if pad is not None and (pad_length := max_length - len(text_tokens)) > 0: if pad == 'left': text_tokens = [self.pad_token_id] * pad_length + text_tokens elif pad == 'right': text_tokens = text_tokens + [self.pad_token_id] * pad_length else: raise ValueError(f"Unsupported padding method: {pad}.") if return_lengths: return text_tokens, lengths return text_tokens @staticmethod def _check_key_number_matched(keys, data): # Assert keys and token_source are matched assert set(keys) == set(data.keys()), ( f"Keys in the template and token source should be matched, but got {set(keys)} and {list(data.keys())}." ) key_counts = {k: 0 for k in keys} for key in keys: key_counts[key] += 1 for key, count in key_counts.items(): assert len(data[key]) == count, ( f"Number of `{key}` in the token source should be matched with the template, but got " f"{data[key]}({len(data[key])}) and {count}." ) def _add_image_meta_info_token(self, token_seq, token_count, extra_token_pos, add_timestep_token=False, add_image_shape_token=False, base_size=None, ratio_idx=None, image_type=None, add_guidance_token=False): if add_image_shape_token: token_seq.extend([ self.special_token_map[f""], self.special_token_map[f""] ]) token_count += 2 if add_timestep_token: token_seq.extend([self.special_token_map[""]]) extra_token_pos['timestep'].append(token_count) if image_type is not None: if image_type == "gen_image": extra_token_pos['gen_timestep'].append(token_count) elif image_type in ["joint_image"]: extra_token_pos['cond_timestep'].append(token_count) else: raise ValueError(f"Unsupported image type: {image_type}.") token_count += 1 if add_guidance_token: token_seq.extend([self.special_token_map[""]]) extra_token_pos['guidance'].append(token_count) token_count += 1 return token_count @staticmethod def _shorten_text(text): import re text = re.sub(r"()+", lambda m: f"[]{{{len(m.group(0)) // 5}}}", text) text = re.sub(r"()+", lambda m: f"[]{{{len(m.group(0)) // 5}}}", text) return text def encode_sequence( self, template: str, token_source: Dict[str, List], total_length=None, add_timestep_token=False, add_guidance_token=False, last_key_only_prefix=False, add_eos=True, use_front_boi_token=True, add_pad=True, add_bos=True, drop_last: Union[str, bool] = 'auto', add_image_shape_token=False, ): """ Encode a sequence based on the template (e.g., `text-image` for t2i, `text-image-image` for instruction tuning) and token source. Parameters ---------- template: str Template of the sequence. E.g., "text-gen_image" means the sequence is composed of text and an image. "text-text-gen_image" means the sequence is composed of two sections of text and an image. token_source: Dict[str, List] Token source for each key in the template, in order. - text: List[Dict]. - gen_image: List[Dict]. - joint_image: List[Dict]. total_length: int Total length of the encoded sequence, include padding tokens. add_timestep_token: bool Whether to add timestep token before the image tokens. (Right after the tokens) add_guidance_token: bool Whether to add guidance token before the image tokens. last_key_only_prefix: bool Whether to only use the modal prefix in the last key. add_eos: bool or 'auto' Whether to add eos token at the end of the sequence. If True, always add eos token. If 'auto', add eos token only when the total_length is not reached and the last token is not . use_front_boi_token: bool: Whether to put the token at the front of iw, ih and timestep tokens. add_pad: bool or 'auto' Whether to add padding tokens to the sequence. If True and total_length is not reached, add padding tokens. add_bos: bool Whether to add bos token at the beginning of the sequence. drop_last: bool or 'auto' - If auto, drop last tokens exceeding the total_length if the total_length is provided. If cut point is in the middle of the image tokens, an error will raised. - If True, drop last tokens exceeding the total_length. If cut point is in the middle of the image tokens, all the successive image tokens will be dropped. - If False, keep the last tokens exceeding the total_length, even if the total_length is reached. add_image_shape_token: bool Whether to add image shape token before the image tokens. (Right before the token) Returns ------- token_seq: list Encoded token sequence. extra_token_pos: dict Positions of extra tokens. """ if last_key_only_prefix: assert add_eos is not True, "add_eos should not be True when last_key_only_prefix is True." if drop_last is True and total_length is None: raise ValueError("total_length should be provided when drop_last is True.") keys = template.split('-') modal_length = len(keys) index_indicator = {k: 0 for k in token_source} for k, v in token_source.items(): assert isinstance(v, (list, tuple)), ( f"Value of `{k}` in the token source should be a list or tuple, but got {type(v)}." ) self._check_key_number_matched(keys, token_source) token_seq = [] token_count = 0 extra_token_pos = defaultdict(list) if add_bos: token_seq.append(self.bos_token_id) token_count += 1 # If drop_last is True, we check the token_count on the fly and exit the loop if the total_length is reached. # This check is only applied to the block tokens. Block tokens mean the tokens that are unsplittable, like # image tokens. Text tokens are splittable, so we don't need to check the token_count for text. # If the loop is broken by drop_last, we don't add the eos token at the end because the sequence is not # complete. drop_last_break = False for i, key in enumerate(keys): source = token_source[key][index_indicator[key]] if key == "text": token_seq.extend(source) # text token sequence extra_token_pos["_start"].append(token_count) token_count += len(source) extra_token_pos["_end"].append(token_count - 1) elif key == "gen_image": if isinstance(source, int): source = {'length': source} extra_count = 2 + ( 1 if source.get('timestep', add_timestep_token) else 0) + ( 1 if source.get('guidance', add_guidance_token) else 0) + ( 2 if source.get('image_shape', add_image_shape_token) else 0 ) if drop_last is True and token_count + extra_count + source['length'] > total_length: drop_last_break = True break if source.get('front_boi', use_front_boi_token): token_seq.append(self.boi_token_id) extra_token_pos["boi"].append(token_count) token_count += 1 token_count = self._add_image_meta_info_token( token_seq=token_seq, token_count=token_count, extra_token_pos=extra_token_pos, add_timestep_token=source.get('timestep', add_timestep_token), add_guidance_token=source.get('guidance', add_guidance_token), add_image_shape_token=source.get('image_shape', add_image_shape_token), base_size=source.get('base_size'), ratio_idx=source.get('ratio_idx'), image_type=key, ) if not source.get('front_boi', use_front_boi_token): token_seq.append(self.boi_token_id) extra_token_pos["boi"].append(token_count) token_count += 1 if last_key_only_prefix and i == modal_length - 1: pass # for AR inference else: token_seq.extend( [self.img_token_id] * source['length'] + # token number [self.eoi_token_id] ) extra_token_pos["_start"].append(token_count) extra_token_pos["_start"].append(token_count) token_count += source['length'] extra_token_pos["_end"].append(token_count - 1) extra_token_pos["_end"].append(token_count - 1) extra_token_pos["eoi"].append(token_count) token_count += 1 # elif key == "joint_image": assert isinstance(source['length'], list) and len( source['length']) == 2, "joint_image length should be a list of two integers" extra_count = 2 + 1 + ( # boi, eoi, joint_img_sep 1 if source.get('timestep', add_timestep_token) else 0) + ( 2 if source.get('image_shape', add_image_shape_token) else 0 ) if drop_last is True and token_count + extra_count + sum(source['length']) > total_length: drop_last_break = True break if source.get('front_boi', use_front_boi_token): token_seq.append(self.boi_token_id) # Use patched boi for Janus, otherwise useing default extra_token_pos["boi"].append(token_count) token_count += 1 token_count = self._add_image_meta_info_token( token_seq=token_seq, token_count=token_count, extra_token_pos=extra_token_pos, add_timestep_token=source.get('timestep', add_timestep_token), add_image_shape_token=source.get('image_shape', add_image_shape_token), base_size=source.get('base_size'), ratio_idx=source.get('ratio_idx'), image_type=key, ) if not source.get('front_boi', use_front_boi_token): token_seq.append(self.boi_token_id) extra_token_pos["boi"].append(token_count) token_count += 1 if last_key_only_prefix and i == modal_length - 1: pass # for AR inference else: token_seq.extend( [self.img_token_id] * source['length'][0] ) extra_token_pos["_start"].append(token_count) extra_token_pos["_start"].append(token_count) extra_token_pos["_start"].append(token_count) token_count += source['length'][0] extra_token_pos["_end"].append(token_count - 1) extra_token_pos["_end"].append(token_count - 1) token_seq.extend( [self.special_token_map[""]] ) extra_token_pos["joint_img_sep"].append(token_count) token_count += 1 token_seq.extend( [self.img_token_id] * source['length'][1] ) extra_token_pos["_start"].append(token_count) extra_token_pos["_start"].append(token_count) token_count += source['length'][1] extra_token_pos["_end"].append(token_count - 1) extra_token_pos["_end"].append(token_count - 1) extra_token_pos["_end"].append(token_count - 1) token_seq.extend( [self.eoi_token_id] ) extra_token_pos["eoi"].append(token_count) token_count += 1 # else: raise ValueError(f"Not supported key: {key}") index_indicator[key] += 1 if add_eos is True and not drop_last_break: # Typically used for t2i task. token_seq.append(self.eos_token_id) extra_token_pos["eos"].append(token_count) token_count += 1 elif add_eos == 'auto' and not drop_last_break: # Typically used for lm and mmu task. if token_seq[-1] != self.eos_token_id and (total_length is None or token_count < total_length): token_seq.append(self.eos_token_id) extra_token_pos["eos"].append(token_count) token_count += 1 if total_length: # Check token count and clip sequence if necessary if token_count > total_length and drop_last: # Assert clip position is not in the middle of the block-wise tokens (gen_image, joint_image) for start_key, end_key in [ ("_start", "_end"), ("_start", "_end"), ("_start", "_end"), ("_start", "_end"), ]: if start_key in extra_token_pos and end_key in extra_token_pos: assert all( (start > total_length or end + 1 < total_length) for start, end in zip(extra_token_pos[start_key], extra_token_pos[end_key]) ), ("Clip position should not be in the middle of the image tokens.\n" f"Below is the text:\n{self._shorten_text(self.tokenizer.decode(token_seq))}") token_seq = token_seq[:total_length] # Pad the sequence if necessary pad_num = max(0, total_length - len(token_seq)) if add_pad and pad_num: token_seq.extend([self.pad_token_id] * pad_num) extra_token_pos["first_pad"].append(token_count) return token_seq, extra_token_pos def batch_gen_infer( self, infer_fn, prompt_list: list, negative_prompt_list: list = None, infer_fn_kwargs_list: List[Dict[str, int]] = None, do_classifier_free_guidance=False, condition_repeat_times: int = 1, uncondition_repeat_times: int = 1, ): """ Batch inference for the AR-like model training of the text-to-image/instruction tuning tasks. Parameters ---------- infer_fn: callable Inference function to encode the prompt. prompt_list: list List of prompts. Each element can be a single prompt or a list of prompts passed to the infer_fn. negative_prompt_list: list List of negative prompts. Only used when do_classifier_free_guidance is True. If None, will use token sequence as negative prompt. infer_fn_kwargs_list: List[Dict[str, int]] List of keyword arguments for the infer_fn. do_classifier_free_guidance: bool Whether to do classifier-free guidance. condition_repeat_times: int Support multi-condition. uncondition_repeat_times: int Support multi-uncondition. """ if infer_fn_kwargs_list is None: infer_fn_kwargs_list = [{} for _ in prompt_list] # [n_output, bsz] cond_results_list = None uncond_results_list = None output_type_list = [] for prompt_idx, (prompt, infer_fn_kwargs) in enumerate(zip(prompt_list, infer_fn_kwargs_list)): if not isinstance(prompt, (list, tuple)): prompt = [prompt] cond_kwargs = {"uncond_p": 0.0} if do_classifier_free_guidance else {} results = infer_fn( *prompt, **infer_fn_kwargs, **cond_kwargs, ) output_type_list.append((type(results), len(results) if isinstance(results, (list, tuple)) else 1)) if isinstance(results, dict): raise ValueError("Make batch on dict is not supported. Please return list or tuple for infer_fn.") if not isinstance(results, (list, tuple)): results = (results,) if cond_results_list is None: cond_results_list = [[] for _ in results] uncond_results_list = [[] for _ in results] for i, result in enumerate(results): cond_results_list[i].append(result) if do_classifier_free_guidance: if negative_prompt_list is None: uncond_kwargs = {"uncond_p": 1.0} uncond_results = infer_fn( *prompt, **infer_fn_kwargs, **uncond_kwargs, ) else: negative_prompt = negative_prompt_list[prompt_idx] if not isinstance(negative_prompt, (list, tuple)): negative_prompt = [negative_prompt] uncond_results = infer_fn( *negative_prompt, **infer_fn_kwargs, ) if isinstance(uncond_results, TokenizerEncodeOutput): uncond_results_list.append(uncond_results) else: for i, result in enumerate(uncond_results): uncond_results_list[i].append(result) assert all(output_type_list[0] == n for n in output_type_list), \ f"Number of outputs should be equal for all samples, but got {output_type_list}." output_type, output_num = output_type_list[0] def make_batch(batch_cond_item, batch_uncond_item): # Process each output item to make batch first = batch_cond_item[0] # The first element in the batch if isinstance(first, torch.Tensor): stacked_item = torch.stack(self.pad( batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times, )) elif first is None: assert all(item is None for item in batch_cond_item + batch_uncond_item), \ (f"The first cond item is None, but some items are not None:\n\n" f"condition: {batch_cond_item}\n\n" f"uncondition: {batch_uncond_item}") stacked_item = None elif isinstance(first, (list, tuple)): # If the output item is a list or tuple, we treat it as a whole, and won't make nested batch any more. stacked_item = batch_cond_item * condition_repeat_times + batch_uncond_item * uncondition_repeat_times elif isinstance(first, TokenizerEncodeOutput): stacked_item = {} # Traverse not-None attributes for key in list(first.keys()): merged_list = [cond_item[key] for cond_item in batch_cond_item] * condition_repeat_times + \ [uncond_item[key] for uncond_item in batch_uncond_item] * uncondition_repeat_times if isinstance(first[key], torch.Tensor): if 'mask' in key: pad_val = 0.0 elif key == 'tokens': pad_val = self.special_token_map[""] else: pad_val = False # Should not pad for other tensors stacked_item[key] = torch.stack(self.pad(merged_list, pad_val=pad_val), dim=0) elif isinstance(first[key], list): stacked_item[key] = merged_list elif first[key] is None: pass else: raise ValueError(f"Unsupported type of {key}: {type(first[key])}.") stacked_item = TokenizerEncodeOutput(stacked_item) else: raise TypeError(f"Making batch on type {type(first)} is not supported.") return stacked_item stacked_outputs = [] for cond_results, uncond_results in zip(cond_results_list, uncond_results_list): stacked_outputs.append(make_batch(cond_results, uncond_results)) if output_type == list: return stacked_outputs elif output_type == tuple: return tuple(stacked_outputs) elif output_num == 1: return stacked_outputs[0] else: raise ValueError(f"Unsupported output type: {output_type}.") @staticmethod def parse_extra_token_pos(extra_token_pos, prefix, tokens, rng=None): if rng is None: rng = slice(None) image_slices = [ slice(start, end + 1) for start, end in zip(extra_token_pos[f'<{prefix}>_start'][rng], extra_token_pos[f'<{prefix}>_end'][rng]) ] if f'<{prefix}>_start' in extra_token_pos and f'<{prefix}>_end' in extra_token_pos else [] if image_slices: image_mask = torch.zeros_like(tokens, dtype=torch.bool) for image_slice in image_slices: image_mask[image_slice] = True else: image_mask = None return image_slices, image_mask def encode_general( self, sections: Optional[List[Dict[str, Any]]] = None, max_token_length: Optional[int] = None, add_eos='auto', use_text_mask=True, add_pad='auto', add_bos=True, drop_last='auto', ): """ General encode function to encode a sequence with multiple sections of text and images. Each section is a dict with a `type` key and other keys depending on the type. Supported section types: - text: dict with keys: - text: str or List[int], text to be encoded. Either `text` or `tokens` should be provided. - tokens: List[int], pre-encoded text tokens. Either `text` or `tokens` should be provided. - uncond_enabled: bool, whether to enable uncondition for this text section. - uncond_p: float, probability to drop the text section for uncondition. - max_length: int, maximum length of the text section. - ignore: bool, whether to ignore this text section in the text mask. - start_offset: int, start offset of the text mask. - end_offset: int, end offset of the text mask. - gen_image: dict with keys: - token_length: int, number of image tokens. - add_timestep_token: bool, whether to add timestep token before the image tokens. - add_guidance_token: bool, whether to add guidance token before the image tokens. - use_front_boi_token: bool, whether to put the token at the front of size, ratio and timestep tokens. - add_image_shape_token: bool, whether to add image shape token before the image tokens. - base_size: int, base size of the image. - ratio_idx: int, ratio index of the image. - joint_image: dict with keys: - token_length: List[int], number of image tokens for the two images. - add_timestep_token: bool, whether to add timestep token before the image tokens. - use_front_boi_token: bool, whether to put the token at the front of size, ratio and timestep tokens. - add_image_shape_token: bool, whether to add image shape token before the image tokens. - base_size: int, base size of the image. - ratio_idx: int, ratio index of the image. Parameters ---------- sections: List[Dict[str, Any]] List of sections to be encoded. max_token_length: int Maximum length of the encoded token sequence. add_eos: bool or 'auto' Whether to add eos token at the end of the sequence. If True, always add eos token. If 'auto', add eos token only when the total_length is not reached and the last token is not . use_text_mask: bool Whether to generate text mask. add_pad: bool or 'auto' Whether to add padding tokens to the sequence. If True and total_length is not reached, add padding tokens. add_bos: bool Whether to add bos token at the beginning of the sequence. drop_last: bool or 'auto' - If auto, drop last tokens exceeding the total_length if the total_length is provided. If cut point is in the middle of the image tokens, an error will raised. - If True, drop last tokens exceeding the total_length. If cut point is in the middle of the image tokens, all the successive image tokens will be dropped. - If False, keep the last tokens exceeding the total_length, even if the total_length is reached. Returns ------- TokenizerEncodeOutput Encoded token sequence and extra information. """ if sections is None: raise ValueError("sections must be provided.") template = '-'.join([section['type'] for section in sections]) sections = deepcopy(sections) token_source = defaultdict(list) text_mask_specs = [] for section in sections: if section['type'] == 'text': text = self.encode_text( section['text'] if 'text' in section else section['tokens'], uncond_enabled=section.get('uncond_enabled'), uncond_p=section.get('uncond_p'), max_length=section.get('max_length'), ) token_source['text'].append(text) text_mask_specs.append(dict( ignore=section.get('ignore', False), start_offset=section.get('start_offset', 0), end_offset=section.get('end_offset', 0), )) elif section['type'] == 'gen_image': token_source['gen_image'].append(dict( length=section['token_length'], timestep=section.get('add_timestep_token', False), guidance=section.get('add_guidance_token', False), front_boi=section.get('use_front_boi_token', False), image_shape=section.get('add_image_shape_token', False), base_size=section.get('base_size'), ratio_idx=section.get('ratio_idx'), )) elif section['type'] == 'joint_image': token_source['joint_image'].append(dict( length=section['token_length'], timestep=section.get('add_timestep_token', False), front_boi=section.get('use_front_boi_token', False), image_shape=section.get('add_image_shape_token', False), base_size=section.get('base_size'), ratio_idx=section.get('ratio_idx'), )) else: raise ValueError(f"Invalid section type: {section['type']}") # Combine text and image tokens full_token_seq, extra_token_pos = self.encode_sequence( template=template, token_source=dict(token_source), total_length=max_token_length, add_eos=add_eos, add_pad=add_pad, add_bos=add_bos, drop_last=drop_last, ) full_seq_token_tensor = torch.tensor(full_token_seq, dtype=torch.long) timestep_scatter_index = torch.tensor(extra_token_pos['timestep'], dtype=torch.long) \ if 'timestep' in extra_token_pos else None guidance_scatter_index = torch.tensor(extra_token_pos['guidance'], dtype=torch.long) \ if 'guidance' in extra_token_pos else None cond_timestep_scatter_index = torch.tensor(extra_token_pos['cond_timestep'], dtype=torch.long) \ if 'cond_timestep' in extra_token_pos else None gen_timestep_scatter_index = torch.tensor(extra_token_pos['gen_timestep'], dtype=torch.long) \ if 'gen_timestep' in extra_token_pos else None # Gen image mask gen_image_slices, gen_image_mask = self.parse_extra_token_pos(extra_token_pos, 'img', full_seq_token_tensor) # Joint image joint_image_slices, _ = self.parse_extra_token_pos(extra_token_pos, 'joint_img', full_seq_token_tensor) # Conditional vae image cond_vae_image_slices, cond_vae_image_mask = self.parse_extra_token_pos( extra_token_pos, 'vae_img', full_seq_token_tensor) # Conditional vit image cond_vit_image_slices, cond_vit_image_mask = self.parse_extra_token_pos( extra_token_pos, 'vit_img', full_seq_token_tensor) # All image slices (gen_image, joint_image) all_image_slices = [ slice(start, end + 1) for start, end in zip(extra_token_pos['_start'], extra_token_pos['_end']) ] if '_start' in extra_token_pos and '_end' in extra_token_pos else [] # Text mask text_slices = [ slice(start, end + 1) for start, end in zip(extra_token_pos['_start'], extra_token_pos['_end']) ] if '_start' in extra_token_pos and '_end' in extra_token_pos else [] assert len(text_slices) <= len(text_mask_specs), \ (f"Number of text slices ({len(text_slices)}) should be less than or equal to " f"number of text mask specs ({len(text_mask_specs)})") if use_text_mask: text_mask = torch.zeros_like(full_seq_token_tensor, dtype=torch.float32) for text_slice, mask_spec in zip(text_slices, text_mask_specs): if not mask_spec['ignore']: real_slice = slice( text_slice.start + mask_spec['start_offset'], text_slice.stop + mask_spec['end_offset'] ) text_mask[real_slice] = 1.0 else: text_mask = None # real_pos is the first position of the token real_pos = torch.tensor(extra_token_pos.get('first_pad', [full_seq_token_tensor.shape[0]]), dtype=torch.long) return TokenizerEncodeOutput( tokens=full_seq_token_tensor, timestep_scatter_index=timestep_scatter_index, guidance_scatter_index=guidance_scatter_index, text_slices=text_slices, gen_image_slices=gen_image_slices, joint_image_slices=joint_image_slices, cond_vae_image_slices=cond_vae_image_slices, cond_vit_image_slices=cond_vit_image_slices, text_mask=text_mask, gen_image_mask=gen_image_mask, cond_vae_image_mask=cond_vae_image_mask, cond_vit_image_mask=cond_vit_image_mask, real_pos=real_pos, all_image_slices=all_image_slices, cond_timestep_scatter_index=cond_timestep_scatter_index, gen_timestep_scatter_index=gen_timestep_scatter_index, ) def get_cot_sections(self, cot_text, uncond_kwargs, cot_max_length=None, drop_think=False): if not cot_text: # None or empty return [] if '' in cot_text and '' in cot_text: before_think_sec = cot_text.split('')[0] after_think_sec = cot_text.split('')[1] think_sec = cot_text.split('')[1].split('')[0] return self.get_cot_sections(before_think_sec, uncond_kwargs, drop_think=drop_think) + \ ([ dict(type="text", text=""), dict(type="text", text=think_sec, max_length=cot_max_length, **uncond_kwargs), dict(type="text", text="") ] if not drop_think else []) + \ self.get_cot_sections(after_think_sec, uncond_kwargs, drop_think=drop_think) if '' in cot_text and '' in cot_text: before_recaption_sec = cot_text.split('')[0] after_recaption_sec = cot_text.split('')[1] recaption_sec = cot_text.split('')[1].split('')[0] return self.get_cot_sections(before_recaption_sec, uncond_kwargs, drop_think=drop_think) + \ [ dict(type="text", text=""), dict(type="text", text=recaption_sec, max_length=cot_max_length, **uncond_kwargs), dict(type="text", text="") ] + \ self.get_cot_sections(after_recaption_sec, uncond_kwargs, drop_think=drop_think) return [ dict(type="text", text=cot_text, **uncond_kwargs), ] def apply_general_template( self, message_list, max_length=None, add_assistant_prefix=False, answer="auto", bot_task="auto", sequence_template="instruct", uncond_p=0.0, cfg_factor=1, batchify=False, image_base_size=1024, drop_think=False, ): # If cfg_factor > 1, we need to repeat the unconditioned part if batchify: assert isinstance(message_list[0], list), \ f"When batchify is True, message_list should be a list of list, but got [{type(message_list[0])}, ...]." return self.batch_gen_infer( infer_fn=self.apply_general_template, prompt_list=[[]], infer_fn_kwargs_list=[dict( message_list=message_list_i, max_length=max_length, add_assistant_prefix=add_assistant_prefix, answer=answer, bot_task=bot_task, sequence_template=sequence_template, image_base_size=image_base_size, drop_think=drop_think, ) for message_list_i in message_list], do_classifier_free_guidance=cfg_factor > 1, condition_repeat_times=1, uncondition_repeat_times=cfg_factor - 1, ) conv = Conversation() uncond_kwargs = dict(uncond_enabled=uncond_p == 1.0, uncond_p=uncond_p) def process_successive_message(_message_list, _cur_message_idx, role, prefix, suffix, answer_prefix="", answer_suffix=""): _sub_sections = [] while _cur_message_idx < len(message_list) and _message_list[_cur_message_idx]['role'] == role: message = _message_list[_cur_message_idx] if message['type'] == 'text': text = message['content'] if role == "system": _sub_sections.append(dict(type="text", text=text)) elif role == "assistant": if ("" in text and "" in text) or ( "" in text and "" in text): _sub_sections.extend(self.get_cot_sections(text, uncond_kwargs, drop_think=drop_think)) else: _sub_sections.append(dict(type="text", text=text, **uncond_kwargs)) else: _sub_sections.append(dict( type="text", text=f"{answer_prefix}{text}{answer_suffix}", **uncond_kwargs)) elif message['type'] == 'gen_image': info = message['content'] assert isinstance(info, ImageInfo), f"Expected ImageInfo, but got {type(info)}" if role == "assistant": _sub_sections.append(dict(type="text", text=answer_prefix)) _sub_sections.append(dict(type=message['type'], **info.meta_info)) if role == "assistant": _sub_sections.append(dict(type="text", text=answer_suffix)) elif message['type'] == 'joint_image': info = message['content'] assert isinstance(info, JointImageInfo), f"Expected JointImageInfo, but got {type(info)}" _sub_sections.append(dict(type=message['type'], **info.meta_info)) else: raise ValueError(f"Unknown message type: {message['type']}") _cur_message_idx += 1 if len(_sub_sections) > 0: # Add role prefix and suffix _sub_sections.insert(0, dict(type='text', text=prefix)) _sub_sections.append(dict(type='text', text=suffix)) return _sub_sections, _cur_message_idx # Define assistant prefix and suffix if (answer == "auto" and sequence_template == "instruct") or answer is True: answer_prefix, answer_suffix = "", "" else: answer_prefix, answer_suffix = "", "" if sequence_template == "pretrain": system_suffix = "" user_prefix = "" user_suffix = "" bot_prefix = "" bot_suffix = "" else: system_suffix = f"{conv.sep}" user_prefix = f"{conv.roles[0]}: " user_suffix = f"{conv.sep}" bot_prefix = f"{conv.roles[1]}: " bot_suffix = f"{conv.sep}" # Process successive user and assistant messages sections = [] cur_message_idx = 0 final_role = None while cur_message_idx < len(message_list): # Process successive system messages sub_sections, cur_message_idx = process_successive_message( message_list, cur_message_idx, role="system", prefix="", suffix=system_suffix) # Add to the template and sections sections.extend(sub_sections) if len(sub_sections) > 0: final_role = "system" # Process successive user messages sub_sections, cur_message_idx = process_successive_message( message_list, cur_message_idx, role="user", prefix=user_prefix, suffix=user_suffix) # Add to the template and sections sections.extend(sub_sections) if len(sub_sections) > 0: final_role = "user" # Process successive assistant messages sub_sections, cur_message_idx = process_successive_message( message_list, cur_message_idx, role="assistant", prefix=bot_prefix, suffix=bot_suffix, answer_prefix=answer_prefix, answer_suffix=answer_suffix, ) # Add to the template and sections sections.extend(sub_sections) if len(sub_sections) > 0: final_role = "assistant" if add_assistant_prefix: if final_role == "assistant": # Avoid adding prefix twice _bot_prefix = "" # Remove the final bot_suffix if len(sections) > 0 and sections[-1]['type'] == 'text' and sections[-1]['text'] == bot_suffix: sections = sections[:-1] else: _bot_prefix = bot_prefix # We can add special tokens for the bot lastest message according to different tasks bot_response_prefix = dict( auto=_bot_prefix, image="", think=f"{_bot_prefix}", recaption=f"{_bot_prefix}", img_ratio=f"{_bot_prefix}{answer_prefix}", )[bot_task] sections.append(dict(type='text', text=bot_response_prefix)) output = self.encode_general( sections=sections, use_text_mask=False, add_eos=False, add_pad=False, ) if max_length is not None: if output.tokens.shape[-1] > max_length: raise ValueError( f"Encoded token length {output.tokens.shape[-1]} exceeds max_length {max_length}.\n" f"Please set a larger max_length or check the input messages:\n{message_list}" ) return output, sections def apply_chat_template( self, batch_prompt: Optional[List[str]] = None, batch_message_list: Optional[List[List[Dict[str, Any]]]] = None, mode: str = "gen_text", batch_gen_image_info: Optional[List[ImageInfo]] = None, batch_cond_image_info: Optional[Union[List[JointImageInfo], List[List[JointImageInfo]]]] = None, batch_system_prompt: Optional[List[str]] = None, batch_cot_text: Optional[List[str]] = None, max_length: Optional[int] = None, bot_task: str = "auto", # auto/image/think/recaption/img_ratio image_base_size: int = 1024, sequence_template: str = "pretrain", cfg_factor: int = 1, add_assistant_prefix: Optional[bool] = None, drop_think: bool = False, ) -> Dict[str, Any]: assert bot_task in ["image", "auto", "think", "recaption", "img_ratio"], \ f"bot_task should be one of ['image', 'auto', 'think', 'recaption', 'img_ratio'], but got {bot_task}." if batch_message_list is None: # Simple text-to-image or text-cot-to-image task batch_size = len(batch_prompt) # Batchify inputs if not isinstance(batch_system_prompt, list): batch_system_prompt = [batch_system_prompt] * batch_size if not isinstance(batch_gen_image_info, list): batch_gen_image_info = [batch_gen_image_info] * batch_size if batch_cot_text is not None: assert len(batch_cot_text) == batch_size, \ (f"batch_cot_text should have the same length as batch_size ({batch_size}), " f"but got {len(batch_cot_text)}.") else: batch_cot_text = [None] * batch_size if batch_cond_image_info is not None: assert len(batch_cond_image_info) == batch_size, \ (f"batch_cond_image_info should have the same length as batch_size ({batch_size}), " f"but got {len(batch_cond_image_info)}.") batch_cond_image_info = [ cond_image_info if isinstance(cond_image_info, list) else [cond_image_info] for cond_image_info in batch_cond_image_info ] else: batch_cond_image_info = [[] for _ in range(batch_size)] # Convert single round materials into standard message list batch_message_list = [] for prompt, system_prompt, cot_text, gen_image_info, cond_image_info_list in zip( batch_prompt, batch_system_prompt, batch_cot_text, batch_gen_image_info, batch_cond_image_info, ): message_list = [] # 1. system prompt section if system_prompt: message_list.append(dict( role="system", type="text", content=system_prompt, context_type="str")) # 2. user inputs sections # 2.1 image inputs if len(cond_image_info_list) > 0: message_list.extend([ dict(role="user", type="joint_image", content=cond_image_info, context_type="image_info") for cond_image_info in cond_image_info_list ]) # 2.2 text inputs message_list.append(dict( role="user", type="text", content=prompt, context_type="str")) # 3. assistant answer sections if cot_text is not None: message_list.append(dict(role="assistant", type="text", content=cot_text, context_type="str")) if mode == "gen_image": message_list.append(dict( role="assistant", type="gen_image", content=gen_image_info, context_type="image_info")) # --- batch_message_list.append(message_list) output, sections = self.apply_general_template( message_list=batch_message_list, max_length=max_length, add_assistant_prefix=default(add_assistant_prefix, mode != "gen_image"), bot_task=bot_task, sequence_template=sequence_template, cfg_factor=cfg_factor, batchify=True, image_base_size=image_base_size, drop_think=drop_think, ) return dict(output=output, sections=sections)