# caca-5M-untrained / caca_transformers.py
# Author: Lyon28 — "Add custom modeling file" (commit a963c81, verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, List
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.generation.utils import GenerationMixin
from collections import OrderedDict
import logging
from functools import lru_cache
logger = logging.getLogger(__name__)

# Optional fused-attention backends. Each HAS_* flag records whether the
# corresponding kernel could be imported; attention code consults them at
# construction time and falls back to plain PyTorch when they are False.
try:
    from flash_attn import flash_attn_func
except ImportError:
    HAS_FLASH_ATTN = False
else:
    HAS_FLASH_ATTN = True

try:
    from xformers.ops import memory_efficient_attention
except ImportError:
    HAS_XFORMERS = False
else:
    HAS_XFORMERS = True

# torch.nn.functional.scaled_dot_product_attention only exists in newer PyTorch.
HAS_SDPA = hasattr(F, 'scaled_dot_product_attention')
# --- config ---
class CacaConfig(PretrainedConfig):
    """Configuration for the Caca decoder-only transformer.

    Every constructor argument is stored verbatim as an attribute (a few get
    derived defaults: ``head_dim``, ``vision_config``, ``audio_config``,
    ``projector_hidden_size`` and ``chat_template``). Invariants are checked
    by ``_validate_config`` before ``PretrainedConfig.__init__`` runs; any
    extra keyword arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "caca"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=8,
        head_dim=64,
        max_position_embeddings=8192,
        rms_norm_eps=1e-6,
        qk_norm_eps=1e-6,
        initializer_range=0.02,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_rotary_embeddings=True,
        attention_bias=False,
        attention_dropout=0.0,
        use_qk_norm=True,
        use_alibi=False,
        use_flash_attn=True,
        use_grouped_query_attention=False,
        use_multi_query_attention=False,
        sliding_window=None,
        use_longformer_attention=False,
        longformer_attention_window=512,
        attn_logit_softcapping=None,
        final_logit_softcapping=None,
        attention_sink_size=4,
        attention_sink_window=1024,
        use_attention_sink=False,
        attention_pattern="all_global",
        global_attention_every_n_layers=2,
        mlp_bias=False,
        hidden_dropout=0.1,
        residual_dropout=0.1,
        use_moe=False,
        num_experts=8,
        num_experts_per_tok=2,
        use_expert_choice=False,
        expert_choice_k=0.125,
        router_aux_loss_coef=0.01,
        router_z_loss_coef=0.001,
        moe_layer_frequency=2,
        expert_capacity_factor=1.0,
        use_grouped_moe=False,
        num_expert_groups=1,
        use_layer_scale=False,
        layer_scale_init=1e-5,
        use_stochastic_depth=False,
        stochastic_depth_prob=0.1,
        use_mixture_of_depths=False,
        mod_capacity_factor=0.5,
        mod_route_method="learned",
        use_cross_attention=False,
        cross_attention_frequency=4,
        use_multimodal=False,
        vision_config=None,
        audio_config=None,
        projector_hidden_size=None,
        use_soft_merging=False,
        merge_threshold=0.5,
        pretraining_tp=1,
        tensor_parallel_size=1,
        pipeline_parallel_size=1,
        chat_template=None,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        # A falsy head_dim falls back to hidden_size // num_attention_heads.
        self.head_dim = head_dim or (hidden_size // num_attention_heads if hidden_size and num_attention_heads else None)
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.qk_norm_eps = qk_norm_eps
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.tie_word_embeddings = tie_word_embeddings
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.use_rotary_embeddings = use_rotary_embeddings
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.use_qk_norm = use_qk_norm
        self.use_alibi = use_alibi
        self.use_flash_attn = use_flash_attn
        self.use_grouped_query_attention = use_grouped_query_attention
        self.use_multi_query_attention = use_multi_query_attention
        self.sliding_window = sliding_window
        self.use_longformer_attention = use_longformer_attention
        self.longformer_attention_window = longformer_attention_window
        self.attn_logit_softcapping = attn_logit_softcapping
        self.final_logit_softcapping = final_logit_softcapping
        self.attention_sink_size = attention_sink_size
        self.attention_sink_window = attention_sink_window
        self.use_attention_sink = use_attention_sink
        self.attention_pattern = attention_pattern
        self.global_attention_every_n_layers = global_attention_every_n_layers
        self.mlp_bias = mlp_bias
        self.hidden_dropout = hidden_dropout
        self.residual_dropout = residual_dropout
        self.use_moe = use_moe
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.use_expert_choice = use_expert_choice
        self.expert_choice_k = expert_choice_k
        self.router_aux_loss_coef = router_aux_loss_coef
        self.router_z_loss_coef = router_z_loss_coef
        self.moe_layer_frequency = moe_layer_frequency
        self.expert_capacity_factor = expert_capacity_factor
        self.use_grouped_moe = use_grouped_moe
        self.num_expert_groups = num_expert_groups
        self.use_layer_scale = use_layer_scale
        self.layer_scale_init = layer_scale_init
        self.use_stochastic_depth = use_stochastic_depth
        self.stochastic_depth_prob = stochastic_depth_prob
        self.use_mixture_of_depths = use_mixture_of_depths
        self.mod_capacity_factor = mod_capacity_factor
        self.mod_route_method = mod_route_method
        self.use_cross_attention = use_cross_attention
        self.cross_attention_frequency = cross_attention_frequency
        self.use_multimodal = use_multimodal
        self.vision_config = vision_config or {}
        self.audio_config = audio_config or {}
        self.projector_hidden_size = projector_hidden_size or hidden_size
        self.use_soft_merging = use_soft_merging
        self.merge_threshold = merge_threshold
        self.pretraining_tp = pretraining_tp
        self.tensor_parallel_size = tensor_parallel_size
        self.pipeline_parallel_size = pipeline_parallel_size
        if chat_template is None:
            # Default Jinja template: "Role: content" lines, optionally ending
            # with an "Assistant:" generation prompt.
            self.chat_template = (
                "{% for message in messages %}"
                "{% if message['role'] == 'system' %}"
                "System: {{ message['content'] }}\n"
                "{% elif message['role'] == 'user' %}"
                "User: {{ message['content'] }}\n"
                "{% elif message['role'] == 'assistant' %}"
                "Assistant: {{ message['content'] }}\n"
                "{% endif %}"
                "{% endfor %}"
                "{% if add_generation_prompt %}Assistant:{% endif %}"
            )
        else:
            self.chat_template = chat_template
        self._validate_config()
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs
        )

    def _validate_config(self):
        """Check cross-field invariants.

        Raises:
            ValueError: for head counts that are not positive or do not divide
                evenly, too few experts, a non-positive vocabulary, or a
                sliding window that is non-positive or longer than
                ``max_position_embeddings``.

        Soft problems (flash-attn requested but missing, MoE frequency larger
        than the layer count) are only logged as warnings.
        """
        # Guard the modulo checks below against ZeroDivisionError.
        if self.num_attention_heads <= 0 or self.num_key_value_heads <= 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) dan "
                f"num_key_value_heads ({self.num_key_value_heads}) harus > 0"
            )
        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) harus habis dibagi "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )
        if self.use_moe and self.num_experts < self.num_experts_per_tok:
            raise ValueError(
                f"num_experts ({self.num_experts}) harus >= "
                f"num_experts_per_tok ({self.num_experts_per_tok})"
            )
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) harus habis dibagi "
                f"num_attention_heads ({self.num_attention_heads})"
            )
        if self.vocab_size <= 0:
            raise ValueError(f"vocab_size harus > 0, dapat {self.vocab_size}")
        if self.use_flash_attn and not HAS_FLASH_ATTN:
            logger.warning(
                "use_flash_attn=True tapi flash-attn tidak terinstall. "
                "Akan fallback ke SDPA/standard attention."
            )
        if self.sliding_window is not None:
            if self.sliding_window <= 0:
                raise ValueError(
                    f"sliding_window harus > 0, dapat {self.sliding_window}"
                )
            if self.sliding_window > self.max_position_embeddings:
                raise ValueError(
                    f"sliding_window ({self.sliding_window}) tidak boleh > "
                    f"max_position_embeddings ({self.max_position_embeddings})"
                )
        if self.use_moe:
            if self.moe_layer_frequency <= 0:
                raise ValueError("moe_layer_frequency harus > 0")
            if self.moe_layer_frequency > self.num_hidden_layers:
                # Decoder layers enable MoE when layer_idx % moe_layer_frequency == 0,
                # so (assuming layer indices start at 0) layer 0 still gets MoE.
                logger.warning(
                    f"moe_layer_frequency ({self.moe_layer_frequency}) > "
                    f"num_hidden_layers ({self.num_hidden_layers}). "
                    f"MoE hanya akan digunakan di layer dengan indeks 0."
                )

    def to_dict(self):
        """Serialise the config and add the ``auto_map`` for trust_remote_code loading.

        A ``quantization_config`` attribute that is present but ``None`` is
        temporarily removed while the parent serialises, then restored.
        """
        has_quant_config = hasattr(self, 'quantization_config')
        quantization_config_backup = getattr(self, 'quantization_config', None)
        if has_quant_config and quantization_config_backup is None:
            delattr(self, 'quantization_config')
        try:
            output = super().to_dict()
            output['auto_map'] = {
                "AutoConfig": "caca_transformers.CacaConfig",
                "AutoModel": "caca_transformers.CacaModel",
                "AutoModelForCausalLM": "caca_transformers.CacaForCausalLM"
            }
        finally:
            if has_quant_config:
                self.quantization_config = quantization_config_backup
        return output
# --- Arsitektur Model ---
class CacaRMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learned per-channel gain.

    Statistics are computed in float32; the scaled result is cast back to the
    dtype of the input.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        orig_dtype = x.dtype
        x32 = x.float()
        # 1 / sqrt(mean(x^2) + eps), kept with a trailing singleton dim for broadcasting.
        inv_rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        normed = x32 * inv_rms
        return (self.weight * normed).to(orig_dtype)
class LayerScale(nn.Module):
    """Learnable per-channel residual scaling; ``gamma`` starts at ``init_value``."""

    def __init__(self, dim, init_value=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim).mul(init_value))

    def forward(self, x):
        # Broadcasts gamma over every leading dimension of x.
        return x * self.gamma
class StochasticDepth(nn.Module):
    """Per-sample drop-path.

    With probability ``drop_prob``, each sample along dim 0 is zeroed. Surviving
    samples are rescaled by ``1 / (1 - drop_prob)`` so the expectation is unchanged.
    """

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x, training=True):
        """Apply drop-path to ``x``; identity when not training or ``drop_prob == 0``."""
        if not training or self.drop_prob == 0.0:
            return x
        keep_prob = 1 - self.drop_prob
        if keep_prob <= 0.0:
            # Every sample is dropped. Without this guard, x / 0 * 0 yields NaN.
            return torch.zeros_like(x)
        # One Bernoulli(keep_prob) draw per sample, broadcast over the other dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()
        return x.div(keep_prob) * random_tensor
class CacaRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) cos/sin tables.

    ``forward`` returns ``(cos, sin)`` shaped ``(1, 1, seq_len, 2 * len(inv_freq))``
    (i.e. ``dim`` for even ``dim``) for absolute positions
    ``position_offset .. position_offset + seq_len - 1``, cast to ``x.dtype``.

    ``scaling_type``:
        * ``"linear"``  - inverse frequencies divided by ``scaling_factor``.
        * ``"dynamic"`` - positions compressed by ``total_len / max_position_embeddings``
          once the total length (offset + chunk) exceeds ``max_position_embeddings``.
        * ``"yarn"``    - inverse frequencies rewritten by ``_yarn_get_inv_freq``.
        * anything else - no scaling.
    """

    def __init__(
        self,
        dim,
        max_position_embeddings=8192,
        base=10000.0,
        scaling_factor=1.0,
        scaling_type=None,
    ):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.scaling_factor = scaling_factor
        self.scaling_type = scaling_type
        inv_freq = 1.0 / (
            self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
        )
        if scaling_type == "linear":
            inv_freq = inv_freq / scaling_factor
        elif scaling_type == "yarn":
            inv_freq = self._yarn_get_inv_freq(inv_freq)
        # "dynamic" scaling acts on positions in forward(), not on inv_freq.
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def _yarn_get_inv_freq(self, inv_freq):
        """Split ``inv_freq`` at ``1 / (max_pos * beta_fast)``, scale and blend the two parts.

        Frequencies above the threshold are divided by ``scaling_factor``. Those at
        or below it are blended between scaled and unscaled values with a clamped
        ramp. NOTE(review): this ramp may not match the published YaRN
        formulation; it is left as-is to keep existing behaviour.
        """
        if len(inv_freq) == 0:
            return inv_freq
        alpha = self.scaling_factor
        beta_fast = 32
        beta_slow = 1
        freq_threshold = 1 / (self.max_position_embeddings * beta_fast)
        low_freq_mask = inv_freq > freq_threshold
        high_freq_mask = ~low_freq_mask
        low_freq = inv_freq[low_freq_mask]
        high_freq = inv_freq[high_freq_mask]
        if len(low_freq) > 0:
            low_freq = low_freq / alpha
        if len(high_freq) > 0:
            smooth_factor = (
                self.max_position_embeddings * beta_slow / high_freq - beta_fast
            ) / (beta_slow - beta_fast)
            smooth_factor = torch.clamp(smooth_factor, 0.0, 1.0)
            high_freq = (1 - smooth_factor) * (
                high_freq / alpha
            ) + smooth_factor * high_freq
        result = torch.zeros_like(inv_freq)
        result[low_freq_mask] = low_freq
        result[high_freq_mask] = high_freq
        return result

    def forward(self, x, seq_len, position_offset=0):
        """Return ``(cos, sin)`` for ``seq_len`` positions starting at ``position_offset``."""
        t = torch.arange(
            position_offset, position_offset + seq_len, device=x.device
        ).type_as(self.inv_freq)
        if self.scaling_type == "dynamic":
            # Use the total length seen so far. With a KV cache, seq_len is just
            # the new chunk (often 1), and comparing it alone would never trigger
            # scaling during incremental decoding.
            total_len = position_offset + seq_len
            if total_len > self.max_position_embeddings:
                dynamic_scale = total_len / self.max_position_embeddings
                t = t / dynamic_scale
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()[None, None, :, :]
        sin = emb.sin()[None, None, :, :]
        return cos.to(x.dtype), sin.to(x.dtype)
