| | """ |
| | LLaDA configuration |
| | """ |
| | from transformers import AutoConfig, PretrainedConfig |
| |
|
| | from enum import Enum |
| | from os import PathLike |
| | from typing import Union |
| | from dataclasses import asdict, dataclass, field |
| | from glob import glob |
| | from pathlib import Path |
| | from typing import ( |
| | Any, |
| | Dict, |
| | Iterable, |
| | List, |
| | Optional, |
| | Tuple, |
| | Type, |
| | TypeVar, |
| | Union, |
| | cast, |
| | ) |
| |
|
| |
|
# Public API of this module. ``LLaDAConfig`` (the ``PretrainedConfig`` subclass
# defined and registered with ``AutoConfig`` further down) is the main entry
# point for callers, so it is exported alongside the plain-Python config types.
__all__ = [
    "ActivationType",
    "ActivationCheckpointingStrategy",
    "BlockType",
    "LayerNormType",
    "InitFnType",
    "ModelConfig",
    "LLaDAConfig",
]

# Type alias for anything accepted as a filesystem path: a ``str`` or an ``os.PathLike``.
PathOrStr = Union[str, PathLike]
| |
|
| |
|
class StrEnum(str, Enum):
    """
    Minimal stand-in for :class:`enum.StrEnum` (part of the standard library since
    Python 3.11), kept so this module also works on older interpreters.

    Members are ``str`` instances. ``str(member)`` returns the member's raw value and
    ``repr(member)`` returns that value wrapped in single quotes (e.g. ``'rms'``).
    """

    def __str__(self) -> str:
        raw: str = self.value
        return raw

    def __repr__(self) -> str:
        text = str(self)
        return "'" + text + "'"
| |
|
| |
|
class LayerNormType(StrEnum):
    """
    Which layer-norm implementation the model uses (see ``ModelConfig.layer_norm_type``).
    """

    default = "default"
    """
    The default LayerNorm implementation, equivalent to PyTorch's built-in version.
    """

    low_precision = "low_precision"
    """
    A low-precision version of the default LayerNorm.
    """

    rms = "rms"
    """
    An RMSNorm implementation. When using ``torch.compile`` this is
    probably the fastest implementation.
    """

    gemma_rms = "gemma_rms"
    """
    An RMSNorm variant that follows the Gemma model family's implementation.
    """

    amd_compatible = "amd_compatible"
    """
    LayerNorm implemented manually to work around an issue with ROCm.
    """
| |
|
| |
|
class ActivationType(StrEnum):
    """
    Activation function applied inside the MLP layers, selected through
    ``ModelConfig.activation_type`` (whose default is ``swiglu``).
    """

    gelu = "gelu"
    relu = "relu"
    silu = "silu"
    swiglu = "swiglu"
| |
|
| |
|
class BlockType(StrEnum):
    """
    Transformer block implementation, selected through ``ModelConfig.block_type``
    (whose default is ``sequential``).
    """

    sequential = "sequential"
    parallel = "parallel"

    llama = "llama"
    """
    Like the sequential block, but with operations such as attention implemented
    slightly differently so that the block mimics Llama's behavior.
    """
| |
|
| |
|
class InitFnType(StrEnum):
    """
    Weight-initialization strategy, selected through ``ModelConfig.init_fn``
    (whose default is ``normal``).
    """

    mitchell = "mitchell"
    """
    Strategy suggested by Mitchell Wortsman (UW): a truncated normal distribution whose
    standard deviation adapts to both the size of the weights and the depth of the layer.
    """

    normal = "normal"
    """
    Every weight is drawn from one shared normal distribution.
    """

    kaiming_normal = "kaiming_normal"
    """
    Every weight is initialized with the Kaiming method using a normal distribution.
    Currently incompatible with FSDP.
    """

    fan_in = "fan_in"
    """
    "Fan-in variance scaling": a normal distribution with standard deviation ``1/sqrt(d_in)``,
    where ``d_in`` is the input dimensionality of the kernel.
    """

    full_megatron = "full_megatron"
    """
    The scheme metaseq calls "full megatron init"; this is the init used for Llama 2.
    """
