YuhuaJiang committed on
Commit
b510dde
·
verified ·
1 Parent(s): f7b7447

initial upload

Browse files
__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
4
+
5
+ from fla.models.transformer.configuration_transformer import TransformerConfig
6
+ from fla.models.transformer.modeling_transformer import (
7
+ TransformerForCausalLM, TransformerModel)
8
+
9
+ from .configuration_transformer_rnn import TransformerConfig_rnn
10
+ from .modeling_transformer_rnn import TransformerForCausalLM_rnn, TransformerModel_rnn
11
+ from .task_aware_delta_net import Task_Aware_Delta_Net
12
+ from .ttt_cross_layer import TTT_Cross_Layer
13
+
14
+
15
# Register the Transformer classes imported from `fla` with the transformers
# Auto* factories, keyed by the config's `model_type` / config class.
AutoConfig.register(TransformerConfig.model_type, TransformerConfig)
AutoModel.register(TransformerConfig, TransformerModel)
AutoModelForCausalLM.register(TransformerConfig, TransformerForCausalLM)

# Register the `_rnn` variants defined in this package the same way.
AutoConfig.register(TransformerConfig_rnn.model_type, TransformerConfig_rnn)
AutoModel.register(TransformerConfig_rnn, TransformerModel_rnn)
AutoModelForCausalLM.register(TransformerConfig_rnn, TransformerForCausalLM_rnn)

# Public names of this package (what `from <package> import *` exposes).
__all__ = ['TransformerConfig', 'TransformerForCausalLM', 'TransformerModel',
           'TransformerConfig_rnn', 'TransformerForCausalLM_rnn', 'TransformerModel_rnn',
           'Task_Aware_Delta_Net', 'TTT_Cross_Layer']
config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/cpfs02/user/jiangyuhua/flash-linear-attention/training/configs/nirvana_1.3B-t3/nirvana_1_3B.json",
3
+ "architectures": [
4
+ "TransformerForCausalLM_rnn"
5
+ ],
6
+ "attention_bias": false,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_transformer_rnn.TransformerConfig_rnn",
9
+ "AutoModel": "modeling_transformer_rnn.TransformerModel_rnn",
10
+ "AutoModelForCausalLM": "modeling_transformer_rnn.TransformerForCausalLM_rnn"
11
+ },
12
+ "bos_token_id": 1,
13
+ "concept_dim": 64,
14
+ "elementwise_affine": true,
15
+ "eos_token_id": 2,
16
+ "fuse_cross_entropy": true,
17
+ "fuse_norm": true,
18
+ "hidden_act": "swish",
19
+ "hidden_ratio": 4,
20
+ "hidden_size": 2048,
21
+ "initializer_range": 0.006,
22
+ "intermediate_size": null,
23
+ "logit_dim": 32,
24
+ "max_position_embeddings": 32768,
25
+ "model_type": "transformer_rnn",
26
+ "norm_eps": 1e-06,
27
+ "norm_first": false,
28
+ "num_heads": 16,
29
+ "num_hidden_layers": 16,
30
+ "num_kv_heads": null,
31
+ "pad_token_id": 2,
32
+ "recurrent_depth": 4,
33
+ "rope_theta": 10000.0,
34
+ "tie_word_embeddings": false,
35
+ "torch_dtype": "bfloat16",
36
+ "transformers_version": "4.46.0",
37
+ "use_cache": false,
38
+ "vocab_size": 128512,
39
+ "window_size": 2048
40
+ }
configuration_transformer_rnn.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from typing import Optional
4
+
5
+ from transformers.configuration_utils import PretrainedConfig
6
+
7
+ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
8
+
9
+ # AutoConfig
10
+
11
class TransformerConfig_rnn(PretrainedConfig):
    """Configuration object for models of type ``transformer_rnn``.

    Each constructor argument other than the token ids and
    ``tie_word_embeddings`` is stored on the instance under its own name.
    ``pad_token_id``, ``bos_token_id``, ``eos_token_id`` and
    ``tie_word_embeddings`` are handed to ``PretrainedConfig.__init__``
    together with any additional ``**kwargs``.
    """

    model_type = 'transformer_rnn'
    keys_to_ignore_at_inference = ['past_key_values']

    def __init__(
        self,
        vocab_size: int = 32000,
        hidden_size: int = 2048,
        num_hidden_layers: int = 24,
        num_heads: int = 32,
        num_kv_heads: int = None,
        window_size: Optional[int] = None,
        rope_theta: Optional[float] = 10000.,
        max_position_embeddings: int = 2048,
        hidden_ratio: Optional[int] = 4,
        intermediate_size: Optional[int] = None,
        hidden_act: str = "swish",
        initializer_range: float = 0.02,
        elementwise_affine: Optional[bool] = True,
        norm_first: bool = False,
        norm_eps: float = 1e-6,
        use_cache: bool = True,
        pad_token_id: int = None,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = False,
        attention_bias: bool = False,
        fuse_norm: bool = True,
        fuse_cross_entropy: bool = True,
        recurrent_depth: int = 4,
        concept_dim: int = 128,
        **kwargs,
    ):
        # --- vocabulary, width/depth, heads and positional settings ---
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.window_size = window_size
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        # --- feed-forward sizing and activation ---
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act

        # --- initialisation, normalisation, caching and fused-kernel switches ---
        self.initializer_range = initializer_range
        self.elementwise_affine = elementwise_affine
        self.norm_first = norm_first
        self.norm_eps = norm_eps
        self.use_cache = use_cache
        self.attention_bias = attention_bias
        self.fuse_cross_entropy = fuse_cross_entropy
        self.fuse_norm = fuse_norm

        # --- options specific to the `_rnn` variant ---
        self.recurrent_depth = recurrent_depth
        self.concept_dim = concept_dim

        # The base class stores the special-token ids, the embedding-tying
        # flag and whatever extra keyword arguments were supplied.
        base_kwargs = dict(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
        )
        super().__init__(**base_kwargs, **kwargs)
generation_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 2,
6
+ "transformers_version": "4.46.0",
7
+ "use_cache": false
8
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9c8b64a92f136b61f25043ccd79b6b14efd5ca5287a4f3ea185c5c19bd39bcf
3
+ size 3226267140
modeling_transformer_rnn.py ADDED
@@ -0,0 +1,628 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from __future__ import annotations
4
+
5
+ import math
6
+ import warnings
7
+ from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, Dict
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from transformers.activations import ACT2FN
13
+ from transformers.generation import GenerationMixin
14
+ from transformers.modeling_outputs import (BaseModelOutputWithPast,
15
+ CausalLMOutputWithPast)
16
+ from dataclasses import dataclass
17
+ from transformers.utils import ModelOutput
18
@dataclass
class BaseModelOutputWithPast_with_two_caches(ModelOutput):
    """Base-model output container with two separate cache fields.

    Has the fields `last_hidden_state`, `hidden_states` and `attentions`,
    plus `past_key_values1` and `all_past_key_values` for the two caches.

    NOTE(review): in the visible part of this file the only use of this class
    is inside commented-out code.
    """

    last_hidden_state: torch.FloatTensor = None
    past_key_values1: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    all_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
26
@dataclass
class CausalLMOutputWithPast_with_two_caches(ModelOutput):
    """Causal-LM output container with two separate cache fields.

    Has the fields `logits`, `loss`, `hidden_states` and `attentions`,
    plus `past_key_values1` and `all_past_key_values` for the two caches.

    NOTE(review): in the visible part of this file the only use of this class
    is inside commented-out code.
    """
    logits: torch.FloatTensor = None
    loss: Optional[torch.FloatTensor] = None
    past_key_values1: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    all_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
34
+
35
+ from transformers.modeling_utils import PreTrainedModel
36
+ from transformers.utils import logging
37
+
38
+ # from fla.layers.attn import Attention
39
+ from configuration_transformer_rnn import TransformerConfig_rnn
40
+
41
+ import sys
42
+ import os
43
+ # # Add the grandparent of the current directory to the Python path
44
+ # current_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
45
+ # sys.path.append(current_dir)
46
+ # sys.path.append("/cpfs02/user/jiangyuhua/flash-linear-attention/fla/layers")
47
+ # from attn_rnn import Attention_rnn ###########################################################
48
+ # from attn_svd import Attention_svd ###########################################################
49
+ # from attn import Attention ###########################################################
50
+ # from gated_deltanet import GatedDeltaNet ###########################################################
51
+ # from rwkv7 import RWKV7Attention ###########################################################
52
+ # from attn_gated_delta import GatedDeltaNet_attention ###########################################################
53
+ # from scattering_mixer2 import Scattering_Mixer ###########################################################
54
+ from task_aware_delta_net import Task_Aware_Delta_Net ###########################################################
55
+
56
+ # from moe_rnn import CustomGRUCell, CustomRNNCell
57
+ from ttt_cross_layer import TTT_Cross_Layer
58
+
59
+ # from fla.models.transformer.configuration_transformer import TransformerConfig
60
+ from fla.models.utils import Cache
61
+ from fla.modules import (FusedCrossEntropyLoss, FusedLinearCrossEntropyLoss,
62
+ RMSNorm)
63
+ from fla.modules.activations import swiglu_linear
64
+ from fla.modules.layernorm import rms_norm_linear
65
+
66
+ if TYPE_CHECKING:
67
+ from transformers.processing_utils import Unpack
68
+
69
+ logger = logging.get_logger(__name__)
70
+
71
class TransformerMLP(nn.Module):
    """Gated feed-forward block.

    A single bias-free linear layer produces two halves (gate and value) of
    width `intermediate_size` each; fla's `swiglu_linear` combines them and
    applies the bias-free down projection. When `norm_first` is true an
    RMSNorm is fused in front of the first projection via `rms_norm_linear`.
    """

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'swish',
        norm_first: bool = True,
        norm_eps: float = 1e-5
    ) -> TransformerMLP:
        super().__init__()

        # Default expansion ratio is 4.
        ratio = 4 if hidden_ratio is None else hidden_ratio
        width = intermediate_size
        if width is None:
            # 2/3 * hidden_size * ratio, rounded up to the next multiple of 256.
            two_thirds = int(hidden_size * ratio * 2 / 3)
            num_tiles = (two_thirds + 256 - 1) // 256
            width = 256 * num_tiles

        self.hidden_size = hidden_size
        self.hidden_ratio = ratio
        self.intermediate_size = width
        self.norm_first = norm_first

        # The norm only exists in the pre-norm configuration.
        if norm_first:
            self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)

        # `gate_proj` emits both halves at once, hence twice the width.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Looked up here; `forward` below does not reference it.
        self.act_fn = ACT2FN[hidden_act]

    def forward(
        self,
        x: torch.Tensor,
        **kwargs: Unpack[Any]
    ) -> torch.Tensor:
        """Apply the (optionally pre-normed) gated feed-forward to `x`.

        Extra keyword arguments are accepted and ignored.
        """
        if not self.norm_first:
            projected = self.gate_proj(x)
        else:
            projected = rms_norm_linear(
                x,
                self.norm.weight,
                self.norm.bias,
                self.gate_proj.weight,
                self.gate_proj.bias,
            )
        # Split the last dimension into the two halves.
        gate, y = projected.chunk(2, -1)
        return swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
