import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint


def scan(f, init, xs, out, checkpoint_group=0):
    """
    模拟JAX中的lax.scan函数,用于序列化处理数据。
    
    参数:
        f: 处理函数,接收(carry, x)作为输入,返回(new_carry, y)
        init: 初始状态值
        xs: 输入序列,可以是字典或列表
        out: 输出结果的存储张量
        checkpoint_group: 梯度检查点分组数量,用于节省内存
    
    返回:
        carry: 最终的状态值
        out: 填充好的输出张量
    """
    # 初始化状态值
    carry = init
    
    # 确定输入序列的长度
    if isinstance(xs, dict):
        # 如果输入是字典,取第一个键对应值的长度
        num_items = len(next(iter(xs.values())))
    else:
        # 如果输入是列表,取第一个元素的长度
        num_items = len(xs[0])

    def scan_fn(carry, i_start, i_end):
        """内部扫描函数,处理从i_start到i_end的元素"""
        for i in range(i_start, i_end):
            # 提取当前位置的输入
            if isinstance(xs, dict):
                # 字典情况:创建包含每个键在位置i处值的新字典
                x = {key: tensor[i] for key, tensor in xs.items()}
            else:
                # 列表情况:创建包含每个列表在位置i处值的新列表
                x = [x[i] for x in xs]
            
            # 调用处理函数f,获取新的状态和输出
            carry, y = f(carry, x)
            
            # 将输出存储到结果张量中
            out[i] = y
        
        # 返回最终状态
        return carry

    # 根据checkpoint_group决定是否使用梯度检查点
    if checkpoint_group > 0:
        # 计算每个检查点组包含的元素数量
        ckpt_every_n = num_items // checkpoint_group
        
        # 按组处理数据
        for k in range(0, num_items, ckpt_every_n):
            # 使用torch.utils.checkpoint节省内存
            carry = torch.utils.checkpoint.checkpoint(
                scan_fn, carry, k, min(k + ckpt_every_n, num_items), use_reentrant=False
            )
    else:
        # 不使用检查点,直接处理所有数据
        carry = scan_fn(carry, 0, num_items)

    # 返回最终状态和填充好的输出张量
    return carry, out
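
# Illustrative usage sketch for `scan` (an added example, not original code): a
# running-sum step function, assuming `xs` is a dict of tensors indexed along the
# leading dimension and `out` is pre-allocated with one slot per step.
def _example_scan_usage():
    xs = {"x": torch.arange(6, dtype=torch.float32)}
    out = torch.empty(6)

    def step(carry, inputs):
        new_carry = carry + inputs["x"]
        return new_carry, new_carry  # y is the running sum so far

    final_carry, out = scan(step, torch.tensor(0.0), xs, out)
    # final_carry == 15.0, out == [0., 1., 3., 6., 10., 15.]
    return final_carry, out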

def ln_fwd(x, gamma, beta, eps=1e-6):
    "Batch forward for LayerNorm."

    # Mean and variance computation
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)

    # Normalization
    std = torch.sqrt(var + eps)
    x_hat = (x - mu) / std

    # Scale and shift
    y = gamma * x_hat + beta

    return y
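
# Sanity-check sketch (an added example): for gamma/beta of shape (D,), ln_fwd
# should agree with F.layer_norm at the same eps, up to floating-point error.
def _check_ln_fwd(D=8):
    x = torch.randn(4, D)
    gamma, beta = torch.rand(D), torch.rand(D)
    y_manual = ln_fwd(x, gamma, beta, eps=1e-6)
    y_ref = F.layer_norm(x, (D,), weight=gamma, bias=beta, eps=1e-6)
    return torch.allclose(y_manual, y_ref, atol=1e-5)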

def ln_fused_l2_bwd(x, l2_target, gamma, beta, eps=1e-6):
    """
    层归一化(LayerNorm)与L2损失融合的反向传播函数。
    
    这个函数执行两个操作:
    1. 前向传播:对输入x进行层归一化,得到输出y
    2. 反向传播:计算L2损失(y - l2_target)对输入x的梯度
    
    参数:
        x: 输入张量
        l2_target: L2损失的目标值
        gamma: 层归一化的缩放参数
        beta: 层归一化的偏移参数
        eps: 数值稳定性的小常数
        
    返回:
        z: 损失对输入x的梯度
    """
    D = x.shape[-1]  # 获取特征维度

    # 计算均值和方差
    mu = x.mean(dim=-1, keepdim=True)  # 沿特征维度计算均值
    var = x.var(dim=-1, keepdim=True, unbiased=False)  # 计算方差

    # 归一化处理
    std = torch.sqrt(var + eps)  # 计算标准差
    x_hat = (x - mu) / std  # 标准化输入

    # 缩放和偏移
    y = gamma * x_hat + beta  # 层归一化的输出

    # 计算梯度
    grad_output = y - l2_target  # L2损失的梯度
    grad_x_hat = grad_output * gamma  # 对标准化输入的梯度
    
    # 完整的反向传播公式,考虑了归一化操作的链式法则
    z = (
        (1.0 / D)
        * (
            D * grad_x_hat
            - grad_x_hat.sum(dim=-1, keepdim=True)  # 均值项的梯度贡献
            - x_hat * (grad_x_hat * x_hat).sum(dim=-1, keepdim=True)  # 方差项的梯度贡献
        )
        / std  # 除以标准差完成梯度计算
    )

    return z
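
# Verification sketch (an added example): the fused backward should match the
# autograd gradient of 0.5 * ||LN(x) - target||^2 with respect to x.
def _check_ln_fused_l2_bwd(D=8):
    x = torch.randn(4, D, dtype=torch.double, requires_grad=True)
    gamma = torch.rand(D, dtype=torch.double)
    beta = torch.rand(D, dtype=torch.double)
    target = torch.randn(4, D, dtype=torch.double)
    loss = 0.5 * ((ln_fwd(x, gamma, beta) - target) ** 2).sum()
    loss.backward()
    z = ln_fused_l2_bwd(x.detach(), target, gamma, beta)
    return torch.allclose(z, x.grad, atol=1e-6)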

from torch.autograd import Function
class MyLinearFunction(Function):
    @staticmethod
    def forward(ctx, input, weight, bias):
        """
        正向计算: y = x * W^T + b
        参数:
            ctx    :上下文对象,用于保存反向传播时需要的信息。
            input  :输入 tensor, 尺寸为 (N, in_features)
            weight :权重 tensor, 尺寸为 (out_features, in_features)
            bias   :偏置 tensor, 尺寸为 (out_features)
        返回:
            输出 tensor, 尺寸为 (N, out_features)
        """
        # 保存必要的中间变量,供 backward 时使用
        ctx.save_for_backward(input, weight, bias)
        
        # 计算输出
        output = input.matmul(weight.t()) + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        """
        反向传播:计算正向计算中各个输入的梯度。
        参数:
            grad_output:从上层传回来的梯度,形状与 forward 的输出相同 (N, out_features)
        返回:
            grad_input  :关于 input 的梯度,形状 (N, in_features)
            grad_weight :关于 weight 的梯度,形状 (out_features, in_features)
            grad_bias   :关于 bias 的梯度,形状 (out_features)
        """
        # 从上下文中取出保存的变量
        input, weight, bias = ctx.saved_tensors
        
        # 链式法则:已知 output = input.matmul(weight.t()) + bias
        # 关于 input 的梯度:
        # ∂L/∂input = ∂L/∂output * ∂output/∂input = grad_output.matmul(weight)
        grad_input = grad_output.matmul(weight)
        
        # 关于 weight 的梯度:
        # ∂L/∂weight = ∂L/∂output^T * ∂output/∂weight
        # 注意到 output 对 weight 的导数为 input 的转置,此处:
        # grad_weight 的计算通常为:grad_output^T.matmul(input)
        grad_weight = grad_output.t().matmul(input)
        
        # 关于 bias 的梯度:
        # 因为 output = ... + bias,因此每个 bias 项对应所有样本的梯度和
        grad_bias = grad_output.sum(dim=0)
        
        # 注意:返回的梯度顺序必须与 forward 中参数的顺序一致
        return grad_input, grad_weight, grad_bias
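
# Usage sketch (an added example): custom autograd Functions are invoked through
# .apply; the hand-written gradients can be validated with torch.autograd.gradcheck,
# which requires double-precision inputs.
def _check_my_linear(N=3, in_features=4, out_features=5):
    x = torch.randn(N, in_features, dtype=torch.double, requires_grad=True)
    w = torch.randn(out_features, in_features, dtype=torch.double, requires_grad=True)
    b = torch.randn(out_features, dtype=torch.double, requires_grad=True)
    assert torch.allclose(MyLinearFunction.apply(x, w, b), F.linear(x, w, b))
    return torch.autograd.gradcheck(MyLinearFunction.apply, (x, w, b))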

class TTT_Cross_Layer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.input_size = config.concept_dim   # 128
        self.concept_dim = config.concept_dim  # 128
        # self.linear = nn.Linear(self.input_size, self.hidden_size)
        # self.ln = nn.LayerNorm(self.hidden_size)

        # self.logit_dim = 32
        self.logit_dim = config.logit_dim

        self.weight_linear = nn.Parameter(torch.empty(self.concept_dim, self.input_size, self.logit_dim))
        self.weight_ln = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
        self.bias_linear = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))
        self.bias_ln = nn.Parameter(torch.empty(self.concept_dim, self.logit_dim))

        # self.weight_linear_tmp = torch.empty_like(self.weight_linear)
        # self.weight_ln_tmp = torch.empty_like(self.weight_ln)
        # self.bias_linear_tmp = torch.empty_like(self.bias_linear)
        # self.bias_ln_tmp = torch.empty_like(self.bias_ln)
        
        self.config = config
        self.init_weights()
    # def init_tmp_weights(self):
    #     weight_linear_tmp = self.weight_linear.clone().to(self.weight_linear.device).to(self.weight_linear.dtype)
    #     weight_ln_tmp = self.weight_ln.clone().to(self.weight_ln.device).to(self.weight_ln.dtype)
    #     bias_linear_tmp = self.bias_linear.clone().to(self.bias_linear.device).to(self.bias_linear.dtype)
    #     bias_ln_tmp = self.bias_ln.clone().to(self.bias_ln.device).to(self.bias_ln.dtype)
    #     params = {
    #         'weight_linear_tmp': weight_linear_tmp,
    #         'weight_ln_tmp': weight_ln_tmp,
    #         'bias_linear_tmp': bias_linear_tmp,
    #         'bias_ln_tmp': bias_ln_tmp
    #     }
    #     return params
    
    def init_params_as_logits(self, batch_size, sequence_length):
        weight_linear_tmp = torch.ones(batch_size, sequence_length, self.logit_dim).to(self.weight_linear.device).to(self.weight_linear.dtype)
        weight_ln_tmp = torch.ones(batch_size, sequence_length, self.logit_dim).to(self.weight_linear.device).to(self.weight_linear.dtype)
        bias_linear_tmp = torch.ones(batch_size, sequence_length, self.logit_dim).to(self.weight_linear.device).to(self.weight_linear.dtype)
        bias_ln_tmp = torch.ones(batch_size, sequence_length, self.logit_dim).to(self.weight_linear.device).to(self.weight_linear.dtype)
        
        params = {
            'weight_linear_tmp': weight_linear_tmp,
            'weight_ln_tmp': weight_ln_tmp,
            'bias_linear_tmp': bias_linear_tmp,
            'bias_ln_tmp': bias_ln_tmp
        }
        return params

    def init_weights(self):
        # torch.manual_seed(42)  # fixing the random seed here would make initialization predictable
        nn.init.normal_(self.weight_linear, mean=0.0, std=self.config.initializer_range)
        nn.init.constant_(self.weight_ln, 1.0 / self.logit_dim)
        # nn.init.zeros_(self.bias_linear)
        # nn.init.zeros_(self.bias_ln)
        nn.init.normal_(self.bias_linear, mean=0.0, std=self.config.initializer_range / self.logit_dim)
        nn.init.normal_(self.bias_ln, mean=0.0, std=self.config.initializer_range / self.logit_dim)

    def get_weight_per_token(self, params):
        # Mix the shared parameter banks with the per-token logits to build per-token fast weights:
        #   weight_linear: (concept_dim, input_size, logit_dim) x (b, s, logit_dim) -> (b, s, concept_dim, input_size)
        #   weight_ln / bias_*: (concept_dim, logit_dim) x (b, s, logit_dim) -> (b, s, concept_dim)
        weight_linear_tmp = torch.einsum('iol,bsl->bsio', self.weight_linear, params['weight_linear_tmp'])
        weight_ln_tmp = torch.einsum('ol,bsl->bso', self.weight_ln, params['weight_ln_tmp'])
        bias_linear_tmp = torch.einsum('ol,bsl->bso', self.bias_linear, params['bias_linear_tmp'])
        bias_ln_tmp = torch.einsum('ol,bsl->bso', self.bias_ln, params['bias_ln_tmp'])

        return weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp

    def learn(self, k, v, params, lr_linear=1, lr_ln=1):
        # k and v have shape [batch_size, length, channel_dim]
        # batch_size, seq_length, channel_dim = k.shape
        # weight_linear_tmp = params['weight_linear_tmp']
        # weight_ln_tmp = params['weight_ln_tmp']
        # bias_linear_tmp = params['bias_linear_tmp']
        # bias_ln_tmp = params['bias_ln_tmp']
        weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp = self.get_weight_per_token(params)
        # 1. Reshape the input to 2-D for the prediction step
        # k_reshaped = k.reshape(-1, channel_dim)  # [batch_size*length, channel_dim]
        
        # output_reshaped = self.predict(k_reshaped, params)  # [batch_size*length, channel_dim]
        # z = F.linear(k_reshaped, params['weight_linear_tmp'], params['bias_linear_tmp'])
        # mu = z.mean(dim=-1, keepdim=True)
        # var = z.var(dim=-1, keepdim=True, unbiased=False)

        z = torch.einsum('bsi,bsio->bso', k, weight_linear_tmp) + bias_linear_tmp
        mu = z.mean(dim=-1, keepdim=True)
        var = z.var(dim=-1, keepdim=True, unbiased=False)

        # Normalization
        eps = 1e-6
        std = torch.sqrt(var + eps)
        z_hat = (z - mu) / std     
        # output_reshaped = params['weight_ln_tmp'] * z_hat + params['bias_ln_tmp'] + k
        output_reshaped = weight_ln_tmp * z_hat + bias_ln_tmp + k

        # # Compute the error
        # v_reshaped = v.reshape(-1, channel_dim)
        # error_reshaped = output_reshaped - v_reshaped  # [batch_size*length, channel_dim]
        error_reshaped = output_reshaped - v
        # Compute the LayerNorm gradients
        # LayerNorm parameter updates
        # ln_rate = learning_rate * 0.1  # lower the LN learning rate
        grad_weight_ln_temp = error_reshaped * z_hat
        # grad_weight_ln = grad_weight_ln_temp.mean(dim=0) # 
        # weight_ln_tmp = weight_ln_tmp - ln_rate * grad_weight_ln # sequence length, channel_dim
        grad_weight_ln = grad_weight_ln_temp
        # batch_size, sequence length, logit_dim
        params0 = params['weight_ln_tmp'] - lr_ln * torch.einsum('ol,bso->bsl', self.weight_ln, grad_weight_ln)
        
        # bias_update = ln_rate * error_reshaped # .mean(dim=0)
        # bias_ln_tmp = bias_ln_tmp - bias_update # batch_size, sequence length, concept_dim
        grad_bias_ln = error_reshaped
        params1 = params['bias_ln_tmp'] - lr_ln * torch.einsum('ol,bso->bsl', self.bias_ln, grad_bias_ln)

        # Linear-layer weight gradient: [out_dim, in_dim]
        # grad_linear_temp = error_reshaped - error_reshaped.mean(dim=-1, keepdim=True) - z_hat * grad_weight_ln_temp.mean(dim=-1, keepdim=True)
        grad_linear = weight_ln_tmp * error_reshaped / std # batch_size, sequence length, concept_dim
        # grad_weight_linear = grad_linear.t() @ k  # [channel_dim, channel_dim]
        grad_weight_linear = torch.einsum('bsi,bso->bsio', k, grad_linear)
        # Apply the gradient (avoiding the in-place -= operation)
        # weight_linear_tmp = weight_linear_tmp - learning_rate * grad_weight_linear.mean(dim=0)
        params2 = params['weight_linear_tmp'] - lr_linear * torch.einsum('iol,bsio->bsl', self.weight_linear, grad_weight_linear)
        # Update the bias, if present (avoiding the in-place -= operation)
        grad_b = grad_linear #.mean(dim=0)  # [channel_dim]
        # bias_linear_tmp = bias_linear_tmp - learning_rate * grad_b
        params3 = params['bias_linear_tmp'] - lr_linear * torch.einsum('ol,bso->bsl', self.bias_linear, grad_b)
        
        params_new = {
            'weight_linear_tmp': params2,
            'weight_ln_tmp': params0,
            'bias_linear_tmp': params3,
            'bias_ln_tmp': params1
        }

        return params_new

    def predict(self, q, params):
        weight_linear_tmp, weight_ln_tmp, bias_linear_tmp, bias_ln_tmp = self.get_weight_per_token(params)
        z = torch.einsum('bsi,bsio->bso', q, weight_linear_tmp) + bias_linear_tmp
        mu = z.mean(dim=-1, keepdim=True)
        var = z.var(dim=-1, keepdim=True, unbiased=False)

        # Normalization
        eps = 1e-6
        std = torch.sqrt(var + eps)
        z_hat = (z - mu) / std     
        # output_reshaped = params['weight_ln_tmp'] * z_hat + params['bias_ln_tmp'] + k
        output = weight_ln_tmp * z_hat + bias_ln_tmp + q

        return output
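
# End-to-end usage sketch (an added example; SimpleNamespace stands in for the real
# config object, which only needs concept_dim, logit_dim, and initializer_range here):
# initialize per-token logits, take one fast-weight update from (k, v), then predict on q.
def _example_ttt_cross_layer(batch_size=2, seq_len=4):
    from types import SimpleNamespace
    config = SimpleNamespace(concept_dim=16, logit_dim=8, initializer_range=0.02)
    layer = TTT_Cross_Layer(config)
    k = torch.randn(batch_size, seq_len, config.concept_dim)
    v = torch.randn(batch_size, seq_len, config.concept_dim)
    q = torch.randn(batch_size, seq_len, config.concept_dim)
    params = layer.init_params_as_logits(batch_size, seq_len)
    params = layer.learn(k, v, params)  # one test-time training step on (k, v)
    return layer.predict(q, params)     # shape: [batch_size, seq_len, concept_dim]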