class ALiBiPositionalBias(nn.Module):
    """ALiBi attention bias ``slope_h * (key_pos - query_pos)`` for each head ``h``.

    ``forward(seq_len, key_len)`` returns a ``(1, num_heads, seq_len, key_len)``
    tensor. The ``seq_len`` queries are aligned to the *last* ``seq_len`` key
    positions, matching the KV-cache offset used by the causal mask. Biases
    for visible (past or current) keys are therefore <= 0.
    """

    def __init__(self, num_heads, max_positions=8192):
        super().__init__()
        self.num_heads = num_heads
        self.max_positions = max_positions
        slopes = torch.tensor(self._get_slopes(num_heads))
        self.register_buffer("slopes", slopes, persistent=False)

    def _get_slopes(self, n):
        """Geometric per-head slopes; non-power-of-two ``n`` interleaves the next power's slopes."""
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * (ratio**i) for i in range(n)]
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + self._get_slopes(2 * closest_power_of_2)[0::2][
                    : n - closest_power_of_2
                ]
            )

    def forward(self, seq_len, key_len=None):
        """Return the bias for ``seq_len`` queries attending to ``key_len`` keys."""
        if key_len is None:
            key_len = seq_len
        device = self.slopes.device
        # With a KV cache the queries are the newest positions, so shift them by
        # the number of cached keys. This keeps visible biases small and
        # non-positive instead of growing with key_len. Softmax results are
        # unchanged in exact arithmetic, but low-precision masks stay accurate.
        query_offset = max(key_len - seq_len, 0)
        query_pos = (torch.arange(seq_len, device=device) + query_offset).unsqueeze(1)
        key_pos = torch.arange(key_len, device=device).unsqueeze(0)
        relative_pos = key_pos - query_pos
        bias = relative_pos.unsqueeze(0) * self.slopes.unsqueeze(1).unsqueeze(2)
        return bias.unsqueeze(0)
