Lyon28 committed on
Commit a963c81 · verified
1 Parent(s): 1b87288

Add custom modeling file

Files changed (1)
  1. caca_transformers.py +2005 -0
caca_transformers.py ADDED
@@ -0,0 +1,2005 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import Optional, Tuple, List
6
+ from transformers import PreTrainedModel, PretrainedConfig
7
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
8
+ from transformers.generation.utils import GenerationMixin
9
+ from collections import OrderedDict
10
+ import logging
11
+ from functools import lru_cache
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+ try:
16
+ from flash_attn import flash_attn_func
17
+ HAS_FLASH_ATTN = True
18
+ except ImportError:
19
+ HAS_FLASH_ATTN = False
20
+
21
+ try:
22
+ from xformers.ops import memory_efficient_attention
23
+ HAS_XFORMERS = True
24
+ except ImportError:
25
+ HAS_XFORMERS = False
26
+
27
+ HAS_SDPA = hasattr(F, 'scaled_dot_product_attention')
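+ # Backend availability is probed once at import time; at runtime the attention layers
+ # fall back in the order flash-attn -> xFormers -> SDPA -> plain matmul attention.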
28
+
29
+ # --- config ---
30
+ class CacaConfig(PretrainedConfig):
31
+ model_type = "caca"
32
+
33
+ def __init__(
34
+ self,
35
+ vocab_size=32000,
36
+ hidden_size=2048,
37
+ intermediate_size=8192,
38
+ num_hidden_layers=24,
39
+ num_attention_heads=32,
40
+ num_key_value_heads=8,
41
+ head_dim=64,
42
+ max_position_embeddings=8192,
43
+ rms_norm_eps=1e-6,
44
+ qk_norm_eps=1e-6,
45
+ initializer_range=0.02,
46
+ use_cache=True,
47
+ pad_token_id=None,
48
+ bos_token_id=1,
49
+ eos_token_id=2,
50
+ tie_word_embeddings=False,
51
+ rope_theta=10000.0,
52
+ rope_scaling=None,
53
+ use_rotary_embeddings=True,
54
+ attention_bias=False,
55
+ attention_dropout=0.0,
56
+ use_qk_norm=True,
57
+ use_alibi=False,
58
+ use_flash_attn=True,
59
+ use_grouped_query_attention=False,
60
+ use_multi_query_attention=False,
61
+ sliding_window=None,
62
+ use_longformer_attention=False,
63
+ longformer_attention_window=512,
64
+ attn_logit_softcapping=None,
65
+ final_logit_softcapping=None,
66
+ attention_sink_size=4,
67
+ attention_sink_window=1024,
68
+ use_attention_sink=False,
69
+ attention_pattern="all_global",
70
+ global_attention_every_n_layers=2,
71
+ mlp_bias=False,
72
+ hidden_dropout=0.1,
73
+ residual_dropout=0.1,
74
+ use_moe=False,
75
+ num_experts=8,
76
+ num_experts_per_tok=2,
77
+ use_expert_choice=False,
78
+ expert_choice_k=0.125,
79
+ router_aux_loss_coef=0.01,
80
+ router_z_loss_coef=0.001,
81
+ moe_layer_frequency=2,
82
+ expert_capacity_factor=1.0,
83
+ use_grouped_moe=False,
84
+ num_expert_groups=1,
85
+ use_layer_scale=False,
86
+ layer_scale_init=1e-5,
87
+ use_stochastic_depth=False,
88
+ stochastic_depth_prob=0.1,
89
+ use_mixture_of_depths=False,
90
+ mod_capacity_factor=0.5,
91
+ mod_route_method="learned",
92
+ use_cross_attention=False,
93
+ cross_attention_frequency=4,
94
+ use_multimodal=False,
95
+ vision_config=None,
96
+ audio_config=None,
97
+ projector_hidden_size=None,
98
+ use_soft_merging=False,
99
+ merge_threshold=0.5,
100
+ pretraining_tp=1,
101
+ tensor_parallel_size=1,
102
+ pipeline_parallel_size=1,
103
+ chat_template=None,
104
+ **kwargs
105
+ ):
106
+ self.vocab_size = vocab_size
107
+ self.hidden_size = hidden_size
108
+ self.intermediate_size = intermediate_size
109
+ self.num_hidden_layers = num_hidden_layers
110
+ self.num_attention_heads = num_attention_heads
111
+ self.num_key_value_heads = num_key_value_heads
112
+ self.head_dim = head_dim or (hidden_size // num_attention_heads if hidden_size and num_attention_heads else None)
113
+ self.max_position_embeddings = max_position_embeddings
114
+ self.rms_norm_eps = rms_norm_eps
115
+ self.qk_norm_eps = qk_norm_eps
116
+ self.initializer_range = initializer_range
117
+ self.use_cache = use_cache
118
+ self.pad_token_id = pad_token_id
119
+ self.bos_token_id = bos_token_id
120
+ self.eos_token_id = eos_token_id
121
+ self.tie_word_embeddings = tie_word_embeddings
122
+ self.rope_theta = rope_theta
123
+ self.rope_scaling = rope_scaling
124
+ self.use_rotary_embeddings = use_rotary_embeddings
125
+ self.attention_bias = attention_bias
126
+ self.attention_dropout = attention_dropout
127
+ self.use_qk_norm = use_qk_norm
128
+ self.use_alibi = use_alibi
129
+ self.use_flash_attn = use_flash_attn
130
+ self.use_grouped_query_attention = use_grouped_query_attention
131
+ self.use_multi_query_attention = use_multi_query_attention
132
+ self.sliding_window = sliding_window
133
+ self.use_longformer_attention = use_longformer_attention
134
+ self.longformer_attention_window = longformer_attention_window
135
+ self.attn_logit_softcapping = attn_logit_softcapping
136
+ self.final_logit_softcapping = final_logit_softcapping
137
+ self.attention_sink_size = attention_sink_size
138
+ self.attention_sink_window = attention_sink_window
139
+ self.use_attention_sink = use_attention_sink
140
+ self.attention_pattern = attention_pattern
141
+ self.global_attention_every_n_layers = global_attention_every_n_layers
142
+ self.mlp_bias = mlp_bias
143
+ self.hidden_dropout = hidden_dropout
144
+ self.residual_dropout = residual_dropout
145
+ self.use_moe = use_moe
146
+ self.num_experts = num_experts
147
+ self.num_experts_per_tok = num_experts_per_tok
148
+ self.use_expert_choice = use_expert_choice
149
+ self.expert_choice_k = expert_choice_k
150
+ self.router_aux_loss_coef = router_aux_loss_coef
151
+ self.router_z_loss_coef = router_z_loss_coef
152
+ self.moe_layer_frequency = moe_layer_frequency
153
+ self.expert_capacity_factor = expert_capacity_factor
154
+ self.use_grouped_moe = use_grouped_moe
155
+ self.num_expert_groups = num_expert_groups
156
+ self.use_layer_scale = use_layer_scale
157
+ self.layer_scale_init = layer_scale_init
158
+ self.use_stochastic_depth = use_stochastic_depth
159
+ self.stochastic_depth_prob = stochastic_depth_prob
160
+ self.use_mixture_of_depths = use_mixture_of_depths
161
+ self.mod_capacity_factor = mod_capacity_factor
162
+ self.mod_route_method = mod_route_method
163
+ self.use_cross_attention = use_cross_attention
164
+ self.cross_attention_frequency = cross_attention_frequency
165
+ self.use_multimodal = use_multimodal
166
+ self.vision_config = vision_config or {}
167
+ self.audio_config = audio_config or {}
168
+ self.projector_hidden_size = projector_hidden_size or hidden_size
169
+ self.use_soft_merging = use_soft_merging
170
+ self.merge_threshold = merge_threshold
171
+ self.pretraining_tp = pretraining_tp
172
+ self.tensor_parallel_size = tensor_parallel_size
173
+ self.pipeline_parallel_size = pipeline_parallel_size
174
+
175
+ if chat_template is None:
176
+ self.chat_template = (
177
+ "{% for message in messages %}"
178
+ "{% if message['role'] == 'system' %}"
179
+ "System: {{ message['content'] }}\n"
180
+ "{% elif message['role'] == 'user' %}"
181
+ "User: {{ message['content'] }}\n"
182
+ "{% elif message['role'] == 'assistant' %}"
183
+ "Assistant: {{ message['content'] }}\n"
184
+ "{% endif %}"
185
+ "{% endfor %}"
186
+ "{% if add_generation_prompt %}Assistant:{% endif %}"
187
+ )
188
+ else:
189
+ self.chat_template = chat_template
190
+
191
+ self._validate_config()
192
+ super().__init__(
193
+ pad_token_id=pad_token_id,
194
+ bos_token_id=bos_token_id,
195
+ eos_token_id=eos_token_id,
196
+ tie_word_embeddings=tie_word_embeddings,
197
+ **kwargs
198
+ )
199
+
200
+ def _validate_config(self):
201
+ if self.num_attention_heads % self.num_key_value_heads != 0:
202
+ raise ValueError(
203
+ f"num_attention_heads ({self.num_attention_heads}) harus habis dibagi "
204
+ f"num_key_value_heads ({self.num_key_value_heads})"
205
+ )
206
+
207
+ if self.use_moe and self.num_experts < self.num_experts_per_tok:
208
+ raise ValueError(
209
+ f"num_experts ({self.num_experts}) harus >= "
210
+ f"num_experts_per_tok ({self.num_experts_per_tok})"
211
+ )
212
+
213
+ if self.hidden_size % self.num_attention_heads != 0:
214
+ raise ValueError(
215
+ f"hidden_size ({self.hidden_size}) harus habis dibagi "
216
+ f"num_attention_heads ({self.num_attention_heads})"
217
+ )
218
+
219
+ if self.vocab_size <= 0:
220
+ raise ValueError(f"vocab_size harus > 0, dapat {self.vocab_size}")
221
+
222
+ if self.use_flash_attn and not HAS_FLASH_ATTN:
223
+ logger.warning(
224
+ "use_flash_attn=True tapi flash-attn tidak terinstall. "
225
+ "Akan fallback ke SDPA/standard attention."
226
+ )
227
+
228
+ if self.sliding_window is not None:
229
+ if self.sliding_window > self.max_position_embeddings:
230
+ raise ValueError(
231
+ f"sliding_window ({self.sliding_window}) tidak boleh > "
232
+ f"max_position_embeddings ({self.max_position_embeddings})"
233
+ )
234
+
235
+ if self.use_moe:
236
+ if self.moe_layer_frequency <= 0:
237
+ raise ValueError(f"moe_layer_frequency harus > 0")
238
+ if self.moe_layer_frequency > self.num_hidden_layers:
239
+ logger.warning(
240
+ f"moe_layer_frequency ({self.moe_layer_frequency}) > "
241
+ f"num_hidden_layers ({self.num_hidden_layers}). "
242
+ f"MoE tidak akan digunakan."
243
+ )
244
+
245
+ def to_dict(self):
246
+ has_quant_config = hasattr(self, 'quantization_config')
247
+ quantization_config_backup = getattr(self, 'quantization_config', None)
248
+
249
+ if has_quant_config and quantization_config_backup is None:
250
+ delattr(self, 'quantization_config')
251
+
252
+ try:
253
+ output = super().to_dict()
254
+ output['auto_map'] = {
255
+ "AutoConfig": "caca_transformers.CacaConfig",
256
+ "AutoModel": "caca_transformers.CacaModel",
257
+ "AutoModelForCausalLM": "caca_transformers.CacaForCausalLM"
258
+ }
259
+ finally:
260
+ if has_quant_config:
261
+ self.quantization_config = quantization_config_backup
262
+
263
+ return output
264
+
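+ # Rough usage sketch (an assumption, not part of this file): since to_dict() registers an
+ # auto_map, a checkpoint that ships this module can normally be loaded through the Auto
+ # classes with trust_remote_code=True. The repository id below is a placeholder.
+ #
+ # from transformers import AutoConfig, AutoModelForCausalLM
+ # config = AutoConfig.from_pretrained("org/caca-checkpoint", trust_remote_code=True)
+ # model = AutoModelForCausalLM.from_pretrained("org/caca-checkpoint", trust_remote_code=True)
+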
265
+ # --- Model architecture ---
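+ # CacaRMSNorm computes y = w * x / sqrt(mean(x^2) + eps); the reduction runs in fp32
+ # and the result is cast back to the input dtype.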
266
+ class CacaRMSNorm(nn.Module):
267
+ def __init__(self, hidden_size, eps=1e-6):
268
+ super().__init__()
269
+ self.weight = nn.Parameter(torch.ones(hidden_size))
270
+ self.eps = eps
271
+
272
+ def forward(self, x):
273
+ input_dtype = x.dtype
274
+ x = x.float()
275
+ variance = x.pow(2).mean(-1, keepdim=True)
276
+ x = x * torch.rsqrt(variance + self.eps)
277
+ return (self.weight * x).to(input_dtype)
278
+
279
+ class LayerScale(nn.Module):
280
+ def __init__(self, dim, init_value=1e-5):
281
+ super().__init__()
282
+ self.gamma = nn.Parameter(init_value * torch.ones(dim))
283
+
284
+ def forward(self, x):
285
+ return self.gamma * x
286
+
287
+ class StochasticDepth(nn.Module):
288
+ def __init__(self, drop_prob=0.0):
289
+ super().__init__()
290
+ self.drop_prob = drop_prob
291
+
292
+ def forward(self, x, training=True):
293
+ if not training or self.drop_prob == 0.0:
294
+ return x
295
+ keep_prob = 1 - self.drop_prob
296
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
297
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
298
+ random_tensor.floor_()
299
+ return x.div(keep_prob) * random_tensor
300
+
301
+ class CacaRotaryEmbedding(nn.Module):
302
+ def __init__(
303
+ self,
304
+ dim,
305
+ max_position_embeddings=8192,
306
+ base=10000.0,
307
+ scaling_factor=1.0,
308
+ scaling_type=None,
309
+ ):
310
+ super().__init__()
311
+ self.dim = dim
312
+ self.max_position_embeddings = max_position_embeddings
313
+ self.base = base
314
+ self.scaling_factor = scaling_factor
315
+ self.scaling_type = scaling_type
316
+ inv_freq = 1.0 / (
317
+ self.base ** (torch.arange(0, self.dim, 2).float() / self.dim)
318
+ )
319
+ if scaling_type == "linear":
320
+ inv_freq = inv_freq / scaling_factor
321
+ elif scaling_type == "dynamic":
322
+ inv_freq = inv_freq
323
+ elif scaling_type == "yarn":
324
+ inv_freq = self._yarn_get_inv_freq(inv_freq)
325
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
326
+
327
+ def _yarn_get_inv_freq(self, inv_freq):
328
+ if len(inv_freq) == 0:
329
+ return inv_freq
330
+ alpha = self.scaling_factor
331
+ beta_fast = 32
332
+ beta_slow = 1
333
+ freq_threshold = 1 / (self.max_position_embeddings * beta_fast)
334
+ low_freq_mask = inv_freq > freq_threshold
335
+ high_freq_mask = ~low_freq_mask
336
+ low_freq = inv_freq[low_freq_mask]
337
+ high_freq = inv_freq[high_freq_mask]
338
+ if len(low_freq) > 0:
339
+ low_freq = low_freq / alpha
340
+ if len(high_freq) > 0:
341
+ smooth_factor = (
342
+ self.max_position_embeddings * beta_slow / high_freq - beta_fast
343
+ ) / (beta_slow - beta_fast)
344
+ smooth_factor = torch.clamp(smooth_factor, 0.0, 1.0)
345
+ high_freq = (1 - smooth_factor) * (
346
+ high_freq / alpha
347
+ ) + smooth_factor * high_freq
348
+ result = torch.zeros_like(inv_freq)
349
+ result[low_freq_mask] = low_freq
350
+ result[high_freq_mask] = high_freq
351
+ return result
352
+
353
+ def forward(self, x, seq_len, position_offset=0):
354
+ t = torch.arange(
355
+ position_offset, position_offset + seq_len, device=x.device
356
+ ).type_as(self.inv_freq)
357
+ if self.scaling_type == "dynamic":
358
+ if seq_len > self.max_position_embeddings:
359
+ dynamic_scale = seq_len / self.max_position_embeddings
360
+ t = t / dynamic_scale
361
+ freqs = torch.outer(t, self.inv_freq)
362
+ emb = torch.cat((freqs, freqs), dim=-1)
363
+ cos = emb.cos()[None, None, :, :]
364
+ sin = emb.sin()[None, None, :, :]
365
+ return cos.to(x.dtype), sin.to(x.dtype)
366
+
367
+ class ALiBiPositionalBias(nn.Module):
368
+ def __init__(self, num_heads, max_positions=8192):
369
+ super().__init__()
370
+ self.num_heads = num_heads
371
+ self.max_positions = max_positions
372
+ slopes = torch.tensor(self._get_slopes(num_heads))
373
+ self.register_buffer("slopes", slopes, persistent=False)
374
+
375
+ def _get_slopes(self, n):
376
+ def get_slopes_power_of_2(n):
377
+ start = 2 ** (-(2 ** -(math.log2(n) - 3)))
378
+ ratio = start
379
+ return [start * (ratio**i) for i in range(n)]
380
+
381
+ if math.log2(n).is_integer():
382
+ return get_slopes_power_of_2(n)
383
+ else:
384
+ closest_power_of_2 = 2 ** math.floor(math.log2(n))
385
+ return (
386
+ get_slopes_power_of_2(closest_power_of_2)
387
+ + self._get_slopes(2 * closest_power_of_2)[0::2][
388
+ : n - closest_power_of_2
389
+ ]
390
+ )
391
+
392
+ def forward(self, seq_len, key_len=None):
393
+ if key_len is None:
394
+ key_len = seq_len
395
+ query_pos = torch.arange(seq_len, device=self.slopes.device).unsqueeze(1)
396
+ key_pos = torch.arange(key_len, device=self.slopes.device).unsqueeze(0)
397
+ relative_pos = key_pos - query_pos
398
+ bias = relative_pos.unsqueeze(0) * self.slopes.unsqueeze(1).unsqueeze(2)
399
+ return bias.unsqueeze(0)
400
+
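+ # Rotary position embedding helpers: apply_rotary_pos_emb below computes
+ # q' = q*cos + rotate_half(q)*sin and k' = k*cos + rotate_half(k)*sin.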
401
+ def rotate_half(x):
402
+ x1 = x[..., : x.shape[-1] // 2]
403
+ x2 = x[..., x.shape[-1] // 2 :]
404
+ return torch.cat((-x2, x1), dim=-1)
405
+
406
+ def apply_rotary_pos_emb(q, k, cos, sin):
407
+ cos = cos.to(q.dtype)
408
+ sin = sin.to(q.dtype)
409
+ q_embed = (q * cos) + (rotate_half(q) * sin)
410
+ k_embed = (k * cos) + (rotate_half(k) * sin)
411
+ return q_embed, k_embed
412
+
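+ # soft_cap_logits() below clamps logits to [-0.99*cap, 0.99*cap]; this is a hard clip
+ # rather than tanh-style soft capping.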
413
+ def soft_cap_logits(x, cap):
414
+ if cap is None or cap <= 0:
415
+ return x
416
+ return x.clamp(-cap * 0.99, cap * 0.99)
417
+
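+ # TopKRouter returns per-token top-k expert weights and indices plus two regularisers:
+ # an imbalance penalty on mean expert usage (aux_loss) and a z-loss on router logit magnitude.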
418
+ class TopKRouter(nn.Module):
419
+ def __init__(self, hidden_size, num_experts, num_experts_per_tok):
420
+ super().__init__()
421
+ self.num_experts = num_experts
422
+ self.num_experts_per_tok = num_experts_per_tok
423
+ self.gate = nn.Linear(hidden_size, num_experts, bias=False)
424
+ self.gate_norm = nn.LayerNorm(hidden_size)
425
+
426
+ def forward(self, hidden_states):
427
+ batch_size, seq_len, hidden_size = hidden_states.shape
428
+ hidden_states = hidden_states.view(-1, hidden_size)
429
+ hidden_states = self.gate_norm(hidden_states)
430
+ router_logits = self.gate(hidden_states)
431
+ router_logits = torch.clamp(router_logits, min=-10, max=10)
432
+ routing_weights = F.softmax(router_logits, dim=-1)
433
+ top_k_weights, top_k_indices = torch.topk(
434
+ routing_weights, self.num_experts_per_tok, dim=-1
435
+ )
436
+ top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-8)
437
+ router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
438
+ expert_usage = router_probs.mean(dim=0)
439
+ mean_usage = expert_usage.mean()
440
+ aux_loss = ((expert_usage - mean_usage) ** 2).sum() / (mean_usage + 1e-10)
441
+ router_logits_for_z = router_logits.to(torch.float32)
442
+ z_loss = torch.logsumexp(router_logits_for_z, dim=-1).mean()
443
+ return top_k_weights, top_k_indices, aux_loss, z_loss
444
+
445
+ class ExpertChoiceRouter(nn.Module):
446
+ def __init__(self, hidden_size, num_experts, expert_choice_k):
447
+ super().__init__()
448
+ self.num_experts = num_experts
449
+ self.expert_choice_k = expert_choice_k
450
+ self.gate = nn.Linear(hidden_size, num_experts, bias=False)
451
+
452
+ def forward(self, hidden_states):
453
+ batch_size, seq_len, hidden_size = hidden_states.shape
454
+ total_tokens = batch_size * seq_len
455
+ hidden_states_flat = hidden_states.view(-1, hidden_size)
456
+ router_logits = self.gate(hidden_states_flat)
457
+ router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
458
+ router_probs_t = router_probs.t()
459
+ capacity = max(1, int(self.expert_choice_k * total_tokens / self.num_experts))
460
+ top_k_values, top_k_indices = torch.topk(
461
+ router_probs_t, k=min(capacity, total_tokens), dim=-1
462
+ )
463
+ expert_mask = torch.zeros(
464
+ self.num_experts, total_tokens, device=hidden_states.device
465
+ )
466
+ for expert_idx in range(self.num_experts):
467
+ expert_mask[expert_idx, top_k_indices[expert_idx]] = 1.0
468
+ routing_weights = expert_mask.t() * router_probs
469
+ aux_loss = (router_probs.mean(dim=0) ** 2).sum() * self.num_experts
470
+ z_loss = torch.logsumexp(router_logits, dim=-1).mean()
471
+ return routing_weights, aux_loss, z_loss
472
+
473
+ class Expert(nn.Module):
474
+ def __init__(self, config):
475
+ super().__init__()
476
+ self.gate_proj = nn.Linear(
477
+ config.hidden_size, config.intermediate_size, bias=config.mlp_bias
478
+ )
479
+ self.up_proj = nn.Linear(
480
+ config.hidden_size, config.intermediate_size, bias=config.mlp_bias
481
+ )
482
+ self.down_proj = nn.Linear(
483
+ config.intermediate_size, config.hidden_size, bias=config.mlp_bias
484
+ )
485
+ self.dropout = nn.Dropout(config.hidden_dropout)
486
+
487
+ def forward(self, x):
488
+ gate = F.silu(self.gate_proj(x))
489
+ up = self.up_proj(x)
490
+ hidden = gate * up
491
+ hidden = self.dropout(hidden)
492
+ return self.down_proj(hidden)
493
+
494
+ class MixtureOfExperts(nn.Module):
495
+ def __init__(self, config):
496
+ super().__init__()
497
+ self.config = config
498
+ self.num_experts = config.num_experts
499
+ self.num_experts_per_tok = config.num_experts_per_tok
500
+ self.use_expert_choice = config.use_expert_choice
501
+ self.experts = nn.ModuleList([Expert(config) for _ in range(self.num_experts)])
502
+ if self.use_expert_choice:
503
+ self.router = ExpertChoiceRouter(
504
+ config.hidden_size, config.num_experts, config.expert_choice_k
505
+ )
506
+ else:
507
+ self.router = TopKRouter(
508
+ config.hidden_size, config.num_experts, config.num_experts_per_tok
509
+ )
510
+
511
+ def forward(self, hidden_states):
512
+ batch_size, seq_len, hidden_size = hidden_states.shape
513
+ hidden_states_flat = hidden_states.view(-1, hidden_size)
514
+ if self.use_expert_choice:
515
+ routing_weights, aux_loss, z_loss = self.router(hidden_states)
516
+ final_output = torch.zeros_like(hidden_states_flat)
517
+ for expert_idx, expert in enumerate(self.experts):
518
+ expert_mask = routing_weights[:, expert_idx] > 0
519
+ if expert_mask.any():
520
+ expert_input = hidden_states_flat[expert_mask]
521
+ expert_output = expert(expert_input)
522
+ final_output[expert_mask] += (
523
+ expert_output
524
+ * routing_weights[expert_mask, expert_idx : expert_idx + 1]
525
+ )
526
+ else:
527
+ top_k_weights, top_k_indices, aux_loss, z_loss = self.router(hidden_states)
528
+ final_output = torch.zeros_like(hidden_states_flat)
529
+ for i in range(self.num_experts_per_tok):
530
+ expert_indices = top_k_indices[:, i]
531
+ expert_weights = top_k_weights[:, i : i + 1]
532
+ for expert_idx in range(self.num_experts):
533
+ expert_mask = expert_indices == expert_idx
534
+ if expert_mask.any():
535
+ expert_input = hidden_states_flat[expert_mask]
536
+ expert_output = self.experts[expert_idx](expert_input)
537
+ final_output[expert_mask] += (
538
+ expert_output * expert_weights[expert_mask]
539
+ )
540
+ final_output = final_output.view(batch_size, seq_len, hidden_size)
541
+ return final_output, aux_loss, z_loss
542
+
543
+ class MixtureOfDepthsRouter(nn.Module):
544
+ def __init__(self, hidden_size, capacity_factor=0.5, route_method="learned"):
545
+ super().__init__()
546
+ self.capacity_factor = capacity_factor
547
+ self.route_method = route_method
548
+ if route_method == "learned":
549
+ self.router = nn.Linear(hidden_size, 1)
550
+
551
+ def forward(self, hidden_states):
552
+ batch_size, seq_len, hidden_size = hidden_states.shape
553
+ if self.route_method == "learned":
554
+ routing_logits = self.router(hidden_states).squeeze(-1)
555
+ elif self.route_method == "random":
556
+ routing_logits = torch.rand(
557
+ batch_size, seq_len, device=hidden_states.device
558
+ )
559
+ else:
560
+ routing_logits = torch.zeros(
561
+ batch_size, seq_len, device=hidden_states.device
562
+ )
563
+ capacity = max(1, int(seq_len * self.capacity_factor))
564
+ _, top_indices = torch.topk(routing_logits, k=capacity, dim=-1)
565
+ process_mask = torch.zeros(
566
+ batch_size, seq_len, dtype=torch.bool, device=hidden_states.device
567
+ )
568
+ process_mask.scatter_(1, top_indices, True)
569
+ return process_mask
570
+
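+ # CacaAttention: grouped-query attention (num_key_value_heads <= num_attention_heads) with
+ # optional QK RMSNorm, RoPE or ALiBi positions, sliding-window masking with attention sinks,
+ # and runtime selection among flash-attn, xFormers, SDPA, and eager attention.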
571
+ class CacaAttention(nn.Module):
572
+ def __init__(self, config, layer_idx=None):
573
+ super().__init__()
574
+ self.config = config
575
+ self.layer_idx = layer_idx
576
+ self.hidden_size = config.hidden_size
577
+ self.num_heads = config.num_attention_heads
578
+ self.num_key_value_heads = config.num_key_value_heads
579
+ self.head_dim = config.head_dim
580
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
581
+ self.sliding_window = config.sliding_window
582
+ self.attn_logit_softcapping = config.attn_logit_softcapping
583
+ self.attention_sink_size = config.attention_sink_size
584
+ self.attention_sink_window = config.attention_sink_window
585
+ self.q_proj = nn.Linear(
586
+ self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
587
+ )
588
+ self.k_proj = nn.Linear(
589
+ self.hidden_size,
590
+ self.num_key_value_heads * self.head_dim,
591
+ bias=config.attention_bias,
592
+ )
593
+ self.v_proj = nn.Linear(
594
+ self.hidden_size,
595
+ self.num_key_value_heads * self.head_dim,
596
+ bias=config.attention_bias,
597
+ )
598
+ self.o_proj = nn.Linear(
599
+ self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
600
+ )
601
+ if config.use_qk_norm:
602
+ self.q_norm = CacaRMSNorm(self.head_dim, eps=config.qk_norm_eps)
603
+ self.k_norm = CacaRMSNorm(self.head_dim, eps=config.qk_norm_eps)
604
+ else:
605
+ self.q_norm = None
606
+ self.k_norm = None
607
+ if config.use_rotary_embeddings:
608
+ scaling_factor = 1.0
609
+ scaling_type = None
610
+ if config.rope_scaling is not None:
611
+ scaling_type = config.rope_scaling.get("type", "linear")
612
+ scaling_factor = config.rope_scaling.get("factor", 1.0)
613
+ self.rotary_emb = CacaRotaryEmbedding(
614
+ self.head_dim,
615
+ config.max_position_embeddings,
616
+ config.rope_theta,
617
+ scaling_factor=scaling_factor,
618
+ scaling_type=scaling_type,
619
+ )
620
+ else:
621
+ self.rotary_emb = None
622
+ if config.use_alibi:
623
+ self.alibi = ALiBiPositionalBias(
624
+ self.num_heads, config.max_position_embeddings
625
+ )
626
+ else:
627
+ self.alibi = None
628
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
629
+ self.is_global_attention = self._determine_attention_type(config, layer_idx)
630
+ self.has_flash_attn = HAS_FLASH_ATTN and config.use_flash_attn
631
+ self.has_xformers = HAS_XFORMERS
632
+ self.has_sdpa = HAS_SDPA
633
+ self._mask_cache = {}
634
+ self._max_cache_size = 10
635
+
636
+ def _determine_attention_type(self, config, layer_idx):
637
+ if layer_idx is None:
638
+ return False
639
+ if config.attention_pattern == "all_global":
640
+ return True
641
+ elif config.attention_pattern == "all_local":
642
+ return False
643
+ elif config.attention_pattern == "interleaved":
644
+ return (layer_idx % config.global_attention_every_n_layers) == (
645
+ config.global_attention_every_n_layers - 1
646
+ )
647
+ return False
648
+
649
+ def forward(
650
+ self, hidden_states, attention_mask=None, past_key_value=None, use_cache=False
651
+ ):
652
+ batch_size, seq_length, _ = hidden_states.size()
653
+ query_states = self.q_proj(hidden_states)
654
+ key_states = self.k_proj(hidden_states)
655
+ value_states = self.v_proj(hidden_states)
656
+ query_states = query_states.view(
657
+ batch_size, seq_length, self.num_heads, self.head_dim
658
+ ).transpose(1, 2)
659
+ key_states = key_states.view(
660
+ batch_size, seq_length, self.num_key_value_heads, self.head_dim
661
+ ).transpose(1, 2)
662
+ value_states = value_states.view(
663
+ batch_size, seq_length, self.num_key_value_heads, self.head_dim
664
+ ).transpose(1, 2)
665
+ if self.q_norm is not None and self.k_norm is not None:
666
+ query_states = self.q_norm(query_states)
667
+ key_states = self.k_norm(key_states)
668
+
669
+ position_offset = 0
670
+ if past_key_value is not None:
671
+ try:
672
+ if isinstance(past_key_value, (tuple, list)) and len(past_key_value) >= 2:
673
+ if past_key_value[0] is not None:
674
+ position_offset = past_key_value[0].shape[2]
675
+ except (IndexError, AttributeError, TypeError):
676
+ position_offset = 0
677
+
678
+ if self.rotary_emb is not None:
679
+ cos, sin = self.rotary_emb(query_states, seq_length, position_offset)
680
+ query_states, key_states = apply_rotary_pos_emb(
681
+ query_states, key_states, cos, sin
682
+ )
683
+
684
+ if past_key_value is not None and past_key_value[0] is not None:
685
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
686
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
687
+
688
+ if use_cache:
689
+ present_key_value = (key_states, value_states)
690
+ else:
691
+ present_key_value = None
692
+
693
+ key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
694
+ value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
695
+ kv_seq_len = key_states.shape[-2]
696
+
697
+ use_sliding_window = (not self.is_global_attention) and (
698
+ self.sliding_window is not None
699
+ )
700
+ if self.has_flash_attn and attention_mask is None:
701
+ if query_states.device.type == "cuda" and query_states.dtype in [
702
+ torch.float16,
703
+ torch.bfloat16,
704
+ ]:
705
+ try:
706
+ attn_output = self._flash_attention(
707
+ query_states, key_states, value_states, use_sliding_window
708
+ )
709
+ except Exception as e:
710
+ logger.warning(f"Flash Attention gagal, pakai fallback: {e}")
711
+ attn_output = self._fallback_attention(
712
+ query_states,
713
+ key_states,
714
+ value_states,
715
+ attention_mask,
716
+ kv_seq_len,
717
+ use_sliding_window,
718
+ )
719
+ else:
720
+ attn_output = self._fallback_attention(
721
+ query_states,
722
+ key_states,
723
+ value_states,
724
+ attention_mask,
725
+ kv_seq_len,
726
+ use_sliding_window,
727
+ )
728
+ else:
729
+ attn_output = self._fallback_attention(
730
+ query_states,
731
+ key_states,
732
+ value_states,
733
+ attention_mask,
734
+ kv_seq_len,
735
+ use_sliding_window,
736
+ )
737
+ attn_output = self.o_proj(attn_output)
738
+ return attn_output, present_key_value
739
+
740
+ def _flash_attention(
741
+ self, query_states, key_states, value_states, use_sliding_window
742
+ ):
743
+ batch_size, num_heads, seq_length, head_dim = query_states.shape
744
+ kv_seq_len = key_states.shape[-2]
745
+ original_dtype = query_states.dtype
746
+ if original_dtype == torch.bfloat16:
747
+ if not torch.cuda.is_bf16_supported():
748
+ logger.warning("BF16 not supported on this GPU, falling back to FP16")
749
+ original_dtype = torch.float16
750
+ compute_dtype = (
751
+ torch.bfloat16
752
+ if original_dtype not in [torch.float16, torch.bfloat16]
753
+ else original_dtype
754
+ )
755
+ query_states = query_states.transpose(1, 2).contiguous().to(compute_dtype)
756
+ key_states = key_states.transpose(1, 2).contiguous().to(compute_dtype)
757
+ value_states = value_states.transpose(1, 2).contiguous().to(compute_dtype)
758
+ if use_sliding_window and self.sliding_window < kv_seq_len:
759
+ window_size = (self.sliding_window, 0)
760
+ else:
761
+ window_size = (-1, 0)
762
+ attn_output = flash_attn_func(
763
+ query_states,
764
+ key_states,
765
+ value_states,
766
+ dropout_p=self.config.attention_dropout if self.training else 0.0,
767
+ softmax_scale=None,
768
+ causal=True,
769
+ window_size=window_size,
770
+ )
771
+ attn_output = attn_output.to(original_dtype)
772
+ attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size)
773
+ return attn_output
774
+
775
+ def _fallback_attention(
776
+ self,
777
+ query_states,
778
+ key_states,
779
+ value_states,
780
+ attention_mask,
781
+ kv_seq_len,
782
+ use_sliding_window,
783
+ ):
784
+ device_type = query_states.device.type
785
+ if self.has_xformers and device_type == "cuda" and attention_mask is None:
786
+ try:
787
+ return self._xformers_attention(
788
+ query_states,
789
+ key_states,
790
+ value_states,
791
+ kv_seq_len,
792
+ use_sliding_window,
793
+ )
794
+ except Exception as e:
795
+ logger.warning(f"xFormers gagal, pakai SDPA: {e}")
796
+ if self.has_sdpa:
797
+ return self._sdpa_attention(
798
+ query_states,
799
+ key_states,
800
+ value_states,
801
+ attention_mask,
802
+ kv_seq_len,
803
+ use_sliding_window,
804
+ )
805
+ else:
806
+ return self._standard_attention(
807
+ query_states,
808
+ key_states,
809
+ value_states,
810
+ attention_mask,
811
+ kv_seq_len,
812
+ use_sliding_window,
813
+ )
814
+
815
+ def _create_causal_mask(
816
+ self, query_length, key_length, dtype, device, use_sliding_window
817
+ ):
818
+ cache_key = (
819
+ query_length,
820
+ key_length,
821
+ str(dtype),
822
+ use_sliding_window,
823
+ self.sliding_window if use_sliding_window else None,
824
+ )
825
+ if cache_key in self._mask_cache:
826
+ cached_mask = self._mask_cache[cache_key]
827
+ return cached_mask.to(device, dtype)
828
+ if query_length > key_length:
829
+ key_length = query_length
830
+ query_pos = torch.arange(query_length, device=device) + (
831
+ key_length - query_length
832
+ )
833
+ key_pos = torch.arange(key_length, device=device)
834
+ distance = query_pos[:, None] - key_pos[None, :]
835
+ mask = distance < 0
836
+
837
+ if use_sliding_window and self.sliding_window is not None:
838
+ if self.config.use_attention_sink and self.attention_sink_size > 0:
839
+ is_sink = key_pos[None, :] < self.attention_sink_size
840
+ in_window = (distance >= 0) & (distance <= self.sliding_window)
841
+ mask = (distance < 0) | ((~is_sink) & (~in_window))
842
+
843
+ else:
844
+ too_far_mask = distance > self.sliding_window
845
+ mask = mask | too_far_mask
846
+ float_mask = torch.zeros(
847
+ 1, 1, query_length, key_length, dtype=dtype, device=device
848
+ )
849
+ float_mask.masked_fill_(mask.unsqueeze(0).unsqueeze(0), -1e9)
850
+ if len(self._mask_cache) >= self._max_cache_size:
851
+ oldest_key = next(iter(self._mask_cache))
852
+ del self._mask_cache[oldest_key]
853
+ self._mask_cache[cache_key] = float_mask.detach().cpu()
854
+ return float_mask
855
+
856
+ def _xformers_attention(
857
+ self, query_states, key_states, value_states, kv_seq_len, use_sliding_window
858
+ ):
859
+ batch_size, num_heads, seq_length, head_dim = query_states.shape
860
+ attn_bias = self._create_causal_mask(
861
+ seq_length,
862
+ kv_seq_len,
863
+ query_states.dtype,
864
+ query_states.device,
865
+ use_sliding_window,
866
+ )
867
+ query_states = query_states.transpose(1, 2)
868
+ key_states = key_states.transpose(1, 2)
869
+ value_states = value_states.transpose(1, 2)
870
+ attn_output = memory_efficient_attention(
871
+ query_states,
872
+ key_states,
873
+ value_states,
874
+ attn_bias=attn_bias,
875
+ p=self.config.attention_dropout if self.training else 0.0,
876
+ )
877
+ attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size)
878
+ return attn_output
879
+
880
+ def _sdpa_attention(
881
+ self,
882
+ query_states,
883
+ key_states,
884
+ value_states,
885
+ attention_mask,
886
+ kv_seq_len,
887
+ use_sliding_window,
888
+ ):
889
+ batch_size, num_heads, seq_length, head_dim = query_states.shape
890
+ if attention_mask is None:
891
+ attention_mask = self._create_causal_mask(
892
+ seq_length,
893
+ kv_seq_len,
894
+ query_states.dtype,
895
+ query_states.device,
896
+ use_sliding_window,
897
+ )
898
+ if self.alibi is not None:
899
+ alibi_bias = self.alibi(seq_length, kv_seq_len)
900
+ attention_mask = attention_mask + alibi_bias
901
+ attn_output = F.scaled_dot_product_attention(
902
+ query_states,
903
+ key_states,
904
+ value_states,
905
+ attn_mask=attention_mask,
906
+ dropout_p=self.config.attention_dropout if self.training else 0.0,
907
+ is_causal=False,
908
+ )
909
+ attn_output = attn_output.transpose(1, 2).contiguous()
910
+ attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size)
911
+ return attn_output
912
+
913
+ def _standard_attention(
914
+ self,
915
+ query_states,
916
+ key_states,
917
+ value_states,
918
+ attention_mask,
919
+ kv_seq_len,
920
+ use_sliding_window,
921
+ ):
922
+ batch_size, num_heads, seq_length, head_dim = query_states.shape
923
+ attn_weights = torch.matmul(
924
+ query_states, key_states.transpose(2, 3)
925
+ ) / math.sqrt(head_dim)
926
+ attn_weights = torch.clamp(attn_weights, min=-50.0, max=50.0)
927
+ attn_weights = soft_cap_logits(attn_weights, self.attn_logit_softcapping)
928
+ if attention_mask is None:
929
+ attention_mask = self._create_causal_mask(
930
+ seq_length,
931
+ kv_seq_len,
932
+ attn_weights.dtype,
933
+ attn_weights.device,
934
+ use_sliding_window,
935
+ )
936
+ if self.alibi is not None:
937
+ alibi_bias = self.alibi(seq_length, kv_seq_len)
938
+ attention_mask = attention_mask + alibi_bias
939
+ attn_weights = attn_weights + attention_mask
940
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
941
+ query_states.dtype
942
+ )
943
+ attn_weights = self.attention_dropout(attn_weights)
944
+ attn_output = torch.matmul(attn_weights, value_states)
945
+ attn_output = attn_output.transpose(1, 2).contiguous()
946
+ attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size)
947
+ return attn_output
948
+
949
+ class CacaCrossAttention(nn.Module):
950
+ def __init__(self, config):
951
+ super().__init__()
952
+ self.config = config
953
+ self.hidden_size = config.hidden_size
954
+ self.num_heads = config.num_attention_heads
955
+ self.num_key_value_heads = config.num_key_value_heads
956
+ self.head_dim = config.head_dim
957
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
958
+ self.q_proj = nn.Linear(
959
+ self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
960
+ )
961
+ self.k_proj = nn.Linear(
962
+ self.hidden_size,
963
+ self.num_key_value_heads * self.head_dim,
964
+ bias=config.attention_bias,
965
+ )
966
+ self.v_proj = nn.Linear(
967
+ self.hidden_size,
968
+ self.num_key_value_heads * self.head_dim,
969
+ bias=config.attention_bias,
970
+ )
971
+ self.o_proj = nn.Linear(
972
+ self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias
973
+ )
974
+ self.attention_dropout = nn.Dropout(config.attention_dropout)
975
+
976
+ def forward(self, hidden_states, encoder_hidden_states, attention_mask=None):
977
+ batch_size, seq_length, _ = hidden_states.size()
978
+ encoder_seq_length = encoder_hidden_states.size(1)
979
+ query_states = self.q_proj(hidden_states)
980
+ key_states = self.k_proj(encoder_hidden_states)
981
+ value_states = self.v_proj(encoder_hidden_states)
982
+ query_states = query_states.view(
983
+ batch_size, seq_length, self.num_heads, self.head_dim
984
+ ).transpose(1, 2)
985
+ key_states = key_states.view(
986
+ batch_size, encoder_seq_length, self.num_key_value_heads, self.head_dim
987
+ ).transpose(1, 2)
988
+ value_states = value_states.view(
989
+ batch_size, encoder_seq_length, self.num_key_value_heads, self.head_dim
990
+ ).transpose(1, 2)
991
+ key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
992
+ value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)
993
+ attn_weights = torch.matmul(
994
+ query_states, key_states.transpose(2, 3)
995
+ ) / math.sqrt(self.head_dim)
996
+ if attention_mask is not None:
997
+ attn_weights = attn_weights + attention_mask
998
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
999
+ query_states.dtype
1000
+ )
1001
+ attn_weights = self.attention_dropout(attn_weights)
1002
+ attn_output = torch.matmul(attn_weights, value_states)
1003
+ attn_output = attn_output.transpose(1, 2).contiguous()
1004
+ attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size)
1005
+ attn_output = self.o_proj(attn_output)
1006
+ return attn_output
1007
+
1008
+ class CacaMLP(nn.Module):
1009
+ def __init__(self, config):
1010
+ super().__init__()
1011
+ self.hidden_size = config.hidden_size
1012
+ self.intermediate_size = config.intermediate_size
1013
+ self.gate_proj = nn.Linear(
1014
+ self.hidden_size, self.intermediate_size, bias=config.mlp_bias
1015
+ )
1016
+ self.up_proj = nn.Linear(
1017
+ self.hidden_size, self.intermediate_size, bias=config.mlp_bias
1018
+ )
1019
+ self.down_proj = nn.Linear(
1020
+ self.intermediate_size, self.hidden_size, bias=config.mlp_bias
1021
+ )
1022
+ self.dropout = nn.Dropout(config.hidden_dropout)
1023
+
1024
+ def forward(self, x):
1025
+ gate = F.silu(self.gate_proj(x))
1026
+ up = self.up_proj(x)
1027
+ hidden = gate * up
1028
+ hidden = self.dropout(hidden)
1029
+ output = self.down_proj(hidden)
1030
+ return output
1031
+
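+ # CacaDecoderLayer: pre-norm block (RMSNorm -> self-attention -> RMSNorm -> MLP or MoE) with
+ # optional cross-attention, LayerScale, stochastic depth, and mixture-of-depths routing that
+ # processes only a capacity-limited subset of tokens.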
1032
+ class CacaDecoderLayer(nn.Module):
1033
+ def __init__(self, config, layer_idx):
1034
+ super().__init__()
1035
+ self.config = config
1036
+ self.layer_idx = layer_idx
1037
+ self.self_attn = CacaAttention(config, layer_idx=layer_idx)
1038
+ self.use_moe = config.use_moe and (layer_idx % config.moe_layer_frequency == 0)
1039
+ if self.use_moe:
1040
+ self.mlp = MixtureOfExperts(config)
1041
+ else:
1042
+ self.mlp = CacaMLP(config)
1043
+ self.use_cross_attention = config.use_cross_attention and (
1044
+ layer_idx % config.cross_attention_frequency == 0
1045
+ )
1046
+ if self.use_cross_attention:
1047
+ self.cross_attn = CacaCrossAttention(config)
1048
+ self.cross_attn_layernorm = CacaRMSNorm(
1049
+ config.hidden_size, config.rms_norm_eps
1050
+ )
1051
+ self.input_layernorm = CacaRMSNorm(config.hidden_size, config.rms_norm_eps)
1052
+ self.post_attention_layernorm = CacaRMSNorm(
1053
+ config.hidden_size, config.rms_norm_eps
1054
+ )
1055
+ self.residual_dropout = nn.Dropout(config.residual_dropout)
1056
+ if config.use_layer_scale:
1057
+ self.layer_scale_1 = LayerScale(config.hidden_size, config.layer_scale_init)
1058
+ self.layer_scale_2 = LayerScale(config.hidden_size, config.layer_scale_init)
1059
+ if self.use_cross_attention:
1060
+ self.layer_scale_cross = LayerScale(
1061
+ config.hidden_size, config.layer_scale_init
1062
+ )
1063
+ else:
1064
+ self.layer_scale_1 = None
1065
+ self.layer_scale_2 = None
1066
+ self.layer_scale_cross = None
1067
+ if config.use_stochastic_depth:
1068
+ drop_prob = (
1069
+ config.stochastic_depth_prob * layer_idx / config.num_hidden_layers
1070
+ )
1071
+ self.stochastic_depth = StochasticDepth(drop_prob)
1072
+ else:
1073
+ self.stochastic_depth = None
1074
+ if config.use_mixture_of_depths:
1075
+ self.mod_router = MixtureOfDepthsRouter(
1076
+ config.hidden_size, config.mod_capacity_factor, config.mod_route_method
1077
+ )
1078
+ else:
1079
+ self.mod_router = None
1080
+
1081
+ def forward(
1082
+ self,
1083
+ hidden_states,
1084
+ attention_mask=None,
1085
+ encoder_hidden_states=None,
1086
+ encoder_attention_mask=None,
1087
+ past_key_value=None,
1088
+ use_cache=False,
1089
+ ):
1090
+ aux_loss = 0.0
1091
+ z_loss = 0.0
1092
+ if self.mod_router is not None:
1093
+ process_mask = self.mod_router(hidden_states)
1094
+ tokens_to_process = hidden_states[process_mask]
1095
+ if tokens_to_process.numel() == 0:
1096
+ present_key_value = past_key_value if use_cache else None
1097
+ return hidden_states, present_key_value, aux_loss, z_loss
1098
+ else:
1099
+ process_mask = None
1100
+ tokens_to_process = hidden_states
1101
+ residual = tokens_to_process
1102
+ tokens_to_process = self.input_layernorm(tokens_to_process)
1103
+ attn_output, present_key_value = self.self_attn(
1104
+ tokens_to_process,
1105
+ attention_mask,
1106
+ past_key_value=past_key_value,
1107
+ use_cache=use_cache,
1108
+ )
1109
+ if self.layer_scale_1 is not None:
1110
+ attn_output = self.layer_scale_1(attn_output)
1111
+ if self.stochastic_depth is not None:
1112
+ attn_output = self.stochastic_depth(attn_output, self.training)
1113
+ tokens_to_process = residual + self.residual_dropout(attn_output)
1114
+ if self.use_cross_attention and encoder_hidden_states is not None:
1115
+ residual = tokens_to_process
1116
+ tokens_to_process = self.cross_attn_layernorm(tokens_to_process)
1117
+ cross_attn_output = self.cross_attn(
1118
+ tokens_to_process,
1119
+ encoder_hidden_states,
1120
+ attention_mask=encoder_attention_mask,
1121
+ )
1122
+ if self.layer_scale_cross is not None:
1123
+ cross_attn_output = self.layer_scale_cross(cross_attn_output)
1124
+ if self.stochastic_depth is not None:
1125
+ cross_attn_output = self.stochastic_depth(
1126
+ cross_attn_output, self.training
1127
+ )
1128
+ tokens_to_process = residual + self.residual_dropout(cross_attn_output)
1129
+ residual = tokens_to_process
1130
+ tokens_to_process = self.post_attention_layernorm(tokens_to_process)
1131
+ if self.use_moe:
1132
+ mlp_output, moe_aux_loss, moe_z_loss = self.mlp(tokens_to_process)
1133
+ aux_loss += moe_aux_loss
1134
+ z_loss += moe_z_loss
1135
+ else:
1136
+ mlp_output = self.mlp(tokens_to_process)
1137
+ if self.layer_scale_2 is not None:
1138
+ mlp_output = self.layer_scale_2(mlp_output)
1139
+ if self.stochastic_depth is not None:
1140
+ mlp_output = self.stochastic_depth(mlp_output, self.training)
1141
+ tokens_to_process = residual + self.residual_dropout(mlp_output)
1142
+ if process_mask is not None:
1143
+ hidden_states[process_mask] = tokens_to_process
1144
+ else:
1145
+ hidden_states = tokens_to_process
1146
+ return hidden_states, present_key_value, aux_loss, z_loss
1147
+
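+ # Multimodal front-ends: a ViT-style VisionEncoder (patch embedding + transformer blocks), a
+ # convolution-downsampled AudioEncoder, and MultiModalProjector variants (linear, MLP,
+ # Perceiver resampler, Q-Former) that map encoder features into the LM hidden size.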
1148
+ class VisionEncoder(nn.Module):
1149
+ def __init__(self, config):
1150
+ super().__init__()
1151
+ vision_config = config.vision_config
1152
+ self.patch_size = vision_config.get("patch_size", 14)
1153
+ self.image_size = vision_config.get("image_size", 224)
1154
+ self.num_channels = vision_config.get("num_channels", 3)
1155
+ self.hidden_size = vision_config.get("hidden_size", 1024)
1156
+ self.num_layers = vision_config.get("num_layers", 24)
1157
+ self.num_heads = vision_config.get("num_heads", 16)
1158
+ self.intermediate_size = vision_config.get("intermediate_size", 4096)
1159
+ self.layer_norm_eps = vision_config.get("layer_norm_eps", 1e-6)
1160
+ self.num_patches = (self.image_size // self.patch_size) ** 2
1161
+ self.patch_embed = nn.Sequential(
1162
+ nn.Conv2d(
1163
+ self.num_channels,
1164
+ self.hidden_size,
1165
+ kernel_size=self.patch_size,
1166
+ stride=self.patch_size,
1167
+ bias=False,
1168
+ ),
1169
+ nn.Dropout(p=vision_config.get("dropout", 0.0)),
1170
+ )
1171
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, self.hidden_size))
1172
+ self.pos_embed = nn.Parameter(
1173
+ torch.zeros(1, self.num_patches + 1, self.hidden_size)
1174
+ )
1175
+ self.pos_drop = nn.Dropout(p=vision_config.get("dropout", 0.0))
1176
+ self.blocks = nn.ModuleList(
1177
+ [
1178
+ VisionTransformerBlock(
1179
+ dim=self.hidden_size,
1180
+ num_heads=self.num_heads,
1181
+ mlp_ratio=self.intermediate_size / self.hidden_size,
1182
+ dropout=vision_config.get("dropout", 0.0),
1183
+ layer_norm_eps=self.layer_norm_eps,
1184
+ )
1185
+ for _ in range(self.num_layers)
1186
+ ]
1187
+ )
1188
+ self.norm = nn.LayerNorm(self.hidden_size, eps=self.layer_norm_eps)
1189
+ self._init_weights()
1190
+
1191
+ def _init_weights(self):
1192
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
1193
+ nn.init.trunc_normal_(self.cls_token, std=0.02)
1194
+ nn.init.trunc_normal_(self.patch_embed[0].weight, std=0.02)
1195
+
1196
+ def forward(self, pixel_values):
1197
+ batch_size = pixel_values.shape[0]
1198
+ x = self.patch_embed(pixel_values)
1199
+ x = x.flatten(2).transpose(1, 2)
1200
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
1201
+ x = torch.cat([cls_tokens, x], dim=1)
1202
+ x = x + self.pos_embed
1203
+ x = self.pos_drop(x)
1204
+ for block in self.blocks:
1205
+ x = block(x)
1206
+ x = self.norm(x)
1207
+ return x
1208
+
1209
+ class VisionTransformerBlock(nn.Module):
1210
+ def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0, layer_norm_eps=1e-6):
1211
+ super().__init__()
1212
+ self.norm1 = nn.LayerNorm(dim, eps=layer_norm_eps)
1213
+ self.attn = nn.MultiheadAttention(
1214
+ dim, num_heads, dropout=dropout, batch_first=True
1215
+ )
1216
+ self.drop_path1 = nn.Dropout(dropout)
1217
+ self.norm2 = nn.LayerNorm(dim, eps=layer_norm_eps)
1218
+ mlp_hidden_dim = int(dim * mlp_ratio)
1219
+ self.mlp = nn.Sequential(
1220
+ nn.Linear(dim, mlp_hidden_dim),
1221
+ nn.GELU(),
1222
+ nn.Dropout(dropout),
1223
+ nn.Linear(mlp_hidden_dim, dim),
1224
+ nn.Dropout(dropout),
1225
+ )
1226
+ self.drop_path2 = nn.Dropout(dropout)
1227
+
1228
+ def forward(self, x):
1229
+ residual = x
1230
+ x = self.norm1(x)
1231
+ x = self.attn(x, x, x, need_weights=False)[0]
1232
+ x = self.drop_path1(x)
1233
+ x = residual + x
1234
+ residual = x
1235
+ x = self.norm2(x)
1236
+ x = self.mlp(x)
1237
+ x = self.drop_path2(x)
1238
+ x = residual + x
1239
+ return x
1240
+
1241
+ class AudioEncoder(nn.Module):
1242
+ def __init__(self, config):
1243
+ super().__init__()
1244
+ audio_config = config.audio_config
1245
+ self.num_mel_bins = audio_config.get("num_mel_bins", 80)
1246
+ self.hidden_size = audio_config.get("hidden_size", 1024)
1247
+ self.num_layers = audio_config.get("num_layers", 12)
1248
+ self.num_heads = audio_config.get("num_heads", 16)
1249
+ self.intermediate_size = audio_config.get("intermediate_size", 4096)
1250
+ self.max_audio_length = audio_config.get("max_audio_length", 3000)
1251
+ self.dropout = audio_config.get("dropout", 0.0)
1252
+ self.conv1 = nn.Sequential(
1253
+ nn.Conv1d(self.num_mel_bins, self.hidden_size, kernel_size=3, padding=1),
1254
+ nn.GELU(),
1255
+ nn.Dropout(p=self.dropout),
1256
+ )
1257
+ self.conv2 = nn.Sequential(
1258
+ nn.Conv1d(
1259
+ self.hidden_size, self.hidden_size, kernel_size=3, stride=2, padding=1
1260
+ ),
1261
+ nn.GELU(),
1262
+ nn.Dropout(p=self.dropout),
1263
+ )
1264
+ self.pos_embed = nn.Parameter(
1265
+ torch.zeros(1, self.max_audio_length // 2, self.hidden_size)
1266
+ )
1267
+ self.pos_drop = nn.Dropout(p=self.dropout)
1268
+ self.blocks = nn.ModuleList(
1269
+ [
1270
+ AudioTransformerBlock(
1271
+ dim=self.hidden_size,
1272
+ num_heads=self.num_heads,
1273
+ mlp_ratio=self.intermediate_size / self.hidden_size,
1274
+ dropout=self.dropout,
1275
+ )
1276
+ for _ in range(self.num_layers)
1277
+ ]
1278
+ )
1279
+ self.norm = nn.LayerNorm(self.hidden_size)
1280
+ self._init_weights()
1281
+
1282
+ def _init_weights(self):
1283
+ nn.init.trunc_normal_(self.pos_embed, std=0.02)
1284
+
1285
+ def forward(self, audio_features):
1286
+ x = F.gelu(self.conv1(audio_features))
1287
+ x = F.gelu(self.conv2(x))
1288
+ x = x.transpose(1, 2)
1289
+ seq_len = x.shape[1]
1290
+ if seq_len <= self.pos_embed.shape[1]:
1291
+ x = x + self.pos_embed[:, :seq_len, :]
1292
+ else:
1293
+ pos_embed_interp = F.interpolate(
1294
+ self.pos_embed.transpose(1, 2),
1295
+ size=seq_len,
1296
+ mode="linear",
1297
+ align_corners=False,
1298
+ ).transpose(1, 2)
1299
+ x = x + pos_embed_interp
1300
+ x = self.pos_drop(x)
1301
+ for block in self.blocks:
1302
+ x = block(x)
1303
+ x = self.norm(x)
1304
+ return x
1305
+
1306
+ class AudioTransformerBlock(nn.Module):
1307
+ def __init__(self, dim, num_heads, mlp_ratio=4.0, dropout=0.0):
1308
+ super().__init__()
1309
+ self.norm1 = nn.LayerNorm(dim)
1310
+ self.attn = nn.MultiheadAttention(
1311
+ dim, num_heads, dropout=dropout, batch_first=True
1312
+ )
1313
+ self.drop_path1 = nn.Dropout(dropout)
1314
+ self.norm2 = nn.LayerNorm(dim)
1315
+ mlp_hidden_dim = int(dim * mlp_ratio)
1316
+ self.mlp = nn.Sequential(
1317
+ nn.Linear(dim, mlp_hidden_dim),
1318
+ nn.GELU(),
1319
+ nn.Dropout(dropout),
1320
+ nn.Linear(mlp_hidden_dim, dim),
1321
+ nn.Dropout(dropout),
1322
+ )
1323
+ self.drop_path2 = nn.Dropout(dropout)
1324
+
1325
+ def forward(self, x):
1326
+ residual = x
1327
+ x = self.norm1(x)
1328
+ x = self.attn(x, x, x, need_weights=False)[0]
1329
+ x = self.drop_path1(x)
1330
+ x = residual + x
1331
+ residual = x
1332
+ x = self.norm2(x)
1333
+ x = self.mlp(x)
1334
+ x = self.drop_path2(x)
1335
+ x = residual + x
1336
+ return x
1337
+
1338
+ class MultiModalProjector(nn.Module):
1339
+ def __init__(self, input_size, output_size, projector_type="mlp", num_layers=2):
1340
+ super().__init__()
1341
+ self.projector_type = projector_type
1342
+ if projector_type == "linear":
1343
+ self.projector = nn.Linear(input_size, output_size)
1344
+ elif projector_type == "mlp":
1345
+ layers = []
1346
+ current_size = input_size
1347
+ for i in range(num_layers - 1):
1348
+ layers.extend(
1349
+ [nn.Linear(current_size, output_size), nn.GELU(), nn.Dropout(0.1)]
1350
+ )
1351
+ current_size = output_size
1352
+ layers.append(nn.Linear(current_size, output_size))
1353
+ self.projector = nn.Sequential(*layers)
1354
+ elif projector_type == "perceiver":
1355
+ self.projector = PerceiverResampler(
1356
+ input_size, output_size, num_latents=64, num_layers=2
1357
+ )
1358
+ elif projector_type == "qformer":
1359
+ self.projector = QFormerProjector(
1360
+ input_size, output_size, num_queries=32, num_layers=2
1361
+ )
1362
+ else:
1363
+ raise ValueError(f"projector_type tidak dikenal: {projector_type}")
1364
+
1365
+ def forward(self, x):
1366
+ return self.projector(x)
1367
+
1368
+ class PerceiverResampler(nn.Module):
1369
+ def __init__(self, input_size, output_size, num_latents=64, num_layers=2):
1370
+ super().__init__()
1371
+ self.num_latents = num_latents
1372
+ self.latents = nn.Parameter(torch.randn(num_latents, output_size))
1373
+ self.layers = nn.ModuleList(
1374
+ [
1375
+ PerceiverLayer(output_size, input_size if i == 0 else output_size)
1376
+ for i in range(num_layers)
1377
+ ]
1378
+ )
1379
+ self.norm = nn.LayerNorm(output_size)
1380
+
1381
+ def forward(self, x):
1382
+ batch_size = x.shape[0]
1383
+ latents = self.latents.unsqueeze(0).expand(batch_size, -1, -1)
1384
+ for i, layer in enumerate(self.layers):
1385
+ if i == 0:
1386
+ latents = layer(latents, x)
1387
+ else:
1388
+ latents = layer(latents, latents)
1389
+ return self.norm(latents)
1390
+
1391
+ class PerceiverLayer(nn.Module):
1392
+ def __init__(self, query_dim, key_dim):
1393
+ super().__init__()
1394
+ self.cross_attn = nn.MultiheadAttention(
1395
+ query_dim, num_heads=8, kdim=key_dim, vdim=key_dim, batch_first=True
1396
+ )
1397
+ self.mlp = nn.Sequential(
1398
+ nn.LayerNorm(query_dim),
1399
+ nn.Linear(query_dim, query_dim * 4),
1400
+ nn.GELU(),
1401
+ nn.Linear(query_dim * 4, query_dim),
1402
+ )
1403
+ self.norm1 = nn.LayerNorm(query_dim)
1404
+ self.norm2 = nn.LayerNorm(query_dim)
1405
+
1406
+ def forward(self, query, key):
1407
+ query = (
1408
+ query + self.cross_attn(self.norm1(query), key, key, need_weights=False)[0]
1409
+ )
1410
+ query = query + self.mlp(self.norm2(query))
1411
+ return query
1412
+
1413
+ class QFormerProjector(nn.Module):
+     def __init__(self, input_size, output_size, num_queries=32, num_layers=2):
+         super().__init__()
+         self.num_queries = num_queries
+         self.query_embeds = nn.Parameter(torch.randn(num_queries, output_size))
+         self.query_layers = nn.ModuleList(
+             [
+                 nn.TransformerEncoderLayer(
+                     d_model=output_size,
+                     nhead=8,
+                     dim_feedforward=output_size * 4,
+                     batch_first=True,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.cross_attn_layers = nn.ModuleList(
+             [
+                 nn.MultiheadAttention(
+                     output_size,
+                     num_heads=8,
+                     kdim=input_size,
+                     vdim=input_size,
+                     batch_first=True,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+         self.norm = nn.LayerNorm(output_size)
+
+     def forward(self, x):
+         batch_size = x.shape[0]
+         queries = self.query_embeds.unsqueeze(0).expand(batch_size, -1, -1)
+         for query_layer, cross_attn_layer in zip(
+             self.query_layers, self.cross_attn_layers
+         ):
+             queries = query_layer(queries)
+             queries = queries + cross_attn_layer(queries, x, x, need_weights=False)[0]
+         return self.norm(queries)
+
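+ # Shared PreTrainedModel base: binds CacaConfig, initializes Linear/Embedding
+ # weights from config.initializer_range, and exposes the gradient-checkpointing
+ # toggle used by CacaModel.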
+ class CacaPreTrainedModel(PreTrainedModel):
+     config_class = CacaConfig
+     base_model_prefix = "model"
+     supports_gradient_checkpointing = True
+     _no_split_modules = ["CacaDecoderLayer"]
+     _skip_keys_device_placement = "past_key_values"
+
+     def _init_weights(self, module):
+         std = self.config.initializer_range
+         if isinstance(module, nn.Linear):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.bias is not None:
+                 module.bias.data.zero_()
+         elif isinstance(module, nn.Embedding):
+             module.weight.data.normal_(mean=0.0, std=std)
+             if module.padding_idx is not None:
+                 module.weight.data[module.padding_idx].zero_()
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if isinstance(module, CacaModel):
+             module.gradient_checkpointing = value
+
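+ # Decoder-only backbone: token embeddings, a stack of CacaDecoderLayer blocks,
+ # a final RMSNorm, and optional vision/audio encoders plus projectors for
+ # multimodal inputs (fused via cross-attention or prepended as prefix tokens).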
+ class CacaModel(CacaPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.layers = nn.ModuleList(
+             [
+                 CacaDecoderLayer(config, layer_idx=idx)
+                 for idx in range(config.num_hidden_layers)
+             ]
+         )
+         self.norm = CacaRMSNorm(config.hidden_size, config.rms_norm_eps)
+         self.gradient_checkpointing = False
+         if config.use_multimodal:
+             if config.vision_config:
+                 self.vision_encoder = VisionEncoder(config)
+                 vision_hidden_size = config.vision_config.get("hidden_size", 768)
+                 self.vision_projector = MultiModalProjector(
+                     vision_hidden_size,
+                     config.hidden_size,
+                     projector_type=config.vision_config.get("projector_type", "mlp"),
+                 )
+             else:
+                 self.vision_encoder = None
+                 self.vision_projector = None
+             if config.audio_config:
+                 self.audio_encoder = AudioEncoder(config)
+                 audio_hidden_size = config.audio_config.get("hidden_size", 768)
+                 self.audio_projector = MultiModalProjector(
+                     audio_hidden_size,
+                     config.hidden_size,
+                     projector_type=config.audio_config.get("projector_type", "mlp"),
+                 )
+             else:
+                 self.audio_encoder = None
+                 self.audio_projector = None
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.embed_tokens = value
+
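+     # Expands a 2D/3D padding mask to a head-broadcastable shape and converts
+     # it to an additive mask (1 -> 0.0, 0 -> dtype minimum).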
+     def _prepare_attention_mask(self, attention_mask, input_shape, dtype):
+         if attention_mask is None:
+             return None
+         batch_size, seq_length = input_shape
+         if attention_mask.dim() == 2:
+             attention_mask = attention_mask[:, None, None, :]
+         elif attention_mask.dim() == 3:
+             attention_mask = attention_mask[:, None, :, :]
+         attention_mask = attention_mask.to(dtype=dtype)
+         attention_mask = (1.0 - attention_mask) * torch.finfo(dtype).min
+         return attention_mask
+
+     def forward(
+         self,
+         input_ids=None,
+         pixel_values=None,
+         audio_features=None,
+         attention_mask=None,
+         past_key_values=None,
+         use_cache=None,
+         output_hidden_states=False,
+         return_dict=True,
+         **kwargs,
+     ):
+         use_cache = use_cache if use_cache is not None else self.config.use_cache
+         if input_ids is not None:
+             batch_size, seq_length = input_ids.shape
+             device = input_ids.device
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             raise ValueError("input_ids must not be None")
+
+         if pixel_values is not None:
+             pixel_values = pixel_values.to(device)
+         if audio_features is not None:
+             audio_features = audio_features.to(device)
+
+         encoder_hidden_states = None
+         encoder_attention_mask = None
+         if self.config.use_multimodal:
+             multimodal_embeds = []
+             if pixel_values is not None and self.vision_encoder is not None:
+                 vision_features = self.vision_encoder(pixel_values.to(hidden_states.device))
+                 vision_embeds = self.vision_projector(vision_features)
+                 multimodal_embeds.append(vision_embeds)
+             if audio_features is not None and self.audio_encoder is not None:
+                 audio_encoded = self.audio_encoder(audio_features.to(hidden_states.device))
+                 audio_embeds = self.audio_projector(audio_encoded)
+                 multimodal_embeds.append(audio_embeds)
+             if multimodal_embeds and self.config.use_cross_attention:
+                 encoder_hidden_states = torch.cat(multimodal_embeds, dim=1)
+                 encoder_seq_len = encoder_hidden_states.shape[1]
+                 encoder_attention_mask = torch.ones(
+                     batch_size,
+                     encoder_seq_len,
+                     dtype=hidden_states.dtype,
+                     device=hidden_states.device,
+                 )
+
+             elif multimodal_embeds:
+                 multimodal_concat = torch.cat(multimodal_embeds, dim=1)
+                 max_multimodal_tokens = self.config.max_position_embeddings // 4
+                 if multimodal_concat.shape[1] > max_multimodal_tokens:
+                     logger.warning(
+                         f"Multimodal tokens ({multimodal_concat.shape[1]}) > max ({max_multimodal_tokens}). "
+                         f"Truncating..."
+                     )
+                     multimodal_concat = multimodal_concat[:, :max_multimodal_tokens]
+                 hidden_states = torch.cat([multimodal_concat, hidden_states], dim=1)
+                 seq_length = hidden_states.shape[1]
+                 if attention_mask is not None:
+                     multimodal_mask = torch.ones(
+                         batch_size,
+                         multimodal_concat.shape[1],
+                         dtype=attention_mask.dtype,
+                         device=attention_mask.device,
+                     )
+                     attention_mask = torch.cat([multimodal_mask, attention_mask], dim=1)
+                 else:
+                     attention_mask = torch.ones(
+                         batch_size,
+                         seq_length,
+                         dtype=hidden_states.dtype,
+                         device=device,
+                     )
+         if attention_mask is not None:
+             attention_mask = self._prepare_attention_mask(
+                 attention_mask, (batch_size, seq_length), hidden_states.dtype
+             )
+         if encoder_attention_mask is not None and self.config.use_cross_attention:
+             encoder_attention_mask = self._prepare_attention_mask(
+                 encoder_attention_mask,
+                 (batch_size, encoder_hidden_states.shape[1]),
+                 hidden_states.dtype,
+             )
+
+         if use_cache:
+             if past_key_values is None:
+                 past_key_values = tuple([None] * len(self.layers))
+
+         present_key_values = [] if use_cache else None
+         all_hidden_states = [] if output_hidden_states else None
+         total_aux_loss = 0.0
+         total_z_loss = 0.0
+         for idx, layer in enumerate(self.layers):
+             if output_hidden_states:
+                 all_hidden_states.append(hidden_states)
+             past_key_value = (
+                 past_key_values[idx] if past_key_values is not None else None
+             )
+             if self.gradient_checkpointing and self.training and not use_cache:
+                 hidden_states, aux_loss, z_loss = self._gradient_checkpointing_forward(
+                     layer,
+                     hidden_states,
+                     attention_mask,
+                     encoder_hidden_states,
+                     encoder_attention_mask,
+                 )
+                 present_key_value = None
+             else:
+                 hidden_states, present_key_value, aux_loss, z_loss = layer(
+                     hidden_states,
+                     attention_mask,
+                     encoder_hidden_states=encoder_hidden_states,
+                     encoder_attention_mask=encoder_attention_mask,
+                     past_key_value=past_key_value,
+                     use_cache=use_cache,
+                 )
+             if use_cache:
+                 present_key_values.append(present_key_value)
+             total_aux_loss += aux_loss
+             total_z_loss += z_loss
+
+         if self.training and torch.cuda.is_available():
+             allocated_gb = torch.cuda.memory_allocated() / 1024**3
+             reserved_gb = torch.cuda.memory_reserved() / 1024**3
+             if allocated_gb > 10:
+                 logger.warning(
+                     f"High GPU memory usage - Allocated: {allocated_gb:.2f}GB, "
+                     f"Reserved: {reserved_gb:.2f}GB"
+                 )
+         hidden_states = self.norm(hidden_states)
+         if output_hidden_states:
+             all_hidden_states.append(hidden_states)
+         if not return_dict:
+             outputs = tuple(
+                 v
+                 for v in [hidden_states, present_key_values, all_hidden_states]
+                 if v is not None
+             )
+             return outputs, total_aux_loss, total_z_loss
+         return (
+             BaseModelOutputWithPast(
+                 last_hidden_state=hidden_states,
+                 past_key_values=tuple(present_key_values) if use_cache else None,
+                 hidden_states=all_hidden_states,
+                 attentions=None,
+             ),
+             total_aux_loss,
+             total_z_loss,
+         )
+
+     def _gradient_checkpointing_forward(
+         self,
+         layer,
+         hidden_states,
+         attention_mask,
+         encoder_hidden_states,
+         encoder_attention_mask,
+     ):
+         from torch.utils.checkpoint import checkpoint
+
+         def custom_forward(hidden_states, attention_mask, encoder_hidden_states,
+                            encoder_attention_mask):
+             output, _, aux_loss, z_loss = layer(
+                 hidden_states, attention_mask,
+                 encoder_hidden_states=encoder_hidden_states,
+                 encoder_attention_mask=encoder_attention_mask,
+                 past_key_value=None,
+                 use_cache=False,
+             )
+             return output, aux_loss, z_loss
+
+         hidden_states, aux_loss, z_loss = checkpoint(
+             custom_forward,
+             hidden_states, attention_mask,
+             encoder_hidden_states, encoder_attention_mask,
+             use_reentrant=False,
+         )
+         return hidden_states, aux_loss, z_loss
+
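+ # Causal-LM wrapper: validates input_ids/labels/attention_mask, applies the LM
+ # head (with optional final logit soft-capping), and, when MoE is enabled, adds
+ # the router auxiliary and z-losses to the cross-entropy loss.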
+ class CacaForCausalLM(CacaPreTrainedModel, GenerationMixin):
+     _tied_weights_keys = ["lm_head.weight"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = CacaModel(config)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embed_tokens
+
+     def set_input_embeddings(self, value):
+         self.model.embed_tokens = value
+
+     def get_output_embeddings(self):
+         return self.lm_head
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_head = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def forward(
+         self,
+         input_ids=None,
+         pixel_values=None,
+         audio_features=None,
+         attention_mask=None,
+         labels=None,
+         past_key_values=None,
+         inputs_embeds=None,
+         use_cache=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         **kwargs,
+     ):
+         if input_ids is not None:
+             if input_ids.dtype.is_floating_point:
+                 raise TypeError(
+                     f"input_ids must have an integer dtype, got {input_ids.dtype}. "
+                     f"Use input_ids.long() to convert."
+                 )
+             if (input_ids < 0).any():
+                 neg_vals = input_ids[input_ids < 0].unique().tolist()
+                 raise ValueError(f"input_ids contains negative values: {neg_vals}")
+             max_val = input_ids.max().item()
+             if max_val >= self.config.vocab_size:
+                 raise ValueError(
+                     f"input_ids contains values >= vocab_size. "
+                     f"Max value: {max_val}, vocab_size: {self.config.vocab_size:,}"
+                 )
+
+         if labels is not None:
+             if labels.dtype not in [torch.long, torch.int, torch.int32, torch.int64]:
+                 raise TypeError(f"labels must have an integer dtype, got {labels.dtype}")
+             if (labels[labels != -100] < 0).any():
+                 raise ValueError("labels contain negative values (other than -100)")
+             max_label = labels[labels != -100].max().item() if (labels != -100).any() else 0
+             if max_label >= self.config.vocab_size:
+                 raise ValueError(
+                     f"labels contain values >= vocab_size. "
+                     f"Max: {max_label}, vocab_size: {self.config.vocab_size}"
+                 )
+         if attention_mask is not None:
+             if attention_mask.shape[0] != input_ids.shape[0]:
+                 raise ValueError(
+                     f"attention_mask batch size ({attention_mask.shape[0]}) != "
+                     f"input_ids batch size ({input_ids.shape[0]})"
+                 )
+             if attention_mask.shape[1] != input_ids.shape[1]:
+                 raise ValueError(
+                     f"attention_mask seq length ({attention_mask.shape[1]}) != "
+                     f"input_ids seq length ({input_ids.shape[1]})"
+                 )
+
+         return_dict = (
+             return_dict if return_dict is not None else self.config.use_return_dict
+         )
+         outputs, aux_loss, z_loss = self.model(
+             input_ids,
+             pixel_values=pixel_values,
+             audio_features=audio_features,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         if return_dict:
+             hidden_states = outputs.last_hidden_state
+         else:
+             hidden_states = outputs[0]
+         logits = self.lm_head(hidden_states)
+         if self.config.final_logit_softcapping:
+             logits = soft_cap_logits(logits, self.config.final_logit_softcapping)
+         loss = None
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             lm_loss = loss_fct(
+                 shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
+             )
+             if self.config.use_moe:
+                 total_loss = (
+                     lm_loss
+                     + (self.config.router_aux_loss_coef * aux_loss)
+                     + (self.config.router_z_loss_coef * z_loss)
+                 )
+             else:
+                 total_loss = lm_loss
+             loss = total_loss
+         if not return_dict:
+             output = (logits,) + tuple(v for v in outputs[1:] if v is not None)
+             return ((loss,) + output) if loss is not None else output
+         return CausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values if return_dict else None,
+             hidden_states=outputs.hidden_states if return_dict else None,
+             attentions=None,
+         )
+
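+     # With a non-empty KV cache only the last token is fed back in, and
+     # pixel_values/audio_features are forwarded only on the first (prefill) step.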
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         pixel_values=None,
+         audio_features=None,
+         **kwargs,
+     ):
+
+         has_past = (
+             past_key_values is not None
+             and len(past_key_values) > 0
+             and past_key_values[0] is not None
+         )
+
+         if has_past:
+             input_ids = input_ids[:, -1:]
+
+         if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
+             attention_mask = attention_mask[:, -input_ids.shape[1]:]
+
+         if inputs_embeds is not None and not has_past:
+             model_inputs = {"inputs_embeds": inputs_embeds}
+         else:
+             model_inputs = {"input_ids": input_ids}
+
+         model_inputs.update(
+             {
+                 "past_key_values": past_key_values if has_past else None,
+                 "use_cache": kwargs.get("use_cache"),
+                 "attention_mask": attention_mask,
+                 "pixel_values": pixel_values if not has_past else None,
+                 "audio_features": audio_features if not has_past else None,
+             }
+         )
+         return model_inputs
+
+     @staticmethod
+     def _reorder_cache(past_key_values, beam_idx):
+         reordered_past = ()
+         for layer_past in past_key_values:
+             if layer_past is not None and len(layer_past) > 0:
+                 reordered_past += (
+                     tuple(
+                         past_state.index_select(0, beam_idx.to(past_state.device))
+                         for past_state in layer_past
+                         if past_state is not None
+                     ),
+                 )
+             else:
+                 reordered_past += (None,)
+         return reordered_past
+
+     def save_pretrained(self, save_directory, **kwargs):
+         has_quant_config = hasattr(self.config, 'quantization_config')
+         quantization_config_backup = getattr(self.config, 'quantization_config', None)
+
+         if has_quant_config and quantization_config_backup is None:
+             delattr(self.config, 'quantization_config')
+
+         try:
+             super().save_pretrained(save_directory, **kwargs)
+         finally:
+             if has_quant_config:
+                 self.config.quantization_config = quantization_config_backup
+
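+ # Variant that replaces nn.Linear modules with bitsandbytes 8-bit or 4-bit
+ # layers according to a quantization_config dict (load_in_8bit / load_in_4bit).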
+ class CacaForCausalLMQuantized(CacaForCausalLM):
+     def __init__(self, config, quantization_config=None):
+         super().__init__(config)
+         self.quantization_config = quantization_config
+         if quantization_config:
+             self._apply_quantization()
+
+     def _apply_quantization(self):
+         if self.quantization_config.get("load_in_8bit"):
+             self._quantize_8bit()
+         elif self.quantization_config.get("load_in_4bit"):
+             self._quantize_4bit()
+
+     def _quantize_8bit(self):
+         try:
+             import bitsandbytes as bnb
+
+             for name, module in self.named_modules():
+                 if isinstance(module, nn.Linear):
+                     has_bias = module.bias is not None
+                     new_module = bnb.nn.Linear8bitLt(
+                         module.in_features,
+                         module.out_features,
+                         has_bias,
+                         threshold=self.quantization_config.get(
+                             "llm_int8_threshold", 6.0
+                         ),
+                     )
+                     new_module.weight = module.weight
+                     if has_bias:
+                         new_module.bias = module.bias
+                     parent_name = ".".join(name.split(".")[:-1])
+                     child_name = name.split(".")[-1]
+                     if parent_name:
+                         parent = self.get_submodule(parent_name)
+                         setattr(parent, child_name, new_module)
+                     else:
+                         setattr(self, child_name, new_module)
+             logger.info("8-bit quantization applied successfully")
+         except ImportError:
+             logger.error("bitsandbytes is not installed! pip install bitsandbytes")
+
+     def _quantize_4bit(self):
+         try:
+             import bitsandbytes as bnb
+
+             compute_dtype = torch.float16
+             if self.quantization_config.get("bnb_4bit_compute_dtype"):
+                 compute_dtype = getattr(
+                     torch, self.quantization_config["bnb_4bit_compute_dtype"]
+                 )
+             for name, module in self.named_modules():
+                 if isinstance(module, nn.Linear):
+                     has_bias = module.bias is not None
+                     new_module = bnb.nn.Linear4bit(
+                         module.in_features,
+                         module.out_features,
+                         bias=has_bias,
+                         compute_dtype=compute_dtype,
+                         quant_type=self.quantization_config.get(
+                             "bnb_4bit_quant_type", "nf4"
+                         ),
+                         compress_statistics=self.quantization_config.get(
+                             "bnb_4bit_use_double_quant", True
+                         ),
+                     )
+                     new_module.weight = module.weight
+                     if has_bias:
+                         new_module.bias = module.bias
+                     parent_name = ".".join(name.split(".")[:-1])
+                     child_name = name.split(".")[-1]
+                     if parent_name:
+                         parent = self.get_submodule(parent_name)
+                         setattr(parent, child_name, new_module)
+                     else:
+                         setattr(self, child_name, new_module)
+             logger.info("4-bit quantization applied successfully")
+         except ImportError:
+             logger.error("bitsandbytes is not installed!")
+
+     @classmethod
+     def from_pretrained_quantized(cls, model_path, quantization_config):
+         config = CacaConfig.from_pretrained(model_path)
+         model = cls(config, quantization_config=quantization_config)
+         state_dict = torch.load(f"{model_path}/pytorch_model.bin", map_location="cpu")
+         model.load_state_dict(state_dict, strict=False)
+         return model
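+
+
+ # Minimal usage sketch (not part of the original file's API): builds a model
+ # from a default CacaConfig and runs a single forward pass. It assumes that
+ # CacaConfig() is constructible with its defaults; adjust fields such as
+ # vocab_size or num_hidden_layers for a real setup.
+ if __name__ == "__main__":
+     demo_config = CacaConfig()
+     demo_model = CacaForCausalLM(demo_config)
+     demo_ids = torch.randint(0, demo_config.vocab_size, (1, 8))
+     demo_out = demo_model(input_ids=demo_ids, labels=demo_ids.clone())
+     print(f"loss={demo_out.loss:.4f}, logits shape={tuple(demo_out.logits.shape)}")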