|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
from typing import Optional, Tuple, List |
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast |
|
|
from transformers.generation.utils import GenerationMixin |
|
|
from collections import OrderedDict |
|
|
import logging |
|
|
from functools import lru_cache |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
try: |
|
|
from flash_attn import flash_attn_func |
|
|
HAS_FLASH_ATTN = True |
|
|
except ImportError: |
|
|
HAS_FLASH_ATTN = False |
|
|
|
|
|
try: |
|
|
from xformers.ops import memory_efficient_attention |
|
|
HAS_XFORMERS = True |
|
|
except ImportError: |
|
|
HAS_XFORMERS = False |
|
|
|
|
|
HAS_SDPA = hasattr(F, 'scaled_dot_product_attention') |
|
|
|
|
|
|
|
|
class CacaConfig(PretrainedConfig): |
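    """Configuration class for Caca models.

    Stores the decoder hyperparameters (GQA attention, RoPE/ALiBi positions, QK-norm,
    sliding-window and attention-sink options), optional MoE and mixture-of-depths
    settings, and optional vision/audio encoder configs for multimodal variants.
    """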
|
|
model_type = "caca" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
vocab_size=32000, |
|
|
hidden_size=2048, |
|
|
intermediate_size=8192, |
|
|
num_hidden_layers=24, |
|
|
num_attention_heads=32, |
|
|
num_key_value_heads=8, |
|
|
head_dim=64, |
|
|
max_position_embeddings=8192, |
|
|
rms_norm_eps=1e-6, |
|
|
qk_norm_eps=1e-6, |
|
|
initializer_range=0.02, |
|
|
use_cache=True, |
|
|
pad_token_id=None, |
|
|
bos_token_id=1, |
|
|
eos_token_id=2, |
|
|
tie_word_embeddings=False, |
|
|
rope_theta=10000.0, |
|
|
rope_scaling=None, |
|
|
use_rotary_embeddings=True, |
|
|
attention_bias=False, |
|
|
attention_dropout=0.0, |
|
|
use_qk_norm=True, |
|
|
use_alibi=False, |
|
|
use_flash_attn=True, |
|
|
use_grouped_query_attention=False, |
|
|
use_multi_query_attention=False, |
|
|
sliding_window=None, |
|
|
use_longformer_attention=False, |
|
|
longformer_attention_window=512, |
|
|
attn_logit_softcapping=None, |
|
|
final_logit_softcapping=None, |
|
|
attention_sink_size=4, |
|
|
attention_sink_window=1024, |
|
|
use_attention_sink=False, |
|
|
attention_pattern="all_global", |
|
|
global_attention_every_n_layers=2, |
|
|
mlp_bias=False, |
|
|
hidden_dropout=0.1, |
|
|
residual_dropout=0.1, |
|
|
use_moe=False, |
|
|
num_experts=8, |
|
|
num_experts_per_tok=2, |
|
|
use_expert_choice=False, |
|
|
expert_choice_k=0.125, |
|
|
router_aux_loss_coef=0.01, |
|
|
router_z_loss_coef=0.001, |
|
|
moe_layer_frequency=2, |
|
|
expert_capacity_factor=1.0, |
|
|
use_grouped_moe=False, |
|
|
num_expert_groups=1, |
|
|
use_layer_scale=False, |
|
|
layer_scale_init=1e-5, |
|
|
use_stochastic_depth=False, |
|
|
stochastic_depth_prob=0.1, |
|
|
use_mixture_of_depths=False, |
|
|
mod_capacity_factor=0.5, |
|
|
mod_route_method="learned", |
|
|
use_cross_attention=False, |
|
|
cross_attention_frequency=4, |
|
|
use_multimodal=False, |
|
|
vision_config=None, |
|
|
audio_config=None, |
|
|
projector_hidden_size=None, |
|
|
use_soft_merging=False, |
|
|
merge_threshold=0.5, |
|
|
pretraining_tp=1, |
|
|
tensor_parallel_size=1, |
|
|
pipeline_parallel_size=1, |
|
|
chat_template=None, |
|
|
**kwargs |
|
|
): |
|
|
self.vocab_size = vocab_size |
|
|
self.hidden_size = hidden_size |
|
|
self.intermediate_size = intermediate_size |
|
|
self.num_hidden_layers = num_hidden_layers |
|
|
self.num_attention_heads = num_attention_heads |
|
|
self.num_key_value_heads = num_key_value_heads |
|
|
self.head_dim = head_dim or (hidden_size // num_attention_heads if hidden_size and num_attention_heads else None) |
|
|
self.max_position_embeddings = max_position_embeddings |
|
|
self.rms_norm_eps = rms_norm_eps |
|
|
self.qk_norm_eps = qk_norm_eps |
|
|
self.initializer_range = initializer_range |
|
|
self.use_cache = use_cache |
|
|
self.pad_token_id = pad_token_id |
|
|
self.bos_token_id = bos_token_id |
|
|
self.eos_token_id = eos_token_id |
|
|
self.tie_word_embeddings = tie_word_embeddings |
|
|
self.rope_theta = rope_theta |
|
|
self.rope_scaling = rope_scaling |
|
|
self.use_rotary_embeddings = use_rotary_embeddings |
|
|
self.attention_bias = attention_bias |
|
|
self.attention_dropout = attention_dropout |
|
|
self.use_qk_norm = use_qk_norm |
|
|
self.use_alibi = use_alibi |
|
|
self.use_flash_attn = use_flash_attn |
|
|
self.use_grouped_query_attention = use_grouped_query_attention |
|
|
self.use_multi_query_attention = use_multi_query_attention |
|
|
self.sliding_window = sliding_window |
|
|
self.use_longformer_attention = use_longformer_attention |
|
|
self.longformer_attention_window = longformer_attention_window |
|
|
self.attn_logit_softcapping = attn_logit_softcapping |
|
|
self.final_logit_softcapping = final_logit_softcapping |
|
|
self.attention_sink_size = attention_sink_size |
|
|
self.attention_sink_window = attention_sink_window |
|
|
self.use_attention_sink = use_attention_sink |
|
|
self.attention_pattern = attention_pattern |
|
|
self.global_attention_every_n_layers = global_attention_every_n_layers |
|
|
self.mlp_bias = mlp_bias |
|
|
self.hidden_dropout = hidden_dropout |
|
|
self.residual_dropout = residual_dropout |
|
|
self.use_moe = use_moe |
|
|
self.num_experts = num_experts |
|
|
self.num_experts_per_tok = num_experts_per_tok |
|
|
self.use_expert_choice = use_expert_choice |
|
|
self.expert_choice_k = expert_choice_k |
|
|
self.router_aux_loss_coef = router_aux_loss_coef |
|
|
self.router_z_loss_coef = router_z_loss_coef |
|
|
self.moe_layer_frequency = moe_layer_frequency |
|
|
self.expert_capacity_factor = expert_capacity_factor |
|
|
self.use_grouped_moe = use_grouped_moe |
|
|
self.num_expert_groups = num_expert_groups |
|
|
self.use_layer_scale = use_layer_scale |
|
|
self.layer_scale_init = layer_scale_init |
|
|
self.use_stochastic_depth = use_stochastic_depth |
|
|
self.stochastic_depth_prob = stochastic_depth_prob |
|
|
self.use_mixture_of_depths = use_mixture_of_depths |
|
|
self.mod_capacity_factor = mod_capacity_factor |
|
|
self.mod_route_method = mod_route_method |
|
|
self.use_cross_attention = use_cross_attention |
|
|
self.cross_attention_frequency = cross_attention_frequency |
|
|
self.use_multimodal = use_multimodal |
|
|
self.vision_config = vision_config or {} |
|
|
self.audio_config = audio_config or {} |
|
|
self.projector_hidden_size = projector_hidden_size or hidden_size |
|
|
self.use_soft_merging = use_soft_merging |
|
|
self.merge_threshold = merge_threshold |
|
|
self.pretraining_tp = pretraining_tp |
|
|
self.tensor_parallel_size = tensor_parallel_size |
|
|
self.pipeline_parallel_size = pipeline_parallel_size |
|
|
|
|
|
if chat_template is None: |
|
|
self.chat_template = ( |
|
|
"{% for message in messages %}" |
|
|
"{% if message['role'] == 'system' %}" |
|
|
"System: {{ message['content'] }}\n" |
|
|
"{% elif message['role'] == 'user' %}" |
|
|
"User: {{ message['content'] }}\n" |
|
|
"{% elif message['role'] == 'assistant' %}" |
|
|
"Assistant: {{ message['content'] }}\n" |
|
|
"{% endif %}" |
|
|
"{% endfor %}" |
|
|
"{% if add_generation_prompt %}Assistant:{% endif %}" |
|
|
) |
|
|
else: |
|
|
self.chat_template = chat_template |
|
|
|
|
|
self._validate_config() |
|
|
super().__init__( |
|
|
pad_token_id=pad_token_id, |
|
|
bos_token_id=bos_token_id, |
|
|
eos_token_id=eos_token_id, |
|
|
tie_word_embeddings=tie_word_embeddings, |
|
|
**kwargs |
|
|
) |
|
|
|
|
|
def _validate_config(self): |
|
|
if self.num_attention_heads % self.num_key_value_heads != 0: |
|
|
raise ValueError( |
|
|
f"num_attention_heads ({self.num_attention_heads}) harus habis dibagi " |
|
|
f"num_key_value_heads ({self.num_key_value_heads})" |
|
|
) |
|
|
|
|
|
if self.use_moe and self.num_experts < self.num_experts_per_tok: |
|
|
raise ValueError( |
|
|
f"num_experts ({self.num_experts}) harus >= " |
|
|
f"num_experts_per_tok ({self.num_experts_per_tok})" |
|
|
) |
|
|
|
|
|
if self.hidden_size % self.num_attention_heads != 0: |
|
|
raise ValueError( |
|
|
f"hidden_size ({self.hidden_size}) harus habis dibagi " |
|
|
f"num_attention_heads ({self.num_attention_heads})" |
|
|
) |
|
|
|
|
|
if self.vocab_size <= 0: |
|
|
raise ValueError(f"vocab_size harus > 0, dapat {self.vocab_size}") |
|
|
|
|
|
if self.use_flash_attn and not HAS_FLASH_ATTN: |
|
|
logger.warning( |
|
|
"use_flash_attn=True tapi flash-attn tidak terinstall. " |
|
|
"Akan fallback ke SDPA/standard attention." |
|
|
) |
|
|
|
|
|
if self.sliding_window is not None: |
|
|
if self.sliding_window > self.max_position_embeddings: |
|
|
raise ValueError( |
|
|
f"sliding_window ({self.sliding_window}) tidak boleh > " |
|
|
f"max_position_embeddings ({self.max_position_embeddings})" |
|
|
) |
|
|
|
|
|
if self.use_moe: |
|
|
if self.moe_layer_frequency <= 0: |
|
|
raise ValueError(f"moe_layer_frequency harus > 0") |
|
|
if self.moe_layer_frequency > self.num_hidden_layers: |
|
|
logger.warning( |
|
|
f"moe_layer_frequency ({self.moe_layer_frequency}) > " |
|
|
f"num_hidden_layers ({self.num_hidden_layers}). " |
|
|
f"MoE tidak akan digunakan." |
|
|
) |
|
|
|
|
|
def to_dict(self): |
|
|
has_quant_config = hasattr(self, 'quantization_config') |
|
|
quantization_config_backup = getattr(self, 'quantization_config', None) |
|
|
|
|
|
if has_quant_config and quantization_config_backup is None: |
|
|
delattr(self, 'quantization_config') |
|
|
|
|
|
try: |
|
|
output = super().to_dict() |
|
|
output['auto_map'] = { |
|
|
"AutoConfig": "caca_transformers.CacaConfig", |
|
|
"AutoModel": "caca_transformers.CacaModel", |
|
|
"AutoModelForCausalLM": "caca_transformers.CacaForCausalLM" |
|
|
} |
|
|
finally: |
|
|
if has_quant_config: |
|
|
self.quantization_config = quantization_config_backup |
|
|
|
|
|
return output |
|
|
|
|
|
|
|
|
class CacaRMSNorm(nn.Module): |
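    """Root-mean-square layer normalization (no mean centering), computed in float32."""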
|
|
def __init__(self, hidden_size, eps=1e-6): |
|
|
super().__init__() |
|
|
self.weight = nn.Parameter(torch.ones(hidden_size)) |
|
|
self.eps = eps |
|
|
|
|
|
def forward(self, x): |
|
|
input_dtype = x.dtype |
|
|
x = x.float() |
|
|
variance = x.pow(2).mean(-1, keepdim=True) |
|
|
x = x * torch.rsqrt(variance + self.eps) |
|
|
return (self.weight * x).to(input_dtype) |
|
|
|
|
|
class LayerScale(nn.Module): |
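    """Learnable per-channel scaling of a residual branch, initialized to a small value."""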
|
|
def __init__(self, dim, init_value=1e-5): |
|
|
super().__init__() |
|
|
self.gamma = nn.Parameter(init_value * torch.ones(dim)) |
|
|
|
|
|
def forward(self, x): |
|
|
return self.gamma * x |
|
|
|
|
|
class StochasticDepth(nn.Module): |
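    """Drops the whole residual branch per sample with probability `drop_prob` during training."""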
|
|
def __init__(self, drop_prob=0.0): |
|
|
super().__init__() |
|
|
self.drop_prob = drop_prob |
|
|
|
|
|
def forward(self, x, training=True): |
|
|
if not training or self.drop_prob == 0.0: |
|
|
return x |
|
|
keep_prob = 1 - self.drop_prob |
|
|
shape = (x.shape[0],) + (1,) * (x.ndim - 1) |
|
|
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) |
|
|
random_tensor.floor_() |
|
|
return x.div(keep_prob) * random_tensor |
|
|
|
|
|
class CacaRotaryEmbedding(nn.Module): |
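    """Rotary position embedding with optional linear, dynamic, or YaRN-style frequency scaling."""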
|
|
def __init__( |
|
|
self, |
|
|
dim, |
|
|
max_position_embeddings=8192, |
|
|
base=10000.0, |
|
|
scaling_factor=1.0, |
|
|
scaling_type=None, |
|
|
): |
|
|
super().__init__() |
|
|
self.dim = dim |
|
|
self.max_position_embeddings = max_position_embeddings |
|
|
self.base = base |
|
|
self.scaling_factor = scaling_factor |
|
|
self.scaling_type = scaling_type |
|
|
inv_freq = 1.0 / ( |
|
|
self.base ** (torch.arange(0, self.dim, 2).float() / self.dim) |
|
|
) |
|
|
if scaling_type == "linear": |
|
|
inv_freq = inv_freq / scaling_factor |
|
|
elif scaling_type == "dynamic": |
|
|
inv_freq = inv_freq |
|
|
elif scaling_type == "yarn": |
|
|
inv_freq = self._yarn_get_inv_freq(inv_freq) |
|
|
self.register_buffer("inv_freq", inv_freq, persistent=False) |
|
|
|
|
|
def _yarn_get_inv_freq(self, inv_freq): |
|
|
if len(inv_freq) == 0: |
|
|
return inv_freq |
|
|
alpha = self.scaling_factor |
|
|
beta_fast = 32 |
|
|
beta_slow = 1 |
|
|
freq_threshold = 1 / (self.max_position_embeddings * beta_fast) |
|
|
low_freq_mask = inv_freq > freq_threshold |
|
|
high_freq_mask = ~low_freq_mask |
|
|
low_freq = inv_freq[low_freq_mask] |
|
|
high_freq = inv_freq[high_freq_mask] |
|
|
if len(low_freq) > 0: |
|
|
low_freq = low_freq / alpha |
|
|
if len(high_freq) > 0: |
|
|
smooth_factor = ( |
|
|
self.max_position_embeddings * beta_slow / high_freq - beta_fast |
|
|
) / (beta_slow - beta_fast) |
|
|
smooth_factor = torch.clamp(smooth_factor, 0.0, 1.0) |
|
|
high_freq = (1 - smooth_factor) * ( |
|
|
high_freq / alpha |
|
|
) + smooth_factor * high_freq |
|
|
result = torch.zeros_like(inv_freq) |
|
|
result[low_freq_mask] = low_freq |
|
|
result[high_freq_mask] = high_freq |
|
|
return result |
|
|
|
|
|
def forward(self, x, seq_len, position_offset=0): |
|
|
t = torch.arange( |
|
|
position_offset, position_offset + seq_len, device=x.device |
|
|
).type_as(self.inv_freq) |
|
|
if self.scaling_type == "dynamic": |
|
|
if seq_len > self.max_position_embeddings: |
|
|
dynamic_scale = seq_len / self.max_position_embeddings |
|
|
t = t / dynamic_scale |
|
|
freqs = torch.outer(t, self.inv_freq) |
|
|
emb = torch.cat((freqs, freqs), dim=-1) |
|
|
cos = emb.cos()[None, None, :, :] |
|
|
sin = emb.sin()[None, None, :, :] |
|
|
return cos.to(x.dtype), sin.to(x.dtype) |
|
|
|
|
|
class ALiBiPositionalBias(nn.Module): |
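    """Additive attention bias proportional to key-query distance, with per-head slopes (ALiBi)."""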
|
|
def __init__(self, num_heads, max_positions=8192): |
|
|
super().__init__() |
|
|
self.num_heads = num_heads |
|
|
self.max_positions = max_positions |
|
|
slopes = torch.tensor(self._get_slopes(num_heads)) |
|
|
self.register_buffer("slopes", slopes, persistent=False) |
|
|
|
|
|
def _get_slopes(self, n): |
|
|
def get_slopes_power_of_2(n): |
|
|
start = 2 ** (-(2 ** -(math.log2(n) - 3))) |
|
|
ratio = start |
|
|
return [start * (ratio**i) for i in range(n)] |
|
|
|
|
|
if math.log2(n).is_integer(): |
|
|
return get_slopes_power_of_2(n) |
|
|
else: |
|
|
closest_power_of_2 = 2 ** math.floor(math.log2(n)) |
|
|
return ( |
|
|
get_slopes_power_of_2(closest_power_of_2) |
|
|
+ self._get_slopes(2 * closest_power_of_2)[0::2][ |
|
|
: n - closest_power_of_2 |
|
|
] |
|
|
) |
|
|
|
|
|
def forward(self, seq_len, key_len=None): |
|
|
if key_len is None: |
|
|
key_len = seq_len |
|
|
query_pos = torch.arange(seq_len, device=self.slopes.device).unsqueeze(1) |
|
|
key_pos = torch.arange(key_len, device=self.slopes.device).unsqueeze(0) |
|
|
relative_pos = key_pos - query_pos |
|
|
bias = relative_pos.unsqueeze(0) * self.slopes.unsqueeze(1).unsqueeze(2) |
|
|
return bias.unsqueeze(0) |
|
|
|
|
|
def rotate_half(x): |
|
|
x1 = x[..., : x.shape[-1] // 2] |
|
|
x2 = x[..., x.shape[-1] // 2 :] |
|
|
return torch.cat((-x2, x1), dim=-1) |
|
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin): |
|
|
cos = cos.to(q.dtype) |
|
|
sin = sin.to(q.dtype) |
|
|
q_embed = (q * cos) + (rotate_half(q) * sin) |
|
|
k_embed = (k * cos) + (rotate_half(k) * sin) |
|
|
return q_embed, k_embed |
|
|
|
|
|
def soft_cap_logits(x, cap):
    # Soft-cap logits as cap * tanh(x / cap) so gradients stay smooth near the cap,
    # rather than hard-clamping; this is the usual "logit softcapping" formulation.
    if cap is None or cap <= 0:
        return x
    return cap * torch.tanh(x / cap)
|
|
|
|
|
class TopKRouter(nn.Module): |
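    """Token-choice MoE router: each token selects its top-k experts.

    Returns normalized top-k routing weights, expert indices, a load-balancing
    auxiliary loss, and a router z-loss term.
    """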
|
|
def __init__(self, hidden_size, num_experts, num_experts_per_tok): |
|
|
super().__init__() |
|
|
self.num_experts = num_experts |
|
|
self.num_experts_per_tok = num_experts_per_tok |
|
|
self.gate = nn.Linear(hidden_size, num_experts, bias=False) |
|
|
self.gate_norm = nn.LayerNorm(hidden_size) |
|
|
|
|
|
def forward(self, hidden_states): |
|
|
batch_size, seq_len, hidden_size = hidden_states.shape |
|
|
hidden_states = hidden_states.view(-1, hidden_size) |
|
|
hidden_states = self.gate_norm(hidden_states) |
|
|
router_logits = self.gate(hidden_states) |
|
|
router_logits = torch.clamp(router_logits, min=-10, max=10) |
|
|
routing_weights = F.softmax(router_logits, dim=-1) |
|
|
top_k_weights, top_k_indices = torch.topk( |
|
|
routing_weights, self.num_experts_per_tok, dim=-1 |
|
|
) |
|
|
top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-8) |
|
|
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) |
|
|
expert_usage = router_probs.mean(dim=0) |
|
|
mean_usage = expert_usage.mean() |
|
|
aux_loss = ((expert_usage - mean_usage) ** 2).sum() / (mean_usage + 1e-10) |
|
|
router_logits_for_z = router_logits.to(torch.float32) |
|
|
z_loss = torch.logsumexp(router_logits_for_z, dim=-1).mean() |
|
|
return top_k_weights, top_k_indices, aux_loss, z_loss |
|
|
|
|
|
class ExpertChoiceRouter(nn.Module): |
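    """Expert-choice MoE router: each expert selects its highest-scoring tokens up to a capacity."""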
|
|
def __init__(self, hidden_size, num_experts, expert_choice_k): |
|
|
super().__init__() |
|
|
self.num_experts = num_experts |
|
|
self.expert_choice_k = expert_choice_k |
|
|
self.gate = nn.Linear(hidden_size, num_experts, bias=False) |
|
|
|
|
|
def forward(self, hidden_states): |
|
|
batch_size, seq_len, hidden_size = hidden_states.shape |
|
|
total_tokens = batch_size * seq_len |
|
|
hidden_states_flat = hidden_states.view(-1, hidden_size) |
|
|
router_logits = self.gate(hidden_states_flat) |
|
|
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) |
|
|
router_probs_t = router_probs.t() |
|
|
capacity = max(1, int(self.expert_choice_k * total_tokens / self.num_experts)) |
|
|
top_k_values, top_k_indices = torch.topk( |
|
|
router_probs_t, k=min(capacity, total_tokens), dim=-1 |
|
|
) |
|
|
expert_mask = torch.zeros( |
|
|
self.num_experts, total_tokens, device=hidden_states.device |
|
|
) |
|
|
for expert_idx in range(self.num_experts): |
|
|
expert_mask[expert_idx, top_k_indices[expert_idx]] = 1.0 |
|
|
routing_weights = expert_mask.t() * router_probs |
|
|
aux_loss = (router_probs.mean(dim=0) ** 2).sum() * self.num_experts |
|
|
z_loss = torch.logsumexp(router_logits, dim=-1).mean() |
|
|
return routing_weights, aux_loss, z_loss |
|
|
|
|
|
class Expert(nn.Module): |
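    """Single SwiGLU feed-forward expert used inside the MoE block."""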
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.gate_proj = nn.Linear( |
|
|
config.hidden_size, config.intermediate_size, bias=config.mlp_bias |
|
|
) |
|
|
self.up_proj = nn.Linear( |
|
|
config.hidden_size, config.intermediate_size, bias=config.mlp_bias |
|
|
) |
|
|
self.down_proj = nn.Linear( |
|
|
config.intermediate_size, config.hidden_size, bias=config.mlp_bias |
|
|
) |
|
|
self.dropout = nn.Dropout(config.hidden_dropout) |
|
|
|
|
|
def forward(self, x): |
|
|
gate = F.silu(self.gate_proj(x)) |
|
|
up = self.up_proj(x) |
|
|
hidden = gate * up |
|
|
hidden = self.dropout(hidden) |
|
|
return self.down_proj(hidden) |
|
|
|
|
|
class MixtureOfExperts(nn.Module): |
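    """Sparse mixture-of-experts feed-forward layer using either token-choice or expert-choice routing."""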
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.num_experts = config.num_experts |
|
|
self.num_experts_per_tok = config.num_experts_per_tok |
|
|
self.use_expert_choice = config.use_expert_choice |
|
|
self.experts = nn.ModuleList([Expert(config) for _ in range(self.num_experts)]) |
|
|
if self.use_expert_choice: |
|
|
self.router = ExpertChoiceRouter( |
|
|
config.hidden_size, config.num_experts, config.expert_choice_k |
|
|
) |
|
|
else: |
|
|
self.router = TopKRouter( |
|
|
config.hidden_size, config.num_experts, config.num_experts_per_tok |
|
|
) |
|
|
|
|
|
def forward(self, hidden_states): |
|
|
batch_size, seq_len, hidden_size = hidden_states.shape |
|
|
hidden_states_flat = hidden_states.view(-1, hidden_size) |
|
|
if self.use_expert_choice: |
|
|
routing_weights, aux_loss, z_loss = self.router(hidden_states) |
|
|
final_output = torch.zeros_like(hidden_states_flat) |
|
|
for expert_idx, expert in enumerate(self.experts): |
|
|
expert_mask = routing_weights[:, expert_idx] > 0 |
|
|
if expert_mask.any(): |
|
|
expert_input = hidden_states_flat[expert_mask] |
|
|
expert_output = expert(expert_input) |
|
|
final_output[expert_mask] += ( |
|
|
expert_output |
|
|
* routing_weights[expert_mask, expert_idx : expert_idx + 1] |
|
|
) |
|
|
else: |
|
|
top_k_weights, top_k_indices, aux_loss, z_loss = self.router(hidden_states) |
|
|
final_output = torch.zeros_like(hidden_states_flat) |
|
|
for i in range(self.num_experts_per_tok): |
|
|
expert_indices = top_k_indices[:, i] |
|
|
expert_weights = top_k_weights[:, i : i + 1] |
|
|
for expert_idx in range(self.num_experts): |
|
|
expert_mask = expert_indices == expert_idx |
|
|
if expert_mask.any(): |
|
|
expert_input = hidden_states_flat[expert_mask] |
|
|
expert_output = self.experts[expert_idx](expert_input) |
|
|
final_output[expert_mask] += ( |
|
|
expert_output * expert_weights[expert_mask] |
|
|
) |
|
|
final_output = final_output.view(batch_size, seq_len, hidden_size) |
|
|
return final_output, aux_loss, z_loss |
|
|
|
|
|
class MixtureOfDepthsRouter(nn.Module): |
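    """Mixture-of-depths router: picks a capacity-limited subset of tokens to be processed by a layer."""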
|
|
def __init__(self, hidden_size, capacity_factor=0.5, route_method="learned"): |
|
|
super().__init__() |
|
|
self.capacity_factor = capacity_factor |
|
|
self.route_method = route_method |
|
|
if route_method == "learned": |
|
|
self.router = nn.Linear(hidden_size, 1) |
|
|
|
|
|
def forward(self, hidden_states): |
|
|
batch_size, seq_len, hidden_size = hidden_states.shape |
|
|
if self.route_method == "learned": |
|
|
routing_logits = self.router(hidden_states).squeeze(-1) |
|
|
elif self.route_method == "random": |
|
|
routing_logits = torch.rand( |
|
|
batch_size, seq_len, device=hidden_states.device |
|
|
) |
|
|
else: |
|
|
routing_logits = torch.zeros( |
|
|
batch_size, seq_len, device=hidden_states.device |
|
|
) |
|
|
capacity = max(1, int(seq_len * self.capacity_factor)) |
|
|
_, top_indices = torch.topk(routing_logits, k=capacity, dim=-1) |
|
|
process_mask = torch.zeros( |
|
|
batch_size, seq_len, dtype=torch.bool, device=hidden_states.device |
|
|
) |
|
|
process_mask.scatter_(1, top_indices, True) |
|
|
return process_mask |
|
|
|
|
|
class CacaAttention(nn.Module): |
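    """Grouped-query self-attention with optional RoPE, QK-norm, ALiBi, sliding window,
    attention sinks, and logit soft-capping.

    Dispatches to FlashAttention, xFormers, or PyTorch SDPA when available and falls
    back to a standard matmul/softmax implementation otherwise.
    """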
|
|
def __init__(self, config, layer_idx=None): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.layer_idx = layer_idx |
|
|
self.hidden_size = config.hidden_size |
|
|
self.num_heads = config.num_attention_heads |
|
|
self.num_key_value_heads = config.num_key_value_heads |
|
|
self.head_dim = config.head_dim |
|
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
|
self.sliding_window = config.sliding_window |
|
|
self.attn_logit_softcapping = config.attn_logit_softcapping |
|
|
self.attention_sink_size = config.attention_sink_size |
|
|
self.attention_sink_window = config.attention_sink_window |
|
|
self.q_proj = nn.Linear( |
|
|
self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias |
|
|
) |
|
|
self.k_proj = nn.Linear( |
|
|
self.hidden_size, |
|
|
self.num_key_value_heads * self.head_dim, |
|
|
bias=config.attention_bias, |
|
|
) |
|
|
self.v_proj = nn.Linear( |
|
|
self.hidden_size, |
|
|
self.num_key_value_heads * self.head_dim, |
|
|
bias=config.attention_bias, |
|
|
) |
|
|
self.o_proj = nn.Linear( |
|
|
self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias |
|
|
) |
|
|
if config.use_qk_norm: |
|
|
self.q_norm = CacaRMSNorm(self.head_dim, eps=config.qk_norm_eps) |
|
|
self.k_norm = CacaRMSNorm(self.head_dim, eps=config.qk_norm_eps) |
|
|
else: |
|
|
self.q_norm = None |
|
|
self.k_norm = None |
|
|
if config.use_rotary_embeddings: |
|
|
scaling_factor = 1.0 |
|
|
scaling_type = None |
|
|
if config.rope_scaling is not None: |
|
|
scaling_type = config.rope_scaling.get("type", "linear") |
|
|
scaling_factor = config.rope_scaling.get("factor", 1.0) |
|
|
self.rotary_emb = CacaRotaryEmbedding( |
|
|
self.head_dim, |
|
|
config.max_position_embeddings, |
|
|
config.rope_theta, |
|
|
scaling_factor=scaling_factor, |
|
|
scaling_type=scaling_type, |
|
|
) |
|
|
else: |
|
|
self.rotary_emb = None |
|
|
if config.use_alibi: |
|
|
self.alibi = ALiBiPositionalBias( |
|
|
self.num_heads, config.max_position_embeddings |
|
|
) |
|
|
else: |
|
|
self.alibi = None |
|
|
self.attention_dropout = nn.Dropout(config.attention_dropout) |
|
|
self.is_global_attention = self._determine_attention_type(config, layer_idx) |
|
|
self.has_flash_attn = HAS_FLASH_ATTN and config.use_flash_attn |
|
|
self.has_xformers = HAS_XFORMERS |
|
|
self.has_sdpa = HAS_SDPA |
|
|
self._mask_cache = {} |
|
|
self._max_cache_size = 10 |
|
|
|
|
|
def _determine_attention_type(self, config, layer_idx): |
|
|
if layer_idx is None: |
|
|
return False |
|
|
if config.attention_pattern == "all_global": |
|
|
return True |
|
|
elif config.attention_pattern == "all_local": |
|
|
return False |
|
|
elif config.attention_pattern == "interleaved": |
|
|
return (layer_idx % config.global_attention_every_n_layers) == ( |
|
|
config.global_attention_every_n_layers - 1 |
|
|
) |
|
|
return False |
|
|
|
|
|
def forward( |
|
|
self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False |
|
|
): |
|
|
batch_size, seq_length, _ = hidden_states.size() |
|
|
query_states = self.q_proj(hidden_states) |
|
|
key_states = self.k_proj(hidden_states) |
|
|
value_states = self.v_proj(hidden_states) |
|
|
query_states = query_states.view( |
|
|
batch_size, seq_length, self.num_heads, self.head_dim |
|
|
).transpose(1, 2) |
|
|
key_states = key_states.view( |
|
|
batch_size, seq_length, self.num_key_value_heads, self.head_dim |
|
|
).transpose(1, 2) |
|
|
value_states = value_states.view( |
|
|
batch_size, seq_length, self.num_key_value_heads, self.head_dim |
|
|
).transpose(1, 2) |
|
|
if self.q_norm is not None and self.k_norm is not None: |
|
|
query_states = self.q_norm(query_states) |
|
|
key_states = self.k_norm(key_states) |
|
|
|
|
|
position_offset = 0 |
|
|
if past_key_value is not None: |
|
|
try: |
|
|
if isinstance(past_key_value, (tuple, list)) and len(past_key_value) >= 2: |
|
|
if past_key_value[0] is not None: |
|
|
position_offset = past_key_value[0].shape[2] |
|
|
except (IndexError, AttributeError, TypeError): |
|
|
position_offset = 0 |
|
|
|
|
|
if self.rotary_emb is not None: |
|
|
cos, sin = self.rotary_emb(query_states, seq_length, position_offset) |
|
|
query_states, key_states = apply_rotary_pos_emb( |
|
|
query_states, key_states, cos, sin |
|
|
) |
|
|
|
|
|
if past_key_value is not None and past_key_value[0] is not None: |
|
|
key_states = torch.cat([past_key_value[0], key_states], dim=2) |
|
|
value_states = torch.cat([past_key_value[1], value_states], dim=2) |
|
|
|
|
|
if use_cache: |
|
|
present_key_value = (key_states, value_states) |
|
|
else: |
|
|
present_key_value = None |
|
|
|
|
|
key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1) |
|
|
value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1) |
|
|
kv_seq_len = key_states.shape[-2] |
|
|
|
|
|
use_sliding_window = (not self.is_global_attention) and ( |
|
|
self.sliding_window is not None |
|
|
) |
|
|
if self.has_flash_attn and attention_mask is None: |
|
|
if query_states.device.type == "cuda" and query_states.dtype in [ |
|
|
torch.float16, |
|
|
torch.bfloat16, |
|
|
]: |
|
|
try: |
|
|
attn_output = self._flash_attention( |
|
|
query_states, key_states, value_states, use_sliding_window |
|
|
) |
|
|
except Exception as e: |
|
|
logger.warning(f"Flash Attention gagal, pakai fallback: {e}") |
|
|
attn_output = self._fallback_attention( |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attention_mask, |
|
|
kv_seq_len, |
|
|
use_sliding_window, |
|
|
) |
|
|
else: |
|
|
attn_output = self._fallback_attention( |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attention_mask, |
|
|
kv_seq_len, |
|
|
use_sliding_window, |
|
|
) |
|
|
else: |
|
|
attn_output = self._fallback_attention( |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attention_mask, |
|
|
kv_seq_len, |
|
|
use_sliding_window, |
|
|
) |
|
|
attn_output = self.o_proj(attn_output) |
|
|
return attn_output, present_key_value |
|
|
|
|
|
def _flash_attention( |
|
|
self, query_states, key_states, value_states, use_sliding_window |
|
|
): |
|
|
batch_size, num_heads, seq_length, head_dim = query_states.shape |
|
|
kv_seq_len = key_states.shape[-2] |
|
|
original_dtype = query_states.dtype |
|
|
if original_dtype == torch.bfloat16: |
|
|
if not torch.cuda.is_bf16_supported(): |
|
|
logger.warning("BF16 not supported on this GPU, falling back to FP16") |
|
|
original_dtype = torch.float16 |
|
|
compute_dtype = ( |
|
|
torch.bfloat16 |
|
|
if original_dtype not in [torch.float16, torch.bfloat16] |
|
|
else original_dtype |
|
|
) |
|
|
query_states = query_states.transpose(1, 2).contiguous().to(compute_dtype) |
|
|
key_states = key_states.transpose(1, 2).contiguous().to(compute_dtype) |
|
|
value_states = value_states.transpose(1, 2).contiguous().to(compute_dtype) |
|
|
if use_sliding_window and self.sliding_window < kv_seq_len: |
|
|
window_size = (self.sliding_window, 0) |
|
|
else: |
|
|
window_size = (-1, 0) |
|
|
attn_output = flash_attn_func( |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
dropout_p=self.config.attention_dropout if self.training else 0.0, |
|
|
softmax_scale=None, |
|
|
causal=True, |
|
|
window_size=window_size, |
|
|
) |
|
|
attn_output = attn_output.to(original_dtype) |
|
|
attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) |
|
|
return attn_output |
|
|
|
|
|
def _fallback_attention( |
|
|
self, |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attention_mask, |
|
|
kv_seq_len, |
|
|
use_sliding_window, |
|
|
): |
|
|
device_type = query_states.device.type |
|
|
if self.has_xformers and device_type == "cuda" and attention_mask is None: |
|
|
try: |
|
|
return self._xformers_attention( |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
kv_seq_len, |
|
|
use_sliding_window, |
|
|
) |
|
|
except Exception as e: |
|
|
logger.warning(f"xFormers gagal, pakai SDPA: {e}") |
|
|
if self.has_sdpa: |
|
|
return self._sdpa_attention( |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attention_mask, |
|
|
kv_seq_len, |
|
|
use_sliding_window, |
|
|
) |
|
|
else: |
|
|
return self._standard_attention( |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attention_mask, |
|
|
kv_seq_len, |
|
|
use_sliding_window, |
|
|
) |
|
|
|
|
|
def _create_causal_mask( |
|
|
self, query_length, key_length, dtype, device, use_sliding_window |
|
|
): |
|
|
cache_key = ( |
|
|
query_length, |
|
|
key_length, |
|
|
str(dtype), |
|
|
use_sliding_window, |
|
|
self.sliding_window if use_sliding_window else None, |
|
|
) |
|
|
if cache_key in self._mask_cache: |
|
|
cached_mask = self._mask_cache[cache_key] |
|
|
return cached_mask.to(device, dtype) |
|
|
if query_length > key_length: |
|
|
key_length = query_length |
|
|
query_pos = torch.arange(query_length, device=device) + ( |
|
|
key_length - query_length |
|
|
) |
|
|
key_pos = torch.arange(key_length, device=device) |
|
|
distance = query_pos[:, None] - key_pos[None, :] |
|
|
mask = distance < 0 |
|
|
|
|
|
if use_sliding_window and self.sliding_window is not None: |
|
|
if self.config.use_attention_sink and self.attention_sink_size > 0: |
|
|
is_sink = key_pos[None, :] < self.attention_sink_size |
|
|
in_window = (distance >= 0) & (distance <= self.sliding_window) |
|
|
mask = (distance < 0) | ((~is_sink) & (~in_window)) |
|
|
|
|
|
else: |
|
|
too_far_mask = distance > self.sliding_window |
|
|
mask = mask | too_far_mask |
|
|
float_mask = torch.zeros( |
|
|
1, 1, query_length, key_length, dtype=dtype, device=device |
|
|
) |
|
|
float_mask.masked_fill_(mask.unsqueeze(0).unsqueeze(0), -1e9) |
|
|
if len(self._mask_cache) >= self._max_cache_size: |
|
|
oldest_key = next(iter(self._mask_cache)) |
|
|
del self._mask_cache[oldest_key] |
|
|
self._mask_cache[cache_key] = float_mask.detach().cpu() |
|
|
return float_mask |
|
|
|
|
|
def _xformers_attention( |
|
|
self, query_states, key_states, value_states, kv_seq_len, use_sliding_window |
|
|
): |
|
|
batch_size, num_heads, seq_length, head_dim = query_states.shape |
|
|
attn_bias = self._create_causal_mask( |
|
|
seq_length, |
|
|
kv_seq_len, |
|
|
query_states.dtype, |
|
|
query_states.device, |
|
|
use_sliding_window, |
|
|
) |
|
|
query_states = query_states.transpose(1, 2) |
|
|
key_states = key_states.transpose(1, 2) |
|
|
value_states = value_states.transpose(1, 2) |
|
|
attn_output = memory_efficient_attention( |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attn_bias=attn_bias, |
|
|
p=self.config.attention_dropout if self.training else 0.0, |
|
|
) |
|
|
attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) |
|
|
return attn_output |
|
|
|
|
|
def _sdpa_attention( |
|
|
self, |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attention_mask, |
|
|
kv_seq_len, |
|
|
use_sliding_window, |
|
|
): |
|
|
batch_size, num_heads, seq_length, head_dim = query_states.shape |
|
|
if attention_mask is None: |
|
|
attention_mask = self._create_causal_mask( |
|
|
seq_length, |
|
|
kv_seq_len, |
|
|
query_states.dtype, |
|
|
query_states.device, |
|
|
use_sliding_window, |
|
|
) |
|
|
if self.alibi is not None: |
|
|
alibi_bias = self.alibi(seq_length, kv_seq_len) |
|
|
attention_mask = attention_mask + alibi_bias |
|
|
attn_output = F.scaled_dot_product_attention( |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attn_mask=attention_mask, |
|
|
dropout_p=self.config.attention_dropout if self.training else 0.0, |
|
|
is_causal=False, |
|
|
) |
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) |
|
|
return attn_output |
|
|
|
|
|
def _standard_attention( |
|
|
self, |
|
|
query_states, |
|
|
key_states, |
|
|
value_states, |
|
|
attention_mask, |
|
|
kv_seq_len, |
|
|
use_sliding_window, |
|
|
): |
|
|
batch_size, num_heads, seq_length, head_dim = query_states.shape |
|
|
attn_weights = torch.matmul( |
|
|
query_states, key_states.transpose(2, 3) |
|
|
) / math.sqrt(head_dim) |
|
|
attn_weights = torch.clamp(attn_weights, min=-50.0, max=50.0) |
|
|
attn_weights = soft_cap_logits(attn_weights, self.attn_logit_softcapping) |
|
|
if attention_mask is None: |
|
|
attention_mask = self._create_causal_mask( |
|
|
seq_length, |
|
|
kv_seq_len, |
|
|
attn_weights.dtype, |
|
|
attn_weights.device, |
|
|
use_sliding_window, |
|
|
) |
|
|
if self.alibi is not None: |
|
|
alibi_bias = self.alibi(seq_length, kv_seq_len) |
|
|
attention_mask = attention_mask + alibi_bias |
|
|
attn_weights = attn_weights + attention_mask |
|
|
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( |
|
|
query_states.dtype |
|
|
) |
|
|
attn_weights = self.attention_dropout(attn_weights) |
|
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) |
|
|
return attn_output |
|
|
|
|
|
class CacaCrossAttention(nn.Module): |
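    """Grouped-query cross-attention from decoder states to encoder (e.g. vision/audio) states."""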
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.hidden_size = config.hidden_size |
|
|
self.num_heads = config.num_attention_heads |
|
|
self.num_key_value_heads = config.num_key_value_heads |
|
|
self.head_dim = config.head_dim |
|
|
self.num_key_value_groups = self.num_heads // self.num_key_value_heads |
|
|
self.q_proj = nn.Linear( |
|
|
self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias |
|
|
) |
|
|
self.k_proj = nn.Linear( |
|
|
self.hidden_size, |
|
|
self.num_key_value_heads * self.head_dim, |
|
|
bias=config.attention_bias, |
|
|
) |
|
|
self.v_proj = nn.Linear( |
|
|
self.hidden_size, |
|
|
self.num_key_value_heads * self.head_dim, |
|
|
bias=config.attention_bias, |
|
|
) |
|
|
self.o_proj = nn.Linear( |
|
|
self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias |
|
|
) |
|
|
self.attention_dropout = nn.Dropout(config.attention_dropout) |
|
|
|
|
|
def forward(self, hidden_states, encoder_hidden_states, attention_mask=None): |
|
|
batch_size, seq_length, _ = hidden_states.size() |
|
|
encoder_seq_length = encoder_hidden_states.size(1) |
|
|
query_states = self.q_proj(hidden_states) |
|
|
key_states = self.k_proj(encoder_hidden_states) |
|
|
value_states = self.v_proj(encoder_hidden_states) |
|
|
query_states = query_states.view( |
|
|
batch_size, seq_length, self.num_heads, self.head_dim |
|
|
).transpose(1, 2) |
|
|
key_states = key_states.view( |
|
|
batch_size, encoder_seq_length, self.num_key_value_heads, self.head_dim |
|
|
).transpose(1, 2) |
|
|
value_states = value_states.view( |
|
|
batch_size, encoder_seq_length, self.num_key_value_heads, self.head_dim |
|
|
).transpose(1, 2) |
|
|
key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1) |
|
|
value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1) |
|
|
attn_weights = torch.matmul( |
|
|
query_states, key_states.transpose(2, 3) |
|
|
) / math.sqrt(self.head_dim) |
|
|
if attention_mask is not None: |
|
|
attn_weights = attn_weights + attention_mask |
|
|
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( |
|
|
query_states.dtype |
|
|
) |
|
|
attn_weights = self.attention_dropout(attn_weights) |
|
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
|
attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) |
|
|
attn_output = self.o_proj(attn_output) |
|
|
return attn_output |
|
|
|
|
|
class CacaMLP(nn.Module): |
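    """SwiGLU feed-forward block: SiLU-gated up-projection followed by a down-projection."""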
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
self.hidden_size = config.hidden_size |
|
|
self.intermediate_size = config.intermediate_size |
|
|
self.gate_proj = nn.Linear( |
|
|
self.hidden_size, self.intermediate_size, bias=config.mlp_bias |
|
|
) |
|
|
self.up_proj = nn.Linear( |
|
|
self.hidden_size, self.intermediate_size, bias=config.mlp_bias |
|
|
) |
|
|
self.down_proj = nn.Linear( |
|
|
self.intermediate_size, self.hidden_size, bias=config.mlp_bias |
|
|
) |
|
|
self.dropout = nn.Dropout(config.hidden_dropout) |
|
|
|
|
|
def forward(self, x): |
|
|
gate = F.silu(self.gate_proj(x)) |
|
|
up = self.up_proj(x) |
|
|
hidden = gate * up |
|
|
hidden = self.dropout(hidden) |
|
|
output = self.down_proj(hidden) |
|
|
return output |
|
|
|
|
|
class CacaDecoderLayer(nn.Module): |
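    """Pre-norm decoder block: self-attention, optional cross-attention, and an MLP or MoE
    feed-forward, with optional LayerScale, stochastic depth, and mixture-of-depths routing.
    """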
|
|
def __init__(self, config, layer_idx): |
|
|
super().__init__() |
|
|
self.config = config |
|
|
self.layer_idx = layer_idx |
|
|
self.self_attn = CacaAttention(config, layer_idx=layer_idx) |
|
|
self.use_moe = config.use_moe and (layer_idx % config.moe_layer_frequency == 0) |
|
|
if self.use_moe: |
|
|
self.mlp = MixtureOfExperts(config) |
|
|
else: |
|
|
self.mlp = CacaMLP(config) |
|
|
self.use_cross_attention = config.use_cross_attention and ( |
|
|
layer_idx % config.cross_attention_frequency == 0 |
|
|
) |
|
|
if self.use_cross_attention: |
|
|
self.cross_attn = CacaCrossAttention(config) |
|
|
self.cross_attn_layernorm = CacaRMSNorm( |
|
|
config.hidden_size, config.rms_norm_eps |
|
|
) |
|
|
self.input_layernorm = CacaRMSNorm(config.hidden_size, config.rms_norm_eps) |
|
|
self.post_attention_layernorm = CacaRMSNorm( |
|
|
config.hidden_size, config.rms_norm_eps |
|
|
) |
|
|
self.residual_dropout = nn.Dropout(config.residual_dropout) |
|
|
if config.use_layer_scale: |
|
|
self.layer_scale_1 = LayerScale(config.hidden_size, config.layer_scale_init) |
|
|
self.layer_scale_2 = LayerScale(config.hidden_size, config.layer_scale_init) |
|
|
if self.use_cross_attention: |
|
|
self.layer_scale_cross = LayerScale( |
|
|
config.hidden_size, config.layer_scale_init |
|
|
) |
|
|
else: |
|
|
self.layer_scale_1 = None |
|
|
self.layer_scale_2 = None |
|
|
self.layer_scale_cross = None |
|
|
if config.use_stochastic_depth: |
|
|
drop_prob = ( |
|
|
config.stochastic_depth_prob * layer_idx / config.num_hidden_layers |
|
|
) |
|
|
self.stochastic_depth = StochasticDepth(drop_prob) |
|
|
else: |
|
|
self.stochastic_depth = None |
|
|
if config.use_mixture_of_depths: |
|
|
self.mod_router = MixtureOfDepthsRouter( |
|
|
config.hidden_size, config.mod_capacity_factor, config.mod_route_method |
|
|
) |
|
|
else: |
|
|
self.mod_router = None |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states, |
|
|
attention_mask=None, |
|
|
encoder_hidden_states=None, |
|
|
encoder_attention_mask=None, |
|
|
past_key_value=None, |
|
|
use_cache=False, |
|
|
): |
|
|
aux_loss = 0.0 |
|
|
z_loss = 0.0 |
|
|
if self.mod_router is not None: |
|
|
process_mask = self.mod_router(hidden_states) |
|
|
tokens_to_process = hidden_states[process_mask] |
|
|
if tokens_to_process.numel() == 0: |
|
|
present_key_value = past_key_value if use_cache else None |
|
|
return hidden_states, present_key_value, aux_loss, z_loss |
|
|
else: |
|
|
process_mask = None |
|
|
tokens_to_process = hidden_states |
|
|
residual = tokens_to_process |
|
|
tokens_to_process = self.input_layernorm(tokens_to_process) |
|
|
attn_output, present_key_value = self.self_attn( |
|
|
tokens_to_process, |
|
|
attention_mask, |
|
|
past_key_value=past_key_value, |
|
|
use_cache=use_cache, |
|
|
) |
|
|
if self.layer_scale_1 is not None: |
|
|
attn_output = self.layer_scale_1(attn_output) |
|
|
if self.stochastic_depth is not None: |
|
|
attn_output = self.stochastic_depth(attn_output, self.training) |
|
|
tokens_to_process = residual + self.residual_dropout(attn_output) |
|
|
if self.use_cross_attention and encoder_hidden_states is not None: |
|
|
residual = tokens_to_process |
|
|
tokens_to_process = self.cross_attn_layernorm(tokens_to_process) |
|
|
cross_attn_output = self.cross_attn( |
|
|
tokens_to_process, |
|
|
encoder_hidden_states, |
|
|
attention_mask=encoder_attention_mask, |
|
|
) |
|
|
if self.layer_scale_cross is not None: |
|
|
cross_attn_output = self.layer_scale_cross(cross_attn_output) |
|
|
if self.stochastic_depth is not None: |
|
|
cross_attn_output = self.stochastic_depth( |
|
|
cross_attn_output, self.training |
|
|
) |
|
|
tokens_to_process = residual + self.residual_dropout(cross_attn_output) |
|
|
residual = tokens_to_process |
|
|
tokens_to_process = self.post_attention_layernorm(tokens_to_process) |
|
|
if self.use_moe: |
|
|
mlp_output, moe_aux_loss, moe_z_loss = self.mlp(tokens_to_process) |
|
|
aux_loss += moe_aux_loss |
|
|
z_loss += moe_z_loss |
|
|
else: |
|
|
mlp_output = self.mlp(tokens_to_process) |
|
|
if self.layer_scale_2 is not None: |
|
|
mlp_output = self.layer_scale_2(mlp_output) |
|
|
if self.stochastic_depth is not None: |
|
|
mlp_output = self.stochastic_depth(mlp_output, self.training) |
|
|
tokens_to_process = residual + self.residual_dropout(mlp_output) |
|
|
if process_mask is not None: |
|
|
hidden_states[process_mask] = tokens_to_process |
|
|
else: |
|
|
hidden_states = tokens_to_process |
|
|
return hidden_states, present_key_value, aux_loss, z_loss |
|
|
|
|
|
class VisionEncoder(nn.Module): |
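    """ViT-style image encoder: patch embedding, CLS token, learned position embeddings,
    and a stack of transformer blocks.
    """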
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
vision_config = config.vision_config |
|
|
self.patch_size = vision_config.get("patch_size", 14) |
|
|
self.image_size = vision_config.get("image_size", 224) |
|
|
self.num_channels = vision_config.get("num_channels", 3) |
|
|
self.hidden_size = vision_config.get("hidden_size", 1024) |
|
|
self.num_layers = vision_config.get("num_layers", 24) |
|
|
self.num_heads = vision_config.get("num_heads", 16) |
|
|
self.intermediate_size = vision_config.get("intermediate_size", 4096) |
|
|
self.layer_norm_eps = vision_config.get("layer_norm_eps", 1e-6) |
|
|
self.num_patches = (self.image_size // self.patch_size) ** 2 |
|
|
self.patch_embed = nn.Sequential( |
|
|
nn.Conv2d( |
|
|
self.num_channels, |
|
|
self.hidden_size, |
|
|
kernel_size=self.patch_size, |
|
|
stride=self.patch_size, |
|
|
bias=False, |
|
|
), |
|
|
nn.Dropout(p=vision_config.get("dropout", 0.0)), |
|
|
) |
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size)) |
|
|
self.pos_embed = nn.Parameter( |
|
|
torch.zeros(1, self.num_patches + 1, self.hidden_size) |
|
|
) |
|
|
self.pos_drop = nn.Dropout(p=vision_config.get("dropout", 0.0)) |
|
|
self.blocks = nn.ModuleList( |
|
|
[ |
|
|
VisionTransformerBlock( |
|
|
dim=self.hidden_size, |
|
|
num_heads=self.num_heads, |
|
|
mlp_ratio=self.intermediate_size / self.hidden_size, |
|
|
dropout=vision_config.get("dropout", 0.0), |
|
|
layer_norm_eps=self.layer_norm_eps, |
|
|
) |
|
|
for _ in range(self.num_layers) |
|
|
] |
|
|
) |
|
|
self.norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) |
|
|
self._init_weights() |
|
|
|
|
|
def _init_weights(self): |
|
|
nn.init.trunc_normal_(self.pos_embed, std=0.02) |
|
|
nn.init.trunc_normal_(self.cls_token, std=0.02) |
|
|
nn.init.trunc_normal_(self.patch_embed[0].weight, std=0.02) |
|
|
|
|
|
def forward(self, pixel_values): |
|
|
batch_size = pixel_values.shape[0] |
|
|
x = self.patch_embed(pixel_values) |
|
|
x = x.flatten(2).transpose(1, 2) |
|
|
cls_tokens = self.cls_token.expand(batch_size, -1, -1) |
|
|
x = torch.cat([cls_tokens, x], dim=1) |
|
|
x = x + self.pos_embed |
|
|
x = self.pos_drop(x) |
|
|
for block in self.blocks: |
|
|
x = block(x) |
|
|
x = self.norm(x) |
|
|
return x |
|
|
|
|
|
class VisionTransformerBlock(nn.Module): |
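    """Standard pre-norm ViT block: multi-head self-attention followed by a GELU MLP."""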
|
|
def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0, layer_norm_eps=1e-6): |
|
|
super().__init__() |
|
|
self.norm1 = nn.LayerNorm(dim, eps=layer_norm_eps) |
|
|
self.attn = nn.MultiheadAttention( |
|
|
dim, num_heads, dropout=dropout, batch_first=True |
|
|
) |
|
|
self.drop_path1 = nn.Dropout(dropout) |
|
|
self.norm2 = nn.LayerNorm(dim, eps=layer_norm_eps) |
|
|
mlp_hidden_dim = int(dim * mlp_ratio) |
|
|
self.mlp = nn.Sequential( |
|
|
nn.Linear(dim, mlp_hidden_dim), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(mlp_hidden_dim, dim), |
|
|
nn.Dropout(dropout), |
|
|
) |
|
|
self.drop_path2 = nn.Dropout(dropout) |
|
|
|
|
|
def forward(self, x): |
|
|
residual = x |
|
|
x = self.norm1(x) |
|
|
x = self.attn(x, x, x, need_weights=False)[0] |
|
|
x = self.drop_path1(x) |
|
|
x = residual + x |
|
|
residual = x |
|
|
x = self.norm2(x) |
|
|
x = self.mlp(x) |
|
|
x = self.drop_path2(x) |
|
|
x = residual + x |
|
|
return x |
|
|
|
|
|
class AudioEncoder(nn.Module): |
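    """Audio encoder over mel features: two Conv1d stages (2x temporal downsampling),
    learned position embeddings, and a stack of transformer blocks.
    """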
|
|
def __init__(self, config): |
|
|
super().__init__() |
|
|
audio_config = config.audio_config |
|
|
self.num_mel_bins = audio_config.get("num_mel_bins", 80) |
|
|
self.hidden_size = audio_config.get("hidden_size", 1024) |
|
|
self.num_layers = audio_config.get("num_layers", 12) |
|
|
self.num_heads = audio_config.get("num_heads", 16) |
|
|
self.intermediate_size = audio_config.get("intermediate_size", 4096) |
|
|
self.max_audio_length = audio_config.get("max_audio_length", 3000) |
|
|
self.dropout = audio_config.get("dropout", 0.0) |
|
|
self.conv1 = nn.Sequential( |
|
|
nn.Conv1d(self.num_mel_bins, self.hidden_size, kernel_size=3, padding=1), |
|
|
nn.GELU(), |
|
|
nn.Dropout(p=self.dropout), |
|
|
) |
|
|
self.conv2 = nn.Sequential( |
|
|
nn.Conv1d( |
|
|
self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1 |
|
|
), |
|
|
nn.GELU(), |
|
|
nn.Dropout(p=self.dropout), |
|
|
) |
|
|
self.pos_embed = nn.Parameter( |
|
|
torch.zeros(1, self.max_audio_length // 2, self.hidden_size) |
|
|
) |
|
|
self.pos_drop = nn.Dropout(p=self.dropout) |
|
|
self.blocks = nn.ModuleList( |
|
|
[ |
|
|
AudioTransformerBlock( |
|
|
dim=self.hidden_size, |
|
|
num_heads=self.num_heads, |
|
|
mlp_ratio=self.intermediate_size / self.hidden_size, |
|
|
dropout=self.dropout, |
|
|
) |
|
|
for _ in range(self.num_layers) |
|
|
] |
|
|
) |
|
|
self.norm = nn.LayerNorm(self.hidden_size) |
|
|
self._init_weights() |
|
|
|
|
|
def _init_weights(self): |
|
|
nn.init.trunc_normal_(self.pos_embed, std=0.02) |
|
|
|
|
|
def forward(self, audio_features): |
|
|
        # conv1 and conv2 already apply GELU and dropout internally; avoid double activation.
        x = self.conv1(audio_features)
        x = self.conv2(x)
|
|
x = x.transpose(1, 2) |
|
|
seq_len = x.shape[1] |
|
|
if seq_len <= self.pos_embed.shape[1]: |
|
|
x = x + self.pos_embed[:, :seq_len, :] |
|
|
else: |
|
|
pos_embed_interp = F.interpolate( |
|
|
self.pos_embed.transpose(1, 2), |
|
|
size=seq_len, |
|
|
mode="linear", |
|
|
align_corners=False, |
|
|
).transpose(1, 2) |
|
|
x = x + pos_embed_interp |
|
|
x = self.pos_drop(x) |
|
|
for block in self.blocks: |
|
|
x = block(x) |
|
|
x = self.norm(x) |
|
|
return x |
|
|
|
|
|
class AudioTransformerBlock(nn.Module): |
|
|
def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0): |
|
|
super().__init__() |
|
|
self.norm1 = nn.LayerNorm(dim) |
|
|
self.attn = nn.MultiheadAttention( |
|
|
dim, num_heads, dropout=dropout, batch_first=True |
|
|
) |
|
|
self.drop_path1 = nn.Dropout(dropout) |
|
|
self.norm2 = nn.LayerNorm(dim) |
|
|
mlp_hidden_dim = int(dim * mlp_ratio) |
|
|
self.mlp = nn.Sequential( |
|
|
nn.Linear(dim, mlp_hidden_dim), |
|
|
nn.GELU(), |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(mlp_hidden_dim, dim), |
|
|
nn.Dropout(dropout), |
|
|
) |
|
|
self.drop_path2 = nn.Dropout(dropout) |
|
|
|
|
|
def forward(self, x): |
|
|
residual = x |
|
|
x = self.norm1(x) |
|
|
x = self.attn(x, x, x, need_weights=False)[0] |
|
|
x = self.drop_path1(x) |
|
|
x = residual + x |
|
|
residual = x |
|
|
x = self.norm2(x) |
|
|
x = self.mlp(x) |
|
|
x = self.drop_path2(x) |
|
|
x = residual + x |
|
|
return x |
|
|
|
|
|
class MultiModalProjector(nn.Module): |
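    """Projects encoder features into the language-model hidden size via a linear, MLP,
    Perceiver-resampler, or Q-Former head.
    """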
|
|
def __init__(self, input_size, output_size, projector_type="mlp", num_layers=2): |
|
|
super().__init__() |
|
|
self.projector_type = projector_type |
|
|
if projector_type == "linear": |
|
|
self.projector = nn.Linear(input_size, output_size) |
|
|
elif projector_type == "mlp": |
|
|
layers = [] |
|
|
current_size = input_size |
|
|
for i in range(num_layers - 1): |
|
|
layers.extend( |
|
|
[nn.Linear(current_size, output_size), nn.GELU(), nn.Dropout(0.1)] |
|
|
) |
|
|
current_size = output_size |
|
|
layers.append(nn.Linear(current_size, output_size)) |
|
|
self.projector = nn.Sequential(*layers) |
|
|
elif projector_type == "perceiver": |
|
|
self.projector = PerceiverResampler( |
|
|
input_size, output_size, num_latents=64, num_layers=2 |
|
|
) |
|
|
elif projector_type == "qformer": |
|
|
self.projector = QFormerProjector( |
|
|
input_size, output_size, num_queries=32, num_layers=2 |
|
|
) |
|
|
else: |
|
|
raise ValueError(f"projector_type tidak dikenal: {projector_type}") |
|
|
|
|
|
def forward(self, x): |
|
|
return self.projector(x) |
|
|
|
|
|
class PerceiverResampler(nn.Module): |
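    """Compresses a variable-length feature sequence into a fixed number of learned latents
    via cross-attention.
    """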
|
|
def __init__(self, input_size, output_size, num_latents=64, num_layers=2): |
|
|
super().__init__() |
|
|
self.num_latents = num_latents |
|
|
self.latents = nn.Parameter(torch.randn(num_latents, output_size)) |
|
|
self.layers = nn.ModuleList( |
|
|
[ |
|
|
PerceiverLayer(output_size, input_size if i == 0 else output_size) |
|
|
for i in range(num_layers) |
|
|
] |
|
|
) |
|
|
self.norm = nn.LayerNorm(output_size) |
|
|
|
|
|
def forward(self, x): |
|
|
batch_size = x.shape[0] |
|
|
latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1) |
|
|
for i, layer in enumerate(self.layers): |
|
|
if i == 0: |
|
|
latents = layer(latents, x) |
|
|
else: |
|
|
latents = layer(latents, latents) |
|
|
return self.norm(latents) |
|
|
|
|
|
class PerceiverLayer(nn.Module): |
|
|
def __init__(self, query_dim, key_dim): |
|
|
super().__init__() |
|
|
self.cross_attn = nn.MultiheadAttention( |
|
|
query_dim, num_heads=8, kdim=key_dim, vdim=key_dim, batch_first=True |
|
|
) |
|
|
self.mlp = nn.Sequential( |
|
|
nn.LayerNorm(query_dim), |
|
|
nn.Linear(query_dim, query_dim * 4), |
|
|
nn.GELU(), |
|
|
nn.Linear(query_dim * 4, query_dim), |
|
|
) |
|
|
self.norm1 = nn.LayerNorm(query_dim) |
|
|
self.norm2 = nn.LayerNorm(query_dim) |
|
|
|
|
|
def forward(self, query, key): |
|
|
query = ( |
|
|
query + self.cross_attn(self.norm1(query), key, key, need_weights=False)[0] |
|
|
) |
|
|
query = query + self.mlp(self.norm2(query)) |
|
|
return query |
|
|
|
|
|
class QFormerProjector(nn.Module): |
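    """Q-Former-style projector: learned queries refined by alternating self-attention and
    cross-attention over the input features.
    """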
|
|
def __init__(self, input_size, output_size, num_queries=32, num_layers=2): |
|
|
super().__init__() |
|
|
self.num_queries = num_queries |
|
|
self.query_embeds = nn.Parameter(torch.randn(num_queries, output_size)) |
|
|
self.query_layers = nn.ModuleList( |
|
|
[ |
|
|
nn.TransformerEncoderLayer( |
|
|
d_model=output_size, |
|
|
nhead=8, |
|
|
dim_feedforward=output_size * 4, |
|
|
batch_first=True, |
|
|
) |
|
|
for _ in range(num_layers) |
|
|
] |
|
|
) |
|
|
self.cross_attn_layers = nn.ModuleList( |
|
|
[ |
|
|
nn.MultiheadAttention( |
|
|
output_size, |
|
|
num_heads=8, |
|
|
kdim=input_size, |
|
|
vdim=input_size, |
|
|
batch_first=True, |
|
|
) |
|
|
for _ in range(num_layers) |
|
|
] |
|
|
) |
|
|
self.norm = nn.LayerNorm(output_size) |
|
|
|
|
|
def forward(self, x): |
|
|
batch_size = x.shape[0] |
|
|
queries = self.query_embeds.unsqueeze(0).expand(batch_size, -1, -1) |
|
|
for query_layer, cross_attn_layer in zip( |
|
|
self.query_layers, self.cross_attn_layers |
|
|
): |
|
|
queries = query_layer(queries) |
|
|
queries = queries + cross_attn_layer(queries, x, x, need_weights=False)[0] |
|
|
return self.norm(queries) |
|
|
|
|
|
class CacaPreTrainedModel(PreTrainedModel): |
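    """Base class that wires CacaConfig into the Hugging Face PreTrainedModel machinery
    (weight initialization, gradient checkpointing, device placement hints).
    """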
|
|
config_class = CacaConfig |
|
|
base_model_prefix = "model" |
|
|
supports_gradient_checkpointing = True |
|
|
_no_split_modules = ["CacaDecoderLayer"] |
|
|
_skip_keys_device_placement = "past_key_values" |
|
|
|
|
|
def _init_weights(self, module): |
|
|
std = self.config.initializer_range |
|
|
if isinstance(module, nn.Linear): |
|
|
module.weight.data.normal_(mean=0.0, std=std) |
|
|
if module.bias is not None: |
|
|
module.bias.data.zero_() |
|
|
elif isinstance(module, nn.Embedding): |
|
|
module.weight.data.normal_(mean=0.0, std=std) |
|
|
if module.padding_idx is not None: |
|
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
|
if isinstance(module, CacaModel): |
|
|
module.gradient_checkpointing = value |
|
|
|
|
|
class CacaModel(CacaPreTrainedModel): |
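    """Decoder-only Caca backbone: token embeddings, stacked decoder layers, and a final
    RMSNorm, with optional vision/audio encoders whose features are fused via cross-attention
    or prepended to the token sequence.
    """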
|
|
def __init__(self, config): |
|
|
super().__init__(config) |
|
|
self.config = config |
|
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) |
|
|
self.layers = nn.ModuleList( |
|
|
[ |
|
|
CacaDecoderLayer(config, layer_idx=idx) |
|
|
for idx in range(config.num_hidden_layers) |
|
|
] |
|
|
) |
|
|
self.norm = CacaRMSNorm(config.hidden_size, config.rms_norm_eps) |
|
|
self.gradient_checkpointing = False |
|
|
if config.use_multimodal: |
|
|
if config.vision_config: |
|
|
self.vision_encoder = VisionEncoder(config) |
|
|
vision_hidden_size = config.vision_config.get("hidden_size", 768) |
|
|
self.vision_projector = MultiModalProjector( |
|
|
vision_hidden_size, |
|
|
config.hidden_size, |
|
|
projector_type=config.vision_config.get("projector_type", "mlp"), |
|
|
) |
|
|
else: |
|
|
self.vision_encoder = None |
|
|
self.vision_projector = None |
|
|
if config.audio_config: |
|
|
self.audio_encoder = AudioEncoder(config) |
|
|
audio_hidden_size = config.audio_config.get("hidden_size", 768) |
|
|
self.audio_projector = MultiModalProjector( |
|
|
audio_hidden_size, |
|
|
config.hidden_size, |
|
|
projector_type=config.audio_config.get("projector_type", "mlp"), |
|
|
) |
|
|
else: |
|
|
self.audio_encoder = None |
|
|
self.audio_projector = None |
|
|
self.post_init() |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.embed_tokens |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.embed_tokens = value |
|
|
|
|
|
def _prepare_attention_mask(self, attention_mask, input_shape, dtype): |
|
|
if attention_mask is None: |
|
|
return None |
|
|
batch_size, seq_length = input_shape |
|
|
if attention_mask.dim() == 2: |
|
|
attention_mask = attention_mask[:, None, None, :] |
|
|
elif attention_mask.dim() == 3: |
|
|
attention_mask = attention_mask[:, None, :, :] |
|
|
attention_mask = attention_mask.to(dtype=dtype) |
|
|
attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min |
|
|
return attention_mask |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids=None, |
|
|
pixel_values=None, |
|
|
audio_features=None, |
|
|
attention_mask=None, |
|
|
past_key_values=None, |
|
|
use_cache=None, |
|
|
output_hidden_states=False, |
|
|
return_dict=True, |
|
|
**kwargs, |
|
|
): |
|
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
|
if input_ids is not None: |
|
|
batch_size, seq_length = input_ids.shape |
|
|
device = input_ids.device |
|
|
hidden_states = self.embed_tokens(input_ids) |
|
|
else: |
|
|
raise ValueError("input_ids tidak boleh None") |
|
|
|
|
|
if pixel_values is not None: |
|
|
pixel_values = pixel_values.to(device) |
|
|
if audio_features is not None: |
|
|
audio_features = audio_features.to(device) |
|
|
|
|
|
encoder_hidden_states = None |
|
|
encoder_attention_mask = None |
|
|
if self.config.use_multimodal: |
|
|
multimodal_embeds = [] |
|
|
if pixel_values is not None and self.vision_encoder is not None: |
|
|
vision_features = self.vision_encoder(pixel_values.to(hidden_states.device)) |
|
|
vision_embeds = self.vision_projector(vision_features) |
|
|
multimodal_embeds.append(vision_embeds) |
|
|
if audio_features is not None and self.audio_encoder is not None: |
|
|
audio_encoded = self.audio_encoder(audio_features.to(hidden_states.device)) |
|
|
audio_embeds = self.audio_projector(audio_encoded) |
|
|
multimodal_embeds.append(audio_embeds) |
|
|
if multimodal_embeds and self.config.use_cross_attention: |
|
|
encoder_hidden_states = torch.cat(multimodal_embeds, dim=1) |
|
|
encoder_seq_len = encoder_hidden_states.shape[1] |
|
|
encoder_attention_mask = torch.ones( |
|
|
batch_size, |
|
|
encoder_seq_len, |
|
|
dtype=hidden_states.dtype, |
|
|
device=hidden_states.device, |
|
|
) |
|
|
|
|
|
elif multimodal_embeds: |
|
|
multimodal_concat = torch.cat(multimodal_embeds, dim=1) |
|
|
max_multimodal_tokens = self.config.max_position_embeddings // 4 |
|
|
if multimodal_concat.shape[1] > max_multimodal_tokens: |
|
|
logger.warning( |
|
|
f"Multimodal tokens ({multimodal_concat.shape[1]}) > max ({max_multimodal_tokens}). " |
|
|
f"Truncating..." |
|
|
) |
|
|
multimodal_concat = multimodal_concat[:, :max_multimodal_tokens] |
|
|
hidden_states = torch.cat([multimodal_concat, hidden_states], dim=1) |
|
|
seq_length = hidden_states.shape[1] |
|
|
if attention_mask is not None: |
|
|
multimodal_mask = torch.ones( |
|
|
batch_size, |
|
|
multimodal_concat.shape[1], |
|
|
dtype=attention_mask.dtype, |
|
|
device=attention_mask.device, |
|
|
) |
|
|
attention_mask = torch.cat([multimodal_mask, attention_mask], dim=1) |
|
|
else: |
|
|
attention_mask = torch.ones( |
|
|
batch_size, |
|
|
seq_length, |
|
|
dtype=hidden_states.dtype, |
|
|
device=device, |
|
|
) |
|
|
if attention_mask is not None: |
|
|
attention_mask = self._prepare_attention_mask( |
|
|
attention_mask, (batch_size, seq_length), hidden_states.dtype |
|
|
) |
|
|
if encoder_attention_mask is not None and self.config.use_cross_attention: |
|
|
encoder_attention_mask = self._prepare_attention_mask( |
|
|
encoder_attention_mask, |
|
|
(batch_size, encoder_hidden_states.shape[1]), |
|
|
hidden_states.dtype, |
|
|
) |
|
|
|
|
|
if use_cache: |
|
|
if past_key_values is None: |
|
|
past_key_values = tuple([None] * len(self.layers)) |
|
|
|
|
|
present_key_values = [] if use_cache else None |
|
|
all_hidden_states = [] if output_hidden_states else None |
|
|
total_aux_loss = 0.0 |
|
|
total_z_loss = 0.0 |
|
|
for idx, layer in enumerate(self.layers): |
|
|
if output_hidden_states: |
|
|
all_hidden_states.append(hidden_states) |
|
|
past_key_value = ( |
|
|
past_key_values[idx] if past_key_values is not None else None |
|
|
) |
|
|
if self.gradient_checkpointing and self.training and not use_cache: |
|
|
hidden_states, aux_loss, z_loss = self._gradient_checkpointing_forward( |
|
|
layer, |
|
|
hidden_states, |
|
|
attention_mask, |
|
|
encoder_hidden_states, |
|
|
encoder_attention_mask, |
|
|
) |
|
|
present_key_value = None |
|
|
else: |
|
|
hidden_states, present_key_value, aux_loss, z_loss = layer( |
|
|
hidden_states, |
|
|
attention_mask, |
|
|
encoder_hidden_states=encoder_hidden_states, |
|
|
encoder_attention_mask=encoder_attention_mask, |
|
|
past_key_value=past_key_value, |
|
|
use_cache=use_cache, |
|
|
) |
|
|
if use_cache: |
|
|
present_key_values.append(present_key_value) |
|
|
total_aux_loss += aux_loss |
|
|
total_z_loss += z_loss |
|
|
|
|
|
if self.training and torch.cuda.is_available(): |
|
|
allocated_gb = torch.cuda.memory_allocated() / 1024**3 |
|
|
reserved_gb = torch.cuda.memory_reserved() / 1024**3 |
|
|
if allocated_gb > 10: |
|
|
logger.warning( |
|
|
f"High GPU memory usage - Allocated: {allocated_gb:.2f}GB, " |
|
|
f"Reserved: {reserved_gb:.2f}GB" |
|
|
) |
|
|
hidden_states = self.norm(hidden_states) |
|
|
if output_hidden_states: |
|
|
all_hidden_states.append(hidden_states) |
|
|
        if not return_dict:
            outputs = tuple(
                v
                for v in [
                    hidden_states,
                    tuple(present_key_values) if use_cache else None,
                    all_hidden_states,
                ]
                if v is not None
            )
            # Return the aux/z losses separately so callers can always unpack
            # (outputs, aux_loss, z_loss), matching the return_dict branch below.
            return outputs, total_aux_loss, total_z_loss
|
|
return ( |
|
|
BaseModelOutputWithPast( |
|
|
last_hidden_state=hidden_states, |
|
|
past_key_values=tuple(present_key_values) if use_cache else None, |
|
|
hidden_states=all_hidden_states, |
|
|
attentions=None, |
|
|
), |
|
|
total_aux_loss, |
|
|
total_z_loss, |
|
|
) |
|
|
|
|
|
def _gradient_checkpointing_forward( |
|
|
self, |
|
|
layer, |
|
|
hidden_states, |
|
|
attention_mask, |
|
|
encoder_hidden_states, |
|
|
encoder_attention_mask, |
|
|
): |
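        """Re-run `layer` under `torch.utils.checkpoint` to trade compute for memory.

        The KV cache is disabled here (`use_cache=False`), so only the hidden states
        and the MoE aux/z losses are propagated.
        """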
|
|
from torch.utils.checkpoint import checkpoint |
|
|
|
|
|
def custom_forward(hidden_states, attention_mask, encoder_hidden_states, |
|
|
encoder_attention_mask): |
|
|
output, _, aux_loss, z_loss = layer( |
|
|
hidden_states, attention_mask, |
|
|
encoder_hidden_states=encoder_hidden_states, |
|
|
encoder_attention_mask=encoder_attention_mask, |
|
|
past_key_value=None, |
|
|
use_cache=False, |
|
|
) |
|
|
|
|
|
return output, aux_loss, z_loss |
|
|
|
|
|
hidden_states, aux_loss, z_loss = checkpoint( |
|
|
custom_forward, |
|
|
hidden_states, attention_mask, |
|
|
encoder_hidden_states, encoder_attention_mask, |
|
|
use_reentrant=False, |
|
|
) |
|
|
return hidden_states, aux_loss, z_loss |
|
|
|
|
|
class CacaForCausalLM(CacaPreTrainedModel, GenerationMixin): |
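    """Causal language modeling head on top of `CacaModel`.

    Minimal usage sketch (illustrative values relying on the default config
    options, not a recommended configuration):

        config = CacaConfig(vocab_size=32000, hidden_size=256, num_hidden_layers=2,
                            num_attention_heads=4, num_key_value_heads=4)
        model = CacaForCausalLM(config)
        input_ids = torch.randint(0, config.vocab_size, (1, 16))
        out = model(input_ids=input_ids, labels=input_ids)
        out.loss.backward()
    """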
|
|
_tied_weights_keys = ["lm_head.weight"] |
|
|
|
|
|
def __init__(self, config): |
|
|
super().__init__(config) |
|
|
self.model = CacaModel(config) |
|
|
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
|
|
self.post_init() |
|
|
|
|
|
def get_input_embeddings(self): |
|
|
return self.model.embed_tokens |
|
|
|
|
|
def set_input_embeddings(self, value): |
|
|
self.model.embed_tokens = value |
|
|
|
|
|
def get_output_embeddings(self): |
|
|
return self.lm_head |
|
|
|
|
|
def set_output_embeddings(self, new_embeddings): |
|
|
self.lm_head = new_embeddings |
|
|
|
|
|
def set_decoder(self, decoder): |
|
|
self.model = decoder |
|
|
|
|
|
def get_decoder(self): |
|
|
return self.model |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
input_ids=None, |
|
|
pixel_values=None, |
|
|
audio_features=None, |
|
|
attention_mask=None, |
|
|
labels=None, |
|
|
past_key_values=None, |
|
|
inputs_embeds=None, |
|
|
use_cache=None, |
|
|
output_attentions=None, |
|
|
output_hidden_states=None, |
|
|
return_dict=None, |
|
|
**kwargs, |
|
|
): |
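        """Compute logits (and optionally the training loss).

        Validates `input_ids`, `labels`, and `attention_mask`, runs the backbone,
        projects the hidden states through `lm_head`, applies optional final logit
        soft-capping, and, when `labels` are given, computes a shifted cross-entropy
        loss plus the MoE router aux/z losses scaled by their config coefficients.
        """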
|
|
        if input_ids is not None:
            if input_ids.dtype.is_floating_point:
                raise TypeError(
                    f"input_ids must have an integer dtype, got {input_ids.dtype}. "
                    f"Use input_ids.long() to convert."
                )
            if (input_ids < 0).any():
                neg_vals = input_ids[input_ids < 0].unique().tolist()
                raise ValueError(f"input_ids contains negative values: {neg_vals}")
            max_val = input_ids.max().item()
            if max_val >= self.config.vocab_size:
                raise ValueError(
                    f"input_ids contains values >= vocab_size. "
                    f"Max value: {max_val}, vocab_size: {self.config.vocab_size:,}"
                )
|
|
|
|
|
        if labels is not None:
            if labels.dtype not in (torch.long, torch.int, torch.int32, torch.int64):
                raise TypeError(f"labels must have an integer dtype, got {labels.dtype}")
            if (labels[labels != -100] < 0).any():
                raise ValueError("labels contains negative values (other than -100)")
            max_label = labels[labels != -100].max().item() if (labels != -100).any() else 0
            if max_label >= self.config.vocab_size:
                raise ValueError(
                    f"labels contains values >= vocab_size. "
                    f"Max: {max_label}, vocab_size: {self.config.vocab_size}"
                )
|
|
if attention_mask is not None: |
|
|
if attention_mask.shape[0] != input_ids.shape[0]: |
|
|
raise ValueError( |
|
|
f"attention_mask batch size ({attention_mask.shape[0]}) != " |
|
|
f"input_ids batch size ({input_ids.shape[0]})" |
|
|
) |
|
|
if attention_mask.shape[1] != input_ids.shape[1]: |
|
|
raise ValueError( |
|
|
f"attention_mask seq length ({attention_mask.shape[1]}) != " |
|
|
f"input_ids seq length ({input_ids.shape[1]})" |
|
|
) |
|
|
|
|
|
return_dict = ( |
|
|
return_dict if return_dict is not None else self.config.use_return_dict |
|
|
) |
|
|
outputs, aux_loss, z_loss = self.model( |
|
|
input_ids, |
|
|
pixel_values=pixel_values, |
|
|
audio_features=audio_features, |
|
|
attention_mask=attention_mask, |
|
|
past_key_values=past_key_values, |
|
|
use_cache=use_cache, |
|
|
output_hidden_states=output_hidden_states, |
|
|
return_dict=return_dict, |
|
|
) |
|
|
if return_dict: |
|
|
hidden_states = outputs.last_hidden_state |
|
|
else: |
|
|
hidden_states = outputs[0] |
|
|
logits = self.lm_head(hidden_states) |
|
|
if self.config.final_logit_softcapping: |
|
|
logits = soft_cap_logits(logits, self.config.final_logit_softcapping) |
|
|
loss = None |
|
|
if labels is not None: |
|
|
shift_logits = logits[..., :-1, :].contiguous() |
|
|
shift_labels = labels[..., 1:].contiguous() |
|
|
loss_fct = nn.CrossEntropyLoss() |
|
|
lm_loss = loss_fct( |
|
|
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) |
|
|
) |
|
|
if self.config.use_moe: |
|
|
total_loss = ( |
|
|
lm_loss |
|
|
+ (self.config.router_aux_loss_coef * aux_loss) |
|
|
+ (self.config.router_z_loss_coef * z_loss) |
|
|
) |
|
|
else: |
|
|
total_loss = lm_loss |
|
|
loss = total_loss |
|
|
        if not return_dict:
            # `outputs` is a plain tuple here:
            # (hidden_states, [present_key_values], [all_hidden_states]).
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output
|
|
        # return_dict is always True at this point (the tuple path returned above).
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=None,
        )
|
|
|
|
|
def prepare_inputs_for_generation( |
|
|
self, |
|
|
input_ids, |
|
|
past_key_values=None, |
|
|
attention_mask=None, |
|
|
inputs_embeds=None, |
|
|
pixel_values=None, |
|
|
audio_features=None, |
|
|
**kwargs, |
|
|
): |
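        """Assemble model inputs for one generation step.

        When a non-empty cache is present, only the last token of `input_ids` is fed
        and the multimodal inputs are dropped (they were consumed on the first step);
        the attention mask is trimmed to match the truncated input.
        """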
|
|
|
|
|
has_past = ( |
|
|
past_key_values is not None |
|
|
and len(past_key_values) > 0 |
|
|
and past_key_values[0] is not None |
|
|
) |
|
|
|
|
|
if has_past: |
|
|
input_ids = input_ids[:, -1:] |
|
|
|
|
|
if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: |
|
|
attention_mask = attention_mask[:, -input_ids.shape[1]:] |
|
|
|
|
|
if inputs_embeds is not None and not has_past: |
|
|
model_inputs = {"inputs_embeds": inputs_embeds} |
|
|
else: |
|
|
model_inputs = {"input_ids": input_ids} |
|
|
|
|
|
model_inputs.update( |
|
|
{ |
|
|
"past_key_values": past_key_values if has_past else None, |
|
|
"use_cache": kwargs.get("use_cache"), |
|
|
"attention_mask": attention_mask, |
|
|
"pixel_values": pixel_values if not has_past else None, |
|
|
"audio_features": audio_features if not has_past else None, |
|
|
} |
|
|
) |
|
|
return model_inputs |
|
|
|
|
|
@staticmethod |
|
|
def _reorder_cache(past_key_values, beam_idx): |
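        """Reorder cached key/value states along the batch axis for beam search."""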
|
|
reordered_past = () |
|
|
for layer_past in past_key_values: |
|
|
if layer_past is not None and len(layer_past) > 0: |
|
|
reordered_past += ( |
|
|
tuple( |
|
|
past_state.index_select(0, beam_idx.to(past_state.device)) |
|
|
for past_state in layer_past |
|
|
if past_state is not None |
|
|
), |
|
|
) |
|
|
else: |
|
|
reordered_past += (None,) |
|
|
return reordered_past |
|
|
|
|
|
def save_pretrained(self, save_directory, **kwargs): |
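        """Save the model, temporarily dropping a `None` quantization_config.

        A `quantization_config` attribute that is `None` can break config
        serialization, so it is removed before delegating to the parent
        implementation and restored afterwards.
        """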
|
|
has_quant_config = hasattr(self.config, 'quantization_config') |
|
|
quantization_config_backup = getattr(self.config, 'quantization_config', None) |
|
|
|
|
|
if has_quant_config and quantization_config_backup is None: |
|
|
delattr(self.config, 'quantization_config') |
|
|
|
|
|
try: |
|
|
super().save_pretrained(save_directory, **kwargs) |
|
|
finally: |
|
|
if has_quant_config: |
|
|
self.config.quantization_config = quantization_config_backup |
|
|
|
|
|
class CacaForCausalLMQuantized(CacaForCausalLM): |
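    """`CacaForCausalLM` variant that swaps `nn.Linear` layers for bitsandbytes modules.

    `quantization_config` is a plain dict; keys read by this class include
    `load_in_8bit`, `load_in_4bit`, `llm_int8_threshold`, `bnb_4bit_compute_dtype`,
    `bnb_4bit_quant_type`, and `bnb_4bit_use_double_quant`. A minimal sketch
    (requires bitsandbytes to be installed):

        model = CacaForCausalLMQuantized(
            config,
            quantization_config={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4"},
        )
    """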
|
|
def __init__(self, config, quantization_config=None): |
|
|
super().__init__(config) |
|
|
self.quantization_config = quantization_config |
|
|
if quantization_config: |
|
|
self._apply_quantization() |
|
|
|
|
|
def _apply_quantization(self): |
|
|
if self.quantization_config.get("load_in_8bit"): |
|
|
self._quantize_8bit() |
|
|
elif self.quantization_config.get("load_in_4bit"): |
|
|
self._quantize_4bit() |
|
|
|
|
|
def _quantize_8bit(self): |
|
|
try: |
|
|
import bitsandbytes as bnb |
|
|
|
|
|
            # Snapshot the module list first, since we replace modules below.
            for name, module in list(self.named_modules()):
|
|
if isinstance(module, nn.Linear): |
|
|
has_bias = module.bias is not None |
|
|
new_module = bnb.nn.Linear8bitLt( |
|
|
module.in_features, |
|
|
module.out_features, |
|
|
has_bias, |
|
|
threshold=self.quantization_config.get( |
|
|
"llm_int8_threshold", 6.0 |
|
|
), |
|
|
) |
|
|
new_module.weight = module.weight |
|
|
if has_bias: |
|
|
new_module.bias = module.bias |
|
|
parent_name = ".".join(name.split(".")[:-1]) |
|
|
child_name = name.split(".")[-1] |
|
|
if parent_name: |
|
|
parent = self.get_submodule(parent_name) |
|
|
setattr(parent, child_name, new_module) |
|
|
else: |
|
|
setattr(self, child_name, new_module) |
|
|
logger.info("Quantisasi 8-bit berhasil diterapkan") |
|
|
except ImportError: |
|
|
logger.error("bitsandbytes tidak terinstall! pip install bitsandbytes") |
|
|
|
|
|
def _quantize_4bit(self): |
|
|
try: |
|
|
import bitsandbytes as bnb |
|
|
|
|
|
compute_dtype = torch.float16 |
|
|
if self.quantization_config.get("bnb_4bit_compute_dtype"): |
|
|
compute_dtype = getattr( |
|
|
torch, self.quantization_config["bnb_4bit_compute_dtype"] |
|
|
) |
|
|
            # Snapshot the module list first, since we replace modules below.
            for name, module in list(self.named_modules()):
|
|
if isinstance(module, nn.Linear): |
|
|
has_bias = module.bias is not None |
|
|
new_module = bnb.nn.Linear4bit( |
|
|
module.in_features, |
|
|
module.out_features, |
|
|
bias=has_bias, |
|
|
compute_dtype=compute_dtype, |
|
|
quant_type=self.quantization_config.get( |
|
|
"bnb_4bit_quant_type", "nf4" |
|
|
), |
|
|
use_double_quant=self.quantization_config.get( |
|
|
"bnb_4bit_use_double_quant", True |
|
|
), |
|
|
) |
|
|
new_module.weight = module.weight |
|
|
if has_bias: |
|
|
new_module.bias = module.bias |
|
|
parent_name = ".".join(name.split(".")[:-1]) |
|
|
child_name = name.split(".")[-1] |
|
|
if parent_name: |
|
|
parent = self.get_submodule(parent_name) |
|
|
setattr(parent, child_name, new_module) |
|
|
else: |
|
|
setattr(self, child_name, new_module) |
|
|
logger.info("Quantisasi 4-bit berhasil diterapkan") |
|
|
except ImportError: |
|
|
logger.error("bitsandbytes tidak terinstall!") |
|
|
|
|
|
@classmethod |
|
|
def from_pretrained_quantized(cls, model_path, quantization_config): |
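        """Build a quantized model from a saved checkpoint directory.

        Loads the config, constructs the quantized wrapper, then loads
        `pytorch_model.bin` with `strict=False` so renamed or absent quantized
        parameters do not abort loading.
        """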
|
|
config = CacaConfig.from_pretrained(model_path) |
|
|
model = cls(config, quantization_config=quantization_config) |
|
|
state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location="cpu") |
|
|
model.load_state_dict(state_dict, strict=False) |
|
|
return model |