def rotate_half(x):
    """Map the last dimension ``[x1, x2]`` to ``[-x2, x1]``.

    ``x1`` is the first ``size // 2`` entries and ``x2`` is the rest, so with
    odd sizes ``x2`` holds the extra element.
    """
    size = x.shape[-1]
    half = size // 2
    lower, upper = x.split([half, size - half], dim=-1)
    return torch.cat((upper.neg(), lower), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate queries and keys by the RoPE tables ``cos``/``sin``.

    The tables are first cast to ``q``'s dtype. Returns ``(q_rotated, k_rotated)``.
    """
    cos, sin = cos.to(q.dtype), sin.to(q.dtype)

    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)
def soft_cap_logits(x, cap):
    """Smoothly bound logits to ``(-cap, cap)`` via ``cap * tanh(x / cap)``.

    Unlike a hard clamp, this keeps a non-zero gradient for large logits and is
    ~identity for ``|x| << cap``. A ``cap`` of ``None`` or ``<= 0`` disables
    capping and returns ``x`` unchanged.
    """
    if cap is None or cap <= 0:
        return x
    return torch.tanh(x / cap) * cap
class TopKRouter(nn.Module):
    """Token-choice router: every token selects its ``num_experts_per_tok`` best experts.

    ``forward`` returns ``(top_k_weights, top_k_indices, aux_loss, z_loss)``:
        * ``top_k_weights`` / ``top_k_indices``: shape ``(num_tokens, k)``, with
          weights renormalised to sum to ~1 per token.
        * ``aux_loss``: squared deviation of per-expert mean probability from
          the average usage, divided by that average (load balancing).
        * ``z_loss``: mean squared log-partition of the router logits (ST-MoE).
    """

    def __init__(self, hidden_size, num_experts, num_experts_per_tok):
        super().__init__()
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        self.gate_norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states):
        hidden_size = hidden_states.shape[-1]
        # reshape (not view) so non-contiguous inputs are accepted as well.
        hidden_states = hidden_states.reshape(-1, hidden_size)
        hidden_states = self.gate_norm(hidden_states)
        router_logits = self.gate(hidden_states)
        router_logits = torch.clamp(router_logits, min=-10, max=10)
        routing_weights = F.softmax(router_logits, dim=-1)
        top_k_weights, top_k_indices = torch.topk(
            routing_weights, self.num_experts_per_tok, dim=-1
        )
        top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-8)
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        expert_usage = router_probs.mean(dim=0)
        mean_usage = expert_usage.mean()
        aux_loss = ((expert_usage - mean_usage) ** 2).sum() / (mean_usage + 1e-10)
        # Router z-loss is the mean of the *squared* logsumexp. Without the square,
        # minimising it just pushes every logit toward the -10 clamp.
        z_loss = torch.logsumexp(router_logits.to(torch.float32), dim=-1).pow(2).mean()
        return top_k_weights, top_k_indices, aux_loss, z_loss
class ExpertChoiceRouter(nn.Module):
    """Expert-choice router: each expert picks its top-``capacity`` tokens.

    ``capacity = max(1, int(expert_choice_k * num_tokens / num_experts))``, capped
    at the number of tokens. ``forward`` returns ``(routing_weights, aux_loss, z_loss)``,
    where ``routing_weights`` is ``(num_tokens, num_experts)``. It holds the router
    probability where an expert chose a token and 0 elsewhere.
    """

    def __init__(self, hidden_size, num_experts, expert_choice_k):
        super().__init__()
        self.num_experts = num_experts
        self.expert_choice_k = expert_choice_k
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states):
        hidden_size = hidden_states.shape[-1]
        hidden_states_flat = hidden_states.reshape(-1, hidden_size)
        total_tokens = hidden_states_flat.shape[0]
        router_logits = self.gate(hidden_states_flat)
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
        capacity = max(1, int(self.expert_choice_k * total_tokens / self.num_experts))
        # (num_experts, capacity): token indices chosen by each expert.
        _, top_k_indices = torch.topk(
            router_probs.t(), k=min(capacity, total_tokens), dim=-1
        )
        expert_mask = torch.zeros(
            self.num_experts, total_tokens, device=hidden_states.device
        )
        expert_mask.scatter_(1, top_k_indices, 1.0)
        routing_weights = expert_mask.t() * router_probs
        aux_loss = (router_probs.mean(dim=0) ** 2).sum() * self.num_experts
        # Router z-loss (ST-MoE): mean of the squared log-partition function.
        z_loss = torch.logsumexp(router_logits.to(torch.float32), dim=-1).pow(2).mean()
        return routing_weights, aux_loss, z_loss
class Expert(nn.Module):
    """A single SwiGLU feed-forward expert: ``down(dropout(silu(gate(x)) * up(x)))``."""

    def __init__(self, config):
        super().__init__()
        d_model = config.hidden_size
        d_ff = config.intermediate_size
        use_bias = config.mlp_bias
        self.gate_proj = nn.Linear(d_model, d_ff, bias=use_bias)
        self.up_proj = nn.Linear(d_model, d_ff, bias=use_bias)
        self.down_proj = nn.Linear(d_ff, d_model, bias=use_bias)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, x):
        activated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(self.dropout(activated))
class MixtureOfExperts(nn.Module):
    """Sparse mixture of ``Expert`` FFNs with token-choice or expert-choice routing.

    ``forward(hidden_states)`` returns ``(output, aux_loss, z_loss)``. ``output``
    has the input's shape and is the router-weighted sum of the selected
    experts' outputs for each token.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.use_expert_choice = config.use_expert_choice
        self.experts = nn.ModuleList([Expert(config) for _ in range(self.num_experts)])
        if self.use_expert_choice:
            self.router = ExpertChoiceRouter(
                config.hidden_size, config.num_experts, config.expert_choice_k
            )
        else:
            self.router = TopKRouter(
                config.hidden_size, config.num_experts, config.num_experts_per_tok
            )

    def forward(self, hidden_states):
        input_shape = hidden_states.shape
        hidden_size = input_shape[-1]
        # reshape (not view) so non-contiguous activations are accepted too.
        hidden_states_flat = hidden_states.reshape(-1, hidden_size)
        final_output = torch.zeros_like(hidden_states_flat)
        if self.use_expert_choice:
            routing_weights, aux_loss, z_loss = self.router(hidden_states)
            for expert_idx, expert in enumerate(self.experts):
                token_mask = routing_weights[:, expert_idx] > 0
                if not token_mask.any():
                    continue
                expert_output = expert(hidden_states_flat[token_mask])
                final_output[token_mask] += (
                    expert_output
                    * routing_weights[token_mask, expert_idx : expert_idx + 1]
                )
        else:
            top_k_weights, top_k_indices, aux_loss, z_loss = self.router(hidden_states)
            for expert_idx, expert in enumerate(self.experts):
                # (tokens, k): where this expert appears among each token's choices.
                selected = top_k_indices == expert_idx
                token_mask = selected.any(dim=-1)
                if not token_mask.any():
                    continue
                # topk indices are distinct per token, so at most one slot matches
                # and this sum is simply that slot's weight. Each expert therefore
                # runs once per call instead of once per top-k slot.
                weights = (top_k_weights * selected).sum(dim=-1, keepdim=True)
                expert_output = expert(hidden_states_flat[token_mask])
                final_output[token_mask] += expert_output * weights[token_mask]
        final_output = final_output.view(input_shape)
        return final_output, aux_loss, z_loss
class MixtureOfDepthsRouter(nn.Module):
    """Mixture-of-Depths token selector.

    ``forward(hidden_states)`` returns a ``(batch, seq_len)`` boolean mask marking,
    per sequence, the ``capacity`` highest-scoring tokens, where
    ``capacity = int(seq_len * capacity_factor)`` clamped to ``[1, seq_len]``
    (0 for empty sequences). Scores come from a learned linear layer
    (``"learned"``), uniform noise (``"random"``) or all zeros (any other value).
    """

    def __init__(self, hidden_size, capacity_factor=0.5, route_method="learned"):
        super().__init__()
        self.capacity_factor = capacity_factor
        self.route_method = route_method
        if route_method == "learned":
            self.router = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states):
        batch_size, seq_len, _ = hidden_states.shape
        if self.route_method == "learned":
            routing_logits = self.router(hidden_states).squeeze(-1)
        elif self.route_method == "random":
            routing_logits = torch.rand(
                batch_size, seq_len, device=hidden_states.device
            )
        else:
            routing_logits = torch.zeros(
                batch_size, seq_len, device=hidden_states.device
            )
        # topk requires k <= seq_len; capacity_factor > 1 or an empty sequence
        # would otherwise raise.
        capacity = min(seq_len, max(1, int(seq_len * self.capacity_factor)))
        _, top_indices = torch.topk(routing_logits, k=capacity, dim=-1)
        process_mask = torch.zeros(
            batch_size, seq_len, dtype=torch.bool, device=hidden_states.device
        )
        process_mask.scatter_(1, top_indices, True)
        return process_mask