| |
|
| |
|
@dataclass
class ModelConfig:
    """
    LLaDA (model) configuration.

    A plain dataclass of architecture hyper-parameters. ``LLaDAConfig`` seeds its
    defaults from a fresh instance of this class.
    """

    d_model: int = 768
    """
    The hidden size of the model.
    """

    n_heads: int = 12
    """
    The number of self-attention heads.
    """

    n_kv_heads: Optional[int] = None
    """
    The number of heads to use for keys and values. Defaults to `n_heads`.
    Set this to ``None`` or ``n_heads`` for normal multi-head attention.
    Set this to 1 for multi-query attention.
    Set it to some in-between value for Llama2-style grouped query attention.
    The value actually used is resolved by :attr:`effective_n_kv_heads`.
    """

    n_layers: int = 12
    """
    The number of layers/blocks.
    """

    mlp_ratio: int = 4
    """
    The ratio of the inner MLP dimensionality to ``d_model``.
    This is only used when ``mlp_hidden_size`` is not set.
    """

    mlp_hidden_size: Optional[int] = None
    """
    Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
    """

    activation_type: ActivationType = ActivationType.swiglu
    """
    The activation function to use within the MLP layers.
    """

    block_type: BlockType = BlockType.sequential
    """
    The transformer block implementation.
    """

    block_group_size: int = 1
    """
    The number of blocks to group together into a single parent block.
    This has no effect on the number of parameters in the model and is only used to wrap groups
    of blocks together with a single FSDP wrapper during training.
    """

    alibi: bool = False
    """
    If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``
    (this dataclass does not enforce that).
    """

    alibi_bias_max: float = 8.0
    """
    Maximum absolute value of ALiBi bias.
    """

    rope: bool = False
    """
    Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``
    (this dataclass does not enforce that).
    """

    rope_full_precision: bool = True
    """
    If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
    apply RoPE at the precision of the input.
    """

    flash_attention: bool = False
    """
    If ``True``, use ``FlashAttention``.
    """

    attention_dropout: float = 0.1
    """
    The dropout probability within the attention modules.
    """

    multi_query_attention: Optional[bool] = None
    """
    Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
    and is more efficient during inference. If both this and ``n_kv_heads`` are set they must agree;
    see :attr:`effective_n_kv_heads`.
    """

    attention_layer_norm: bool = False
    """
    Apply layer norm to the keys and queries within the attention mechanism.
    This can help stabilize training.
    """

    residual_dropout: float = 0.1
    """
    The dropout probability for the MLP and attention output within each block.
    """

    embedding_dropout: float = 0.1
    """
    The dropout probability for embeddings.
    """

    input_emb_norm: bool = False
    """
    Whether to apply a Gemma-style norm to the input hidden states.
    """

    layer_norm_type: LayerNormType = LayerNormType.default
    """
    The layernorm implementation to use.
    """

    layer_norm_with_affine: bool = True
    """
    Whether to include bias and weight parameters for the layer norms.
    This only affects layer norms that are immediately followed by a linear layer in the forward pass,
    so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
    to ``False``.
    """

    rms_norm_eps: float = 1e-05
    """
    The epsilon parameter of the RMS layer norm.
    """

    attention_layer_norm_with_affine: bool = True
    """
    Toggle affine transform for the QK norms.
    """

    max_sequence_length: int = 1024
    """
    The maximum input sequence length supported by the model.
    """

    rope_theta: float = 10000.0
    """
    The RoPE base parameter (theta).
    """

    include_qkv_bias: Optional[bool] = False
    """
    Whether or not to include bias parameters in qkv linear layers.
    """

    include_bias: bool = False
    """
    Whether or not to include bias parameters in linear layers.
    In PaLM, they got rid of all bias terms because they found that large
    models tend to have near 0 bias terms anyway.
    """

    bias_for_layer_norm: Optional[bool] = None
    """
    Whether or not to include bias parameters in layer norm.
    This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
    layer norm.
    When this is None (the default), it inherits the setting from include_bias.
    """

    scale_logits: bool = False
    """
    If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
    """

    vocab_size: int = 50257
    """
    Vocabulary size of the model.
    """

    embedding_size: Optional[int] = 50304
    """
    The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
    to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
    next multiple of 128 that's greater than ``vocab_size`` can improve throughput
    substantially.
    """

    weight_tying: bool = True
    """
    Whether to tie output linear weights to the input embedding.
    """

    eos_token_id: int = 50256
    """
    The ID of the end-of-sentence special token.
    """

    pad_token_id: int = 50256
    """
    The ID of the token to use for padding. Defaults to the ID of the EOS token.
    """

    mask_token_id: Optional[int] = 50256
    """
    The ID of the token to use as the mask token. Defaults to the ID of the EOS token.
    """

    init_device: Optional[str] = None
    """
    The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
    """

    init_fn: InitFnType = InitFnType.normal
    """
    The weight initialization strategy.
    """

    init_std: float = 0.02
    """
    The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal".
    """

    init_cutoff_factor: Optional[float] = None
    """
    A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
    as "normal". Setting this to None means values are not cutoff.
    """

    precision: Optional[str] = None
    """
    Precision used to train/evaluate with. You shouldn't set this directly.
    See :data:`TrainConfig.precision` instead.
    """

    @property
    def effective_n_kv_heads(self) -> int:
        """
        Number of key/value heads the attention layers should actually use.

        Resolution rules:

        * ``n_kv_heads`` is ``None``: 1 if ``multi_query_attention is True``, otherwise ``n_heads``.
        * ``n_kv_heads`` is set and ``multi_query_attention`` is ``None``: ``n_kv_heads``.
        * Both are set: they must agree (``n_kv_heads == 1`` when ``multi_query_attention`` is
          truthy, ``n_kv_heads == n_heads`` otherwise), and that agreed value is returned.

        :raises ValueError: If ``n_kv_heads`` and ``multi_query_attention`` are both set but
            disagree. ``ValueError`` subclasses ``Exception``, so handlers written for the
            previously raised bare ``Exception`` still catch it.
        """
        if self.n_kv_heads is None:
            # Only an explicit ``True`` selects multi-query attention here.
            return 1 if self.multi_query_attention is True else self.n_heads

        if self.multi_query_attention is None:
            return self.n_kv_heads

        # Both knobs were set explicitly: accept the pair only when it is consistent.
        n_kv_heads_should_be = 1 if self.multi_query_attention else self.n_heads
        if self.n_kv_heads != n_kv_heads_should_be:
            raise ValueError(
                "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
            )
        return n_kv_heads_should_be
