# -*- coding: utf-8 -*- # Copyright (c) 2024, Songlin Yang, Yu Zhang from __future__ import annotations import math from typing import TYPE_CHECKING, Dict, Optional, Tuple import torch import torch.nn as nn from einops import rearrange from torch.nn import functional as F from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution from fla.ops.gated_delta_rule import (chunk_gated_delta_rule, fused_recurrent_gated_delta_rule) if TYPE_CHECKING: from transformers.processing_utils import Unpack from fla.models.utils import Cache def elu_p1(x): return (F.elu(x, 1., False) + 1.).to(x) def sum_norm(x): return (x / x.sum(-1, keepdim=True)).to(x) from fla.modules import RMSNorm, RotaryEmbedding if TYPE_CHECKING: from fla.models.utils import Cache import warnings try: from flash_attn import flash_attn_func, flash_attn_varlen_func from flash_attn.bert_padding import (index_first_axis, pad_input, unpad_input) except ImportError: warnings.warn( "Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`", category=ImportWarning ) flash_attn_func = None # https://github.com/IDSIA/recurrent-fwp/blob/master/algorithmic/layers.py#L86C1-L146C1 def lambda_init_fn(depth): return 0.8 - 0.6 * math.exp(-0.3 * depth) # -*- coding: utf-8 -*- from typing import Optional, Tuple import torch from einops import rearrange from fla.ops.linear_attn.utils import normalize_output # def scattering_mixer( # q: torch.Tensor, # k: torch.Tensor, # v: torch.Tensor, # gamma: torch.Tensor, # # chi: torch.Tensor, # scale: Optional[float] = None, # normalize: bool = False # ) -> Tuple[torch.Tensor, torch.Tensor]: # if scale is None: # scale = q.shape[-1] ** -0.5 # chunk_size = 64 # # split_size = 2 # q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale # # k = rearrange(k, 'b h (n c) d -> b h n c d', c=chunk_size) # # gamma (b , n*c, h) -> (b, h, n*c, 1) # gamma = rearrange(gamma, 'b l h -> b h l').unsqueeze(-1) # gamma_cumprod = torch.cumprod(gamma, dim=2) # gamma_cumprod_chunk = rearrange(gamma_cumprod, 'b h (n c) d -> b h n c d', c=chunk_size) # gamma_cumprod_chunk = gamma_cumprod_chunk[:, :, :, -1, :].unsqueeze(-2) # [b, h, n, 1, 1] # gamma_cumprod = rearrange(gamma_cumprod, 'b h l d -> b l h d') # k_cumprod = k / gamma_cumprod # k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size) # k_cumprod_chunk = rearrange(k_cumprod, 'b (n c) h d -> b h n c d', c=chunk_size) # # gamma_cumprod_chunk = rearrange(gamma_cumprod, 'b h n c d -> b h (n c) d') # v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size) # gamma = rearrange(gamma, 'b h (n c) d -> b h n c d', c=chunk_size) # d = 1 # # gamma_cumprod_chunk_inter = torch.cumprod(gamma, dim=3) # gamma_inter = torch.cumprod(gamma, dim=3) # [b, h, n, c, 1] # kv = k_cumprod_chunk.transpose(-1, -2) @ v # [b, h, n, d, d] # kv = kv.cumsum(2) # [b, h, n, d, d] n << seq_len # kv = kv * gamma_cumprod_chunk # [b, h, n, d, d] # kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2) # [b, h, n, d, d] # inter = (q @ kv) * gamma_inter # [b, h, n, c, d] # intra = ( # ((q @ (k / gamma_inter).transpose(-1, -2)) ).masked_fill_( # torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1), # 0 # )) @ v * gamma_inter # [b, h, n, c, d] # o = inter + intra # [b, h, n, c, d] # if normalize: # o = normalize_output(q * scale, k, o) # return rearrange(o, 'b h n c d -> b (n c) h d') , None def scattering_mixer_recurrent( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, G0: torch.Tensor, split_size: int, past_kv: Optional[torch.Tensor] = None, beta: Optional[torch.Tensor] = None, # chi: torch.Tensor, scale: Optional[float] = None, normalize: bool = False, order: int = 2, perturb: Optional[torch.Tensor] = None ) -> Tuple[torch.Tensor, torch.Tensor]: if scale is None: scale = q.shape[-1] ** -0.5 # chunk_size = 64 q = rearrange(q, 'b l h (f s) -> b h l s f', s=split_size) * scale k = rearrange(k, 'b l h (f s) -> b h l s f', s=split_size) v = rearrange(v, 'b l h (d s) -> b h l s d', s=split_size) if order == 2: G0 = rearrange(G0, 'b l h d f -> b h l d f') # kv = k.transpose(-1, -2) @ v # [b, h, l, f, d] second_term = torch.einsum('b h l s d, b h l d f -> b h l s f', v, G0) # [b, h, l, s, f] G1 = second_term @ k.transpose(-1, -2) # [b, h, l, s, s] kv2 = k.transpose(-1, -2) @ G1 + k.transpose(-1, -2) # [b, h, l, f ,s] else: kv2 = k.transpose(-1, -2) # [b, h, l, f ,s] kv = kv2 @ v # [b, h, l, f, d] # kv = kv + kv2 perturb = rearrange(perturb, 'b l h f k -> b h l f k') # [b, h, l, f, f] M = q.transpose(-1, -2) @ q # [b, h, l, f, f] M = perturb @ M # [b, h, l, f, f] M = q @ M # [b, h, l, s, f] q = q + M # [b, h, l, s, f] if past_kv is None: if beta is not None: beta = rearrange(beta, 'b l h -> b h l') beta_cumprod = torch.cumprod(beta, dim=2) # print('the shape of beta_cumprod', beta_cumprod.shape) beta_cumprod = torch.cat([torch.ones_like(beta_cumprod[:, :, :1]), beta_cumprod[:, :, :-1]], dim=2) # kv = kv + kv2 beta_cumprod = rearrange(beta_cumprod, 'b h l -> b h l 1 1') kv = kv / beta_cumprod # [b, h, l, f, d] kv = kv.cumsum(2) # [b, h, l, f, d] kv = kv * beta_cumprod # [b, h, l, f, d] else: kv = kv.cumsum(2) # [b, h, l, f, d] o = q @ kv # [b, h, l, s, d] else: if beta is not None: beta = rearrange(beta, 'b l h -> b h l') kv = kv[:, :, -1, :, :] + past_kv * (beta[:, :, -2]).unsqueeze(-1).unsqueeze(-1) else: kv = kv[:, :, -1, :, :] + past_kv # [b, h, l, f, d] o = q @ kv # [b, h, l, s, d] # print('the shape of o', o.shape) if normalize: o = normalize_output(q * scale, k, o) # [b, h, l, s, d] return rearrange(o, 'b h l s d -> b l h (s d)') , kv def safe_exp(x): return torch.exp(x - torch.max(x,dim=-1,keepdim=True)[0]) def random_proj(q, down_proj_matrix, up_proj_matrix, control_vec): temp = q @ down_proj_matrix temp = temp * control_vec temp = temp @ up_proj_matrix return torch.concat([torch.cos(temp), torch.sin(temp)], dim=-1) def lora_proj(x, down_proj_matrix, up_proj_matrix, control_vec): temp = x @ down_proj_matrix temp = temp * control_vec temp = temp @ up_proj_matrix return temp def gaussian_basis(x, basis_a, basis_c, basis_h): # x.shape = [b, q_len, channel] x = x.unsqueeze(-1) # [b, q_len, channel, 1] # basis_a.shape = [b, q_len, 1, num_basis] # basis_c.shape = [b, q_len, 1, num_basis] # basis_h.shape = [b, q_len, 1, num_basis] eps = 1e-6 temp = F.sigmoid(basis_a) * torch.exp(-(x - basis_c) ** 2 / (2 * basis_h ** 2 + eps)) # [b, q_len, channel, num_basis] # temp = F.sigmoid(basis_a) * torch.exp(-(x - basis_c) ** 2 * (basis_h ** 2) ) # [b, q_len, channel, num_basis] return temp.sum(dim=-1, keepdim=False) # [b, q_len, channel] def pad_time_cond(t, len): t_sin = torch.cat([torch.sin(w * t) for w in range(1, len + 1)], dim=-1) t_cos = torch.cat([torch.cos(w * t) for w in range(1, len + 1)], dim=-1) t = torch.cat([t_sin, t_cos, t], dim=-1) return t class condition_interpolation(nn.Module): def __init__( self, hidden_size: int = 2048, concept_dim: int = 64, ): super().__init__() self.hidden_size = hidden_size self.concept_dim = concept_dim self.r = 8 # self.len = 15 self.lora = nn.Sequential( nn.Linear(self.hidden_size * 2 + self.concept_dim * 2, self.hidden_size // self.r, bias=False), nn.SiLU(), nn.Linear(self.hidden_size // self.r, self.hidden_size, bias=False) ) nn.init.xavier_uniform_(self.lora[0].weight) nn.init.zeros_(self.lora[2].weight) def forward(self, start, end, h_new): # t = pad_time_cond(t, self.len) x = torch.cat([start, end, h_new, h_new], dim=-1) x = self.lora(x) return x class Task_Aware_Delta_Net(nn.Module): """ The layer implementaion for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters. Parameter alloation when use_gate=True: - 0.75 * hidden_size * hidden_size for the q_proj and k_proj each - 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each - Others are ignorably small. - In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim. Parameter allocation when use_gate=False: - 1 * hidden_size * hidden_size for the q_proj and k_proj each - 2 * hidden_size * hidden_size for the v_proj and o_proj each - Others are ignorably small. - In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size Args: hidden_size (int, Optional): The hidden size of the input. Default: 2048. expand_v (float, Optional): The expansion ratio for the value dim. Default: 2.0. head_dim (int, Optional): The dimension of each head. Default: 256. num_heads (int, Optional): The number of heads. Default: 4. mode (str, Optional): Which Gated DeltaNet kernel to use. Currently available: `chunk` and `fused_recurrent`. Default: `chunk`. use_beta (bool, Optional): Whether to use beta. Default: `True`. use_gate (bool, Optional): Whether to use output gate. Default: `True`. use_short_conv (bool, Optional): Whether to use short convolutions. Default: `True`. conv_size (int, Optional): The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. conv_bias (bool, Optional): Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. layer_idx (int, Optional): The index of the layer. Default: None. norm_eps (float, Optional): The epsilon value for the normalization layer. Default: 1e-5. """ def __init__( self, hidden_size: int = 2048, expand_v: float = 1, head_dim: int = 256, num_heads: int = 6, num_heads_delta: int = 6, mode: str = 'chunk', use_gate: bool = True, use_short_conv: bool = True, conv_size: int = 4, conv_bias: bool = False, layer_idx: int = None, norm_eps: float = 1e-5, rope_theta: float = 10000., max_position_embeddings: int = None, window_size: int = None, concept_dim: int = 128, **kwargs: Unpack[Dict] ) -> Task_Aware_Delta_Net: super().__init__() self.split_size = 64 # 64 self.mode = mode self.hidden_size = hidden_size self.expand_v = expand_v self.use_gate = use_gate self.use_short_conv = use_short_conv # self.use_short_conv = False self.conv_size = conv_size self.conv_bias = conv_bias self.head_dim = head_dim self.strict_head = False if self.strict_head: head_dim_delta = int (0.75 * hidden_size / num_heads_delta) head_dim = head_dim_delta self.head_dim_delta = head_dim_delta self.head_dim = head_dim_delta self.num_heads = num_heads self.key_dim = self.num_heads * self.head_dim self.value_dim = self.key_dim * self.expand_v self.head_qk_dim = head_dim self.head_v_dim = head_dim * self.expand_v self.layer_idx = layer_idx self.silu = nn.SiLU() assert mode in ['chunk', 'fused_recurrent'], f"Not suppoerted mode `{mode}`." self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False) A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) A_log = torch.log(A) self.A_log = nn.Parameter(A_log) self.A_log._no_weight_decay = True # self.D = nn.Parameter(torch.ones(self.num_heads)) # self.D._no_weight_decay = True # hard coded for now dt_min = 0.001 dt_max = 0.1 dt_init_floor = 1e-4 dt = torch.exp( torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min) ) dt = torch.clamp(dt, min=dt_init_floor) # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759 inv_dt = dt + torch.log(-torch.expm1(-dt)) self.dt_bias = nn.Parameter(inv_dt) # Just to be explicit. Without this we already don't put wd on dt_bias because of the check # name.endswith("bias") in param_grouping.py self.dt_bias._no_weight_decay = True if use_short_conv: self.conv_size = conv_size self.q_conv1d = ShortConvolution( hidden_size=self.key_dim, kernel_size=conv_size, activation='silu' ) self.k_conv1d = ShortConvolution( hidden_size=self.key_dim, kernel_size=conv_size, activation='silu' ) self.v_conv1d = ShortConvolution( hidden_size=self.value_dim, kernel_size=conv_size, activation='silu' ) else: raise UserWarning( "ShortConvolution is crucial to the performance. " "Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing." ) if use_gate: self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) else: self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) self.num_prelude = 2 self.ttt = True if self.ttt and self.layer_idx >= self.num_prelude: # use TTT as cross-layer concept learner self.concept_dim = concept_dim # hidden_size // 8 self.concept_proj = nn.Linear(hidden_size, self.concept_dim * 3, bias=False) self.lr1_proj = nn.Linear(hidden_size, 1, bias=False) self.lr2_proj = nn.Linear(hidden_size, 1, bias=False) # self.router = nn.Linear(hidden_size, self.num_heads * 2, bias=False) # , bias=False self.router2 = nn.Linear(self.concept_dim, self.num_heads * 2, bias=False) self.router3 = nn.Linear(self.concept_dim, 2, bias=False) self.condition_interpolation = condition_interpolation(hidden_size, concept_dim) self.t_proj = nn.Linear(concept_dim, 1, bias=False) # self.num_basis = 2 # self.basis_proj = nn.Linear(self.concept_dim, self.num_basis * 3, bias=False) self.special_mask = nn.Parameter(torch.zeros(self.hidden_size)) # self.special_mask_gated_delta = nn.Parameter(torch.zeros(self.hidden_size)) self.use_bias = True if self.use_bias: self.learnable_bias0 = nn.Parameter(torch.zeros(1)) self.apply(self._initialize_weights) # Initialize LoRA matrices for q, k, v, and o projections using nn.Sequential self.r = 4 self.q_lora = nn.Sequential( nn.Linear(self.hidden_size, self.key_dim // self.r, bias=False), nn.SiLU(), nn.Linear(self.key_dim // self.r, self.key_dim, bias=False) ) nn.init.xavier_uniform_(self.q_lora[0].weight) nn.init.zeros_(self.q_lora[2].weight) self.k_lora = nn.Sequential( nn.Linear(self.hidden_size, self.key_dim // self.r, bias=False), nn.SiLU(), nn.Linear(self.key_dim // self.r, self.key_dim, bias=False) ) nn.init.xavier_uniform_(self.k_lora[0].weight) nn.init.zeros_(self.k_lora[2].weight) self.v_lora = nn.Sequential( nn.Linear(self.hidden_size, self.value_dim // self.r, bias=False), nn.SiLU(), nn.Linear(self.value_dim // self.r, self.value_dim, bias=False) ) nn.init.xavier_uniform_(self.v_lora[0].weight) nn.init.zeros_(self.v_lora[2].weight) self.o_proj_attn = nn.Linear(self.value_dim, self.hidden_size, bias=False) nn.init.xavier_uniform_(self.o_proj_attn.weight, gain=2 ** -2.5) # self.o_proj_attention = nn.Linear(self.value_dim, self.hidden_size, bias=False) self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta) self.window_size = window_size def _initialize_weights(self, module: nn.Module): if getattr(module, "_is_hf_initialized", False): return if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) if module.bias is not None: nn.init.zeros_(module.bias) module._is_hf_initialized = True def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, past_key_values1: Optional[Cache] = None, all_past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, rnn_router: Optional[nn.Module] = None, h_old: Optional[torch.Tensor] = None, params: Optional[Dict] = None, **kwargs: Unpack[Dict] ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Cache], Optional[torch.Tensor], Optional[torch.Tensor]]: # output: return o, None, past_key_values1, past_key_values2, h_new, params if attention_mask is not None: assert len(attention_mask.shape) == 2, ( "Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " "for padding purposes (0 indicating padding). " "Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." ) mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode # # mode = self.mode # mode = 'chunk' if self.training: assert mode == 'chunk', "Only chunk mode is supported in training." last_state2 = None if all_past_key_values is not None: if all_past_key_values._seen_tokens > 0: past_key_values1, past_key_values2 = all_past_key_values else: from fla.models.utils import Cache past_key_values1, past_key_values2 = Cache(), Cache() if len(past_key_values2) > self.layer_idx: last_state2 = past_key_values2[self.layer_idx] batch_size, q_len, _ = hidden_states.size() cu_seqlens = kwargs.get('cu_seqlens', None) max_seqlen = kwargs.get('max_seqlen', q_len) if self.ttt: flag = True if self.layer_idx < self.num_prelude: # 前2层 (0-1) if flag == True: params = rnn_router.init_params_as_logits(batch_size, q_len) flag = False mask = torch.ones(batch_size, q_len, self.num_heads, 2, device=hidden_states.device).to(hidden_states.dtype) h_new = None special_mask_attn = torch.zeros(batch_size, q_len, 1, device=hidden_states.device).to(hidden_states.dtype) else: concept_qkv = self.concept_proj(hidden_states) concept_q, concept_k, concept_v = concept_qkv.chunk(3, dim=-1) lr_linear = F.sigmoid(self.lr1_proj(hidden_states)) * 1e-2 lr_ln = F.sigmoid(self.lr2_proj(hidden_states)) * 1e-2 # lr_linear = 1e-2 # lr_ln = 1e-2 if rnn_router is not None: params = rnn_router.learn(concept_k, concept_v, params, lr_linear, lr_ln) h_new = rnn_router.predict(concept_q, params) t = F.sigmoid(self.t_proj(h_new)) t_b = 1 - t input_router = self.router2(h_new) # input_router = nn.Softmax(dim=-1)(input_router) # [batch_size, seq_len, head_dim, 2] input_router = F.sigmoid(input_router) # [batch_size, seq_len, head_dim * 2] special_mask = self.router3(h_new) # 添加偏置使第一个位置更容易被选中(通过增加第一个位置的logits值) bias = torch.zeros_like(special_mask) bias[..., 0] = 2.0 if self.use_bias: bias[..., 0] = 2.0 + self.learnable_bias0 # 给第0个位置添加正偏置,使第一个位置更容易被选为0 special_mask = F.gumbel_softmax(special_mask + bias, tau=0.1, hard=True) special_mask_attn = special_mask[:, :, 1].unsqueeze(-1) # [batch_size, seq_len, 1] mask = input_router mask = mask.reshape(batch_size, q_len, self.num_heads, 2) # if self.layer_idx >= self.num_prelude: # hidden_states = hidden_states + special_mask_gated_delta * self.special_mask_gated_delta.reshape(1, 1, -1) if self.use_short_conv: conv_state_q, conv_state_k, conv_state_v = None, None, None if last_state2 is not None: conv_state_q, conv_state_k, conv_state_v = last_state2['conv_state'] conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None # position_ids = kwargs.get('position_ids', None) q_shared = self.q_proj(hidden_states) k_shared = self.k_proj(hidden_states) v_shared = self.v_proj(hidden_states) q, conv_state_q = self.q_conv1d(x=q_shared, mask=conv_mask, cache=conv_state_q, output_final_state=use_cache, cu_seqlens = cu_seqlens ) k, conv_state_k = self.k_conv1d(x=k_shared, mask=conv_mask, cache=conv_state_k, output_final_state=use_cache, cu_seqlens = cu_seqlens ) v, conv_state_v = self.v_conv1d(x=v_shared, mask=conv_mask, cache=conv_state_v, output_final_state=use_cache, cu_seqlens = cu_seqlens ) else: q = self.silu(self.q_proj(hidden_states)) k = self.silu(self.k_proj(hidden_states)) v = self.silu(self.v_proj(hidden_states)) if self.layer_idx >= self.num_prelude: hidden_states_attn = hidden_states + special_mask_attn * self.special_mask.reshape(1, 1, -1) else: hidden_states_attn = hidden_states q_attn = self.q_lora(hidden_states_attn) + q_shared k_attn = self.k_lora(hidden_states_attn) + k_shared v_attn = self.v_lora(hidden_states_attn) + v_shared # q_attn = input_router[:, :, 1].unsqueeze(-1) * q_attn q_attn, k_attn, v_attn = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (q_attn, k_attn, v_attn)) # equivalent to cu_seqlens in `flash_attn` seqlen_offset = 0 # seqlen_offset, max_seqlen = 0, q_len if all_past_key_values is not None: seqlen_offset = past_key_values1.get_seq_length(self.layer_idx) max_seqlen = q_attn.shape[1] + seqlen_offset if attention_mask is not None: # to deliminate the offsets of padding tokens seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0) max_seqlen = q_attn.shape[1] + max(seqlen_offset) if self.max_position_embeddings is not None: max_seqlen_rotary = max(max_seqlen, self.max_position_embeddings) else: max_seqlen_rotary = max_seqlen q_attn, k_attn = self.rotary(q_attn, k_attn, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen_rotary, cu_seqlens=cu_seqlens) if all_past_key_values is not None: k_attn, v_attn = past_key_values1.update( attn_state=(k_attn.flatten(-2, -1), v_attn.flatten(-2, -1)), layer_idx=self.layer_idx, offset=q_len, cache_kwargs=dict(window_size=self.window_size) )['attn_state'] k_attn = rearrange(k_attn, '... (h d) -> ... h d', h=self.num_heads) v_attn = rearrange(v_attn, '... (h d) -> ... h d', h=self.num_heads) if flash_attn_func is None: raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first") # Contains at least one padding token in the sequence if attention_mask is not None: q_attn, k_attn, v_attn, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q_attn, k_attn, v_attn, attention_mask, q_len) cu_seqlens_q, cu_seqlens_k = cu_seq_lens max_seqlen_q, max_seqlen_k = max_seq_lens o_attn = flash_attn_varlen_func( q_attn, k_attn, v_attn, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, max_seqlen_q=max_seqlen_q, max_seqlen_k=max_seqlen_k, causal=True, window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) ) o_attn = pad_input(o_attn, indices_q, batch_size, q_len) elif cu_seqlens is not None: o_attn = flash_attn_varlen_func( q_attn.squeeze(0), k_attn.squeeze(0), v_attn.squeeze(0), cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_k=max_seqlen, causal=True, window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) ).unsqueeze(0) else: o_attn = flash_attn_func( q_attn, k_attn, v_attn, causal=True, window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) ) # [total, num_heads, head_dim] (total = batch_size * seq_len) if batch_size > 1: o_attn = o_attn.reshape(batch_size, q_len, self.num_heads, self.head_dim) if self.layer_idx >= self.num_prelude: o_attn = torch.einsum("bnh,bnhd->bnhd", mask[:, :, :, 0], o_attn) # [batch_size, seq_len, num_heads, head_dim] o_attn = o_attn.reshape(batch_size, q_len, self.value_dim) # o_attn = self.o_proj_attention(o_attn) o_attn = self.o_proj_attn(o_attn) # + self.o_proj(o_attn) #################################################### end of attention #################################################### k, v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (k, v)) beta = self.b_proj(hidden_states).sigmoid() g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias) # dealing with padding if attention_mask is not None: beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None]) g = g.mul(attention_mask[:, -g.shape[-2]:, None]) recurrent_state = last_state2['recurrent_state'] if last_state2 is not None else None # if self.layer_idx >= self.num_prelude: # # q_plus_feature = q.clone() # q_safe_exp = safe_exp(q) # q_plus_feature = q + q_safe_exp * if_feature_map # # q_random_feature = random_proj(q, self.down_proj_matrix, self.up_proj_matrix, control_vec) # # q_plus_feature = q_plus_feature + q_random_feature * if_feature_map2 # q_lora = lora_proj(q, self.down_proj_matrix, self.up_proj_matrix, torch.ones_like(control_vec)) # F.sigmoid(control_vec)) # F.sigmoid(control_vec) # q_gaussian_feature = gaussian_basis(q_lora, basis_a, basis_c, basis_h) # q_plus_feature = q_plus_feature + q_gaussian_feature * if_feature_map3 # q = q_plus_feature q = rearrange(q, 'b t (h d) -> b t h d', h=self.num_heads) if mode == 'chunk': o, recurrent_state = chunk_gated_delta_rule( q=q, k=k, v=v, g=g, beta=beta, initial_state=recurrent_state, output_final_state=use_cache, cu_seqlens=cu_seqlens, head_first=False, use_qk_l2norm_in_kernel=True ) elif mode == 'fused_recurrent': o, recurrent_state = fused_recurrent_gated_delta_rule( q=q, k=k, v=v, g=g, beta=beta, initial_state=recurrent_state, output_final_state=use_cache, cu_seqlens=cu_seqlens, # head_first=False, use_qk_l2norm_in_kernel=True ) if all_past_key_values is not None: past_key_values2.update( recurrent_state=recurrent_state, conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, layer_idx=self.layer_idx, offset=q.shape[1] ) if self.use_gate: g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads) o = self.o_norm(o, g) else: o = self.o_norm(o) if self.layer_idx >= self.num_prelude: o = torch.einsum("bnh,bnhd->bnhd", mask[:, :, :, 1], o) # [batch_size, seq_len, num_heads, head_dim] o_gated_delta = rearrange(o, 'b t h d -> b t (h d)') o_gated_delta = self.o_proj(o_gated_delta) #################################################### end of delta rule #################################################### if self.layer_idx < self.num_prelude: o = o_gated_delta + o_attn else: o = t_b * o_gated_delta + t * o_attn noise_std = t_b * t noise = self.condition_interpolation(o_gated_delta, o_attn, h_new) * noise_std o = o + noise if all_past_key_values is not None: all_past_key_values = (past_key_values1, past_key_values2) return o, None, None, all_past_key_values, h_new, params def _upad_input(self, q, k, v, attention_mask, q_len): seqlens = attention_mask.sum(-1, dtype=torch.int32) indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_k = seqlens.max().item() cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) batch_size, seq_len, num_key_value_heads, head_dim = k.shape k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k) v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k) if q_len == seq_len: q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k) cu_seqlens_q = cu_seqlens_k max_seqlen_q = max_seqlen_k indices_q = indices_k elif q_len == 1: max_seqlen_q = 1 # There is a memcpy here, that is very bad. cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) indices_q = cu_seqlens_q[:-1] q = q.squeeze(1) else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -q_len:] q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask) return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) if __name__ == "__main__": gated_delta_net_attention = Task_Aware_Delta_Net() q = torch.randn(1, 10, 6, 256) k = torch.randn(1, 10, 6, 256) v = torch.randn(1, 10, 6, 256) print(q.shape, k.shape, v.shape) # 调用forward函数 o, _, _, _ = gated_delta_net_attention.forward(hidden_states=torch.randn(2, 70, 128)) print(o.shape)