class CacaAttention(nn.Module):
    """Causal self-attention with grouped key/value heads and an optional KV cache.

    Optional features, all driven by ``config``: QK RMS-norm, rotary position
    embeddings, an ALiBi bias, sliding-window masking (with attention sinks) on
    non-global layers, and attention-logit soft-capping.

    The attention itself runs on the first usable backend among flash-attn,
    xFormers, PyTorch SDPA and an explicit matmul/softmax implementation. A
    fused backend is skipped whenever it cannot honour an enabled feature.
    """

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.sliding_window = config.sliding_window
        self.attn_logit_softcapping = config.attn_logit_softcapping
        self.attention_sink_size = config.attention_sink_size
        self.attention_sink_window = config.attention_sink_window
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
        )
        if config.use_qk_norm:
            self.q_norm = CacaRMSNorm(self.head_dim, eps=config.qk_norm_eps)
            self.k_norm = CacaRMSNorm(self.head_dim, eps=config.qk_norm_eps)
        else:
            self.q_norm = None
            self.k_norm = None
        if config.use_rotary_embeddings:
            scaling_factor = 1.0
            scaling_type = None
            if config.rope_scaling is not None:
                scaling_type = config.rope_scaling.get("type", "linear")
                scaling_factor = config.rope_scaling.get("factor", 1.0)
            self.rotary_emb = CacaRotaryEmbedding(
                self.head_dim,
                config.max_position_embeddings,
                config.rope_theta,
                scaling_factor=scaling_factor,
                scaling_type=scaling_type,
            )
        else:
            self.rotary_emb = None
        if config.use_alibi:
            self.alibi = ALiBiPositionalBias(
                self.num_heads, config.max_position_embeddings
            )
        else:
            self.alibi = None
        self.attention_dropout = nn.Dropout(config.attention_dropout)
        self.is_global_attention = self._determine_attention_type(config, layer_idx)
        self.has_flash_attn = HAS_FLASH_ATTN and config.use_flash_attn
        self.has_xformers = HAS_XFORMERS
        self.has_sdpa = HAS_SDPA
        # Small FIFO cache of CPU copies of causal masks (see _create_causal_mask).
        self._mask_cache = {}
        self._max_cache_size = 10

    def _determine_attention_type(self, config, layer_idx):
        """Return True if this layer attends globally (no sliding window).

        ``"all_global"`` -> True, ``"all_local"`` -> False, ``"interleaved"`` ->
        True on every ``global_attention_every_n_layers``-th layer (the last of
        each group). Unknown patterns and ``layer_idx=None`` -> False.
        """
        if layer_idx is None:
            return False
        if config.attention_pattern == "all_global":
            return True
        elif config.attention_pattern == "all_local":
            return False
        elif config.attention_pattern == "interleaved":
            return (layer_idx % config.global_attention_every_n_layers) == (
                config.global_attention_every_n_layers - 1
            )
        return False

    def _uses_softcapping(self):
        """True when attention-logit soft-capping is active (cap set and > 0)."""
        cap = self.attn_logit_softcapping
        return cap is not None and cap > 0

    def _uses_attention_sink(self):
        """True when the sliding-window mask keeps the first tokens as attention sinks."""
        return bool(self.config.use_attention_sink) and self.attention_sink_size > 0

    def _can_use_flash_attention(self, query_states, attention_mask, use_sliding_window):
        """Return True when flash-attn can compute exactly what this layer needs.

        ``flash_attn_func`` is called here without an explicit mask, ALiBi slopes
        or soft-capping, and its window cannot express attention sinks. Any of
        those features therefore forces a fallback path.
        """
        if not self.has_flash_attn or attention_mask is not None:
            return False
        if query_states.device.type != "cuda":
            return False
        if query_states.dtype not in (torch.float16, torch.bfloat16):
            return False
        if self.alibi is not None or self._uses_softcapping():
            return False
        return not (use_sliding_window and self._uses_attention_sink())

    def forward(
        self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False
    ):
        """Attend over ``hidden_states`` (assumed ``(batch, seq, hidden)``).

        ``past_key_value`` is an optional ``(keys, values)`` pair with keys shaped
        ``(batch, kv_heads, past_len, head_dim)``. It is prepended to the new
        keys/values and its length offsets the rotary positions.

        Returns ``(attn_output, present_key_value)``. ``present_key_value`` is
        the concatenated (pre-repeat) ``(keys, values)`` when ``use_cache``,
        otherwise ``None``.
        """
        batch_size, seq_length, _ = hidden_states.size()
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)
        query_states = query_states.view(
            batch_size, seq_length, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, seq_length, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        if self.q_norm is not None and self.k_norm is not None:
            query_states = self.q_norm(query_states)
            key_states = self.k_norm(key_states)
        position_offset = 0
        if past_key_value is not None:
            try:
                if isinstance(past_key_value, (tuple, list)) and len(past_key_value) >= 2:
                    if past_key_value[0] is not None:
                        position_offset = past_key_value[0].shape[2]
            except (IndexError, AttributeError, TypeError):
                position_offset = 0
        if self.rotary_emb is not None:
            cos, sin = self.rotary_emb(query_states, seq_length, position_offset)
            query_states, key_states = apply_rotary_pos_emb(
                query_states, key_states, cos, sin
            )
        if past_key_value is not None and past_key_value[0] is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        present_key_value = (key_states, value_states) if use_cache else None
        # Expand grouped KV heads so every query head has its own key/value head.
        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
        kv_seq_len = key_states.shape[-2]
        use_sliding_window = (not self.is_global_attention) and (
            self.sliding_window is not None
        )
        attn_output = None
        if self._can_use_flash_attention(query_states, attention_mask, use_sliding_window):
            try:
                attn_output = self._flash_attention(
                    query_states, key_states, value_states, use_sliding_window
                )
            except Exception as e:
                logger.warning(f"Flash Attention gagal, pakai fallback: {e}")
        if attn_output is None:
            attn_output = self._fallback_attention(
                query_states,
                key_states,
                value_states,
                attention_mask,
                kv_seq_len,
                use_sliding_window,
            )
        attn_output = self.o_proj(attn_output)
        return attn_output, present_key_value

    def _flash_attention(
        self, query_states, key_states, value_states, use_sliding_window
    ):
        """Causal attention via ``flash_attn_func`` on ``(batch, heads, seq, head_dim)`` inputs.

        Returns ``(batch, seq, num_heads * head_dim)`` in the *input* dtype.
        """
        batch_size, num_heads, seq_length, head_dim = query_states.shape
        kv_seq_len = key_states.shape[-2]
        input_dtype = query_states.dtype
        if input_dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
            logger.warning("BF16 not supported on this GPU, falling back to FP16")
            compute_dtype = torch.float16
        elif input_dtype in (torch.float16, torch.bfloat16):
            compute_dtype = input_dtype
        else:
            compute_dtype = torch.bfloat16
        # flash-attn expects (batch, seq, heads, head_dim).
        query_states = query_states.transpose(1, 2).contiguous().to(compute_dtype)
        key_states = key_states.transpose(1, 2).contiguous().to(compute_dtype)
        value_states = value_states.transpose(1, 2).contiguous().to(compute_dtype)
        if use_sliding_window and self.sliding_window < kv_seq_len:
            window_size = (self.sliding_window, 0)
        else:
            window_size = (-1, 0)
        attn_output = flash_attn_func(
            query_states,
            key_states,
            value_states,
            dropout_p=self.config.attention_dropout if self.training else 0.0,
            softmax_scale=None,
            causal=True,
            window_size=window_size,
        )
        # Cast back to the caller's dtype (not the compute dtype), so o_proj sees
        # the same dtype as the rest of the layer even after the FP16 fallback.
        attn_output = attn_output.to(input_dtype)
        return attn_output.reshape(batch_size, seq_length, num_heads * head_dim)

    def _fallback_attention(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        kv_seq_len,
        use_sliding_window,
    ):
        """Dispatch to xFormers, SDPA or the explicit implementation.

        Soft-capping is only implemented by ``_standard_attention``, so it is
        used directly when capping is on. xFormers is tried only on CUDA without
        an explicit mask or ALiBi, because that path adds neither.
        """
        if self._uses_softcapping():
            return self._standard_attention(
                query_states,
                key_states,
                value_states,
                attention_mask,
                kv_seq_len,
                use_sliding_window,
            )
        device_type = query_states.device.type
        if (
            self.has_xformers
            and device_type == "cuda"
            and attention_mask is None
            and self.alibi is None
        ):
            try:
                return self._xformers_attention(
                    query_states,
                    key_states,
                    value_states,
                    kv_seq_len,
                    use_sliding_window,
                )
            except Exception as e:
                logger.warning(f"xFormers gagal, pakai SDPA: {e}")
        if self.has_sdpa:
            return self._sdpa_attention(
                query_states,
                key_states,
                value_states,
                attention_mask,
                kv_seq_len,
                use_sliding_window,
            )
        return self._standard_attention(
            query_states,
            key_states,
            value_states,
            attention_mask,
            kv_seq_len,
            use_sliding_window,
        )

    def _create_causal_mask(
        self, query_length, key_length, dtype, device, use_sliding_window
    ):
        """Build an additive ``(1, 1, q, k)`` mask holding -1e9 at disallowed positions.

        Queries are aligned to the last ``query_length`` keys. Future keys are
        masked and, when ``use_sliding_window``, so are keys more than
        ``sliding_window`` positions back, except attention-sink keys when
        enabled. Results are memoised on CPU with FIFO eviction beyond
        ``_max_cache_size`` entries.
        """
        cache_key = (
            query_length,
            key_length,
            str(dtype),
            use_sliding_window,
            self.sliding_window if use_sliding_window else None,
        )
        if cache_key in self._mask_cache:
            cached_mask = self._mask_cache[cache_key]
            return cached_mask.to(device, dtype)
        if query_length > key_length:
            key_length = query_length
        query_pos = torch.arange(query_length, device=device) + (
            key_length - query_length
        )
        key_pos = torch.arange(key_length, device=device)
        distance = query_pos[:, None] - key_pos[None, :]
        mask = distance < 0
        if use_sliding_window and self.sliding_window is not None:
            if self.config.use_attention_sink and self.attention_sink_size > 0:
                is_sink = key_pos[None, :] < self.attention_sink_size
                in_window = (distance >= 0) & (distance <= self.sliding_window)
                mask = (distance < 0) | ((~is_sink) & (~in_window))
            else:
                too_far_mask = distance > self.sliding_window
                mask = mask | too_far_mask
        float_mask = torch.zeros(
            1, 1, query_length, key_length, dtype=dtype, device=device
        )
        float_mask.masked_fill_(mask.unsqueeze(0).unsqueeze(0), -1e9)
        if len(self._mask_cache) >= self._max_cache_size:
            oldest_key = next(iter(self._mask_cache))
            del self._mask_cache[oldest_key]
        self._mask_cache[cache_key] = float_mask.detach().cpu()
        return float_mask

    def _add_alibi(self, attention_mask, seq_length, kv_seq_len):
        """Add the ALiBi bias (if enabled) to an additive mask, keeping the mask's dtype."""
        if self.alibi is None:
            return attention_mask
        alibi_bias = self.alibi(seq_length, kv_seq_len)
        if attention_mask.is_floating_point():
            # The ALiBi bias may not share the mask's dtype (e.g. float32 bias vs a
            # half-precision mask). SDPA requires a float mask to match the query
            # dtype, so cast instead of letting the sum promote to float32.
            alibi_bias = alibi_bias.to(attention_mask.dtype)
        return attention_mask + alibi_bias

    def _xformers_attention(
        self, query_states, key_states, value_states, kv_seq_len, use_sliding_window
    ):
        """Attention via xFormers ``memory_efficient_attention`` with the causal mask as bias."""
        batch_size, num_heads, seq_length, head_dim = query_states.shape
        attn_bias = self._create_causal_mask(
            seq_length,
            kv_seq_len,
            query_states.dtype,
            query_states.device,
            use_sliding_window,
        )
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        attn_output = memory_efficient_attention(
            query_states,
            key_states,
            value_states,
            attn_bias=attn_bias,
            p=self.config.attention_dropout if self.training else 0.0,
        )
        return attn_output.reshape(batch_size, seq_length, num_heads * head_dim)

    def _sdpa_attention(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        kv_seq_len,
        use_sliding_window,
    ):
        """Attention via ``F.scaled_dot_product_attention`` with an explicit additive mask."""
        batch_size, num_heads, seq_length, head_dim = query_states.shape
        if attention_mask is None:
            attention_mask = self._create_causal_mask(
                seq_length,
                kv_seq_len,
                query_states.dtype,
                query_states.device,
                use_sliding_window,
            )
        attention_mask = self._add_alibi(attention_mask, seq_length, kv_seq_len)
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.config.attention_dropout if self.training else 0.0,
            is_causal=False,
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        return attn_output.reshape(batch_size, seq_length, num_heads * head_dim)

    def _standard_attention(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        kv_seq_len,
        use_sliding_window,
    ):
        """Explicit matmul/softmax attention.

        Logits are clamped to [-50, 50] and then soft-capped (if configured)
        before the mask is added. Softmax runs in float32.
        """
        batch_size, num_heads, seq_length, head_dim = query_states.shape
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(head_dim)
        attn_weights = torch.clamp(attn_weights, min=-50.0, max=50.0)
        attn_weights = soft_cap_logits(attn_weights, self.attn_logit_softcapping)
        if attention_mask is None:
            attention_mask = self._create_causal_mask(
                seq_length,
                kv_seq_len,
                attn_weights.dtype,
                attn_weights.device,
                use_sliding_window,
            )
        attention_mask = self._add_alibi(attention_mask, seq_length, kv_seq_len)
        attn_weights = attn_weights + attention_mask
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )
        attn_weights = self.attention_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        return attn_output.reshape(batch_size, seq_length, num_heads * head_dim)