114
+
115
class TransformerMLP_svd(nn.Module):
    """Gated feed-forward block with an extra "reflector" head.

    Built like `TransformerMLP`, plus `reflector_qkvo`, a linear layer (with
    bias) from `intermediate_size` to `4 * hidden_size`. When `forward` is
    called with `reflect=True` that head is evaluated on the same gate/value
    halves and squashed through a sigmoid.
    """

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'swish',
        norm_first: bool = True,
        norm_eps: float = 1e-5
    ) -> TransformerMLP_svd:
        super().__init__()

        # Default expansion ratio is 4.
        ratio = 4 if hidden_ratio is None else hidden_ratio
        width = intermediate_size
        if width is None:
            # 2/3 * hidden_size * ratio, rounded up to the next multiple of 256.
            two_thirds = int(hidden_size * ratio * 2 / 3)
            num_tiles = (two_thirds + 256 - 1) // 256
            width = 256 * num_tiles

        self.hidden_size = hidden_size
        self.hidden_ratio = ratio
        self.intermediate_size = width
        self.norm_first = norm_first

        # The norm only exists in the pre-norm configuration.
        if norm_first:
            self.norm = RMSNorm(hidden_size=hidden_size, eps=norm_eps)

        # `gate_proj` emits both halves at once, hence twice the width.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Looked up here; `forward` below does not reference it.
        self.act_fn = ACT2FN[hidden_act]

        self.reflector_qkvo = nn.Linear(self.intermediate_size, self.hidden_size * 4)

    def forward(
        self,
        x: torch.Tensor,
        reflect: bool = False,
        **kwargs: Unpack[Any]
    ) -> torch.Tensor:
        """Run the feed-forward; optionally also return the reflector output.

        Returns the feed-forward result alone when `reflect` is false, and the
        tuple `(hidden_states, reflector_qkvo)` when it is true. Extra keyword
        arguments are accepted and ignored.
        """
        if not self.norm_first:
            projected = self.gate_proj(x)
        else:
            projected = rms_norm_linear(
                x,
                self.norm.weight,
                self.norm.bias,
                self.gate_proj.weight,
                self.gate_proj.bias,
            )
        gate, y = projected.chunk(2, -1)
        hidden_states = swiglu_linear(gate, y, self.down_proj.weight, self.down_proj.bias)
        if not reflect:
            return hidden_states

        # Same gate/value halves, different output projection, then sigmoid.
        reflected = swiglu_linear(gate, y, self.reflector_qkvo.weight, self.reflector_qkvo.bias)
        reflector_qkvo = torch.sigmoid(reflected)
        return hidden_states, reflector_qkvo
166
+
167
class TransformerBlock_rnn(nn.Module):
    """One decoder layer: a `Task_Aware_Delta_Net` mixer followed by a `TransformerMLP`.

    When `config.norm_first` is false the block owns two RMSNorm modules
    (`attn_norm`, `mlp_norm`); otherwise normalisation is delegated to the
    sub-modules, which receive `norm_first` themselves.
    """

    def __init__(self, config: TransformerConfig_rnn, layer_idx: int):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.layer_idx = layer_idx

        # Block-level norms exist only when the sub-modules do not pre-norm.
        owns_norms = not config.norm_first

        if owns_norms:
            self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)

        self.head_dim = config.hidden_size // config.num_heads
        mixer_kwargs = dict(
            hidden_size=config.hidden_size,
            head_dim=self.head_dim,
            num_heads=config.num_heads,
            mode='chunk',
            rope_theta=config.rope_theta,
            max_position_embeddings=config.max_position_embeddings,
            norm_first=config.norm_first,
            norm_eps=config.norm_eps,
            layer_idx=layer_idx,
            concept_dim=config.concept_dim,
        )
        self.Task_Aware_Delta_Net = Task_Aware_Delta_Net(**mixer_kwargs)

        if owns_norms:
            self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        self.mlp = TransformerMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            norm_first=config.norm_first,
            norm_eps=config.norm_eps,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values1: Optional[Tuple[torch.Tensor]] = None,
        all_past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        h_old: Optional[torch.Tensor] = None,
        rnn_router: Optional[nn.Module] = None,
        params: Optional[Dict] = None,
        **kwargs: Unpack[Any]
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """Run the mixer and the MLP, each with a residual connection.

        Returns a tuple laid out as
        `(hidden_states, [attentions], [past_key_values1, all_past_key_values], h_new, params)`
        where the bracketed entries are present only when `output_attentions`
        / `use_cache` are truthy. `h_new` and `params` are whatever the mixer
        returned and are always the last two entries.
        """
        skip = hidden_states
        if hasattr(self, 'attn_norm'):
            hidden_states = self.attn_norm(hidden_states)

        mixer_out = self.Task_Aware_Delta_Net(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values1=past_key_values1,
            all_past_key_values=all_past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            rnn_router=rnn_router,
            h_old=h_old,
            params=params,
            **kwargs
        )
        hidden_states, attentions, past_key_values1, all_past_key_values, h_new, params = mixer_out

        if hasattr(self, 'mlp_norm'):
            # Called with the residual and a positional `True`; it hands back
            # both the normalised activations and the updated residual.
            hidden_states, skip = self.mlp_norm(hidden_states, skip, True)
        else:
            hidden_states = skip + hidden_states
            skip = hidden_states

        hidden_states = skip + self.mlp(hidden_states, **kwargs)

        optional_entries = []
        if output_attentions:
            optional_entries.append(attentions)
        if use_cache:
            optional_entries.extend((past_key_values1, all_past_key_values))

        return (hidden_states, *optional_entries, h_new, params)
260
+
261
class TransformerPreTrainedModel_rnn(PreTrainedModel):
    """Shared `PreTrainedModel` base for the `_rnn` models: class-level
    metadata plus the weight-initialisation hook."""

    config_class = TransformerConfig_rnn
    supports_gradient_checkpointing = True
    _no_split_modules = ['TransformerBlock_rnn']

    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
        rescale_prenorm_residual: bool = False,
        num_residuals_per_layer: int = 2,
    ):
        """Initialise a single sub-module.

        Linear, Conv1d and Embedding weights are drawn from
        N(0, `config.initializer_range`); Linear/Conv1d biases are zeroed and
        the Embedding padding row (if any) is zeroed. Any other module that
        exposes `reset_parameters` has it called.

        If `rescale_prenorm_residual` is true, parameters of `module` named
        exactly `o_proj.weight` or `down_proj.weight` are divided in place by
        `sqrt(num_residuals_per_layer * config.num_hidden_layers)`
        (the GPT-2 residual-scaling scheme).
        """
        if isinstance(module, (nn.Linear, nn.Conv1d, nn.Embedding)):
            # Plain normal init (not the truncated normal of the TF version),
            # cf https://github.com/pytorch/pytorch/pull/5617
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Embedding):
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()
            elif module.bias is not None:
                nn.init.zeros_(module.bias)
        elif hasattr(module, 'reset_parameters'):
            module.reset_parameters()

        if not rescale_prenorm_residual:
            return

        # GPT-2 scheme: scale residual-branch output weights by 1/sqrt(N),
        # N being the number of residual additions in the whole network.
        #   -- GPT-2 :: https://openai.com/blog/better-language-models/
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        residual_outputs = ("o_proj.weight", "down_proj.weight")
        for name, p in module.named_parameters():
            if name not in residual_outputs:
                continue
            with torch.no_grad():
                p /= math.sqrt(num_residuals_per_layer * self.config.num_hidden_layers)