| |
|
class ActivationCheckpointingStrategy(StrEnum):
    """
    Which transformer layers have their activations checkpointed.

    The ``x_in_y`` members checkpoint ``x`` out of every ``y`` consecutive layers.
    """

    whole_layer = "whole_layer"
    """
    Checkpoint every transformer layer.
    """

    one_in_two = "one_in_two"
    """
    Checkpoint one out of every two transformer layers.
    """

    one_in_three = "one_in_three"
    """
    Checkpoint one out of every three transformer layers.
    """

    one_in_four = "one_in_four"
    """
    Checkpoint one out of every four transformer layers.
    """

    two_in_three = "two_in_three"
    """
    Checkpoint two out of every three transformer layers.
    """

    three_in_four = "three_in_four"
    """
    Checkpoint three out of every four transformer layers.
    """

    four_in_five = "four_in_five"
    """
    Checkpoint four out of every five transformer layers.
    """

    nine_in_ten = "nine_in_ten"
    """
    Checkpoint nine out of every ten transformer layers.
    """

    fine_grained = "fine_grained"
    """
    Checkpoint selectively, where recomputation is cheap and the memory savings are largest.
    """
| |
|
| |
|
class LLaDAConfig(PretrainedConfig):
    """
    Hugging Face ``PretrainedConfig`` wrapper around :class:`ModelConfig`.

    Every ``ModelConfig`` field starts at its dataclass default and can be overridden
    through keyword arguments. ``use_cache`` is always recorded, and ``architectures``
    falls back to ``["LLaDAModelLM"]`` when the caller does not supply one.
    """

    model_type = "llada"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(self, use_cache: bool = False, **kwargs):
        # Dataclass defaults first, then caller overrides layered on top.
        merged = dict(vars(ModelConfig()))
        merged.update(kwargs)
        merged["use_cache"] = use_cache
        merged["architectures"] = merged.get("architectures", ["LLaDAModelLM"])
        super().__init__(**merged)

    @property
    def num_attention_heads(self):
        """Alias of ``n_heads`` under the conventional Hugging Face attribute name."""
        return self.n_heads

    @property
    def num_hidden_layers(self):
        """Alias of ``n_layers`` under the conventional Hugging Face attribute name."""
        return self.n_layers

    @property
    def hidden_size(self):
        """Alias of ``d_model`` under the conventional Hugging Face attribute name."""
        return self.d_model
| |
|
| |
|
| | |
# Register LLaDAConfig with transformers' AutoConfig under its own ``model_type`` ("llada"),
# so the registration key can never drift from the class attribute.
AutoConfig.register(LLaDAConfig.model_type, LLaDAConfig)
| |
|