class CacaCrossAttention(nn.Module):
    """Multi-head attention from decoder states (queries) to encoder states (keys/values).

    Key/value heads are shared by groups of query heads via ``repeat_interleave``.
    No causal mask is applied; an optional additive ``attention_mask`` is added
    to the logits.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
        )
        self.k_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.v_proj = nn.Linear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=config.attention_bias,
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
        )
        self.attention_dropout = nn.Dropout(config.attention_dropout)

    def forward(self, hidden_states, encoder_hidden_states, attention_mask=None):
        """Return ``(batch, seq, hidden_size)`` cross-attention outputs."""
        batch_size, seq_length, _ = hidden_states.size()
        encoder_seq_length = encoder_hidden_states.size(1)
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(encoder_hidden_states)
        value_states = self.v_proj(encoder_hidden_states)
        query_states = query_states.view(
            batch_size, seq_length, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, encoder_seq_length, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, encoder_seq_length, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
        value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )
        attn_weights = self.attention_dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value_states)
        attn_output = attn_output.transpose(1, 2).contiguous()
        # o_proj consumes num_heads * head_dim features, which need not equal
        # hidden_size when head_dim is configured independently.
        attn_output = attn_output.reshape(
            batch_size, seq_length, self.num_heads * self.head_dim
        )
        attn_output = self.o_proj(attn_output)
        return attn_output
class CacaMLP(nn.Module):
    """Dense SwiGLU feed-forward block: ``down(dropout(silu(gate(x)) * up(x)))``."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        bias = config.mlp_bias
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
        self.dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, x):
        gated = F.silu(self.gate_proj(x)).mul(self.up_proj(x))
        return self.down_proj(self.dropout(gated))
class CacaDecoderLayer(nn.Module):
    """One pre-norm decoder layer: self-attention, optional cross-attention, MLP/MoE.

    Optional features, all selected from ``config``:
    - MoE MLP on layers where ``layer_idx % moe_layer_frequency == 0``.
    - Cross-attention on layers where ``layer_idx % cross_attention_frequency == 0``.
    - LayerScale on each residual branch (``use_layer_scale``).
    - Stochastic depth with probability
      ``stochastic_depth_prob * layer_idx / num_hidden_layers``.
    - A Mixture-of-Depths router that selects which tokens this layer processes.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.self_attn = CacaAttention(config, layer_idx=layer_idx)
        self.use_moe = config.use_moe and (layer_idx % config.moe_layer_frequency == 0)
        if self.use_moe:
            self.mlp = MixtureOfExperts(config)
        else:
            self.mlp = CacaMLP(config)
        self.use_cross_attention = config.use_cross_attention and (
            layer_idx % config.cross_attention_frequency == 0
        )
        if self.use_cross_attention:
            self.cross_attn = CacaCrossAttention(config)
            self.cross_attn_layernorm = CacaRMSNorm(
                config.hidden_size, config.rms_norm_eps
            )
        self.input_layernorm = CacaRMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = CacaRMSNorm(
            config.hidden_size, config.rms_norm_eps
        )
        self.residual_dropout = nn.Dropout(config.residual_dropout)
        if config.use_layer_scale:
            self.layer_scale_1 = LayerScale(config.hidden_size, config.layer_scale_init)
            self.layer_scale_2 = LayerScale(config.hidden_size, config.layer_scale_init)
            if self.use_cross_attention:
                self.layer_scale_cross = LayerScale(
                    config.hidden_size, config.layer_scale_init
                )
        else:
            self.layer_scale_1 = None
            self.layer_scale_2 = None
            self.layer_scale_cross = None
        if config.use_stochastic_depth:
            # The drop probability grows linearly with depth.
            drop_prob = (
                config.stochastic_depth_prob * layer_idx / config.num_hidden_layers
            )
            self.stochastic_depth = StochasticDepth(drop_prob)
        else:
            self.stochastic_depth = None
        if config.use_mixture_of_depths:
            self.mod_router = MixtureOfDepthsRouter(
                config.hidden_size, config.mod_capacity_factor, config.mod_route_method
            )
        else:
            self.mod_router = None

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        use_cache=False,
    ):
        """Run the layer.

        Returns:
            ``(hidden_states, present_key_value, aux_loss, z_loss)``.
            ``aux_loss``/``z_loss`` stay 0.0 unless this layer uses MoE.
            When the Mixture-of-Depths router selects no tokens, the input is
            returned unchanged, together with ``past_key_value`` if
            ``use_cache`` is set.
        """
        aux_loss = 0.0
        z_loss = 0.0
        if self.mod_router is not None:
            process_mask = self.mod_router(hidden_states)
            # NOTE(review): boolean-mask indexing collapses the masked
            # dimensions, so the shape handed to self_attn (and its
            # compatibility with attention_mask / past_key_value) depends on
            # the mask shape MixtureOfDepthsRouter returns. TODO confirm.
            tokens_to_process = hidden_states[process_mask]
            if tokens_to_process.numel() == 0:
                present_key_value = past_key_value if use_cache else None
                return hidden_states, present_key_value, aux_loss, z_loss
        else:
            process_mask = None
            tokens_to_process = hidden_states

        # Self-attention sub-block (pre-norm residual).
        residual = tokens_to_process
        tokens_to_process = self.input_layernorm(tokens_to_process)
        attn_output, present_key_value = self.self_attn(
            tokens_to_process,
            attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        if self.layer_scale_1 is not None:
            attn_output = self.layer_scale_1(attn_output)
        if self.stochastic_depth is not None:
            attn_output = self.stochastic_depth(attn_output, self.training)
        tokens_to_process = residual + self.residual_dropout(attn_output)

        # Optional cross-attention over encoder (multimodal) states.
        if self.use_cross_attention and encoder_hidden_states is not None:
            residual = tokens_to_process
            tokens_to_process = self.cross_attn_layernorm(tokens_to_process)
            cross_attn_output = self.cross_attn(
                tokens_to_process,
                encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            if self.layer_scale_cross is not None:
                cross_attn_output = self.layer_scale_cross(cross_attn_output)
            if self.stochastic_depth is not None:
                cross_attn_output = self.stochastic_depth(
                    cross_attn_output, self.training
                )
            tokens_to_process = residual + self.residual_dropout(cross_attn_output)

        # Feed-forward sub-block (dense MLP or MoE).
        residual = tokens_to_process
        tokens_to_process = self.post_attention_layernorm(tokens_to_process)
        if self.use_moe:
            mlp_output, moe_aux_loss, moe_z_loss = self.mlp(tokens_to_process)
            aux_loss += moe_aux_loss
            z_loss += moe_z_loss
        else:
            mlp_output = self.mlp(tokens_to_process)
        if self.layer_scale_2 is not None:
            mlp_output = self.layer_scale_2(mlp_output)
        if self.stochastic_depth is not None:
            mlp_output = self.stochastic_depth(mlp_output, self.training)
        tokens_to_process = residual + self.residual_dropout(mlp_output)

        if process_mask is not None:
            # Scatter the processed tokens into a copy. Writing into
            # ``hidden_states`` in place would mutate the caller's tensor
            # (e.g. one already stored as an intermediate hidden state, or
            # the input of a checkpointed call that is recomputed in backward).
            hidden_states = hidden_states.clone()
            hidden_states[process_mask] = tokens_to_process
        else:
            hidden_states = tokens_to_process
        return hidden_states, present_key_value, aux_loss, z_loss
class VisionEncoder(nn.Module):
    """ViT-style image encoder.

    Pipeline: patchify with a strided conv, prepend a learned CLS token, add
    learned absolute positions, run ``num_layers`` transformer blocks, then a
    final LayerNorm. Hyper-parameters are read from the ``config.vision_config``
    dict with the defaults shown below. The positional table is sized for
    exactly ``(image_size // patch_size) ** 2 + 1`` tokens.
    """

    def __init__(self, config):
        super().__init__()
        vcfg = config.vision_config
        self.patch_size = vcfg.get("patch_size", 14)
        self.image_size = vcfg.get("image_size", 224)
        self.num_channels = vcfg.get("num_channels", 3)
        self.hidden_size = vcfg.get("hidden_size", 1024)
        self.num_layers = vcfg.get("num_layers", 24)
        self.num_heads = vcfg.get("num_heads", 16)
        self.intermediate_size = vcfg.get("intermediate_size", 4096)
        self.layer_norm_eps = vcfg.get("layer_norm_eps", 1e-6)
        drop = vcfg.get("dropout", 0.0)

        side = self.image_size // self.patch_size
        self.num_patches = side * side

        patchify = nn.Conv2d(
            self.num_channels,
            self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )
        self.patch_embed = nn.Sequential(patchify, nn.Dropout(p=drop))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
        # One extra position for the CLS token.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, self.hidden_size)
        )
        self.pos_drop = nn.Dropout(p=drop)
        ratio = self.intermediate_size / self.hidden_size
        self.blocks = nn.ModuleList(
            VisionTransformerBlock(
                dim=self.hidden_size,
                num_heads=self.num_heads,
                mlp_ratio=ratio,
                dropout=drop,
                layer_norm_eps=self.layer_norm_eps,
            )
            for _ in range(self.num_layers)
        )
        self.norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal (std 0.02) init for positions, CLS token and patch conv."""
        for tensor in (self.pos_embed, self.cls_token, self.patch_embed[0].weight):
            nn.init.trunc_normal_(tensor, std=0.02)

    def forward(self, pixel_values):
        """Encode ``pixel_values`` into ``(batch, num_patches + 1, hidden_size)`` tokens."""
        patches = self.patch_embed(pixel_values).flatten(2).transpose(1, 2)
        cls = self.cls_token.expand(pixel_values.shape[0], -1, -1)
        tokens = self.pos_drop(torch.cat([cls, patches], dim=1) + self.pos_embed)
        for blk in self.blocks:
            tokens = blk(tokens)
        return self.norm(tokens)
class VisionTransformerBlock(nn.Module):
    """Pre-norm transformer encoder block used by the vision tower.

    ``x -> x + drop(attn(norm1(x)))`` then ``x -> x + drop(mlp(norm2(x)))``.
    The ``drop_path*`` modules are plain ``nn.Dropout``, not stochastic depth.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0, layer_norm_eps=1e-6):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim, eps=layer_norm_eps)
        self.attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, batch_first=True
        )
        self.drop_path1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dim, eps=layer_norm_eps)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )
        self.drop_path2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, seq, dim)``."""
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + self.drop_path1(attended)
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x
class AudioEncoder(nn.Module):
    """Whisper-style audio encoder over mel-spectrogram features.

    Two 1-D convs (the second with stride 2) each followed by GELU + dropout,
    learned absolute positions (linearly interpolated when the sequence is
    longer than the table), ``num_layers`` transformer blocks and a final
    LayerNorm. Hyper-parameters are read from the ``config.audio_config`` dict.
    """

    def __init__(self, config):
        super().__init__()
        audio_config = config.audio_config
        self.num_mel_bins = audio_config.get("num_mel_bins", 80)
        self.hidden_size = audio_config.get("hidden_size", 1024)
        self.num_layers = audio_config.get("num_layers", 12)
        self.num_heads = audio_config.get("num_heads", 16)
        self.intermediate_size = audio_config.get("intermediate_size", 4096)
        self.max_audio_length = audio_config.get("max_audio_length", 3000)
        self.dropout = audio_config.get("dropout", 0.0)
        self.conv1 = nn.Sequential(
            nn.Conv1d(self.num_mel_bins, self.hidden_size, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Dropout(p=self.dropout),
        )
        self.conv2 = nn.Sequential(
            nn.Conv1d(
                self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
            ),
            nn.GELU(),
            nn.Dropout(p=self.dropout),
        )
        # conv2 halves the time axis, so the position table covers
        # max_audio_length // 2 frames.
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.max_audio_length // 2, self.hidden_size)
        )
        self.pos_drop = nn.Dropout(p=self.dropout)
        self.blocks = nn.ModuleList(
            [
                AudioTransformerBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_ratio=self.intermediate_size / self.hidden_size,
                    dropout=self.dropout,
                )
                for _ in range(self.num_layers)
            ]
        )
        self.norm = nn.LayerNorm(self.hidden_size)
        self._init_weights()

    def _init_weights(self):
        """Truncated-normal (std 0.02) init for the positional table."""
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, audio_features):
        """Encode audio features into ``(batch, frames, hidden_size)``.

        Args:
            audio_features: input to ``nn.Conv1d``. Presumably
                ``(batch, num_mel_bins, time)``; TODO confirm with callers.
        """
        # conv1/conv2 already end in nn.GELU (+ dropout). Do not wrap them in
        # another F.gelu, which would apply the activation twice.
        x = self.conv2(self.conv1(audio_features))
        x = x.transpose(1, 2)
        seq_len = x.shape[1]
        if seq_len <= self.pos_embed.shape[1]:
            pos = self.pos_embed[:, :seq_len, :]
        else:
            # Stretch the learned table to the longer sequence.
            pos = F.interpolate(
                self.pos_embed.transpose(1, 2),
                size=seq_len,
                mode="linear",
                align_corners=False,
            ).transpose(1, 2)
        x = self.pos_drop(x + pos)
        for block in self.blocks:
            x = block(x)
        return self.norm(x)
