import torch import torch.nn as nn import torch.nn.functional as F import math from typing import Optional, Tuple, List from transformers import PreTrainedModel, PretrainedConfig from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from transformers.generation.utils import GenerationMixin from collections import OrderedDict import logging logger = logging.getLogger(__name__) # Flash Attention try: from flash_attn import flash_attn_func HAS_FLASH_ATTN = True except ImportError: HAS_FLASH_ATTN = False # xFormers try: from xformers.ops import memory_efficient_attention HAS_XFORMERS = True except ImportError: HAS_XFORMERS = False # PyTorch SDPA HAS_SDPA = hasattr(F, 'scaled_dot_product_attention') # --- config --- class CacaConfig(PretrainedConfig): model_type = "caca" def __init__( self, vocab_size=32000, hidden_size=2048, intermediate_size=8192, num_hidden_layers=24, num_attention_heads=32, num_key_value_heads=8, head_dim=64, max_position_embeddings=8192, rms_norm_eps=1e-6, qk_norm_eps=1e-6, initializer_range=0.02, use_cache=True, pad_token_id=None, bos_token_id=1, eos_token_id=2, tie_word_embeddings=False, rope_theta=10000.0, rope_scaling=None, use_rotary_embeddings=True, attention_bias=False, attention_dropout=0.0, use_qk_norm=True, use_alibi=False, use_flash_attn=True, use_grouped_query_attention=False, use_multi_query_attention=False, sliding_window=None, use_longformer_attention=False, longformer_attention_window=512, attn_logit_softcapping=None, final_logit_softcapping=None, attention_sink_size=4, attention_sink_window=1024, use_attention_sink=False, attention_pattern="all_global", global_attention_every_n_layers=2, mlp_bias=False, hidden_dropout=0.1, residual_dropout=0.1, use_moe=False, num_experts=8, num_experts_per_tok=2, use_expert_choice=False, expert_choice_k=0.125, router_aux_loss_coef=0.01, router_z_loss_coef=0.001, moe_layer_frequency=2, expert_capacity_factor=1.0, use_grouped_moe=False, num_expert_groups=1, use_layer_scale=False, layer_scale_init=1e-5, use_stochastic_depth=False, stochastic_depth_prob=0.1, use_mixture_of_depths=False, mod_capacity_factor=0.5, mod_route_method="learned", use_cross_attention=False, cross_attention_frequency=4, use_multimodal=False, vision_config=None, audio_config=None, projector_hidden_size=None, use_soft_merging=False, merge_threshold=0.5, pretraining_tp=1, tensor_parallel_size=1, pipeline_parallel_size=1, chat_template=None, **kwargs ): self.vocab_size = vocab_size self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.num_hidden_layers = num_hidden_layers self.num_attention_heads = num_attention_heads self.num_key_value_heads = num_key_value_heads self.head_dim = head_dim or (hidden_size // num_attention_heads if hidden_size and num_attention_heads else None) self.max_position_embeddings = max_position_embeddings self.rms_norm_eps = rms_norm_eps self.qk_norm_eps = qk_norm_eps self.initializer_range = initializer_range self.use_cache = use_cache self.pad_token_id = pad_token_id self.bos_token_id = bos_token_id self.eos_token_id = eos_token_id self.tie_word_embeddings = tie_word_embeddings self.rope_theta = rope_theta self.rope_scaling = rope_scaling self.use_rotary_embeddings = use_rotary_embeddings self.attention_bias = attention_bias self.attention_dropout = attention_dropout self.use_qk_norm = use_qk_norm self.use_alibi = use_alibi self.use_flash_attn = use_flash_attn self.use_grouped_query_attention = use_grouped_query_attention self.use_multi_query_attention = 
            use_multi_query_attention
        self.sliding_window = sliding_window
        self.use_longformer_attention = use_longformer_attention
        self.longformer_attention_window = longformer_attention_window
        self.attn_logit_softcapping = attn_logit_softcapping
        self.final_logit_softcapping = final_logit_softcapping
        self.attention_sink_size = attention_sink_size
        self.attention_sink_window = attention_sink_window
        self.use_attention_sink = use_attention_sink
        self.attention_pattern = attention_pattern
        self.global_attention_every_n_layers = global_attention_every_n_layers
        self.mlp_bias = mlp_bias
        self.hidden_dropout = hidden_dropout
        self.residual_dropout = residual_dropout
        self.use_moe = use_moe
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.use_expert_choice = use_expert_choice
        self.expert_choice_k = expert_choice_k
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_z_loss_coef = router_z_loss_coef
        self.moe_layer_frequency = moe_layer_frequency
        self.expert_capacity_factor = expert_capacity_factor
        self.use_grouped_moe = use_grouped_moe
        self.num_expert_groups = num_expert_groups
        self.use_layer_scale = use_layer_scale
        self.layer_scale_init = layer_scale_init
        self.use_stochastic_depth = use_stochastic_depth
        self.stochastic_depth_prob = stochastic_depth_prob
        self.use_mixture_of_depths = use_mixture_of_depths
        self.mod_capacity_factor = mod_capacity_factor
        self.mod_route_method = mod_route_method
        self.use_cross_attention = use_cross_attention
        self.cross_attention_frequency = cross_attention_frequency
        self.use_multimodal = use_multimodal
        self.vision_config = vision_config or {}
        self.audio_config = audio_config or {}
        self.projector_hidden_size = projector_hidden_size or hidden_size
        self.use_soft_merging = use_soft_merging
        self.merge_threshold = merge_threshold
        self.pretraining_tp = pretraining_tp
        self.tensor_parallel_size = tensor_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        if chat_template is None:
            self.chat_template = (
                "{% for message in messages %}"
                "{% if message['role'] == 'system' %}"
                "System: {{ message['content'] }}\n"
                "{% elif message['role'] == 'user' %}"
                "User: {{ message['content'] }}\n"
                "{% elif message['role'] == 'assistant' %}"
                "Assistant: {{ message['content'] }}\n"
                "{% endif %}"
                "{% endfor %}"
                "{% if add_generation_prompt %}Assistant:{% endif %}"
            )
        else:
            self.chat_template = chat_template
        # Pass the special token ids and the tying flag through to PretrainedConfig.
        # Calling super().__init__(**kwargs) alone would reset pad/bos/eos_token_id
        # and tie_word_embeddings to the PretrainedConfig defaults, silently
        # overwriting the values set above.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    def to_dict(self):
        """
        Serializes this instance to a Python dictionary.
        Handles quantization_config properly and adds auto_map for HuggingFace Hub.
""" # Backup and temporarily remove quantization_config if None has_quant_config = hasattr(self, 'quantization_config') quantization_config_backup = getattr(self, 'quantization_config', None) if has_quant_config and quantization_config_backup is None: delattr(self, 'quantization_config') try: # Call parent to_dict output = super().to_dict() # Add auto_map for custom model loading output['auto_map'] = { "AutoConfig": "caca_transformers.CacaConfig", "AutoModel": "caca_transformers.CacaModel", "AutoModelForCausalLM": "caca_transformers.CacaForCausalLM" } finally: # Always restore quantization_config if has_quant_config: self.quantization_config = quantization_config_backup return output # --- Arsitektur Model --- class CacaRMSNorm(nn.Module): def __init__(self, hidden_size, eps=1e-6): super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.eps = eps def forward(self, x): variance = x.pow(2).mean(-1, keepdim=True) x = x * torch.rsqrt(variance + self.eps) return self.weight * x class LayerScale(nn.Module): def __init__(self, dim, init_value=1e-5): super().__init__() self.gamma = nn.Parameter(init_value * torch.ones(dim)) def forward(self, x): return self.gamma * x class StochasticDepth(nn.Module): def __init__(self, drop_prob=0.0): super().__init__() self.drop_prob = drop_prob def forward(self, x, training=True): if not training or self.drop_prob == 0.0: return x keep_prob = 1 - self.drop_prob shape = (x.shape[0],) + (1,) * (x.ndim - 1) random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) random_tensor.floor_() return x.div(keep_prob) * random_tensor class CacaRotaryEmbedding(nn.Module): def __init__( self, dim, max_position_embeddings=8192, base=10000.0, scaling_factor=1.0, scaling_type=None, ): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base self.scaling_factor = scaling_factor self.scaling_type = scaling_type inv_freq = 1.0 / ( self.base ** (torch.arange(0, self.dim, 2).float() / self.dim) ) if scaling_type == "linear": inv_freq = inv_freq / scaling_factor elif scaling_type == "dynamic": inv_freq = inv_freq elif scaling_type == "yarn": inv_freq = self._yarn_get_inv_freq(inv_freq) self.register_buffer("inv_freq", inv_freq, persistent=False) def _yarn_get_inv_freq(self, inv_freq): if len(inv_freq) == 0: return inv_freq alpha = self.scaling_factor beta_fast = 32 beta_slow = 1 freq_threshold = 1 / (self.max_position_embeddings * beta_fast) low_freq_mask = inv_freq > freq_threshold high_freq_mask = ~low_freq_mask low_freq = inv_freq[low_freq_mask] high_freq = inv_freq[high_freq_mask] if len(low_freq) > 0: low_freq = low_freq / alpha if len(high_freq) > 0: smooth_factor = ( self.max_position_embeddings * beta_slow / high_freq - beta_fast ) / (beta_slow - beta_fast) smooth_factor = torch.clamp(smooth_factor, 0.0, 1.0) high_freq = (1 - smooth_factor) * ( high_freq / alpha ) + smooth_factor * high_freq result = torch.zeros_like(inv_freq) result[low_freq_mask] = low_freq result[high_freq_mask] = high_freq return result def forward(self, x, seq_len, position_offset=0): t = torch.arange( position_offset, position_offset + seq_len, device=x.device ).type_as(self.inv_freq) if self.scaling_type == "dynamic": if seq_len > self.max_position_embeddings: dynamic_scale = seq_len / self.max_position_embeddings t = t / dynamic_scale freqs = torch.outer(t, self.inv_freq) emb = torch.cat((freqs, freqs), dim=-1) cos = emb.cos()[None, None, :, :] sin = emb.sin()[None, None, :, :] return cos.to(x.dtype), 
sin.to(x.dtype) class ALiBiPositionalBias(nn.Module): def __init__(self, num_heads, max_positions=8192): super().__init__() self.num_heads = num_heads self.max_positions = max_positions slopes = torch.tensor(self._get_slopes(num_heads)) self.register_buffer("slopes", slopes, persistent=False) def _get_slopes(self, n): def get_slopes_power_of_2(n): start = 2 ** (-(2 ** -(math.log2(n) - 3))) ratio = start return [start * (ratio**i) for i in range(n)] if math.log2(n).is_integer(): return get_slopes_power_of_2(n) else: closest_power_of_2 = 2 ** math.floor(math.log2(n)) return ( get_slopes_power_of_2(closest_power_of_2) + self._get_slopes(2 * closest_power_of_2)[0::2][ : n - closest_power_of_2 ] ) def forward(self, seq_len, key_len=None): if key_len is None: key_len = seq_len query_pos = torch.arange(seq_len, device=self.slopes.device).unsqueeze(1) key_pos = torch.arange(key_len, device=self.slopes.device).unsqueeze(0) relative_pos = key_pos - query_pos bias = relative_pos.unsqueeze(0) * self.slopes.unsqueeze(1).unsqueeze(2) return bias.unsqueeze(0) def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) def apply_rotary_pos_emb(q, k, cos, sin): cos = cos.to(q.dtype) sin = sin.to(q.dtype) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed def soft_cap_logits(x, cap): if cap is None or cap <= 0: return x return cap * torch.tanh(x / cap) class TopKRouter(nn.Module): def __init__(self, hidden_size, num_experts, num_experts_per_tok): super().__init__() self.num_experts = num_experts self.num_experts_per_tok = num_experts_per_tok self.gate = nn.Linear(hidden_size, num_experts, bias=False) self.gate_norm = nn.LayerNorm(hidden_size) def forward(self, hidden_states): batch_size, seq_len, hidden_size = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_size) hidden_states = self.gate_norm(hidden_states) router_logits = self.gate(hidden_states) router_logits = torch.clamp(router_logits, min=-10, max=10) routing_weights = F.softmax(router_logits, dim=-1) top_k_weights, top_k_indices = torch.topk( routing_weights, self.num_experts_per_tok, dim=-1 ) top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-8) router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) expert_usage = router_probs.mean(dim=0) aux_loss = self.num_experts * (expert_usage * expert_usage).sum() router_logits_for_z = router_logits.to(torch.float32) z_loss = torch.logsumexp(router_logits_for_z, dim=-1).mean() return top_k_weights, top_k_indices, aux_loss, z_loss class ExpertChoiceRouter(nn.Module): def __init__(self, hidden_size, num_experts, expert_choice_k): super().__init__() self.num_experts = num_experts self.expert_choice_k = expert_choice_k self.gate = nn.Linear(hidden_size, num_experts, bias=False) def forward(self, hidden_states): batch_size, seq_len, hidden_size = hidden_states.shape total_tokens = batch_size * seq_len hidden_states_flat = hidden_states.view(-1, hidden_size) router_logits = self.gate(hidden_states_flat) router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) router_probs_t = router_probs.t() capacity = max(1, int(self.expert_choice_k * total_tokens / self.num_experts)) top_k_values, top_k_indices = torch.topk( router_probs_t, k=min(capacity, total_tokens), dim=-1 ) expert_mask = torch.zeros( self.num_experts, total_tokens, device=hidden_states.device ) for expert_idx in range(self.num_experts): expert_mask[expert_idx, 
top_k_indices[expert_idx]] = 1.0 routing_weights = expert_mask.t() * router_probs aux_loss = (router_probs.mean(dim=0) ** 2).sum() * self.num_experts z_loss = torch.logsumexp(router_logits, dim=-1).mean() return routing_weights, aux_loss, z_loss class Expert(nn.Module): def __init__(self, config): super().__init__() self.gate_proj = nn.Linear( config.hidden_size, config.intermediate_size, bias=config.mlp_bias ) self.up_proj = nn.Linear( config.hidden_size, config.intermediate_size, bias=config.mlp_bias ) self.down_proj = nn.Linear( config.intermediate_size, config.hidden_size, bias=config.mlp_bias ) self.dropout = nn.Dropout(config.hidden_dropout) def forward(self, x): gate = F.silu(self.gate_proj(x)) up = self.up_proj(x) hidden = gate * up hidden = self.dropout(hidden) return self.down_proj(hidden) class MixtureOfExperts(nn.Module): def __init__(self, config): super().__init__() self.config = config self.num_experts = config.num_experts self.num_experts_per_tok = config.num_experts_per_tok self.use_expert_choice = config.use_expert_choice self.experts = nn.ModuleList([Expert(config) for _ in range(self.num_experts)]) if self.use_expert_choice: self.router = ExpertChoiceRouter( config.hidden_size, config.num_experts, config.expert_choice_k ) else: self.router = TopKRouter( config.hidden_size, config.num_experts, config.num_experts_per_tok ) def forward(self, hidden_states): batch_size, seq_len, hidden_size = hidden_states.shape hidden_states_flat = hidden_states.view(-1, hidden_size) if self.use_expert_choice: routing_weights, aux_loss, z_loss = self.router(hidden_states) final_output = torch.zeros_like(hidden_states_flat) for expert_idx, expert in enumerate(self.experts): expert_mask = routing_weights[:, expert_idx] > 0 if expert_mask.any(): expert_input = hidden_states_flat[expert_mask] expert_output = expert(expert_input) final_output[expert_mask] += ( expert_output * routing_weights[expert_mask, expert_idx : expert_idx + 1] ) else: top_k_weights, top_k_indices, aux_loss, z_loss = self.router(hidden_states) final_output = torch.zeros_like(hidden_states_flat) for i in range(self.num_experts_per_tok): expert_indices = top_k_indices[:, i] expert_weights = top_k_weights[:, i : i + 1] for expert_idx in range(self.num_experts): expert_mask = expert_indices == expert_idx if expert_mask.any(): expert_input = hidden_states_flat[expert_mask] expert_output = self.experts[expert_idx](expert_input) final_output[expert_mask] += ( expert_output * expert_weights[expert_mask] ) final_output = final_output.view(batch_size, seq_len, hidden_size) return final_output, aux_loss, z_loss class MixtureOfDepthsRouter(nn.Module): def __init__(self, hidden_size, capacity_factor=0.5, route_method="learned"): super().__init__() self.capacity_factor = capacity_factor self.route_method = route_method if route_method == "learned": self.router = nn.Linear(hidden_size, 1) def forward(self, hidden_states): batch_size, seq_len, hidden_size = hidden_states.shape if self.route_method == "learned": routing_logits = self.router(hidden_states).squeeze(-1) elif self.route_method == "random": routing_logits = torch.rand( batch_size, seq_len, device=hidden_states.device ) else: routing_logits = torch.zeros( batch_size, seq_len, device=hidden_states.device ) capacity = max(1, int(seq_len * self.capacity_factor)) _, top_indices = torch.topk(routing_logits, k=capacity, dim=-1) process_mask = torch.zeros( batch_size, seq_len, dtype=torch.bool, device=hidden_states.device ) process_mask.scatter_(1, top_indices, True) return 
process_mask class CacaAttention(nn.Module): def __init__(self, config, layer_idx=None): super().__init__() self.config = config self.layer_idx = layer_idx self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.head_dim self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.sliding_window = config.sliding_window self.attn_logit_softcapping = config.attn_logit_softcapping self.attention_sink_size = config.attention_sink_size self.attention_sink_window = config.attention_sink_window self.q_proj = nn.Linear( self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, ) self.v_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, ) self.o_proj = nn.Linear( self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias ) if config.use_qk_norm: self.q_norm = CacaRMSNorm(self.head_dim, eps=config.qk_norm_eps) self.k_norm = CacaRMSNorm(self.head_dim, eps=config.qk_norm_eps) else: self.q_norm = None self.k_norm = None if config.use_rotary_embeddings: scaling_factor = 1.0 scaling_type = None if config.rope_scaling is not None: scaling_type = config.rope_scaling.get("type", "linear") scaling_factor = config.rope_scaling.get("factor", 1.0) self.rotary_emb = CacaRotaryEmbedding( self.head_dim, config.max_position_embeddings, config.rope_theta, scaling_factor=scaling_factor, scaling_type=scaling_type, ) else: self.rotary_emb = None if config.use_alibi: self.alibi = ALiBiPositionalBias( self.num_heads, config.max_position_embeddings ) else: self.alibi = None self.attention_dropout = nn.Dropout(config.attention_dropout) self.is_global_attention = self._determine_attention_type(config, layer_idx) self.has_flash_attn = HAS_FLASH_ATTN and config.use_flash_attn self.has_xformers = HAS_XFORMERS self.has_sdpa = HAS_SDPA self._mask_cache = OrderedDict() self._max_cache_size = 100 def _determine_attention_type(self, config, layer_idx): if layer_idx is None: return False if config.attention_pattern == "all_global": return True elif config.attention_pattern == "all_local": return False elif config.attention_pattern == "interleaved": return (layer_idx % config.global_attention_every_n_layers) == ( config.global_attention_every_n_layers - 1 ) return False def forward( self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False ): batch_size, seq_length, _ = hidden_states.size() query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) query_states = query_states.view( batch_size, seq_length, self.num_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( batch_size, seq_length, self.num_key_value_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( batch_size, seq_length, self.num_key_value_heads, self.head_dim ).transpose(1, 2) if self.q_norm is not None and self.k_norm is not None: query_states = self.q_norm(query_states) key_states = self.k_norm(key_states) if past_key_value is not None and past_key_value[0] is not None: position_offset = past_key_value[0].shape[-2] else: position_offset = 0 if self.rotary_emb is not None: cos, sin = self.rotary_emb(query_states, seq_length, position_offset) query_states, key_states = apply_rotary_pos_emb( query_states, key_states, cos, sin ) if 
past_key_value is not None and past_key_value[0] is not None: key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) if use_cache: present_key_value = (key_states, value_states) else: present_key_value = None key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1) value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1) kv_seq_len = key_states.shape[-2] use_sliding_window = (not self.is_global_attention) and ( self.sliding_window is not None ) if self.has_flash_attn and attention_mask is None: if query_states.device.type == "cuda" and query_states.dtype in [ torch.float16, torch.bfloat16, ]: try: attn_output = self._flash_attention( query_states, key_states, value_states, use_sliding_window ) except Exception as e: logger.warning(f"Flash Attention gagal, pakai fallback: {e}") attn_output = self._fallback_attention( query_states, key_states, value_states, attention_mask, kv_seq_len, use_sliding_window, ) else: attn_output = self._fallback_attention( query_states, key_states, value_states, attention_mask, kv_seq_len, use_sliding_window, ) else: attn_output = self._fallback_attention( query_states, key_states, value_states, attention_mask, kv_seq_len, use_sliding_window, ) attn_output = self.o_proj(attn_output) return attn_output, present_key_value def _flash_attention( self, query_states, key_states, value_states, use_sliding_window ): batch_size, num_heads, seq_length, head_dim = query_states.shape kv_seq_len = key_states.shape[-2] original_dtype = query_states.dtype compute_dtype = ( torch.bfloat16 if original_dtype not in [torch.float16, torch.bfloat16] else original_dtype ) query_states = query_states.transpose(1, 2).contiguous().to(compute_dtype) key_states = key_states.transpose(1, 2).contiguous().to(compute_dtype) value_states = value_states.transpose(1, 2).contiguous().to(compute_dtype) if use_sliding_window and self.sliding_window < kv_seq_len: window_size = (self.sliding_window, 0) else: window_size = (-1, 0) attn_output = flash_attn_func( query_states, key_states, value_states, dropout_p=self.config.attention_dropout if self.training else 0.0, softmax_scale=None, causal=True, window_size=window_size, ) attn_output = attn_output.to(original_dtype) attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) return attn_output def _fallback_attention( self, query_states, key_states, value_states, attention_mask, kv_seq_len, use_sliding_window, ): device_type = query_states.device.type if self.has_xformers and device_type == "cuda" and attention_mask is None: try: return self._xformers_attention( query_states, key_states, value_states, kv_seq_len, use_sliding_window, ) except Exception as e: logger.warning(f"xFormers gagal, pakai SDPA: {e}") if self.has_sdpa: return self._sdpa_attention( query_states, key_states, value_states, attention_mask, kv_seq_len, use_sliding_window, ) else: return self._standard_attention( query_states, key_states, value_states, attention_mask, kv_seq_len, use_sliding_window, ) def _create_causal_mask( self, query_length, key_length, dtype, device, use_sliding_window ): cache_key = ( query_length, key_length, str(device), str(dtype), use_sliding_window, self.sliding_window if use_sliding_window else None, ) if cache_key in self._mask_cache: self._mask_cache.move_to_end(cache_key) cached_mask = self._mask_cache[cache_key] return cached_mask.clone().to(device, dtype) if query_length > key_length: key_length = query_length 
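        # Mask construction (the code below): query positions are right-aligned
        # against the key positions, so during cached decoding a single new query
        # still sees the full KV history. "distance" = query_pos - key_pos, and
        # distance < 0 marks future keys, which are always masked. With a sliding
        # window, a key stays visible either when it is one of the first
        # `attention_sink_size` positions (attention-sink tokens) or when it lies
        # within `sliding_window` of the query; everything else is masked.
        # The boolean mask is converted to an additive float mask (0 for visible,
        # dtype-min for masked) and stored on CPU in the small LRU `_mask_cache`,
        # keyed by shape, dtype, device, and the sliding-window settings.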
query_pos = torch.arange(query_length, device=device) + ( key_length - query_length ) key_pos = torch.arange(key_length, device=device) distance = query_pos[:, None] - key_pos[None, :] mask = distance < 0 if use_sliding_window and self.sliding_window is not None: if self.config.use_attention_sink and self.attention_sink_size > 0: is_sink = key_pos[None, :] < self.attention_sink_size in_window = (distance >= 0) & (distance <= self.sliding_window) mask = (distance < 0) | (~is_sink & ~in_window) else: too_far_mask = distance > self.sliding_window mask = mask | too_far_mask float_mask = torch.zeros( 1, 1, query_length, key_length, dtype=dtype, device=device ) float_mask.masked_fill_(mask.unsqueeze(0).unsqueeze(0), torch.finfo(dtype).min) if len(self._mask_cache) >= self._max_cache_size: for _ in range(len(self._mask_cache) - self._max_cache_size + 1): self._mask_cache.popitem(last=False) self._mask_cache[cache_key] = float_mask.detach().cpu() return float_mask def _xformers_attention( self, query_states, key_states, value_states, kv_seq_len, use_sliding_window ): batch_size, num_heads, seq_length, head_dim = query_states.shape attn_bias = self._create_causal_mask( seq_length, kv_seq_len, query_states.dtype, query_states.device, use_sliding_window, ) query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) attn_output = memory_efficient_attention( query_states, key_states, value_states, attn_bias=attn_bias, p=self.config.attention_dropout if self.training else 0.0, ) attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) return attn_output def _sdpa_attention( self, query_states, key_states, value_states, attention_mask, kv_seq_len, use_sliding_window, ): batch_size, num_heads, seq_length, head_dim = query_states.shape if attention_mask is None: attention_mask = self._create_causal_mask( seq_length, kv_seq_len, query_states.dtype, query_states.device, use_sliding_window, ) if self.alibi is not None: alibi_bias = self.alibi(seq_length, kv_seq_len) attention_mask = attention_mask + alibi_bias attn_output = F.scaled_dot_product_attention( query_states, key_states, value_states, attn_mask=attention_mask, dropout_p=self.config.attention_dropout if self.training else 0.0, is_causal=False, ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) return attn_output def _standard_attention( self, query_states, key_states, value_states, attention_mask, kv_seq_len, use_sliding_window, ): batch_size, num_heads, seq_length, head_dim = query_states.shape attn_weights = torch.matmul( query_states, key_states.transpose(2, 3) ) / math.sqrt(head_dim) attn_weights = soft_cap_logits(attn_weights, self.attn_logit_softcapping) if attention_mask is None: attention_mask = self._create_causal_mask( seq_length, kv_seq_len, attn_weights.dtype, attn_weights.device, use_sliding_window, ) if self.alibi is not None: alibi_bias = self.alibi(seq_length, kv_seq_len) attention_mask = attention_mask + alibi_bias attn_weights = attn_weights + attention_mask attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( query_states.dtype ) attn_weights = self.attention_dropout(attn_weights) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) return attn_output class CacaCrossAttention(nn.Module): def __init__(self, config): 
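        # Cross-attention used by multimodal decoder layers: queries come from the
        # text hidden states, while keys/values are computed from projected encoder
        # features (vision/audio). KV heads are repeated via repeat_interleave to
        # match the number of query heads, mirroring the GQA layout of self-attention.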
super().__init__() self.config = config self.hidden_size = config.hidden_size self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.head_dim self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.q_proj = nn.Linear( self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias ) self.k_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, ) self.v_proj = nn.Linear( self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias, ) self.o_proj = nn.Linear( self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias ) self.attention_dropout = nn.Dropout(config.attention_dropout) def forward(self, hidden_states, encoder_hidden_states, attention_mask=None): batch_size, seq_length, _ = hidden_states.size() encoder_seq_length = encoder_hidden_states.size(1) query_states = self.q_proj(hidden_states) key_states = self.k_proj(encoder_hidden_states) value_states = self.v_proj(encoder_hidden_states) query_states = query_states.view( batch_size, seq_length, self.num_heads, self.head_dim ).transpose(1, 2) key_states = key_states.view( batch_size, encoder_seq_length, self.num_key_value_heads, self.head_dim ).transpose(1, 2) value_states = value_states.view( batch_size, encoder_seq_length, self.num_key_value_heads, self.head_dim ).transpose(1, 2) key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1) value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1) attn_weights = torch.matmul( query_states, key_states.transpose(2, 3) ) / math.sqrt(self.head_dim) if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to( query_states.dtype ) attn_weights = self.attention_dropout(attn_weights) attn_output = torch.matmul(attn_weights, value_states) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size) attn_output = self.o_proj(attn_output) return attn_output class CacaMLP(nn.Module): def __init__(self, config): super().__init__() self.hidden_size = config.hidden_size self.intermediate_size = config.intermediate_size self.gate_proj = nn.Linear( self.hidden_size, self.intermediate_size, bias=config.mlp_bias ) self.up_proj = nn.Linear( self.hidden_size, self.intermediate_size, bias=config.mlp_bias ) self.down_proj = nn.Linear( self.intermediate_size, self.hidden_size, bias=config.mlp_bias ) self.dropout = nn.Dropout(config.hidden_dropout) def forward(self, x): gate = F.silu(self.gate_proj(x)) up = self.up_proj(x) hidden = gate * up hidden = self.dropout(hidden) output = self.down_proj(hidden) return output class CacaDecoderLayer(nn.Module): def __init__(self, config, layer_idx): super().__init__() self.config = config self.layer_idx = layer_idx self.self_attn = CacaAttention(config, layer_idx=layer_idx) self.use_moe = config.use_moe and (layer_idx % config.moe_layer_frequency == 0) if self.use_moe: self.mlp = MixtureOfExperts(config) else: self.mlp = CacaMLP(config) self.use_cross_attention = config.use_cross_attention and ( layer_idx % config.cross_attention_frequency == 0 ) if self.use_cross_attention: self.cross_attn = CacaCrossAttention(config) self.cross_attn_layernorm = CacaRMSNorm( config.hidden_size, config.rms_norm_eps ) self.input_layernorm = CacaRMSNorm(config.hidden_size, config.rms_norm_eps) 
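        # Pre-norm residual layout: `input_layernorm` feeds self-attention and
        # `post_attention_layernorm` (assigned below) feeds the MLP/MoE block, with
        # each sublayer output passing through residual dropout before being added
        # back to the residual stream. Optional per-layer features are wired here:
        # MoE replaces the dense MLP when layer_idx % moe_layer_frequency == 0,
        # cross-attention is added when layer_idx % cross_attention_frequency == 0,
        # and LayerScale, stochastic depth, and mixture-of-depths routing are
        # enabled from the corresponding config flags.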
self.post_attention_layernorm = CacaRMSNorm( config.hidden_size, config.rms_norm_eps ) self.residual_dropout = nn.Dropout(config.residual_dropout) if config.use_layer_scale: self.layer_scale_1 = LayerScale(config.hidden_size, config.layer_scale_init) self.layer_scale_2 = LayerScale(config.hidden_size, config.layer_scale_init) if self.use_cross_attention: self.layer_scale_cross = LayerScale( config.hidden_size, config.layer_scale_init ) else: self.layer_scale_1 = None self.layer_scale_2 = None self.layer_scale_cross = None if config.use_stochastic_depth: drop_prob = ( config.stochastic_depth_prob * layer_idx / config.num_hidden_layers ) self.stochastic_depth = StochasticDepth(drop_prob) else: self.stochastic_depth = None if config.use_mixture_of_depths: self.mod_router = MixtureOfDepthsRouter( config.hidden_size, config.mod_capacity_factor, config.mod_route_method ) else: self.mod_router = None def forward( self, hidden_states, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, past_key_value=None, use_cache=False, ): aux_loss = 0.0 z_loss = 0.0 if self.mod_router is not None: process_mask = self.mod_router(hidden_states) tokens_to_process = hidden_states[process_mask] if tokens_to_process.numel() == 0: present_key_value = past_key_value if use_cache else None return hidden_states, present_key_value, aux_loss, z_loss else: process_mask = None tokens_to_process = hidden_states residual = tokens_to_process tokens_to_process = self.input_layernorm(tokens_to_process) attn_output, present_key_value = self.self_attn( tokens_to_process, attention_mask, past_key_value=past_key_value, use_cache=use_cache, ) if self.layer_scale_1 is not None: attn_output = self.layer_scale_1(attn_output) if self.stochastic_depth is not None: attn_output = self.stochastic_depth(attn_output, self.training) tokens_to_process = residual + self.residual_dropout(attn_output) if self.use_cross_attention and encoder_hidden_states is not None: residual = tokens_to_process tokens_to_process = self.cross_attn_layernorm(tokens_to_process) cross_attn_output = self.cross_attn( tokens_to_process, encoder_hidden_states, attention_mask=encoder_attention_mask, ) if self.layer_scale_cross is not None: cross_attn_output = self.layer_scale_cross(cross_attn_output) if self.stochastic_depth is not None: cross_attn_output = self.stochastic_depth( cross_attn_output, self.training ) tokens_to_process = residual + self.residual_dropout(cross_attn_output) residual = tokens_to_process tokens_to_process = self.post_attention_layernorm(tokens_to_process) if self.use_moe: mlp_output, moe_aux_loss, moe_z_loss = self.mlp(tokens_to_process) aux_loss += moe_aux_loss z_loss += moe_z_loss else: mlp_output = self.mlp(tokens_to_process) if self.layer_scale_2 is not None: mlp_output = self.layer_scale_2(mlp_output) if self.stochastic_depth is not None: mlp_output = self.stochastic_depth(mlp_output, self.training) tokens_to_process = residual + self.residual_dropout(mlp_output) if process_mask is not None: hidden_states[process_mask] = tokens_to_process else: hidden_states = tokens_to_process return hidden_states, present_key_value, aux_loss, z_loss class VisionEncoder(nn.Module): def __init__(self, config): super().__init__() vision_config = config.vision_config self.patch_size = vision_config.get("patch_size", 14) self.image_size = vision_config.get("image_size", 224) self.num_channels = vision_config.get("num_channels", 3) self.hidden_size = vision_config.get("hidden_size", 1024) self.num_layers = vision_config.get("num_layers", 
24) self.num_heads = vision_config.get("num_heads", 16) self.intermediate_size = vision_config.get("intermediate_size", 4096) self.layer_norm_eps = vision_config.get("layer_norm_eps", 1e-6) self.num_patches = (self.image_size // self.patch_size) ** 2 self.patch_embed = nn.Sequential( nn.Conv2d( self.num_channels, self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size, bias=False, ), nn.Dropout(p=vision_config.get("dropout", 0.0)), ) self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size)) self.pos_embed = nn.Parameter( torch.zeros(1, self.num_patches + 1, self.hidden_size) ) self.pos_drop = nn.Dropout(p=vision_config.get("dropout", 0.0)) self.blocks = nn.ModuleList( [ VisionTransformerBlock( dim=self.hidden_size, num_heads=self.num_heads, mlp_ratio=self.intermediate_size / self.hidden_size, dropout=vision_config.get("dropout", 0.0), layer_norm_eps=self.layer_norm_eps, ) for _ in range(self.num_layers) ] ) self.norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps) self._init_weights() def _init_weights(self): nn.init.trunc_normal_(self.pos_embed, std=0.02) nn.init.trunc_normal_(self.cls_token, std=0.02) nn.init.trunc_normal_(self.patch_embed[0].weight, std=0.02) def forward(self, pixel_values): batch_size = pixel_values.shape[0] x = self.patch_embed(pixel_values) x = x.flatten(2).transpose(1, 2) cls_tokens = self.cls_token.expand(batch_size, -1, -1) x = torch.cat([cls_tokens, x], dim=1) x = x + self.pos_embed x = self.pos_drop(x) for block in self.blocks: x = block(x) x = self.norm(x) return x class VisionTransformerBlock(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0, layer_norm_eps=1e-6): super().__init__() self.norm1 = nn.LayerNorm(dim, eps=layer_norm_eps) self.attn = nn.MultiheadAttention( dim, num_heads, dropout=dropout, batch_first=True ) self.drop_path1 = nn.Dropout(dropout) self.norm2 = nn.LayerNorm(dim, eps=layer_norm_eps) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(dim, mlp_hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_hidden_dim, dim), nn.Dropout(dropout), ) self.drop_path2 = nn.Dropout(dropout) def forward(self, x): residual = x x = self.norm1(x) x = self.attn(x, x, x, need_weights=False)[0] x = self.drop_path1(x) x = residual + x residual = x x = self.norm2(x) x = self.mlp(x) x = self.drop_path2(x) x = residual + x return x class AudioEncoder(nn.Module): def __init__(self, config): super().__init__() audio_config = config.audio_config self.num_mel_bins = audio_config.get("num_mel_bins", 80) self.hidden_size = audio_config.get("hidden_size", 1024) self.num_layers = audio_config.get("num_layers", 12) self.num_heads = audio_config.get("num_heads", 16) self.intermediate_size = audio_config.get("intermediate_size", 4096) self.max_audio_length = audio_config.get("max_audio_length", 3000) self.dropout = audio_config.get("dropout", 0.0) self.conv1 = nn.Sequential( nn.Conv1d(self.num_mel_bins, self.hidden_size, kernel_size=3, padding=1), nn.GELU(), nn.Dropout(p=self.dropout), ) self.conv2 = nn.Sequential( nn.Conv1d( self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1 ), nn.GELU(), nn.Dropout(p=self.dropout), ) self.pos_embed = nn.Parameter( torch.zeros(1, self.max_audio_length // 2, self.hidden_size) ) self.pos_drop = nn.Dropout(p=self.dropout) self.blocks = nn.ModuleList( [ AudioTransformerBlock( dim=self.hidden_size, num_heads=self.num_heads, mlp_ratio=self.intermediate_size / self.hidden_size, dropout=self.dropout, ) for _ in range(self.num_layers) ] ) self.norm 
= nn.LayerNorm(self.hidden_size) self._init_weights() def _init_weights(self): nn.init.trunc_normal_(self.pos_embed, std=0.02) def forward(self, audio_features): x = F.gelu(self.conv1(audio_features)) x = F.gelu(self.conv2(x)) x = x.transpose(1, 2) seq_len = x.shape[1] if seq_len <= self.pos_embed.shape[1]: x = x + self.pos_embed[:, :seq_len, :] else: pos_embed_interp = F.interpolate( self.pos_embed.transpose(1, 2), size=seq_len, mode="linear", align_corners=False, ).transpose(1, 2) x = x + pos_embed_interp x = self.pos_drop(x) for block in self.blocks: x = block(x) x = self.norm(x) return x class AudioTransformerBlock(nn.Module): def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0): super().__init__() self.norm1 = nn.LayerNorm(dim) self.attn = nn.MultiheadAttention( dim, num_heads, dropout=dropout, batch_first=True ) self.drop_path1 = nn.Dropout(dropout) self.norm2 = nn.LayerNorm(dim) mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(dim, mlp_hidden_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_hidden_dim, dim), nn.Dropout(dropout), ) self.drop_path2 = nn.Dropout(dropout) def forward(self, x): residual = x x = self.norm1(x) x = self.attn(x, x, x, need_weights=False)[0] x = self.drop_path1(x) x = residual + x residual = x x = self.norm2(x) x = self.mlp(x) x = self.drop_path2(x) x = residual + x return x class MultiModalProjector(nn.Module): def __init__(self, input_size, output_size, projector_type="mlp", num_layers=2): super().__init__() self.projector_type = projector_type if projector_type == "linear": self.projector = nn.Linear(input_size, output_size) elif projector_type == "mlp": layers = [] current_size = input_size for i in range(num_layers - 1): layers.extend( [nn.Linear(current_size, output_size), nn.GELU(), nn.Dropout(0.1)] ) current_size = output_size layers.append(nn.Linear(current_size, output_size)) self.projector = nn.Sequential(*layers) elif projector_type == "perceiver": self.projector = PerceiverResampler( input_size, output_size, num_latents=64, num_layers=2 ) elif projector_type == "qformer": self.projector = QFormerProjector( input_size, output_size, num_queries=32, num_layers=2 ) else: raise ValueError(f"projector_type tidak dikenal: {projector_type}") def forward(self, x): return self.projector(x) class PerceiverResampler(nn.Module): def __init__(self, input_size, output_size, num_latents=64, num_layers=2): super().__init__() self.num_latents = num_latents self.latents = nn.Parameter(torch.randn(num_latents, output_size)) self.layers = nn.ModuleList( [ PerceiverLayer(output_size, input_size if i == 0 else output_size) for i in range(num_layers) ] ) self.norm = nn.LayerNorm(output_size) def forward(self, x): batch_size = x.shape[0] latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1) for i, layer in enumerate(self.layers): if i == 0: latents = layer(latents, x) else: latents = layer(latents, latents) return self.norm(latents) class PerceiverLayer(nn.Module): def __init__(self, query_dim, key_dim): super().__init__() self.cross_attn = nn.MultiheadAttention( query_dim, num_heads=8, kdim=key_dim, vdim=key_dim, batch_first=True ) self.mlp = nn.Sequential( nn.LayerNorm(query_dim), nn.Linear(query_dim, query_dim * 4), nn.GELU(), nn.Linear(query_dim * 4, query_dim), ) self.norm1 = nn.LayerNorm(query_dim) self.norm2 = nn.LayerNorm(query_dim) def forward(self, query, key): query = ( query + self.cross_attn(self.norm1(query), key, key, need_weights=False)[0] ) query = query + self.mlp(self.norm2(query)) return query class 
QFormerProjector(nn.Module): def __init__(self, input_size, output_size, num_queries=32, num_layers=2): super().__init__() self.num_queries = num_queries self.query_embeds = nn.Parameter(torch.randn(num_queries, output_size)) self.query_layers = nn.ModuleList( [ nn.TransformerEncoderLayer( d_model=output_size, nhead=8, dim_feedforward=output_size * 4, batch_first=True, ) for _ in range(num_layers) ] ) self.cross_attn_layers = nn.ModuleList( [ nn.MultiheadAttention( output_size, num_heads=8, kdim=input_size, vdim=input_size, batch_first=True, ) for _ in range(num_layers) ] ) self.norm = nn.LayerNorm(output_size) def forward(self, x): batch_size = x.shape[0] queries = self.query_embeds.unsqueeze(0).expand(batch_size, -1, -1) for query_layer, cross_attn_layer in zip( self.query_layers, self.cross_attn_layers ): queries = query_layer(queries) queries = queries + cross_attn_layer(queries, x, x, need_weights=False)[0] return self.norm(queries) class CacaPreTrainedModel(PreTrainedModel): config_class = CacaConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["CacaDecoderLayer"] _skip_keys_device_placement = "past_key_values" def _init_weights(self, module): std = self.config.initializer_range if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) if module.bias is not None: module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() class CacaModel(CacaPreTrainedModel): def __init__(self, config): super().__init__(config) self.config = config self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList( [ CacaDecoderLayer(config, layer_idx=idx) for idx in range(config.num_hidden_layers) ] ) self.norm = CacaRMSNorm(config.hidden_size, config.rms_norm_eps) self.gradient_checkpointing = False if config.use_multimodal: if config.vision_config: self.vision_encoder = VisionEncoder(config) vision_hidden_size = config.vision_config.get("hidden_size", 768) self.vision_projector = MultiModalProjector( vision_hidden_size, config.hidden_size, projector_type=config.vision_config.get("projector_type", "mlp"), ) else: self.vision_encoder = None self.vision_projector = None if config.audio_config: self.audio_encoder = AudioEncoder(config) audio_hidden_size = config.audio_config.get("hidden_size", 768) self.audio_projector = MultiModalProjector( audio_hidden_size, config.hidden_size, projector_type=config.audio_config.get("projector_type", "mlp"), ) else: self.audio_encoder = None self.audio_projector = None self.post_init() def get_input_embeddings(self): return self.embed_tokens def set_input_embeddings(self, value): self.embed_tokens = value def _prepare_attention_mask(self, attention_mask, input_shape, dtype): if attention_mask is None: return None batch_size, seq_length = input_shape if attention_mask.dim() == 2: attention_mask = attention_mask[:, None, None, :] elif attention_mask.dim() == 3: attention_mask = attention_mask[:, None, :, :] attention_mask = attention_mask.to(dtype=dtype) attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min return attention_mask def forward( self, input_ids=None, pixel_values=None, audio_features=None, attention_mask=None, past_key_values=None, use_cache=None, output_hidden_states=False, return_dict=True, **kwargs, ): use_cache = use_cache if use_cache is not None else self.config.use_cache if input_ids is not None: batch_size, 
seq_length = input_ids.shape hidden_states = self.embed_tokens(input_ids) else: raise ValueError("input_ids tidak boleh None") encoder_hidden_states = None encoder_attention_mask = None if self.config.use_multimodal: multimodal_embeds = [] if pixel_values is not None and self.vision_encoder is not None: vision_features = self.vision_encoder(pixel_values) vision_embeds = self.vision_projector(vision_features) multimodal_embeds.append(vision_embeds) if audio_features is not None and self.audio_encoder is not None: audio_encoded = self.audio_encoder(audio_features) audio_embeds = self.audio_projector(audio_encoded) multimodal_embeds.append(audio_embeds) if multimodal_embeds and self.config.use_cross_attention: encoder_hidden_states = torch.cat(multimodal_embeds, dim=1) encoder_seq_len = encoder_hidden_states.shape[1] encoder_attention_mask = torch.ones( batch_size, encoder_seq_len, dtype=hidden_states.dtype, device=hidden_states.device, ) elif multimodal_embeds: multimodal_concat = torch.cat(multimodal_embeds, dim=1) hidden_states = torch.cat([multimodal_concat, hidden_states], dim=1) seq_length = hidden_states.shape[1] if attention_mask is not None: multimodal_mask = torch.ones( batch_size, multimodal_concat.shape[1], dtype=attention_mask.dtype, device=attention_mask.device, ) attention_mask = torch.cat([multimodal_mask, attention_mask], dim=1) if attention_mask is not None: attention_mask = self._prepare_attention_mask( attention_mask, (batch_size, seq_length), hidden_states.dtype ) if encoder_attention_mask is not None and self.config.use_cross_attention: encoder_attention_mask = self._prepare_attention_mask( encoder_attention_mask, (batch_size, encoder_hidden_states.shape[1]), hidden_states.dtype, ) if use_cache: if past_key_values is None: past_key_values = tuple([None] * len(self.layers)) present_key_values = [] if use_cache else None all_hidden_states = [] if output_hidden_states else None total_aux_loss = 0.0 total_z_loss = 0.0 for idx, layer in enumerate(self.layers): if output_hidden_states: all_hidden_states.append(hidden_states) past_key_value = ( past_key_values[idx] if past_key_values is not None else None ) if self.gradient_checkpointing and self.training and (past_key_value is None or past_key_value[0] is None): hidden_states = self._gradient_checkpointing_forward( layer, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, ) present_key_value = None aux_loss = 0.0 z_loss = 0.0 else: hidden_states, present_key_value, aux_loss, z_loss = layer( hidden_states, attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_value=past_key_value, use_cache=use_cache, ) if use_cache: present_key_values.append(present_key_value) total_aux_loss += aux_loss total_z_loss += z_loss hidden_states = self.norm(hidden_states) if output_hidden_states: all_hidden_states.append(hidden_states) if not return_dict: return tuple( v for v in [ hidden_states, present_key_values, all_hidden_states, total_aux_loss, total_z_loss, ] if v is not None ) return ( BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=tuple(present_key_values) if use_cache else None, hidden_states=all_hidden_states, attentions=None, ), total_aux_loss, total_z_loss, ) def _gradient_checkpointing_forward( self, layer, hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, ): from torch.utils.checkpoint import checkpoint def create_custom_forward(module): def custom_forward(*inputs): ( hidden_states, 
attention_mask, encoder_hidden_states, encoder_attention_mask, ) = inputs output, _, _, _ = module( hidden_states, attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, past_key_value=None, use_cache=False, ) return output return custom_forward hidden_states = checkpoint( create_custom_forward(layer), hidden_states, attention_mask, encoder_hidden_states, encoder_attention_mask, use_reentrant=False, ) return hidden_states class CacaForCausalLM(CacaPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] def __init__(self, config): super().__init__(config) self.model = CacaModel(config) self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) self.post_init() def get_input_embeddings(self): return self.model.embed_tokens def set_input_embeddings(self, value): self.model.embed_tokens = value def get_output_embeddings(self): return self.lm_head def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings def set_decoder(self, decoder): self.model = decoder def get_decoder(self): return self.model def forward( self, input_ids=None, pixel_values=None, audio_features=None, attention_mask=None, labels=None, past_key_values=None, inputs_embeds=None, use_cache=None, output_attentions=None, output_hidden_states=None, return_dict=None, **kwargs, ): if input_ids is not None: if not input_ids.dtype in [torch.long, torch.int, torch.int32, torch.int64]: raise TypeError( f"input_ids harus integer dtype, dapat {input_ids.dtype}. " f"Gunakan input_ids.long() untuk convert." ) if (input_ids < 0).any(): neg_vals = input_ids[input_ids < 0].unique().tolist() raise ValueError(f"input_ids mengandung nilai negatif: {neg_vals}") max_val = input_ids.max().item() if max_val >= self.config.vocab_size: raise ValueError( f"input_ids mengandung nilai >= vocab_size. " f"Max value: {max_val}, vocab_size: {self.config.vocab_size:,}" ) if labels is not None: if not labels.dtype in [torch.long, torch.int, torch.int32, torch.int64]: raise TypeError(f"labels harus integer dtype, dapat {labels.dtype}") if (labels[labels != -100] < 0).any(): raise ValueError(f"labels mengandung nilai negatif (selain -100)") max_label = labels[labels != -100].max().item() if (labels != -100).any() else 0 if max_label >= self.config.vocab_size: raise ValueError( f"labels mengandung nilai >= vocab_size. 
" f"Max: {max_label}, vocab_size: {self.config.vocab_size}" ) return_dict = ( return_dict if return_dict is not None else self.config.use_return_dict ) outputs, aux_loss, z_loss = self.model( input_ids, pixel_values=pixel_values, audio_features=audio_features, attention_mask=attention_mask, past_key_values=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states, return_dict=return_dict, ) if return_dict: hidden_states = outputs.last_hidden_state else: hidden_states = outputs[0] logits = self.lm_head(hidden_states) if self.config.final_logit_softcapping: logits = soft_cap_logits(logits, self.config.final_logit_softcapping) loss = None if labels is not None: shift_logits = logits[..., :-1, :].contiguous() shift_labels = labels[..., 1:].contiguous() loss_fct = nn.CrossEntropyLoss() lm_loss = loss_fct( shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) ) if self.config.use_moe: total_loss = ( lm_loss + (self.config.router_aux_loss_coef * aux_loss) + (self.config.router_z_loss_coef * z_loss) ) else: total_loss = lm_loss loss = total_loss if not return_dict: output = (logits,) if return_dict: output += tuple( v for v in [outputs.past_key_values, outputs.hidden_states] if v is not None ) return ((loss,) + output) if loss is not None else output return CausalLMOutputWithPast( loss=loss, logits=logits, past_key_values=outputs.past_key_values if return_dict else None, hidden_states=outputs.hidden_states if return_dict else None, attentions=None, ) def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, pixel_values=None, audio_features=None, **kwargs, ): has_past = ( past_key_values is not None and len(past_key_values) > 0 and past_key_values[0] is not None and len(past_key_values[0]) > 0 and past_key_values[0][0] is not None ) if has_past: past_length = past_key_values[0][0].shape[2] if input_ids.shape[1] > past_length: remove_prefix_length = past_length else: remove_prefix_length = input_ids.shape[1] - 1 input_ids = input_ids[:, remove_prefix_length:] if inputs_embeds is not None and not has_past: model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} model_inputs.update( { "past_key_values": past_key_values if has_past else None, "use_cache": kwargs.get("use_cache"), "attention_mask": attention_mask, "pixel_values": pixel_values if not has_past else None, "audio_features": audio_features if not has_past else None, } ) return model_inputs @staticmethod def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( tuple( past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past ), ) return reordered_past class CacaForCausalLMQuantized(CacaForCausalLM): def __init__(self, config, quantization_config=None): super().__init__(config) self.quantization_config = quantization_config if quantization_config: self._apply_quantization() def _apply_quantization(self): if self.quantization_config.get("load_in_8bit"): self._quantize_8bit() elif self.quantization_config.get("load_in_4bit"): self._quantize_4bit() def _quantize_8bit(self): try: import bitsandbytes as bnb for name, module in self.named_modules(): if isinstance(module, nn.Linear): has_bias = module.bias is not None new_module = bnb.nn.Linear8bitLt( module.in_features, module.out_features, has_bias, threshold=self.quantization_config.get( "llm_int8_threshold", 6.0 ), ) new_module.weight = module.weight if has_bias: 
                        new_module.bias = module.bias
                    parent_name = ".".join(name.split(".")[:-1])
                    child_name = name.split(".")[-1]
                    if parent_name:
                        parent = self.get_submodule(parent_name)
                        setattr(parent, child_name, new_module)
                    else:
                        setattr(self, child_name, new_module)
            logger.info("8-bit quantization applied successfully")
        except ImportError:
            logger.error("bitsandbytes is not installed! pip install bitsandbytes")

    def _quantize_4bit(self):
        try:
            import bitsandbytes as bnb

            compute_dtype = torch.float16
            if self.quantization_config.get("bnb_4bit_compute_dtype"):
                compute_dtype = getattr(
                    torch, self.quantization_config["bnb_4bit_compute_dtype"]
                )
            for name, module in self.named_modules():
                if isinstance(module, nn.Linear):
                    has_bias = module.bias is not None
                    new_module = bnb.nn.Linear4bit(
                        module.in_features,
                        module.out_features,
                        bias=has_bias,
                        compute_dtype=compute_dtype,
                        quant_type=self.quantization_config.get(
                            "bnb_4bit_quant_type", "nf4"
                        ),
                        use_double_quant=self.quantization_config.get(
                            "bnb_4bit_use_double_quant", True
                        ),
                    )
                    new_module.weight = module.weight
                    if has_bias:
                        new_module.bias = module.bias
                    parent_name = ".".join(name.split(".")[:-1])
                    child_name = name.split(".")[-1]
                    if parent_name:
                        parent = self.get_submodule(parent_name)
                        setattr(parent, child_name, new_module)
                    else:
                        setattr(self, child_name, new_module)
            logger.info("4-bit quantization applied successfully")
        except ImportError:
            logger.error("bitsandbytes is not installed!")

    @classmethod
    def from_pretrained_quantized(cls, model_path, quantization_config):
        config = CacaConfig.from_pretrained(model_path)
        model = cls(config, quantization_config=quantization_config)
        state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location="cpu")
        model.load_state_dict(state_dict, strict=False)
        return model
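# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model code). A minimal smoke
# test with a deliberately tiny config: the sizes below are arbitrary examples,
# and the AutoConfig/AutoModelForCausalLM registration assumes you want to load
# the model through the transformers Auto* classes in the same process.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoConfig, AutoModelForCausalLM

    # Register the custom classes so the Auto* factories resolve model_type="caca".
    AutoConfig.register("caca", CacaConfig)
    AutoModelForCausalLM.register(CacaConfig, CacaForCausalLM)

    demo_config = CacaConfig(
        vocab_size=128,
        hidden_size=64,
        intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=2,
        head_dim=16,
        max_position_embeddings=128,
        use_flash_attn=False,  # keep the pure-PyTorch attention path for CPU
    )
    model = CacaForCausalLM(demo_config).eval()

    input_ids = torch.randint(0, demo_config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=input_ids, labels=input_ids, use_cache=True)
    print("logits:", tuple(out.logits.shape), "loss:", float(out.loss))

    # Incremental decoding with the tuple KV cache returned above: feed only the
    # newly selected token together with past_key_values.
    next_token = out.logits[:, -1:].argmax(dim=-1)
    with torch.no_grad():
        step = model(
            input_ids=next_token,
            past_key_values=out.past_key_values,
            use_cache=True,
        )
    print("step logits:", tuple(step.logits.shape))
    # model.generate(input_ids, max_new_tokens=8) should also work through
    # GenerationMixin, provided the installed transformers version still accepts
    # legacy tuple past_key_values from custom models.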