304
+
305
+
306
class TransformerModel_rnn(TransformerPreTrainedModel_rnn):
    """Backbone: token embedding, a stack of `TransformerBlock_rnn` layers and a final RMSNorm.

    A single `TTT_Cross_Layer` instance (`self.rnn_router`) is created here
    and passed to every layer on each forward call.
    """

    def __init__(
        self,
        config: TransformerConfig_rnn
    ) -> TransformerModel_rnn:
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.concept_dim = config.concept_dim

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([TransformerBlock_rnn(config, layer_idx) for layer_idx in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False

        self.rnn_router = TTT_Cross_Layer(config)

        # `post_init()` must be the last step: it walks the sub-modules that
        # exist at call time, so the router has to be created before it or a
        # standalone backbone would leave the router out of `_init_weights`.
        self.post_init()

    def get_input_embeddings(self):
        """Return the token-embedding module."""
        return self.embeddings

    def set_input_embeddings(self, value):
        """Replace the token-embedding module."""
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values1: Optional[List[torch.FloatTensor]] = None,
        all_past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Embed the input, run every layer, apply the final norm.

        Exactly one of `input_ids` / `inputs_embeds` must be given, otherwise
        `ValueError` is raised. `h_old` and `params` start as `None` and are
        threaded from each layer's output into the next layer's input.

        Returns a `BaseModelOutputWithPast` (or, when `return_dict` is false,
        a tuple of its non-`None` entries).

        NOTE(review): only the first cache (`past_key_values1`) is returned;
        the second cache produced by the layers is collected in `next_cache2`
        but not exposed by either return path.
        """
        if output_attentions:
            warnings.warn(
                "`TransformerModel` does not support output attention weights now, so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        # Caching defaults to the config value, but is off while training.
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds is required.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        # Wrap anything that is not already a `Cache` (including `None`).
        if use_cache and not isinstance(past_key_values1, Cache):
            past_key_values1 = Cache.from_legacy_cache(past_key_values1)
        if use_cache and not isinstance(all_past_key_values, Cache):
            all_past_key_values = Cache.from_legacy_cache(all_past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        next_cache1 = None
        next_cache2 = None
        # Per-forward state handed from one layer to the next.
        h_old = None
        params = None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # Positional arguments follow the layer's `forward` signature.
                layer_outputs = self._gradient_checkpointing_func(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    past_key_values1,
                    all_past_key_values,
                    output_attentions,
                    use_cache,
                    h_old=h_old,
                    params=params,
                    rnn_router=self.rnn_router,
                    **kwargs
                )
            else:
                layer_outputs = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_values1=past_key_values1,
                    all_past_key_values=all_past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    h_old=h_old,
                    params=params,
                    rnn_router=self.rnn_router,
                    **kwargs
                )

            hidden_states = layer_outputs[0]

            # The layer always appends (h_new, params) as its last two outputs.
            h_old = layer_outputs[-2]
            params = layer_outputs[-1]
            if use_cache:
                # Cache entries sit after the optional attentions entry.
                next_cache1 = layer_outputs[2 if output_attentions else 1]
                next_cache2 = layer_outputs[3 if output_attentions else 2]

            if output_attentions:
                all_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache1, all_hidden_states, all_attns] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache1,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
450
+
451
class TransformerForCausalLM_rnn(TransformerPreTrainedModel_rnn, GenerationMixin):
    """`TransformerModel_rnn` backbone with a bias-free linear LM head."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = TransformerModel_rnn(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token-embedding module."""
        return self.model.embeddings

    def set_input_embeddings(self, value):
        """Replace the backbone's token-embedding module."""
        self.model.embeddings = value

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the backbone."""
        self.model = decoder

    def get_decoder(self):
        """Return the backbone."""
        return self.model

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values1: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        all_past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: bool = True,
        num_logits_to_keep: Optional[int] = None,
        **kwargs
    ):
        """Build the keyword arguments for one generation step.

        Returns a dict with either `inputs_embeds` (first step only, when
        provided) or `input_ids`, plus `past_key_values1`, `use_cache`,
        `attention_mask` and `num_logits_to_keep`.

        NOTE(review): `all_past_key_values` and any extra `**kwargs` are
        accepted but not forwarded.
        """
        # Once a cache exists only the newest token has to be fed.
        if past_key_values1 is not None:
            input_ids = input_ids[:, -1:]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values1 is None:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard.
            # Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {'input_ids': input_ids.contiguous()}

        # `num_logits_to_keep` is always forwarded, even when it is `None`;
        # `forward` treats `None` like its own default of 0.
        model_inputs.update({
            'past_key_values1': past_key_values1,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
            'num_logits_to_keep': num_logits_to_keep,
        })
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values1: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        all_past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: Optional[int] = 0,
        **kwargs: Unpack[Any]
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the backbone, then compute logits and/or the loss.

        When `config.fuse_cross_entropy` is set and the module is in training
        mode, logits are not materialised (`logits` is `None`) and the loss is
        computed directly from the hidden states and the LM-head weight.
        Otherwise the LM head is applied to the last `num_logits_to_keep`
        positions (all positions for 0 or `None`).

        Returns a `CausalLMOutputWithPast`, or a tuple when `return_dict` is
        false (with the loss prepended if one was computed).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # `prepare_inputs_for_generation` may pass an explicit `None`; negating
        # it below would raise `TypeError`, so fall back to "keep everything".
        if num_logits_to_keep is None:
            num_logits_to_keep = 0

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values1=past_key_values1,
            all_past_key_values=all_past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs
        )

        hidden_states = outputs[0]
        fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
        # With 0 the slice `[-0:]` is the full sequence.
        logits = None if fuse_linear_and_cross_entropy else self.lm_head(hidden_states[:, -num_logits_to_keep:])
        loss = None
        if labels is not None:
            if self.config.fuse_cross_entropy:
                if fuse_linear_and_cross_entropy:
                    loss_fct = FusedLinearCrossEntropyLoss()
                else:
                    loss_fct = FusedCrossEntropyLoss(inplace_backward=True)
            else:
                loss_fct = nn.CrossEntropyLoss()
            # NOTE(review): labels are consumed exactly as passed in; no
            # one-position shift is applied here, so callers presumably supply
            # already-aligned targets -- TODO confirm against the data pipeline.
            if fuse_linear_and_cross_entropy:
                loss = loss_fct(hidden_states.view(-1, self.config.hidden_size),
                                labels.view(-1),
                                self.lm_head.weight,
                                self.lm_head.bias)
            else:
                loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
598
+
599
if __name__ == '__main__':
    # Smoke test: build a model from an inline config, run one forward pass on
    # random token ids on a CUDA device in bfloat16, and print the outputs.
    config = TransformerConfig_rnn(
        concept_dim=128,
        attention_bias=False,
        bos_token_id=1,
        eos_token_id=2,
        fuse_cross_entropy=True,
        fuse_norm=True,
        hidden_act="swish",
        hidden_size=1024,
        initializer_range=0.02,
        max_position_embeddings=8192,
        model_type="transformer_rnn",
        num_heads=16,
        num_hidden_layers=24,
        norm_eps=1e-06,
        tie_word_embeddings=True,
        use_cache=True,
        vocab_size=32000,
    )
    model = TransformerForCausalLM_rnn(config).cuda().to(torch.bfloat16)
    input_ids = torch.randint(0, 100, (2, 70)).cuda()
    attention_mask = torch.ones_like(input_ids).cuda()
    output = model(input_ids, attention_mask=attention_mask)
    print(output)
    print(output.loss)
    print(output.logits)
    # `forward` returns a `CausalLMOutputWithPast`, whose cache field is
    # `past_key_values`; it has no `all_past_key_values` attribute.
    print(output.past_key_values)
    print(output.hidden_states)
    print(output.attentions)
nirvana_1_3B.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "attention_bias": false,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 2,
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_transformer_rnn.TransformerConfig_rnn",
8
+ "AutoModel": "modeling_transformer_rnn.TransformerModel_rnn",
9
+ "AutoModelForCausalLM": "modeling_transformer_rnn.TransformerForCausalLM_rnn"
10
+ },
11
+ "fuse_cross_entropy": true,
12
+ "fuse_norm": true,
13
+ "hidden_act": "swish",
14
+ "hidden_size": 2048,
15
+ "initializer_range": 6e-3,
16
+ "max_position_embeddings": 32768,
17
+ "rope_theta": 10000.0,
18
+ "model_type": "transformer_rnn",
19
+ "num_heads": 16,
20
+ "num_hidden_layers": 16,
21
+ "norm_eps": 1e-06,
22
+ "tie_word_embeddings": true,
23
+ "use_cache": false,
24
+ "vocab_size": 128512,
25
+ "concept_dim": 64,
26
+ "logit_dim": 32,
27
+ "window_size": 2048
28
+ }
task_aware_delta_net.py ADDED
@@ -0,0 +1,754 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) 2024, Songlin Yang, Yu Zhang
3
+
4
+ from __future__ import annotations
5
+
6
+ import math
7
+ from typing import TYPE_CHECKING, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ from einops import rearrange
12
+ from torch.nn import functional as F
13
+
14
+ from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution
15
+ from fla.ops.gated_delta_rule import (chunk_gated_delta_rule,
16
+ fused_recurrent_gated_delta_rule)
17
+
18
+ if TYPE_CHECKING:
19
+ from transformers.processing_utils import Unpack
20
+
21
+ from fla.models.utils import Cache
22
+
23
+
24
def elu_p1(x):
    """Return ``elu(x) + 1`` cast back to the dtype/device of ``x``."""
    shifted = F.elu(x, alpha=1., inplace=False) + 1.
    return shifted.to(x)
26
+
27
def sum_norm(x):
    """Divide ``x`` by its sum over the last dimension.

    The result is cast back to the dtype/device of ``x``. No guard is applied
    against a zero sum.
    """
    total = x.sum(-1, keepdim=True)
    normalised = x / total
    return normalised.to(x)
29
+
30
+ from fla.modules import RMSNorm, RotaryEmbedding
31
+
32
+ if TYPE_CHECKING:
33
+ from fla.models.utils import Cache
34
+ import warnings
35
+ try:
36
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
37
+ from flash_attn.bert_padding import (index_first_axis, pad_input,
38
+ unpad_input)
39
+ except ImportError:
40
+ warnings.warn(
41
+ "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`",
42
+ category=ImportWarning
43
+ )
44
+ flash_attn_func = None
45
+
46
+ # https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1
47
+
48
def lambda_init_fn(depth):
    """Return ``0.8 - 0.6 * exp(-0.3 * depth)`` for the given layer depth."""
    decay = math.exp(-0.3 * depth)
    return 0.8 - 0.6 * decay
50
+
51
+
52
+
53
+
54
+ # -*- coding: utf-8 -*-
55
+ from typing import Optional, Tuple
56
+ import torch
57
+ from einops import rearrange
58
+ from fla.ops.linear_attn.utils import normalize_output
59
+ # def scattering_mixer(
60
+ # q: torch.Tensor,
61
+ # k: torch.Tensor,
62
+ # v: torch.Tensor,
63
+ # gamma: torch.Tensor,
64
+ # # chi: torch.Tensor,
65
+ # scale: Optional[float] = None,
66
+ # normalize: bool = False
67
+ # ) -> Tuple[torch.Tensor, torch.Tensor]:
68
+ # if scale is None:
69
+ # scale = q.shape[-1] ** -0.5
70
+ # chunk_size = 64
71
+ # # split_size = 2
72
+ # q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale
73
+ # # k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size)
74
+
75
+ # # gamma (b , n*c, h) -> (b, h, n*c, 1)
76
+ # gamma = rearrange(gamma, 'b l h -> b h l').unsqueeze(-1)
77
+ # gamma_cumprod = torch.cumprod(gamma, dim=2)
78
+ # gamma_cumprod_chunk = rearrange(gamma_cumprod, 'b h (n c) d -> b h n c d', c=chunk_size)
79
+ # gamma_cumprod_chunk = gamma_cumprod_chunk[:, :, :, -1, :].unsqueeze(-2) # [b, h, n, 1, 1]
80
+
81
+ # gamma_cumprod = rearrange(gamma_cumprod, 'b h l d -> b l h d')
82
+ # k_cumprod = k / gamma_cumprod
83
+ # k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size)
84
+ # k_cumprod_chunk = rearrange(k_cumprod, 'b (n c) h d -> b h n c d', c=chunk_size)
85
+ # # gamma_cumprod_chunk = rearrange(gamma_cumprod, 'b h n c d -> b h (n c) d')
86
+
87
+ # v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size)
88
+
89
+ # gamma = rearrange(gamma, 'b h (n c) d -> b h n c d', c=chunk_size) # d = 1
90
+ # # gamma_cumprod_chunk_inter = torch.cumprod(gamma, dim=3)
91
+ # gamma_inter = torch.cumprod(gamma, dim=3) # [b, h, n, c, 1]
92
+
93
+ # kv = k_cumprod_chunk.transpose(-1, -2) @ v # [b, h, n, d, d]
94
+ # kv = kv.cumsum(2) # [b, h, n, d, d] n << seq_len
95
+ # kv = kv * gamma_cumprod_chunk # [b, h, n, d, d]
96
+
97
+ # kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) # [b, h, n, d, d]
98
+ # inter = (q @ kv) * gamma_inter # [b, h, n, c, d]
99
+ # intra = (
100
+ # ((q @ (k / gamma_inter).transpose(-1, -2)) ).masked_fill_(
101
+ # torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1),
102
+ # 0
103
+ # )) @ v * gamma_inter # [b, h, n, c, d]
104
+ # o = inter + intra # [b, h, n, c, d]
105
+ # if normalize:
106
+ # o = normalize_output(q * scale, k, o)
107
+ # return rearrange(o, 'b h n c d -> b (n c) h d') , None
108
def scattering_mixer_recurrent(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    G0: torch.Tensor,
    split_size: int,
    past_kv: Optional[torch.Tensor] = None,
    beta: Optional[torch.Tensor] = None,
    # chi: torch.Tensor,
    scale: Optional[float] = None,
    normalize: bool = False,
    order: int = 2,
    perturb: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Linear-attention style mixer computed with plain tensor ops.

    ``q``/``k``/``v`` arrive as ``[b, l, h, dim]`` and their last dimension is
    split into ``split_size`` rows (``(f s)`` for q/k, ``(d s)`` for v).

    Args:
        q, k, v: query / key / value tensors, ``[b, l, h, (f s)]`` /
            ``[b, l, h, (f s)]`` / ``[b, l, h, (d s)]``.
        G0: ``[b, l, h, d, f]`` tensor used for the second-order key
            correction; only read when ``order == 2``.
        split_size: the ``s`` factor of the last-dimension split.
        past_kv: previously accumulated key-value state. When given, only the
            last time step of the freshly computed state is added to it
            instead of running a cumulative sum over the sequence.
        beta: optional ``[b, l, h]`` per-step decay applied to the state.
        scale: query scale; defaults to ``q.shape[-1] ** -0.5``.
        normalize: if True, post-process the output with ``normalize_output``.
        order: ``2`` enables the ``G0`` correction, anything else uses plain keys.
        perturb: optional ``[b, l, h, f, f]`` query perturbation. ``None``
            (the default) leaves the queries untouched.

    Returns:
        ``(o, kv)``: the output rearranged to ``[b, l, h, (s d)]`` and the
        key-value state that was used to produce it.
    """
    if scale is None:
        scale = q.shape[-1] ** -0.5
    q = rearrange(q, 'b l h (f s) -> b h l s f', s=split_size) * scale
    k = rearrange(k, 'b l h (f s) -> b h l s f', s=split_size)
    v = rearrange(v, 'b l h (d s) -> b h l s d', s=split_size)
    if order == 2:
        G0 = rearrange(G0, 'b l h d f -> b h l d f')
        second_term = torch.einsum('b h l s d, b h l d f -> b h l s f', v, G0)  # [b, h, l, s, f]
        G1 = second_term @ k.transpose(-1, -2)  # [b, h, l, s, s]
        kv2 = k.transpose(-1, -2) @ G1 + k.transpose(-1, -2)  # [b, h, l, f, s]
    else:
        kv2 = k.transpose(-1, -2)  # [b, h, l, f, s]
    kv = kv2 @ v  # [b, h, l, f, d]

    # The signature defaults `perturb` to None, but the original code passed it
    # to `rearrange` unconditionally and crashed. A missing perturbation is
    # treated as "add nothing to q", which is what a zero `perturb` would give.
    if perturb is not None:
        perturb = rearrange(perturb, 'b l h f k -> b h l f k')  # [b, h, l, f, f]
        M = q.transpose(-1, -2) @ q  # [b, h, l, f, f]
        M = perturb @ M  # [b, h, l, f, f]
        M = q @ M  # [b, h, l, s, f]
        q = q + M  # [b, h, l, s, f]

    if past_kv is None:
        if beta is not None:
            beta = rearrange(beta, 'b l h -> b h l')
            beta_cumprod = torch.cumprod(beta, dim=2)
            # Shift right by one step so position t is scaled by the product
            # of the decays of the steps before it (1 for the first step).
            beta_cumprod = torch.cat([torch.ones_like(beta_cumprod[:, :, :1]), beta_cumprod[:, :, :-1]], dim=2)
            beta_cumprod = rearrange(beta_cumprod, 'b h l -> b h l 1 1')
            kv = kv / beta_cumprod  # [b, h, l, f, d]
            kv = kv.cumsum(2)  # [b, h, l, f, d]
            kv = kv * beta_cumprod  # [b, h, l, f, d]
        else:
            kv = kv.cumsum(2)  # [b, h, l, f, d]
        o = q @ kv  # [b, h, l, s, d]
    else:
        # NOTE(review): this branch indexes `beta[:, :, -2]` and multiplies a
        # [b, h, l, s, f] query with a [b, h, f, d] state; both look like they
        # only work for specific shapes. Left as-is -- confirm with the caller.
        if beta is not None:
            beta = rearrange(beta, 'b l h -> b h l')
            kv = kv[:, :, -1, :, :] + past_kv * (beta[:, :, -2]).unsqueeze(-1).unsqueeze(-1)
        else:
            kv = kv[:, :, -1, :, :] + past_kv
        o = q @ kv
    if normalize:
        o = normalize_output(q * scale, k, o)
    return rearrange(o, 'b h l s d -> b l h (s d)'), kv
170
+
171
def safe_exp(x):
    """Exponentiate ``x`` after subtracting its maximum over the last dim.

    Every output is therefore in ``(0, 1]`` and the largest entry of each row
    maps to exactly 1.
    """
    row_max = torch.max(x, dim=-1, keepdim=True).values
    return torch.exp(x - row_max)
173
+
174
def random_proj(q, down_proj_matrix, up_proj_matrix, control_vec):
    """Project ``q`` down, scale, project up, then return cos/sin features.

    The result is ``concat([cos(p), sin(p)], dim=-1)`` where
    ``p = ((q @ down_proj_matrix) * control_vec) @ up_proj_matrix``.
    """
    phase = q @ down_proj_matrix
    phase = phase * control_vec
    phase = phase @ up_proj_matrix
    features = [torch.cos(phase), torch.sin(phase)]
    return torch.concat(features, dim=-1)
179
+
180
def lora_proj(x, down_proj_matrix, up_proj_matrix, control_vec):
    """Low-rank projection: ``((x @ down) * control_vec) @ up``."""
    low_rank = (x @ down_proj_matrix) * control_vec
    return low_rank @ up_proj_matrix
185
+
186
def gaussian_basis(x, basis_a, basis_c, basis_h):
    """Evaluate a sum of Gaussian bumps at every entry of ``x``.

    ``x`` gets a trailing singleton axis so it broadcasts against the basis
    tensors, whose last axis indexes the basis functions. Each bump is
    ``sigmoid(basis_a) * exp(-(x - basis_c)^2 / (2 * basis_h^2 + 1e-6))`` and
    the bumps are summed over that last axis.

    Shapes noted in the original source: ``x`` is ``[b, q_len, channel]`` and
    each basis tensor is ``[b, q_len, 1, num_basis]``; the result is
    ``[b, q_len, channel]``.
    """
    offset = x.unsqueeze(-1) - basis_c  # [b, q_len, channel, num_basis]
    amplitude = F.sigmoid(basis_a)
    # The small constant keeps the denominator non-zero when basis_h is 0.
    width = 2 * basis_h ** 2 + 1e-6
    bumps = amplitude * torch.exp(-(offset ** 2) / width)
    return bumps.sum(dim=-1, keepdim=False)
196
+
197
def pad_time_cond(t, len):
    """Append sinusoidal features of ``t`` along the last dimension.

    For integer frequencies ``w = 1 .. len`` the result is
    ``[sin(1*t), ..., sin(len*t), cos(1*t), ..., cos(len*t), t]`` concatenated
    on the last axis.

    NOTE: the ``len`` parameter shadows the builtin inside this function; the
    name is kept because it is part of the call signature.
    """
    frequencies = range(1, len + 1)
    sines = [torch.sin(w * t) for w in frequencies]
    cosines = [torch.cos(w * t) for w in frequencies]
    return torch.cat([*sines, *cosines, t], dim=-1)
202
+
203
+
204
class condition_interpolation(nn.Module):
    """Bottleneck MLP mapping two hidden-size tensors plus a concept vector to one hidden-size tensor.

    The input is ``[start, end, h_new, h_new]`` concatenated on the last axis,
    i.e. ``2 * hidden_size + 2 * concept_dim`` features (the concept vector is
    fed twice). It is reduced to ``hidden_size // 8`` features, passed through
    SiLU and expanded to ``hidden_size``. Both linear layers are bias-free.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        concept_dim: int = 64,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.concept_dim = concept_dim
        self.r = 8  # bottleneck reduction factor

        in_features = self.hidden_size * 2 + self.concept_dim * 2
        bottleneck = self.hidden_size // self.r
        down = nn.Linear(in_features, bottleneck, bias=False)
        up = nn.Linear(bottleneck, self.hidden_size, bias=False)
        self.lora = nn.Sequential(down, nn.SiLU(), up)

        # Xavier-initialised input side and zeroed output side: as constructed
        # here the module outputs exactly zero.
        nn.init.xavier_uniform_(down.weight)
        nn.init.zeros_(up.weight)

    def forward(self, start, end, h_new):
        """Return ``lora(cat([start, end, h_new, h_new], dim=-1))``."""
        features = torch.cat((start, end, h_new, h_new), dim=-1)
        return self.lora(features)
232
+
233
+
234
+
235
class Task_Aware_Delta_Net(nn.Module):
    """
    Hybrid token-mixing layer combining a gated delta-rule branch with a flash-attention branch.

    The gated delta-rule branch follows the layer from
    [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464).  # noqa
    Both branches start from the same ``q_proj`` / ``k_proj`` / ``v_proj``
    projections: the delta-rule branch passes them through short convolutions,
    while the attention branch adds a low-rank (``*_lora``) correction, applies
    rotary embeddings and runs (optionally sliding-window) flash attention.

    Layers with ``layer_idx < num_prelude`` (``num_prelude`` is fixed to 2) sum
    the two branch outputs. Later layers additionally query the ``rnn_router``
    passed to ``forward`` for a concept vector ``h_new`` that drives:
      - per-head sigmoid gates for each branch (``router2``),
      - a hard Gumbel-softmax switch that adds the learned ``special_mask``
        vector to the attention branch input (``router3``),
      - an interpolation weight ``t`` between the branch outputs (``t_proj``),
        plus a learned correction from ``condition_interpolation``.

    NOTE(review): this class was reconstructed from a source whose indentation
    was lost. Block nesting follows the only reading under which every name is
    defined before use; in ``__init__`` the router-related parameters are
    assumed to exist only for ``layer_idx >= num_prelude`` because ``forward``
    reads them only there. Confirm against a checkpoint's state-dict keys.

    Args:
        hidden_size (int, Optional):
            The hidden size of the input. Default: 2048.
        expand_v (float, Optional):
            The expansion ratio for the value dim. Default: 1.
            The attention branch reshapes its output with ``head_dim``, so
            ratios other than 1 are presumably unsupported -- TODO confirm.
        head_dim (int, Optional):
            The dimension of each head. Default: 256.
        num_heads (int, Optional):
            The number of heads. Default: 6.
        num_heads_delta (int, Optional):
            Only read when the internal ``strict_head`` switch is on (it is
            hard-coded off). Default: 6.
        mode (str, Optional):
            Which Gated DeltaNet kernel to use.
            Currently available: `chunk` and `fused_recurrent`.
            Default: `chunk`.
        use_gate (bool, Optional):
            Whether to use output gate. Default: `True`.
        use_short_conv (bool, Optional):
            Whether to use short convolutions. Default: `True`.
            Passing `False` raises ``UserWarning`` in the constructor.
        conv_size (int, Optional):
            The kernel size of the short convolution. Default: 4.
        conv_bias (bool, Optional):
            Stored on the module; not forwarded to the convolutions here. Default: `False`.
        layer_idx (int):
            The index of the layer. Required despite the ``None`` default.
        norm_eps (float, Optional):
            The epsilon value for the normalization layer. Default: 1e-5.
        rope_theta (float, Optional):
            Base of the rotary embedding. Default: 10000.
        max_position_embeddings (int, Optional):
            Lower bound for the rotary ``max_seqlen`` argument. Default: None.
        window_size (int, Optional):
            Sliding-window size of the attention branch; ``None`` means full
            causal attention. Default: None.
        concept_dim (int, Optional):
            Width of the concept vectors exchanged with the router. Default: 128.
    """

    def __init__(
        self,
        hidden_size: int = 2048,
        expand_v: float = 1,
        head_dim: int = 256,
        num_heads: int = 6,
        num_heads_delta: int = 6,
        mode: str = 'chunk',
        use_gate: bool = True,
        use_short_conv: bool = True,
        conv_size: int = 4,
        conv_bias: bool = False,
        layer_idx: Optional[int] = None,
        norm_eps: float = 1e-5,
        rope_theta: float = 10000.,
        max_position_embeddings: Optional[int] = None,
        window_size: Optional[int] = None,
        concept_dim: int = 128,
        **kwargs: Unpack[Dict]
    ) -> None:
        super().__init__()
        # layer_idx is compared against num_prelude below and in forward();
        # fail early with a readable message instead of a TypeError on `>=`.
        if layer_idx is None:
            raise ValueError("`layer_idx` must be provided to Task_Aware_Delta_Net.")
        self.split_size = 64

        self.mode = mode
        self.hidden_size = hidden_size
        self.expand_v = expand_v

        self.use_gate = use_gate
        self.use_short_conv = use_short_conv
        self.conv_size = conv_size
        self.conv_bias = conv_bias
        self.head_dim = head_dim
        self.strict_head = False
        if self.strict_head:
            head_dim_delta = int(0.75 * hidden_size / num_heads_delta)
            head_dim = head_dim_delta
            self.head_dim_delta = head_dim_delta
            self.head_dim = head_dim_delta
        self.num_heads = num_heads
        self.key_dim = self.num_heads * self.head_dim
        # int(): a float expand_v (e.g. 1.0) must not yield float layer sizes.
        self.value_dim = int(self.key_dim * self.expand_v)
        self.head_qk_dim = head_dim
        self.head_v_dim = int(head_dim * self.expand_v)
        self.layer_idx = layer_idx
        self.silu = nn.SiLU()
        assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`."
        self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
        self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False)

        self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
        self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False)
        A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16)
        A_log = torch.log(A)
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True
        # hard coded for now
        dt_min = 0.001
        dt_max = 0.1
        dt_init_floor = 1e-4
        dt = torch.exp(
            torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        )
        dt = torch.clamp(dt, min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        self.dt_bias = nn.Parameter(inv_dt)
        # Just to be explicit. Without this we already don't put wd on dt_bias because of the check
        # name.endswith("bias") in param_grouping.py
        self.dt_bias._no_weight_decay = True
        if use_short_conv:
            self.conv_size = conv_size
            self.q_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu'
            )
            self.k_conv1d = ShortConvolution(
                hidden_size=self.key_dim,
                kernel_size=conv_size,
                activation='silu'
            )
            self.v_conv1d = ShortConvolution(
                hidden_size=self.value_dim,
                kernel_size=conv_size,
                activation='silu'
            )
        else:
            raise UserWarning(
                "ShortConvolution is crucial to the performance. "
                "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing."
            )
        if use_gate:
            self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False)
            self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps)
        else:
            self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)

        self.num_prelude = 2
        self.ttt = True
        if self.ttt and self.layer_idx >= self.num_prelude:  # use TTT as cross-layer concept learner
            self.concept_dim = concept_dim
            self.concept_proj = nn.Linear(hidden_size, self.concept_dim * 3, bias=False)
            self.lr1_proj = nn.Linear(hidden_size, 1, bias=False)
            self.lr2_proj = nn.Linear(hidden_size, 1, bias=False)
            self.router2 = nn.Linear(self.concept_dim, self.num_heads * 2, bias=False)
            self.router3 = nn.Linear(self.concept_dim, 2, bias=False)

            self.condition_interpolation = condition_interpolation(hidden_size, concept_dim)
            self.t_proj = nn.Linear(concept_dim, 1, bias=False)

            self.special_mask = nn.Parameter(torch.zeros(self.hidden_size))
            self.use_bias = True
            if self.use_bias:
                self.learnable_bias0 = nn.Parameter(torch.zeros(1))

        # NOTE(review): this re-initialises every nn.Linear registered so far,
        # including those inside `condition_interpolation`, so that module's
        # own zero-init of its output layer does not survive. The LoRA stacks
        # and `o_proj_attn` below are created afterwards and keep their init.
        self.apply(self._initialize_weights)
        # Low-rank corrections added to the shared q/k/v projections for the
        # attention branch; zero-initialised on the output side.
        self.r = 4
        self.q_lora = nn.Sequential(
            nn.Linear(self.hidden_size, self.key_dim // self.r, bias=False),
            nn.SiLU(),
            nn.Linear(self.key_dim // self.r, self.key_dim, bias=False)
        )
        nn.init.xavier_uniform_(self.q_lora[0].weight)
        nn.init.zeros_(self.q_lora[2].weight)
        self.k_lora = nn.Sequential(
            nn.Linear(self.hidden_size, self.key_dim // self.r, bias=False),
            nn.SiLU(),
            nn.Linear(self.key_dim // self.r, self.key_dim, bias=False)
        )
        nn.init.xavier_uniform_(self.k_lora[0].weight)
        nn.init.zeros_(self.k_lora[2].weight)

        self.v_lora = nn.Sequential(
            nn.Linear(self.hidden_size, self.value_dim // self.r, bias=False),
            nn.SiLU(),
            nn.Linear(self.value_dim // self.r, self.value_dim, bias=False)
        )
        nn.init.xavier_uniform_(self.v_lora[0].weight)
        nn.init.zeros_(self.v_lora[2].weight)

        self.o_proj_attn = nn.Linear(self.value_dim, self.hidden_size, bias=False)
        nn.init.xavier_uniform_(self.o_proj_attn.weight, gain=2 ** -2.5)
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
        self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta)
        self.window_size = window_size

    def _initialize_weights(self, module: nn.Module):
        """Xavier-init (gain ``2 ** -2.5``) every ``nn.Linear`` once, zeroing its bias.

        Modules already flagged with ``_is_hf_initialized`` are skipped; every
        visited module is flagged afterwards.
        """
        if getattr(module, "_is_hf_initialized", False):
            return
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        module._is_hf_initialized = True

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values1: Optional[Cache] = None,
        all_past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
        rnn_router: Optional[nn.Module] = None,
        h_old: Optional[torch.Tensor] = None,
        params: Optional[Dict] = None,
        **kwargs: Unpack[Dict]
    ) -> Tuple[torch.Tensor, None, None, Optional[Tuple[Cache, Cache]], Optional[torch.Tensor], Optional[Dict]]:
        """Run both branches on ``hidden_states`` and mix their outputs.

        Args:
            hidden_states: ``[batch_size, seq_len, hidden_size]`` input.
            attention_mask: optional 0-1 padding mask of shape
                ``[batch_size, seq_len]`` (0 marks padding).
            past_key_values1: overwritten from ``all_past_key_values`` whenever
                that argument is given; otherwise unused.
            all_past_key_values: container exposing ``_seen_tokens`` and
                unpacking into an attention cache and a delta-rule cache.
            use_cache: whether the kernels/convolutions return final states.
            output_attentions: unused.
            rnn_router: required. Must provide ``init_params_as_logits`` (used
                by prelude layers) and ``learn`` / ``predict`` (other layers).
            h_old: unused.
            params: router parameters threaded from the previous layer.
            **kwargs: ``cu_seqlens`` and ``max_seqlen`` are read if present.

        Returns:
            ``(o, None, None, all_past_key_values, h_new, params)`` where
            ``all_past_key_values`` is the ``(attention cache, delta cache)``
            pair when caching is active and ``h_new`` is ``None`` for prelude
            layers.

        Raises:
            ValueError: if ``rnn_router`` is missing.
            ImportError: if flash-attn is not installed.
        """
        if attention_mask is not None:
            assert len(attention_mask.shape) == 2, (
                "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] "
                "for padding purposes (0 indicating padding). "
                "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed."
            )

        # Short inputs (<= 64 tokens) always use the recurrent kernel.
        mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode
        if self.training:
            assert mode == 'chunk', "Only chunk mode is supported in training."
        last_state2 = None
        if all_past_key_values is not None:
            if all_past_key_values._seen_tokens > 0:
                past_key_values1, past_key_values2 = all_past_key_values
            else:
                from fla.models.utils import Cache
                past_key_values1, past_key_values2 = Cache(), Cache()

            if len(past_key_values2) > self.layer_idx:
                last_state2 = past_key_values2[self.layer_idx]
        batch_size, q_len, _ = hidden_states.size()
        cu_seqlens = kwargs.get('cu_seqlens', None)
        max_seqlen = kwargs.get('max_seqlen', q_len)
        if self.ttt:
            # Both paths below call into the router; without it the original
            # code failed with an AttributeError on None.
            if rnn_router is None:
                raise ValueError("`rnn_router` must be provided to Task_Aware_Delta_Net.forward().")
            if self.layer_idx < self.num_prelude:  # first two layers (0-1)
                params = rnn_router.init_params_as_logits(batch_size, q_len)
                mask = torch.ones(batch_size, q_len, self.num_heads, 2, device=hidden_states.device).to(hidden_states.dtype)
                h_new = None
                special_mask_attn = torch.zeros(batch_size, q_len, 1, device=hidden_states.device).to(hidden_states.dtype)
            else:
                concept_qkv = self.concept_proj(hidden_states)
                concept_q, concept_k, concept_v = concept_qkv.chunk(3, dim=-1)
                # Token-wise learning rates in (0, 1e-2) for the router update.
                lr_linear = F.sigmoid(self.lr1_proj(hidden_states)) * 1e-2
                lr_ln = F.sigmoid(self.lr2_proj(hidden_states)) * 1e-2
                params = rnn_router.learn(concept_k, concept_v, params, lr_linear, lr_ln)

                h_new = rnn_router.predict(concept_q, params)
                t = F.sigmoid(self.t_proj(h_new))
                t_b = 1 - t

                input_router = self.router2(h_new)
                input_router = F.sigmoid(input_router)  # [batch_size, seq_len, num_heads * 2]
                special_mask = self.router3(h_new)
                # Bias the logit of choice 0 upwards so it is sampled more
                # often; choice 1 is the one that switches the mask on below.
                bias = torch.zeros_like(special_mask)
                bias[..., 0] = 2.0
                if self.use_bias:
                    bias[..., 0] = 2.0 + self.learnable_bias0
                special_mask = F.gumbel_softmax(special_mask + bias, tau=0.1, hard=True)
                special_mask_attn = special_mask[:, :, 1].unsqueeze(-1)  # [batch_size, seq_len, 1]

                mask = input_router
                mask = mask.reshape(batch_size, q_len, self.num_heads, 2)

        # Shared projections, computed once so that both branches (and the
        # no-short-conv path) see the same tensors.
        q_shared = self.q_proj(hidden_states)
        k_shared = self.k_proj(hidden_states)
        v_shared = self.v_proj(hidden_states)
        if self.use_short_conv:
            conv_state_q, conv_state_k, conv_state_v = None, None, None
            if last_state2 is not None:
                conv_state_q, conv_state_k, conv_state_v = last_state2['conv_state']
            conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None
            q, conv_state_q = self.q_conv1d(
                x=q_shared,
                mask=conv_mask,
                cache=conv_state_q,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            k, conv_state_k = self.k_conv1d(
                x=k_shared,
                mask=conv_mask,
                cache=conv_state_k,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
            v, conv_state_v = self.v_conv1d(
                x=v_shared,
                mask=conv_mask,
                cache=conv_state_v,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens
            )
        else:
            q = self.silu(q_shared)
            k = self.silu(k_shared)
            v = self.silu(v_shared)

        #################################################### attention branch ####################################################
        if self.layer_idx >= self.num_prelude:
            hidden_states_attn = hidden_states + special_mask_attn * self.special_mask.reshape(1, 1, -1)
        else:
            hidden_states_attn = hidden_states
        q_attn = self.q_lora(hidden_states_attn) + q_shared
        k_attn = self.k_lora(hidden_states_attn) + k_shared
        v_attn = self.v_lora(hidden_states_attn) + v_shared

        q_attn, k_attn, v_attn = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (q_attn, k_attn, v_attn))

        seqlen_offset = 0
        if all_past_key_values is not None:
            seqlen_offset = past_key_values1.get_seq_length(self.layer_idx)
            max_seqlen = q_attn.shape[1] + seqlen_offset

            if attention_mask is not None:
                # to deliminate the offsets of padding tokens
                seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0)
                max_seqlen = q_attn.shape[1] + max(seqlen_offset)

        if self.max_position_embeddings is not None:
            max_seqlen_rotary = max(max_seqlen, self.max_position_embeddings)
        else:
            max_seqlen_rotary = max_seqlen
        q_attn, k_attn = self.rotary(q_attn, k_attn, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen_rotary, cu_seqlens=cu_seqlens)
        if all_past_key_values is not None:
            k_attn, v_attn = past_key_values1.update(
                attn_state=(k_attn.flatten(-2, -1), v_attn.flatten(-2, -1)),
                layer_idx=self.layer_idx,
                offset=q_len,
                cache_kwargs=dict(window_size=self.window_size)
            )['attn_state']
            k_attn = rearrange(k_attn, '... (h d) -> ... h d', h=self.num_heads)
            v_attn = rearrange(v_attn, '... (h d) -> ... h d', h=self.num_heads)
        if flash_attn_func is None:
            raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first")

        attn_window = (-1, -1) if self.window_size is None else (self.window_size-1, 0)
        if attention_mask is not None:
            # A padding mask was supplied: unpad, run varlen attention, re-pad.
            q_attn, k_attn, v_attn, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q_attn, k_attn, v_attn, attention_mask, q_len)
            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_q, max_seqlen_k = max_seq_lens
            o_attn = flash_attn_varlen_func(
                q_attn, k_attn, v_attn,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_k,
                causal=True,
                window_size=attn_window
            )
            o_attn = pad_input(o_attn, indices_q, batch_size, q_len)
        elif cu_seqlens is not None:
            o_attn = flash_attn_varlen_func(
                q_attn.squeeze(0), k_attn.squeeze(0), v_attn.squeeze(0),
                cu_seqlens_q=cu_seqlens,
                cu_seqlens_k=cu_seqlens,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                causal=True,
                window_size=attn_window
            ).unsqueeze(0)
        else:
            o_attn = flash_attn_func(
                q_attn, k_attn, v_attn,
                causal=True,
                window_size=attn_window
            )
            if batch_size > 1:
                o_attn = o_attn.reshape(batch_size, q_len, self.num_heads, self.head_dim)

        if self.layer_idx >= self.num_prelude:
            # Per-head gate for the attention branch (router channel 0).
            o_attn = torch.einsum("bnh,bnhd->bnhd", mask[:, :, :, 0], o_attn)  # [batch_size, seq_len, num_heads, head_dim]

        o_attn = o_attn.reshape(batch_size, q_len, self.value_dim)
        o_attn = self.o_proj_attn(o_attn)
        #################################################### end of attention ####################################################
        k, v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (k, v))

        beta = self.b_proj(hidden_states).sigmoid()
        g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias)
        # dealing with padding
        if attention_mask is not None:
            beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None])
            g = g.mul(attention_mask[:, -g.shape[-2]:, None])
        recurrent_state = last_state2['recurrent_state'] if last_state2 is not None else None
        q = rearrange(q, 'b t (h d) -> b t h d', h=self.num_heads)

        if mode == 'chunk':
            o, recurrent_state = chunk_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                head_first=False,
                use_qk_l2norm_in_kernel=True
            )
        elif mode == 'fused_recurrent':
            o, recurrent_state = fused_recurrent_gated_delta_rule(
                q=q,
                k=k,
                v=v,
                g=g,
                beta=beta,
                initial_state=recurrent_state,
                output_final_state=use_cache,
                cu_seqlens=cu_seqlens,
                use_qk_l2norm_in_kernel=True
            )
        if all_past_key_values is not None:
            past_key_values2.update(
                recurrent_state=recurrent_state,
                conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None,
                layer_idx=self.layer_idx,
                offset=q.shape[1]
            )
        if self.use_gate:
            g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads)
            o = self.o_norm(o, g)
        else:
            o = self.o_norm(o)
        if self.layer_idx >= self.num_prelude:
            # Per-head gate for the delta-rule branch (router channel 1).
            o = torch.einsum("bnh,bnhd->bnhd", mask[:, :, :, 1], o)  # [batch_size, seq_len, num_heads, head_dim]
        o_gated_delta = rearrange(o, 'b t h d -> b t (h d)')
        o_gated_delta = self.o_proj(o_gated_delta)
        #################################################### end of delta rule ####################################################

        if self.layer_idx < self.num_prelude:
            o = o_gated_delta + o_attn
        else:
            # Interpolate between the branches, then add a learned correction
            # scaled by t * (1 - t), which vanishes when t is 0 or 1.
            o = t_b * o_gated_delta + t * o_attn
            noise_std = t_b * t
            noise = self.condition_interpolation(o_gated_delta, o_attn, h_new) * noise_std
            o = o + noise

        if all_past_key_values is not None:
            all_past_key_values = (past_key_values1, past_key_values2)
        return o, None, None, all_past_key_values, h_new, params

    def _upad_input(self, q, k, v, attention_mask, q_len):
        """Strip padded positions from q/k/v for ``flash_attn_varlen_func``.

        Returns ``(q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_q, max_seqlen_k))``.
        """
        seqlens = attention_mask.sum(-1, dtype=torch.int32)
        indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_k = seqlens.max().item()
        cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
        batch_size, seq_len, num_key_value_heads, head_dim = k.shape

        k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k)
        if q_len == seq_len:
            # Queries cover the same positions as the keys: reuse key indexing.
            q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k)
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_q = max_seqlen_k
            indices_q = indices_k
        elif q_len == 1:
            # One query token per sequence.
            max_seqlen_q = 1
            # There is a memcpy here, that is very bad.
            cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device)
            indices_q = cu_seqlens_q[:-1]
            q = q.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -q_len:]
            q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask)

        return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
745
+
746
if __name__ == "__main__":
    # Manual smoke test of a single prelude layer (layer_idx=0).
    # The previous version could not run: it built the layer without a
    # layer_idx, fed 128-dim inputs to a 2048-dim layer, passed no router,
    # ran on CPU in fp32 although the attention branch needs flash-attn,
    # and unpacked 4 values from the 6-tuple that forward() returns.

    class _PreludeRouter:
        """Minimal router stand-in: prelude layers only call this one method
        and pass its result through unchanged."""

        def init_params_as_logits(self, batch_size, seq_len):
            return None

    hidden_size = 128
    layer = Task_Aware_Delta_Net(hidden_size=hidden_size, layer_idx=0)
    layer = layer.cuda().to(torch.bfloat16)
    # 70 tokens (> 64) keeps the layer on the `chunk` kernel, which is the
    # only mode allowed while the module is in training mode.
    hidden_states = torch.randn(2, 70, hidden_size, device='cuda', dtype=torch.bfloat16)
    # call forward
    o, _, _, _, h_new, params = layer(hidden_states=hidden_states, rnn_router=_PreludeRouter())
    print(o.shape)
ttt_cross_layer.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
def scan(f, init, xs, out, checkpoint_group=0):
    """Sequentially fold ``f`` over ``xs``, mimicking ``jax.lax.scan``.

    Args:
        f: Step function ``f(carry, x) -> (new_carry, y)``.
        init: Initial carry value.
        xs: Inputs to iterate over: either a dict of indexable sequences or a
            list/tuple of indexable sequences. Step ``i`` receives the
            ``i``-th element of every sequence (as a dict or a list).
        out: Pre-allocated container; ``out[i]`` is assigned the ``y``
            produced at step ``i``.
        checkpoint_group: When greater than zero, the steps are run in chunks
            of ``num_items // checkpoint_group`` (at least 1) and every chunk
            is wrapped in ``torch.utils.checkpoint.checkpoint``. When zero,
            all steps run in a single plain loop.

    Returns:
        ``(carry, out)``: the final carry and the filled output container.
    """
    carry = init

    # The number of steps is the length of the first sequence.
    if isinstance(xs, dict):
        num_items = len(next(iter(xs.values())))
    else:
        num_items = len(xs[0])

    def scan_fn(carry, i_start, i_end):
        """Run steps ``i_start`` (inclusive) to ``i_end`` (exclusive)."""
        for i in range(i_start, i_end):
            if isinstance(xs, dict):
                x = {key: tensor[i] for key, tensor in xs.items()}
            else:
                x = [seq[i] for seq in xs]
            carry, y = f(carry, x)
            out[i] = y
        return carry

    if checkpoint_group > 0:
        # Clamp to 1: the floor division yields 0 when there are fewer items
        # than groups (or no items at all), and range() rejects a zero step.
        ckpt_every_n = max(1, num_items // checkpoint_group)
        for k in range(0, num_items, ckpt_every_n):
            carry = torch.utils.checkpoint.checkpoint(
                scan_fn, carry, k, min(k + ckpt_every_n, num_items), use_reentrant=False
            )
    else:
        carry = scan_fn(carry, 0, num_items)

    return carry, out
69
+
70
def ln_fwd(x, gamma, beta, eps=1e-6):
    """Apply layer normalization over the last dimension of ``x``.

    The input is centred and scaled with the biased variance of its last
    dimension (``eps`` is added under the square root), then multiplied by
    ``gamma`` and shifted by ``beta``.
    """
    mean = x.mean(dim=-1, keepdim=True)
    variance = x.var(dim=-1, keepdim=True, unbiased=False)
    denom = torch.sqrt(variance + eps)

    normalized = (x - mean) / denom
    return gamma * normalized + beta
85
+
86
def ln_fused_l2_bwd(x, l2_target, gamma, beta, eps=1e-6):
    """Fused LayerNorm forward plus the gradient of an L2 loss w.r.t. ``x``.

    The function first layer-normalizes ``x`` over its last dimension
    (``y = gamma * x_hat + beta``) and then back-propagates the residual
    ``y - l2_target`` through the normalization.

    Args:
        x: Input tensor; normalization runs over the last dimension.
        l2_target: Target the normalized output is compared against.
        gamma: LayerNorm scale.
        beta: LayerNorm shift.
        eps: Constant added to the variance before the square root.

    Returns:
        The gradient with respect to ``x``, with the same shape as ``x``.
    """
    feature_dim = x.shape[-1]

    # LayerNorm forward pass.
    mean = x.mean(dim=-1, keepdim=True)
    variance = x.var(dim=-1, keepdim=True, unbiased=False)
    std = torch.sqrt(variance + eps)
    x_hat = (x - mean) / std
    y = gamma * x_hat + beta

    # Residual of the L2 objective, pulled back onto the normalized input.
    grad_x_hat = (y - l2_target) * gamma

    # Contributions that flow through the mean and the variance.
    mean_term = grad_x_hat.sum(dim=-1, keepdim=True)
    var_term = x_hat * (grad_x_hat * x_hat).sum(dim=-1, keepdim=True)

    scaled = (1.0 / feature_dim) * (feature_dim * grad_x_hat - mean_term - var_term)
    return scaled / std
133
+
134
+ from torch.autograd import Function
135
class MyLinearFunction(Function):
    """Hand-written autograd function for the affine map ``y = x @ W^T + b``.

    The input may carry any number of leading dimensions; the gradients for
    ``weight`` and ``bias`` are accumulated over all of them.
    """

    @staticmethod
    def forward(ctx, input, weight, bias):
        """Compute ``input @ weight^T + bias``.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            input: Tensor of shape ``(..., in_features)``.
            weight: Tensor of shape ``(out_features, in_features)``.
            bias: Tensor of shape ``(out_features,)``.

        Returns:
            Tensor of shape ``(..., out_features)``.
        """
        ctx.save_for_backward(input, weight, bias)
        return input.matmul(weight.t()) + bias

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients for ``input``, ``weight`` and ``bias``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient w.r.t. the forward output, of shape
                ``(..., out_features)``.

        Returns:
            ``(grad_input, grad_weight, grad_bias)`` in the order of the
            forward arguments; an entry is ``None`` when the corresponding
            input does not require a gradient.
        """
        input, weight, _ = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # Collapse all leading dimensions so the weight/bias reductions are
        # plain 2-D operations; a bare ``.t()`` only works on <= 2-D tensors.
        grad_2d = grad_output.reshape(-1, grad_output.shape[-1])

        if ctx.needs_input_grad[0]:
            # dL/dinput = dL/doutput @ W
            grad_input = grad_output.matmul(weight)
        if ctx.needs_input_grad[1]:
            # dL/dW = (dL/doutput)^T @ input, summed over every sample.
            grad_weight = grad_2d.t().matmul(input.reshape(-1, input.shape[-1]))
        if ctx.needs_input_grad[2]:
            # The bias is added to every sample, so its gradient is the sum.
            grad_bias = grad_2d.sum(dim=0)

        return grad_input, grad_weight, grad_bias
186
+
187
class TTT_Cross_Layer(nn.Module):
    """Linear + LayerNorm residual map whose weights are mixed per token.

    The module stores ``logit_dim`` candidate parameter sets (the trailing
    dimension of every parameter). For each token, a ``logit_dim``-sized
    coefficient vector ("params") mixes those sets into a token-specific
    linear weight/bias and LayerNorm scale/shift. ``predict`` applies the
    resulting map, ``learn`` updates the coefficient vectors from a (k, v)
    pair.

    The config must provide ``concept_dim``, ``logit_dim`` and
    ``initializer_range``.
    """

    # Constant added to the variance before taking the square root.
    _LN_EPS = 1e-6

    def __init__(self, config):
        super().__init__()
        self.input_size = config.concept_dim
        self.concept_dim = config.concept_dim
        self.logit_dim = config.logit_dim

        # NOTE(review): the einsums below index weight_linear as
        # (input, output, logit); both leading sizes are concept_dim here.
        self.weight_linear = nn.Parameter(torch.empty(self.concept_dim, self.input_size, self.logit_dim))
        self.weight_ln = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
        self.bias_linear = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
        self.bias_ln = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))

        self.config = config
        self.init_weights()

    def init_params_as_logits(self, batch_size, sequence_length):
        """Create the initial per-token mixing coefficients (all ones).

        Returns:
            Dict with keys ``weight_linear_tmp``, ``weight_ln_tmp``,
            ``bias_linear_tmp`` and ``bias_ln_tmp``; every value is a separate
            ``(batch_size, sequence_length, logit_dim)`` tensor on the device
            and in the dtype of ``weight_linear``.
        """
        names = ('weight_linear_tmp', 'weight_ln_tmp', 'bias_linear_tmp', 'bias_ln_tmp')
        return {
            name: torch.ones(
                batch_size,
                sequence_length,
                self.logit_dim,
                device=self.weight_linear.device,
                dtype=self.weight_linear.dtype,
            )
            for name in names
        }

    def init_weights(self):
        """Initialise all four parameters.

        ``weight_linear`` is drawn from N(0, initializer_range), ``weight_ln``
        is filled with ``1 / logit_dim`` and both biases are drawn from
        N(0, initializer_range / logit_dim).
        """
        bias_std = self.config.initializer_range / self.logit_dim
        nn.init.normal_(self.weight_linear, mean=0.0, std=self.config.initializer_range)
        nn.init.constant_(self.weight_ln, 1.0 / self.logit_dim)
        nn.init.normal_(self.bias_linear, mean=0.0, std=bias_std)
        # bias_ln was previously left as uninitialised memory because
        # bias_linear was initialised twice instead.
        nn.init.normal_(self.bias_ln, mean=0.0, std=bias_std)

    def get_weight_per_token(self, params):
        """Mix the stored parameter sets with the per-token coefficients.

        Args:
            params: Dict as returned by ``init_params_as_logits``; every value
                has shape ``(batch, seq, logit_dim)``.

        Returns:
            ``(weight_linear, weight_ln, bias_linear, bias_ln)`` with shapes
            ``(batch, seq, in, out)`` for the linear weight and
            ``(batch, seq, out)`` for the other three.
        """
        weight_linear_tmp = torch.einsum('iol,bsl->bsio', self.weight_linear, params['weight_linear_tmp'])
        weight_ln_tmp = torch.einsum('ol,bsl->bso', self.weight_ln, params['weight_ln_tmp'])
        bias_linear_tmp = torch.einsum('ol,bsl->bso', self.bias_linear, params['bias_linear_tmp'])
        bias_ln_tmp = torch.einsum('ol,bsl->bso', self.bias_ln, params['bias_ln_tmp'])

        return weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp

    def _apply_token_map(self, x, params):
        """Run the per-token linear -> LayerNorm -> residual map on ``x``.

        Returns ``(output, z_hat, std, weight_ln)`` so that ``learn`` can
        reuse the intermediates of the forward computation.
        """
        weight_linear_tok, weight_ln_tok, bias_linear_tok, bias_ln_tok = self.get_weight_per_token(params)

        z = torch.einsum('bsi,bsio->bso', x, weight_linear_tok) + bias_linear_tok
        mu = z.mean(dim=-1, keepdim=True)
        var = z.var(dim=-1, keepdim=True, unbiased=False)
        std = torch.sqrt(var + self._LN_EPS)
        z_hat = (z - mu) / std

        output = weight_ln_tok * z_hat + bias_ln_tok + x
        return output, z_hat, std, weight_ln_tok

    def learn(self, k, v, params, lr_linear=1, lr_ln=1):
        """Take one update step on the per-token coefficients.

        The map is applied to ``k``, the error against ``v`` is computed and
        every coefficient vector is moved against its gradient, projected
        onto the ``logit_dim`` axis through the stored parameters.

        Args:
            k: Input tensor of shape ``(batch, seq, concept_dim)``.
            v: Target tensor with the same shape as ``k``.
            params: Current coefficients (see ``init_params_as_logits``).
            lr_linear: Step size for the linear weight/bias coefficients.
            lr_ln: Step size for the LayerNorm scale/shift coefficients.

        Returns:
            A new dict with the same keys as ``params``; the input dict is
            not modified.
        """
        output, z_hat, std, weight_ln_tok = self._apply_token_map(k, params)
        error = output - v

        # LayerNorm scale / shift coefficients.
        new_weight_ln = params['weight_ln_tmp'] - lr_ln * torch.einsum(
            'ol,bso->bsl', self.weight_ln, error * z_hat
        )
        new_bias_ln = params['bias_ln_tmp'] - lr_ln * torch.einsum(
            'ol,bso->bsl', self.bias_ln, error
        )

        # Gradient reaching the linear output. Only the direct 1/std path is
        # used; the mean/variance terms of the LayerNorm backward are omitted.
        grad_linear = weight_ln_tok * error / std
        grad_weight_linear = torch.einsum('bsi,bso->bsio', k, grad_linear)

        new_weight_linear = params['weight_linear_tmp'] - lr_linear * torch.einsum(
            'iol,bsio->bsl', self.weight_linear, grad_weight_linear
        )
        new_bias_linear = params['bias_linear_tmp'] - lr_linear * torch.einsum(
            'ol,bso->bsl', self.bias_linear, grad_linear
        )

        return {
            'weight_linear_tmp': new_weight_linear,
            'weight_ln_tmp': new_weight_ln,
            'bias_linear_tmp': new_bias_linear,
            'bias_ln_tmp': new_bias_ln,
        }

    def predict(self, q, params):
        """Apply the per-token map to ``q`` and return the result.

        Args:
            q: Input tensor of shape ``(batch, seq, concept_dim)``.
            params: Per-token coefficients (see ``init_params_as_logits``).

        Returns:
            Tensor with the same shape as ``q``.
        """
        output, _, _, _ = self._apply_token_map(q, params)
        return output
337
+
338
+
339
+
340
+