class AudioTransformerBlock(nn.Module):
    """Pre-norm transformer encoder block used by the audio tower.

    Same layout as the vision block, but the LayerNorms use PyTorch's default
    epsilon. The ``drop_path*`` modules are plain ``nn.Dropout``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):
        super().__init__()
        hidden = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(
            dim, num_heads, dropout=dropout, batch_first=True
        )
        self.drop_path1 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )
        self.drop_path2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, seq, dim)``."""
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + self.drop_path1(attended)
        x = x + self.drop_path2(self.mlp(self.norm2(x)))
        return x
class MultiModalProjector(nn.Module):
    """Map encoder features into the language model's hidden size.

    ``projector_type`` selects the design:
    - ``"linear"``: one linear layer.
    - ``"mlp"``: ``num_layers`` linear layers with GELU + Dropout(0.1) between them.
    - ``"perceiver"``: a ``PerceiverResampler`` with 64 latents and 2 layers.
    - ``"qformer"``: a ``QFormerProjector`` with 32 queries and 2 layers.

    Raises:
        ValueError: for any other ``projector_type``.
    """

    def __init__(self, input_size, output_size, projector_type="mlp", num_layers=2):
        super().__init__()
        self.projector_type = projector_type
        if projector_type == "linear":
            projector = nn.Linear(input_size, output_size)
        elif projector_type == "mlp":
            projector = self._build_mlp(input_size, output_size, num_layers)
        elif projector_type == "perceiver":
            projector = PerceiverResampler(
                input_size, output_size, num_latents=64, num_layers=2
            )
        elif projector_type == "qformer":
            projector = QFormerProjector(
                input_size, output_size, num_queries=32, num_layers=2
            )
        else:
            raise ValueError(f"projector_type tidak dikenal: {projector_type}")
        self.projector = projector

    @staticmethod
    def _build_mlp(input_size, output_size, num_layers):
        """Stack ``num_layers`` linears (at least one): in->out, then out->out."""
        in_dims = [input_size] + [output_size] * max(num_layers - 1, 0)
        layers = []
        for in_dim in in_dims[:-1]:
            layers += [nn.Linear(in_dim, output_size), nn.GELU(), nn.Dropout(0.1)]
        layers.append(nn.Linear(in_dims[-1], output_size))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Project ``x`` with the configured projector."""
        return self.projector(x)
class PerceiverResampler(nn.Module):
    """Resample a variable-length sequence to ``num_latents`` learned latent tokens.

    The first layer cross-attends from the latents to the input ``x``. Every
    later layer attends from the latents to themselves. The output has shape
    ``(batch, num_latents, output_size)``.
    """

    def __init__(self, input_size, output_size, num_latents=64, num_layers=2):
        super().__init__()
        self.num_latents = num_latents
        self.latents = nn.Parameter(torch.randn(num_latents, output_size))
        # Only the first layer sees the raw input width.
        self.layers = nn.ModuleList(
            PerceiverLayer(output_size, output_size if depth else input_size)
            for depth in range(num_layers)
        )
        self.norm = nn.LayerNorm(output_size)

    def forward(self, x):
        """Return the normalised latents after attending over ``x``."""
        latents = self.latents.unsqueeze(0).expand(x.shape[0], -1, -1)
        for depth, layer in enumerate(self.layers):
            context = x if depth == 0 else latents
            latents = layer(latents, context)
        return self.norm(latents)
class PerceiverLayer(nn.Module):
    """Cross-attention (8 heads) followed by a 4x MLP, both as pre-norm residuals.

    Args:
        query_dim: width of the query (latent) tokens.
        key_dim: width of the key/value tokens.

    NOTE(review): ``norm2`` is applied before ``mlp``, and ``mlp`` starts with
    its own LayerNorm, so the MLP input is normalised twice. This is kept
    as-is because removing either would change the parameter set.
    """

    def __init__(self, query_dim, key_dim):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            query_dim, num_heads=8, kdim=key_dim, vdim=key_dim, batch_first=True
        )
        expanded = query_dim * 4
        self.mlp = nn.Sequential(
            nn.LayerNorm(query_dim),
            nn.Linear(query_dim, expanded),
            nn.GELU(),
            nn.Linear(expanded, query_dim),
        )
        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

    def forward(self, query, key):
        """Update ``query`` by attending over ``key`` (also used as value)."""
        attended, _ = self.cross_attn(self.norm1(query), key, key, need_weights=False)
        query = query + attended
        return query + self.mlp(self.norm2(query))
class QFormerProjector(nn.Module):
    """Q-Former-style projector: learned queries alternate self- and cross-attention.

    Each stage runs an ``nn.TransformerEncoderLayer`` over the queries, then
    adds a cross-attention read of the input ``x``. The output has shape
    ``(batch, num_queries, output_size)``.
    """

    def __init__(self, input_size, output_size, num_queries=32, num_layers=2):
        super().__init__()
        self.num_queries = num_queries
        self.query_embeds = nn.Parameter(torch.randn(num_queries, output_size))
        self.query_layers = nn.ModuleList(
            nn.TransformerEncoderLayer(
                d_model=output_size,
                nhead=8,
                dim_feedforward=output_size * 4,
                batch_first=True,
            )
            for _ in range(num_layers)
        )
        self.cross_attn_layers = nn.ModuleList(
            nn.MultiheadAttention(
                output_size,
                num_heads=8,
                kdim=input_size,
                vdim=input_size,
                batch_first=True,
            )
            for _ in range(num_layers)
        )
        self.norm = nn.LayerNorm(output_size)

    def forward(self, x):
        """Return normalised query tokens after reading from ``x``."""
        queries = self.query_embeds.unsqueeze(0).expand(x.shape[0], -1, -1)
        for self_block, cross_block in zip(self.query_layers, self.cross_attn_layers):
            queries = self_block(queries)
            attended, _ = cross_block(queries, x, x, need_weights=False)
            queries = queries + attended
        return self.norm(queries)
class CacaPreTrainedModel(PreTrainedModel):
    """Shared base for Caca models: config binding, weight init, checkpointing hook."""

    config_class = CacaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["CacaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights.

        Linear biases are zeroed. The Embedding row at ``padding_idx`` (if
        any) is zeroed. Other module types are left untouched.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
            return
        if not isinstance(module, nn.Embedding):
            return
        module.weight.data.normal_(mean=0.0, std=std)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on ``CacaModel`` instances (old HF hook format)."""
        if not isinstance(module, CacaModel):
            return
        module.gradient_checkpointing = value
class CacaModel(CacaPreTrainedModel):
    """Caca decoder stack (embeddings, decoder layers, final RMSNorm).

    Optional vision/audio towers are used in one of two ways:
    - With ``config.use_cross_attention``, their projected outputs are passed
      to the decoder layers as cross-attention context.
    - Otherwise, they are prepended to the token embeddings as a prefix,
      truncated to ``max_position_embeddings // 4`` tokens.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [
                CacaDecoderLayer(config, layer_idx=idx)
                for idx in range(config.num_hidden_layers)
            ]
        )
        self.norm = CacaRMSNorm(config.hidden_size, config.rms_norm_eps)
        self.gradient_checkpointing = False
        # Always define the multimodal attributes so they exist even when
        # use_multimodal is False.
        self.vision_encoder = None
        self.vision_projector = None
        self.audio_encoder = None
        self.audio_projector = None
        if config.use_multimodal:
            if config.vision_config:
                self.vision_encoder = VisionEncoder(config)
                # Size the projector from the encoder's actual width.
                # VisionEncoder defaults hidden_size to 1024, so reading the
                # config with any other default would mis-size the projector
                # whenever "hidden_size" is absent.
                self.vision_projector = MultiModalProjector(
                    self.vision_encoder.hidden_size,
                    config.hidden_size,
                    projector_type=config.vision_config.get("projector_type", "mlp"),
                )
            if config.audio_config:
                self.audio_encoder = AudioEncoder(config)
                self.audio_projector = MultiModalProjector(
                    self.audio_encoder.hidden_size,
                    config.hidden_size,
                    projector_type=config.audio_config.get("projector_type", "mlp"),
                )
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _prepare_attention_mask(self, attention_mask, input_shape, dtype):
        """Turn a 1/0 keep-mask into an additive mask (0 = keep, dtype-min = drop).

        A 2-D ``(batch, seq)`` mask becomes ``(batch, 1, 1, seq)``; a 3-D mask
        gets a head axis at dim 1. ``input_shape`` is accepted for API
        compatibility and is not used. Returns ``None`` for a ``None`` mask.
        """
        if attention_mask is None:
            return None
        if attention_mask.dim() == 2:
            attention_mask = attention_mask[:, None, None, :]
        elif attention_mask.dim() == 3:
            attention_mask = attention_mask[:, None, :, :]
        attention_mask = attention_mask.to(dtype=dtype)
        return (1.0 - attention_mask) * torch.finfo(dtype).min

    def forward(
        self,
        input_ids=None,
        pixel_values=None,
        audio_features=None,
        attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_hidden_states=False,
        return_dict=True,
        **kwargs,
    ):
        """Run the decoder stack.

        Returns:
            With ``return_dict`` (the default), the tuple
            ``(BaseModelOutputWithPast, total_aux_loss, total_z_loss)``.
            Otherwise a flat tuple of the non-None values among
            ``(hidden_states, present_key_values, all_hidden_states,
            total_aux_loss, total_z_loss)``.

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if input_ids is None:
            raise ValueError("input_ids tidak boleh None")
        batch_size, seq_length = input_ids.shape
        device = input_ids.device
        hidden_states = self.embed_tokens(input_ids)

        # Checkpointed layers never return a KV cache. Without this, the
        # default use_cache=True would silently disable checkpointing in
        # training.
        if self.gradient_checkpointing and self.training and use_cache:
            if not getattr(self, "_warned_cache_with_checkpointing", False):
                logger.warning(
                    "use_cache=True is incompatible with gradient checkpointing. "
                    "Setting use_cache=False."
                )
                self._warned_cache_with_checkpointing = True
            use_cache = False

        encoder_hidden_states = None
        encoder_attention_mask = None
        if self.config.use_multimodal:
            multimodal_embeds = []
            if pixel_values is not None and self.vision_encoder is not None:
                vision_features = self.vision_encoder(pixel_values.to(device))
                multimodal_embeds.append(self.vision_projector(vision_features))
            if audio_features is not None and self.audio_encoder is not None:
                audio_encoded = self.audio_encoder(audio_features.to(device))
                multimodal_embeds.append(self.audio_projector(audio_encoded))
            if multimodal_embeds and self.config.use_cross_attention:
                encoder_hidden_states = torch.cat(multimodal_embeds, dim=1)
                encoder_attention_mask = torch.ones(
                    batch_size,
                    encoder_hidden_states.shape[1],
                    dtype=hidden_states.dtype,
                    device=hidden_states.device,
                )
            elif multimodal_embeds:
                # Prefix mode: prepend multimodal tokens to the text tokens.
                multimodal_concat = torch.cat(multimodal_embeds, dim=1)
                max_multimodal_tokens = self.config.max_position_embeddings // 4
                if multimodal_concat.shape[1] > max_multimodal_tokens:
                    logger.warning(
                        f"Multimodal tokens ({multimodal_concat.shape[1]}) > max ({max_multimodal_tokens}). "
                        f"Truncating..."
                    )
                    multimodal_concat = multimodal_concat[:, :max_multimodal_tokens]
                hidden_states = torch.cat([multimodal_concat, hidden_states], dim=1)
                seq_length = hidden_states.shape[1]
                if attention_mask is not None:
                    multimodal_mask = torch.ones(
                        batch_size,
                        multimodal_concat.shape[1],
                        dtype=attention_mask.dtype,
                        device=attention_mask.device,
                    )
                    attention_mask = torch.cat([multimodal_mask, attention_mask], dim=1)
                else:
                    attention_mask = torch.ones(
                        batch_size,
                        seq_length,
                        dtype=hidden_states.dtype,
                        device=device,
                    )

        if attention_mask is not None:
            attention_mask = self._prepare_attention_mask(
                attention_mask, (batch_size, seq_length), hidden_states.dtype
            )
        if encoder_attention_mask is not None and self.config.use_cross_attention:
            encoder_attention_mask = self._prepare_attention_mask(
                encoder_attention_mask,
                (batch_size, encoder_hidden_states.shape[1]),
                hidden_states.dtype,
            )

        if use_cache and past_key_values is None:
            past_key_values = tuple([None] * len(self.layers))
        present_key_values = [] if use_cache else None
        all_hidden_states = [] if output_hidden_states else None
        total_aux_loss = 0.0
        total_z_loss = 0.0
        for idx, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states.append(hidden_states)
            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )
            if self.gradient_checkpointing and self.training and not use_cache:
                hidden_states, aux_loss, z_loss = self._gradient_checkpointing_forward(
                    layer,
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
                present_key_value = None
            else:
                hidden_states, present_key_value, aux_loss, z_loss = layer(
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    past_key_value=past_key_value,
                    use_cache=use_cache,
                )
            if use_cache:
                present_key_values.append(present_key_value)
            total_aux_loss += aux_loss
            total_z_loss += z_loss

        if self.training and torch.cuda.is_available():
            allocated_gb = torch.cuda.memory_allocated() / 1024**3
            reserved_gb = torch.cuda.memory_reserved() / 1024**3
            if allocated_gb > 10:
                logger.warning(
                    f"High GPU memory usage - Allocated: {allocated_gb:.2f}GB, "
                    f"Reserved: {reserved_gb:.2f}GB"
                )

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states.append(hidden_states)
        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    present_key_values,
                    all_hidden_states,
                    total_aux_loss,
                    total_z_loss,
                ]
                if v is not None
            )
        return (
            BaseModelOutputWithPast(
                last_hidden_state=hidden_states,
                past_key_values=tuple(present_key_values) if use_cache else None,
                hidden_states=all_hidden_states,
                attentions=None,
            ),
            total_aux_loss,
            total_z_loss,
        )

    def _gradient_checkpointing_forward(
        self,
        layer,
        hidden_states,
        attention_mask,
        encoder_hidden_states,
        encoder_attention_mask,
    ):
        """Run ``layer`` under non-reentrant activation checkpointing (no KV cache).

        Returns ``(hidden_states, aux_loss, z_loss)``.
        """
        from torch.utils.checkpoint import checkpoint

        def custom_forward(hidden_states, attention_mask, encoder_hidden_states,
                           encoder_attention_mask):
            output, _, aux_loss, z_loss = layer(
                hidden_states, attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_value=None,
                use_cache=False,
            )
            return output, aux_loss, z_loss

        hidden_states, aux_loss, z_loss = checkpoint(
            custom_forward,
            hidden_states, attention_mask,
            encoder_hidden_states, encoder_attention_mask,
            use_reentrant=False,
        )
        return hidden_states, aux_loss, z_loss
class CacaForCausalLM(CacaPreTrainedModel, GenerationMixin):
    """Caca decoder with a language-modelling head and MoE-aware loss."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = CacaModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def _validate_inputs(self, input_ids, labels, attention_mask):
        """Fail fast with readable errors on malformed ids, labels or mask.

        Raises:
            TypeError: non-integer ``input_ids`` / ``labels``.
            ValueError: negative ids, ids/labels >= vocab_size, or a mask
                whose shape differs from ``input_ids``.
        """
        if input_ids is not None:
            if input_ids.dtype.is_floating_point:
                raise TypeError(
                    f"input_ids harus integer dtype, dapat {input_ids.dtype}. "
                    f"Gunakan input_ids.long() untuk convert."
                )
            if (input_ids < 0).any():
                neg_vals = input_ids[input_ids < 0].unique().tolist()
                raise ValueError(f"input_ids mengandung nilai negatif: {neg_vals}")
            max_val = input_ids.max().item()
            if max_val >= self.config.vocab_size:
                raise ValueError(
                    f"input_ids mengandung nilai >= vocab_size. "
                    f"Max value: {max_val}, vocab_size: {self.config.vocab_size:,}"
                )
        if labels is not None:
            if labels.dtype not in (torch.long, torch.int, torch.int32, torch.int64):
                raise TypeError(f"labels harus integer dtype, dapat {labels.dtype}")
            # -100 is CrossEntropyLoss's default ignore_index.
            valid_labels = labels[labels != -100]
            if (valid_labels < 0).any():
                raise ValueError("labels mengandung nilai negatif (selain -100)")
            max_label = valid_labels.max().item() if valid_labels.numel() > 0 else 0
            if max_label >= self.config.vocab_size:
                raise ValueError(
                    f"labels mengandung nilai >= vocab_size. "
                    f"Max: {max_label}, vocab_size: {self.config.vocab_size}"
                )
        if attention_mask is not None and input_ids is not None:
            if attention_mask.shape[0] != input_ids.shape[0]:
                raise ValueError(
                    f"attention_mask batch size ({attention_mask.shape[0]}) != "
                    f"input_ids batch size ({input_ids.shape[0]})"
                )
            if attention_mask.shape[1] != input_ids.shape[1]:
                raise ValueError(
                    f"attention_mask seq length ({attention_mask.shape[1]}) != "
                    f"input_ids seq length ({input_ids.shape[1]})"
                )

    def _compute_loss(self, logits, labels, aux_loss, z_loss):
        """Compute the next-token cross-entropy, plus router losses when MoE is on."""
        # In multimodal-prefix mode the base model prepends image/audio tokens,
        # so logits can be longer than labels. Only the trailing positions
        # correspond to text tokens.
        offset = logits.shape[1] - labels.shape[-1]
        text_logits = logits[:, offset:, :] if offset > 0 else logits
        shift_logits = text_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        lm_loss = nn.CrossEntropyLoss()(
            shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
        )
        if not self.config.use_moe:
            return lm_loss
        return (
            lm_loss
            + (self.config.router_aux_loss_coef * aux_loss)
            + (self.config.router_z_loss_coef * z_loss)
        )

    def forward(
        self,
        input_ids=None,
        pixel_values=None,
        audio_features=None,
        attention_mask=None,
        labels=None,
        past_key_values=None,
        inputs_embeds=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """Compute logits and, when ``labels`` is given, the (MoE-augmented) LM loss.

        ``inputs_embeds`` and ``output_attentions`` are accepted for API
        compatibility but are not used; the base model requires ``input_ids``.

        Returns:
            ``CausalLMOutputWithPast`` when ``return_dict``. Otherwise a tuple
            ``([loss,] logits[, past_key_values][, hidden_states])``.
        """
        self._validate_inputs(input_ids, labels, attention_mask)
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        # Always request a ModelOutput. The base model's tuple form is a flat,
        # variable-length tuple that cannot be unpacked reliably here.
        outputs, aux_loss, z_loss = self.model(
            input_ids,
            pixel_values=pixel_values,
            audio_features=audio_features,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        logits = self.lm_head(outputs.last_hidden_state)
        if self.config.final_logit_softcapping:
            logits = soft_cap_logits(logits, self.config.final_logit_softcapping)
        loss = None
        if labels is not None:
            loss = self._compute_loss(logits, labels, aux_loss, z_loss)
        if not return_dict:
            output = (logits,) + tuple(
                v
                for v in (outputs.past_key_values, outputs.hidden_states)
                if v is not None
            )
            return ((loss,) + output) if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        pixel_values=None,
        audio_features=None,
        **kwargs,
    ):
        """Build ``forward`` kwargs for one generation step.

        Once a cache exists, only the last token is fed and the multimodal
        inputs are dropped, since they are already reflected in the cache.
        """
        has_past = (
            past_key_values is not None
            and len(past_key_values) > 0
            and past_key_values[0] is not None
        )
        if has_past:
            input_ids = input_ids[:, -1:]
            # NOTE(review): truncating the mask to the new token(s) satisfies
            # forward()'s shape check, but it discards padding information for
            # cached positions (matters for left-padded batches). TODO confirm
            # how CacaAttention combines the mask with cached keys.
            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                attention_mask = attention_mask[:, -input_ids.shape[1]:]
        if inputs_embeds is not None and not has_past:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "past_key_values": past_key_values if has_past else None,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "pixel_values": pixel_values if not has_past else None,
                "audio_features": audio_features if not has_past else None,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder each layer's cached tensors along the batch axis for beam search."""
        reordered_past = ()
        for layer_past in past_key_values:
            if layer_past is not None and len(layer_past) > 0:
                reordered_past += (
                    tuple(
                        past_state.index_select(0, beam_idx.to(past_state.device))
                        for past_state in layer_past
                        if past_state is not None
                    ),
                )
            else:
                reordered_past += (None,)
        return reordered_past

    def save_pretrained(self, save_directory, **kwargs):
        """Save, temporarily removing a ``None`` ``quantization_config`` from the config."""
        has_quant_config = hasattr(self.config, 'quantization_config')
        quantization_config_backup = getattr(self.config, 'quantization_config', None)
        if has_quant_config and quantization_config_backup is None:
            delattr(self.config, 'quantization_config')
        try:
            super().save_pretrained(save_directory, **kwargs)
        finally:
            if has_quant_config:
                self.config.quantization_config = quantization_config_backup
class CacaForCausalLMQuantized(CacaForCausalLM):
    """``CacaForCausalLM`` whose ``nn.Linear`` layers are swapped for bitsandbytes layers.

    ``quantization_config`` is a dict. The keys read here are ``load_in_8bit``,
    ``load_in_4bit``, ``llm_int8_threshold``, ``llm_int8_has_fp16_weight``,
    ``bnb_4bit_compute_dtype`` (a torch dtype or its name),
    ``bnb_4bit_quant_type`` and ``bnb_4bit_use_double_quant``. The weights
    are actually quantized by bitsandbytes when the model is moved to CUDA.
    """

    def __init__(self, config, quantization_config=None):
        super().__init__(config)
        self.quantization_config = quantization_config
        if quantization_config:
            self._apply_quantization()

    def _apply_quantization(self):
        """Dispatch to 8-bit or 4-bit conversion (8-bit wins if both are set)."""
        if self.quantization_config.get("load_in_8bit"):
            self._quantize_8bit()
        elif self.quantization_config.get("load_in_4bit"):
            self._quantize_4bit()

    def _linear_targets(self):
        """List ``(parent, child_name, linear)`` for every replaceable ``nn.Linear``.

        The list is fully materialised before any replacement, so the module
        tree is never mutated while ``named_modules`` is walking it. Linears
        owned by ``nn.MultiheadAttention`` / ``nn.TransformerEncoderLayer``
        are skipped: those modules pass their weights directly to functional
        / fused kernels, which cannot consume quantized weights.
        """
        skip_parents = (nn.MultiheadAttention, nn.TransformerEncoderLayer)
        targets = []
        for name, module in self.named_modules():
            if not isinstance(module, nn.Linear):
                continue
            parent_name, _, child_name = name.rpartition(".")
            parent = self.get_submodule(parent_name) if parent_name else self
            if isinstance(parent, skip_parents):
                continue
            targets.append((parent, child_name, module))
        return targets

    def _quantize_8bit(self):
        """Replace Linear layers with ``bnb.nn.Linear8bitLt`` (LLM.int8())."""
        try:
            import bitsandbytes as bnb
        except ImportError:
            logger.error("bitsandbytes tidak terinstall! pip install bitsandbytes")
            return
        cfg = self.quantization_config
        threshold = cfg.get("llm_int8_threshold", 6.0)
        has_fp16_weights = cfg.get("llm_int8_has_fp16_weight", False)
        for parent, child_name, module in self._linear_targets():
            has_bias = module.bias is not None
            new_module = bnb.nn.Linear8bitLt(
                module.in_features,
                module.out_features,
                bias=has_bias,
                has_fp16_weights=has_fp16_weights,
                threshold=threshold,
            )
            # Copy the data into the layer's own Int8Params instead of
            # replacing the parameter. A plain nn.Parameter would never be
            # quantized and lacks the attributes Linear8bitLt expects.
            new_module.weight.data = module.weight.data
            if has_bias:
                new_module.bias = module.bias
            setattr(parent, child_name, new_module)
        logger.info("Quantisasi 8-bit berhasil diterapkan")

    def _quantize_4bit(self):
        """Replace Linear layers with ``bnb.nn.Linear4bit`` (NF4/FP4)."""
        try:
            import bitsandbytes as bnb
        except ImportError:
            logger.error("bitsandbytes tidak terinstall!")
            return
        cfg = self.quantization_config
        compute_dtype = torch.float16
        requested_dtype = cfg.get("bnb_4bit_compute_dtype")
        if requested_dtype:
            compute_dtype = (
                requested_dtype
                if isinstance(requested_dtype, torch.dtype)
                else getattr(torch, requested_dtype)
            )
        quant_type = cfg.get("bnb_4bit_quant_type", "nf4")
        double_quant = cfg.get("bnb_4bit_use_double_quant", True)
        for parent, child_name, module in self._linear_targets():
            has_bias = module.bias is not None
            # bitsandbytes names double quantization "compress_statistics";
            # Linear4bit has no "use_double_quant" keyword.
            new_module = bnb.nn.Linear4bit(
                module.in_features,
                module.out_features,
                bias=has_bias,
                compute_dtype=compute_dtype,
                compress_statistics=double_quant,
                quant_type=quant_type,
            )
            # Keep the layer's Params4bit object (it carries the quant
            # settings) and only swap in the trained data.
            new_module.weight.data = module.weight.data
            if has_bias:
                new_module.bias = module.bias
            setattr(parent, child_name, new_module)
        logger.info("Quantisasi 4-bit berhasil diterapkan")

    @classmethod
    def from_pretrained_quantized(cls, model_path, quantization_config):
        """Load full-precision weights from ``model_path``, then quantize.

        Weights are loaded *before* the bitsandbytes conversion, so the
        quantized layers are built from the trained values.
        """
        import os

        config = CacaConfig.from_pretrained(model_path)
        model = cls(config, quantization_config=None)
        # SECURITY NOTE: torch.load unpickles the file; only load trusted
        # checkpoints (consider weights_only=True on torch versions that
        # support it).
        state_dict = torch.load(
            os.path.join(model_path, "pytorch_model.bin"), map_location="cpu"
        )
        result = model.load_state_dict(state_dict, strict=False)
        if result.missing_keys or result.unexpected_keys:
            logger.warning(
                "Checkpoint mismatch - missing: %s, unexpected: %s",
                result.missing_keys,
                result.unexpected_keys,
            )
        model.quantization_config = quantization_config
        if quantization_config:
            model._apply_quantization()
        return model