|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import math |
|
|
from typing import TYPE_CHECKING, Dict, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from einops import rearrange |
|
|
from torch.nn import functional as F |
|
|
|
|
|
from fla.modules import FusedRMSNormSwishGate, RMSNorm, ShortConvolution |
|
|
from fla.ops.gated_delta_rule import (chunk_gated_delta_rule, |
|
|
fused_recurrent_gated_delta_rule) |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from transformers.processing_utils import Unpack |
|
|
|
|
|
from fla.models.utils import Cache |
|
|
|
|
|
|
|
|
def elu_p1(x): |
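    """Shifted ELU, elu(x) + 1, which stays strictly positive."""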
|
|
return (F.elu(x, 1., False) + 1.).to(x) |
|
|
|
|
|
def sum_norm(x): |
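    """Normalize x so that its last dimension sums to 1."""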
|
|
return (x / x.sum(-1, keepdim=True)).to(x) |
|
|
|
|
|
from fla.modules import RotaryEmbedding
|
|
|
|
|
|
|
import warnings |
|
|
try: |
|
|
from flash_attn import flash_attn_func, flash_attn_varlen_func |
|
|
from flash_attn.bert_padding import (index_first_axis, pad_input, |
|
|
unpad_input) |
|
|
except ImportError: |
|
|
warnings.warn( |
|
|
"Flash Attention is not installed. Please install it via `pip install flash-attn --no-build-isolation`", |
|
|
category=ImportWarning |
|
|
) |
|
|
    flash_attn_func = None
    flash_attn_varlen_func = None
    index_first_axis = None
    pad_input = None
    unpad_input = None
|
|
|
|
|
|
|
|
|
|
|
def lambda_init_fn(depth): |
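    """Depth-dependent lambda initialization: 0.8 - 0.6 * exp(-0.3 * depth)."""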
|
|
return 0.8 - 0.6 * math.exp(-0.3 * depth) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from fla.ops.linear_attn.utils import normalize_output |
|
|
|
|
|
def scattering_mixer_recurrent( |
|
|
q: torch.Tensor, |
|
|
k: torch.Tensor, |
|
|
v: torch.Tensor, |
|
|
G0: torch.Tensor, |
|
|
split_size: int, |
|
|
past_kv: Optional[torch.Tensor] = None, |
|
|
beta: Optional[torch.Tensor] = None, |
|
|
|
|
|
scale: Optional[float] = None, |
|
|
normalize: bool = False, |
|
|
order: int = 2, |
|
|
perturb: Optional[torch.Tensor] = None |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
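    """
    Recurrent form of the scattering mixer (shapes inferred from the rearranges below):
    q and k are [B, L, H, F * split_size], v is [B, L, H, D * split_size]; `G0` feeds the
    second-order term when `order == 2`, and `perturb` optionally perturbs the queries.
    Returns the output reshaped to [B, L, H, split_size * D] together with the updated
    state `kv`.
    """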
|
|
if scale is None: |
|
|
scale = q.shape[-1] ** -0.5 |
|
|
|
|
|
q = rearrange(q, 'b l h (f s) -> b h l s f', s=split_size) * scale |
|
|
k = rearrange(k, 'b l h (f s) -> b h l s f', s=split_size) |
|
|
v = rearrange(v, 'b l h (d s) -> b h l s d', s=split_size) |
|
|
if order == 2: |
|
|
G0 = rearrange(G0, 'b l h d f -> b h l d f') |
|
|
|
|
|
second_term = torch.einsum('b h l s d, b h l d f -> b h l s f', v, G0) |
|
|
G1 = second_term @ k.transpose(-1, -2) |
|
|
kv2 = k.transpose(-1, -2) @ G1 + k.transpose(-1, -2) |
|
|
else: |
|
|
kv2 = k.transpose(-1, -2) |
|
|
kv = kv2 @ v |
|
|
|
|
|
|
|
|
    if perturb is not None:
        # Optional low-rank perturbation of the queries: q <- q + q @ (perturb @ (q^T q)).
        perturb = rearrange(perturb, 'b l h f k -> b h l f k')
        M = q.transpose(-1, -2) @ q
        M = perturb @ M
        M = q @ M
        q = q + M
|
|
|
|
|
if past_kv is None: |
|
|
if beta is not None: |
|
|
beta = rearrange(beta, 'b l h -> b h l') |
|
|
beta_cumprod = torch.cumprod(beta, dim=2) |
|
|
|
|
|
beta_cumprod = torch.cat([torch.ones_like(beta_cumprod[:, :, :1]), beta_cumprod[:, :, :-1]], dim=2) |
|
|
|
|
|
beta_cumprod = rearrange(beta_cumprod, 'b h l -> b h l 1 1') |
|
|
kv = kv / beta_cumprod |
|
|
kv = kv.cumsum(2) |
|
|
kv = kv * beta_cumprod |
|
|
else: |
|
|
kv = kv.cumsum(2) |
|
|
o = q @ kv |
|
|
else: |
|
|
if beta is not None: |
|
|
beta = rearrange(beta, 'b l h -> b h l') |
|
|
kv = kv[:, :, -1, :, :] + past_kv * (beta[:, :, -2]).unsqueeze(-1).unsqueeze(-1) |
|
|
else: |
|
|
kv = kv[:, :, -1, :, :] + past_kv |
|
|
o = q @ kv |
|
|
|
|
|
if normalize: |
|
|
o = normalize_output(q * scale, k, o) |
|
|
    return rearrange(o, 'b h l s d -> b l h (s d)'), kv
|
|
|
|
|
def safe_exp(x): |
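    """Exponentiate after subtracting the per-row max for numerical stability."""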
|
|
    return torch.exp(x - torch.max(x, dim=-1, keepdim=True)[0])
|
|
|
|
|
def random_proj(q, down_proj_matrix, up_proj_matrix, control_vec): |
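    """Random-feature projection: low-rank map modulated by `control_vec`, mapped to cos/sin features."""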
|
|
temp = q @ down_proj_matrix |
|
|
temp = temp * control_vec |
|
|
temp = temp @ up_proj_matrix |
|
|
return torch.concat([torch.cos(temp), torch.sin(temp)], dim=-1) |
|
|
|
|
|
def lora_proj(x, down_proj_matrix, up_proj_matrix, control_vec): |
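    """LoRA-style low-rank projection, elementwise-modulated by `control_vec`."""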
|
|
temp = x @ down_proj_matrix |
|
|
temp = temp * control_vec |
|
|
temp = temp @ up_proj_matrix |
|
|
return temp |
|
|
|
|
|
def gaussian_basis(x, basis_a, basis_c, basis_h): |
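    """Mixture of Gaussian bumps with amplitudes sigmoid(basis_a), centers basis_c and widths basis_h, summed over the last dim."""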
|
|
|
|
|
x = x.unsqueeze(-1) |
|
|
|
|
|
|
|
|
|
|
|
eps = 1e-6 |
|
|
temp = F.sigmoid(basis_a) * torch.exp(-(x - basis_c) ** 2 / (2 * basis_h ** 2 + eps)) |
|
|
|
|
|
return temp.sum(dim=-1, keepdim=False) |
|
|
|
|
|
def pad_time_cond(t, length):
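    """Sinusoidal time conditioning: concatenates sin(w * t) and cos(w * t) for w = 1..length with t itself."""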
|
|
    t_sin = torch.cat([torch.sin(w * t) for w in range(1, length + 1)], dim=-1)
|
|
    t_cos = torch.cat([torch.cos(w * t) for w in range(1, length + 1)], dim=-1)
|
|
t = torch.cat([t_sin, t_cos, t], dim=-1) |
|
|
return t |
|
|
|
|
|
|
|
|
class condition_interpolation(nn.Module): |
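    """
    Low-rank (LoRA-style) head that maps the outputs of the two branches and the router
    state `h_new` to a correction term added when the branches are interpolated.
    """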
|
|
def __init__( |
|
|
self, |
|
|
hidden_size: int = 2048, |
|
|
concept_dim: int = 64, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.hidden_size = hidden_size |
|
|
self.concept_dim = concept_dim |
|
|
self.r = 8 |
|
|
|
|
|
|
|
|
self.lora = nn.Sequential( |
|
|
nn.Linear(self.hidden_size * 2 + self.concept_dim * 2, self.hidden_size // self.r, bias=False), |
|
|
nn.SiLU(), |
|
|
nn.Linear(self.hidden_size // self.r, self.hidden_size, bias=False) |
|
|
) |
|
|
nn.init.xavier_uniform_(self.lora[0].weight) |
|
|
nn.init.zeros_(self.lora[2].weight) |
|
|
|
|
|
def forward(self, start, end, h_new): |
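        # `h_new` is concatenated twice so that the input width matches the
        # 2 * hidden_size + 2 * concept_dim features expected by the first Linear.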
|
|
|
|
|
|
|
|
x = torch.cat([start, end, h_new, h_new], dim=-1) |
|
|
x = self.lora(x) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
|
|
|
class Task_Aware_Delta_Net(nn.Module): |
|
|
""" |
|
|
    The layer implementation for [Gated Delta Networks: Improving Mamba2 with Delta Rule](https://arxiv.org/abs/2412.06464). # noqa
|
|
|
|
|
Similar to Mamba2, each layer contains around 6*hidden_size*hidden_size parameters. |
|
|
    Parameter allocation when use_gate=True:
|
|
- 0.75 * hidden_size * hidden_size for the q_proj and k_proj each |
|
|
- 1.5 * hidden_size * hidden_size for the v_proj, g_proj and o_proj each |
|
|
        - Other parameters are negligibly small.
|
|
- In total = 0.75 * 2 + 1.5 * 3 = 6 * hidden_size * hidden_size |
|
|
NOTE: num_heads * head_dim = 0.75 * hidden_size, please make sure to set the correct num_heads and head_dim. |
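    For example, with the defaults hidden_size = 2048, num_heads = 6 and head_dim = 256,
    num_heads * head_dim = 1536 = 0.75 * 2048.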
|
|
|
|
|
Parameter allocation when use_gate=False: |
|
|
- 1 * hidden_size * hidden_size for the q_proj and k_proj each |
|
|
- 2 * hidden_size * hidden_size for the v_proj and o_proj each |
|
|
        - Other parameters are negligibly small.
|
|
- In total = 1 * 2 + 2 * 2 = 6 * hidden_size * hidden_size |
|
|
Args: |
|
|
hidden_size (int, Optional): |
|
|
The hidden size of the input. Default: 2048. |
|
|
expand_v (float, Optional): |
|
|
            The expansion ratio for the value dim. Default: 1.0.
|
|
head_dim (int, Optional): |
|
|
The dimension of each head. Default: 256. |
|
|
num_heads (int, Optional): |
|
|
            The number of heads. Default: 6.
|
|
mode (str, Optional): |
|
|
Which Gated DeltaNet kernel to use. |
|
|
Currently available: `chunk` and `fused_recurrent`. |
|
|
Default: `chunk`. |
|
|
        num_heads_delta (int, Optional):
            The number of heads used to derive `head_dim` when `strict_head` is enabled. Default: 6.
|
|
use_gate (bool, Optional): |
|
|
Whether to use output gate. Default: `True`. |
|
|
use_short_conv (bool, Optional): |
|
|
Whether to use short convolutions. Default: `True`. |
|
|
conv_size (int, Optional): |
|
|
The kernel size of the short convolution, only used when `use_short_conv` is `True`. Default: 4. |
|
|
conv_bias (bool, Optional): |
|
|
Whether to use bias in the short convolution, only used when `use_short_conv` is `True`. Default: `False`. |
|
|
layer_idx (int, Optional): |
|
|
The index of the layer. Default: None. |
|
|
norm_eps (float, Optional): |
|
|
The epsilon value for the normalization layer. Default: 1e-5. |
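        rope_theta (float, Optional):
            The rotary embedding base used by the attention branch. Default: 10000.
        max_position_embeddings (int, Optional):
            The maximum sequence length assumed by the rotary embedding. Default: None.
        window_size (int, Optional):
            The sliding window size of the attention branch. Default: None.
        concept_dim (int, Optional):
            The dimension of the concept projections used by the router. Default: 128.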
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
hidden_size: int = 2048, |
|
|
expand_v: float = 1, |
|
|
head_dim: int = 256, |
|
|
num_heads: int = 6, |
|
|
num_heads_delta: int = 6, |
|
|
mode: str = 'chunk', |
|
|
use_gate: bool = True, |
|
|
use_short_conv: bool = True, |
|
|
conv_size: int = 4, |
|
|
conv_bias: bool = False, |
|
|
        layer_idx: Optional[int] = None,
|
|
norm_eps: float = 1e-5, |
|
|
rope_theta: float = 10000., |
|
|
        max_position_embeddings: Optional[int] = None,
|
|
        window_size: Optional[int] = None,
|
|
concept_dim: int = 128, |
|
|
**kwargs: Unpack[Dict] |
|
|
    ) -> None:
|
|
super().__init__() |
|
|
self.split_size = 64 |
|
|
|
|
|
self.mode = mode |
|
|
self.hidden_size = hidden_size |
|
|
self.expand_v = expand_v |
|
|
|
|
|
self.use_gate = use_gate |
|
|
self.use_short_conv = use_short_conv |
|
|
|
|
|
self.conv_size = conv_size |
|
|
self.conv_bias = conv_bias |
|
|
self.head_dim = head_dim |
|
|
self.strict_head = False |
|
|
if self.strict_head: |
|
|
            head_dim_delta = int(0.75 * hidden_size / num_heads_delta)
|
|
head_dim = head_dim_delta |
|
|
self.head_dim_delta = head_dim_delta |
|
|
self.head_dim = head_dim_delta |
|
|
self.num_heads = num_heads |
|
|
self.key_dim = self.num_heads * self.head_dim |
|
|
        self.value_dim = int(self.key_dim * self.expand_v)
|
|
self.head_qk_dim = head_dim |
|
|
        self.head_v_dim = int(head_dim * self.expand_v)
|
|
self.layer_idx = layer_idx |
|
|
self.silu = nn.SiLU() |
|
|
        assert mode in ['chunk', 'fused_recurrent'], f"Not supported mode `{mode}`."
|
|
self.q_proj = nn.Linear(hidden_size, self.key_dim, bias=False) |
|
|
self.k_proj = nn.Linear(hidden_size, self.key_dim, bias=False) |
|
|
self.v_proj = nn.Linear(hidden_size, self.value_dim, bias=False) |
|
|
self.o_proj = nn.Linear(self.value_dim, hidden_size, bias=False) |
|
|
|
|
|
self.b_proj = nn.Linear(hidden_size, self.num_heads, bias=False) |
|
|
self.a_proj = nn.Linear(hidden_size, self.num_heads, bias=False) |
|
|
A = torch.empty(self.num_heads, dtype=torch.float32).uniform_(0, 16) |
|
|
A_log = torch.log(A) |
|
|
self.A_log = nn.Parameter(A_log) |
|
|
self.A_log._no_weight_decay = True |
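        # A_log parameterises a per-head decay, Mamba-style: the gate g in forward is
        # computed as -exp(A_log) * softplus(a_proj(x) + dt_bias).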
|
|
|
|
|
|
|
|
|
|
|
dt_min = 0.001 |
|
|
dt_max = 0.1 |
|
|
dt_init_floor = 1e-4 |
|
|
dt = torch.exp( |
|
|
torch.rand(self.num_heads) * (math.log(dt_max) - math.log(dt_min)) |
|
|
+ math.log(dt_min) |
|
|
) |
|
|
dt = torch.clamp(dt, min=dt_init_floor) |
|
|
|
|
|
inv_dt = dt + torch.log(-torch.expm1(-dt)) |
|
|
self.dt_bias = nn.Parameter(inv_dt) |
|
|
|
|
|
|
|
|
self.dt_bias._no_weight_decay = True |
|
|
if use_short_conv: |
|
|
self.conv_size = conv_size |
|
|
self.q_conv1d = ShortConvolution( |
|
|
hidden_size=self.key_dim, |
|
|
kernel_size=conv_size, |
|
|
activation='silu' |
|
|
) |
|
|
self.k_conv1d = ShortConvolution( |
|
|
hidden_size=self.key_dim, |
|
|
kernel_size=conv_size, |
|
|
activation='silu' |
|
|
) |
|
|
self.v_conv1d = ShortConvolution( |
|
|
hidden_size=self.value_dim, |
|
|
kernel_size=conv_size, |
|
|
activation='silu' |
|
|
) |
|
|
else: |
|
|
raise UserWarning( |
|
|
"ShortConvolution is crucial to the performance. " |
|
|
"Do not turn it off, i.e., setting `use_short_conv=False` unless you know what you are doing." |
|
|
) |
|
|
if use_gate: |
|
|
self.g_proj = nn.Linear(hidden_size, self.value_dim, bias=False) |
|
|
self.o_norm = FusedRMSNormSwishGate(self.head_v_dim, eps=norm_eps) |
|
|
else: |
|
|
self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps) |
|
|
|
|
|
self.num_prelude = 2 |
|
|
self.ttt = True |
|
|
if self.ttt and self.layer_idx >= self.num_prelude: |
|
|
self.concept_dim = concept_dim |
|
|
self.concept_proj = nn.Linear(hidden_size, self.concept_dim * 3, bias=False) |
|
|
self.lr1_proj = nn.Linear(hidden_size, 1, bias=False) |
|
|
self.lr2_proj = nn.Linear(hidden_size, 1, bias=False) |
|
|
|
|
|
self.router2 = nn.Linear(self.concept_dim, self.num_heads * 2, bias=False) |
|
|
self.router3 = nn.Linear(self.concept_dim, 2, bias=False) |
|
|
|
|
|
self.condition_interpolation = condition_interpolation(hidden_size, concept_dim) |
|
|
self.t_proj = nn.Linear(concept_dim, 1, bias=False) |
|
|
|
|
|
|
|
|
|
|
|
self.special_mask = nn.Parameter(torch.zeros(self.hidden_size)) |
|
|
|
|
|
self.use_bias = True |
|
|
if self.use_bias: |
|
|
self.learnable_bias0 = nn.Parameter(torch.zeros(1)) |
|
|
|
|
|
self.apply(self._initialize_weights) |
|
|
|
|
|
self.r = 4 |
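        # The LoRA branches below are zero-initialised at the output, so the attention
        # path initially sees exactly the shared q/k/v projections.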
|
|
self.q_lora = nn.Sequential( |
|
|
nn.Linear(self.hidden_size, self.key_dim // self.r, bias=False), |
|
|
nn.SiLU(), |
|
|
nn.Linear(self.key_dim // self.r, self.key_dim, bias=False) |
|
|
) |
|
|
nn.init.xavier_uniform_(self.q_lora[0].weight) |
|
|
nn.init.zeros_(self.q_lora[2].weight) |
|
|
self.k_lora = nn.Sequential( |
|
|
nn.Linear(self.hidden_size, self.key_dim // self.r, bias=False), |
|
|
nn.SiLU(), |
|
|
nn.Linear(self.key_dim // self.r, self.key_dim, bias=False) |
|
|
) |
|
|
nn.init.xavier_uniform_(self.k_lora[0].weight) |
|
|
nn.init.zeros_(self.k_lora[2].weight) |
|
|
|
|
|
self.v_lora = nn.Sequential( |
|
|
nn.Linear(self.hidden_size, self.value_dim // self.r, bias=False), |
|
|
nn.SiLU(), |
|
|
nn.Linear(self.value_dim // self.r, self.value_dim, bias=False) |
|
|
) |
|
|
nn.init.xavier_uniform_(self.v_lora[0].weight) |
|
|
nn.init.zeros_(self.v_lora[2].weight) |
|
|
|
|
|
self.o_proj_attn = nn.Linear(self.value_dim, self.hidden_size, bias=False) |
|
|
nn.init.xavier_uniform_(self.o_proj_attn.weight, gain=2 ** -2.5) |
|
|
|
|
|
self.rope_theta = rope_theta |
|
|
self.max_position_embeddings = max_position_embeddings |
|
|
self.rotary = RotaryEmbedding(dim=self.head_dim, base=self.rope_theta) |
|
|
self.window_size = window_size |
|
|
|
|
|
|
|
|
def _initialize_weights(self, module: nn.Module): |
|
|
if getattr(module, "_is_hf_initialized", False): |
|
|
return |
|
|
if isinstance(module, nn.Linear): |
|
|
nn.init.xavier_uniform_(module.weight, gain=2 ** -2.5) |
|
|
if module.bias is not None: |
|
|
nn.init.zeros_(module.bias) |
|
|
module._is_hf_initialized = True |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
past_key_values1: Optional[Cache] = None, |
|
|
all_past_key_values: Optional[Cache] = None, |
|
|
use_cache: Optional[bool] = False, |
|
|
output_attentions: Optional[bool] = False, |
|
|
rnn_router: Optional[nn.Module] = None, |
|
|
h_old: Optional[torch.Tensor] = None, |
|
|
params: Optional[Dict] = None, |
|
|
**kwargs: Unpack[Dict] |
|
|
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[Tuple[Cache, Cache]], Optional[torch.Tensor], Optional[Dict]]:
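        """
        Sketch of the interface (inferred from the body): `hidden_states` is
        [batch_size, seq_len, hidden_size]; `attention_mask` is an optional 0/1 padding
        mask of shape [batch_size, seq_len]; `rnn_router` and `params` drive the
        test-time-training concept router used in non-prelude layers. Returns the mixed
        output, two placeholder None values, the updated caches, `h_new` and `params`.
        """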
|
|
|
|
|
if attention_mask is not None: |
|
|
assert len(attention_mask.shape) == 2, ( |
|
|
"Expected attention_mask as a 0-1 matrix with shape [batch_size, seq_len] " |
|
|
"for padding purposes (0 indicating padding). " |
|
|
"Arbitrary attention masks of shape [batch_size, seq_len, seq_len] are not allowed." |
|
|
) |
|
|
|
|
|
mode = 'fused_recurrent' if hidden_states.shape[1] <= 64 else self.mode |
|
|
|
|
|
|
|
|
if self.training: |
|
|
assert mode == 'chunk', "Only chunk mode is supported in training." |
|
|
last_state2 = None |
|
|
if all_past_key_values is not None: |
|
|
if all_past_key_values._seen_tokens > 0: |
|
|
past_key_values1, past_key_values2 = all_past_key_values |
|
|
else: |
|
|
from fla.models.utils import Cache |
|
|
past_key_values1, past_key_values2 = Cache(), Cache() |
|
|
|
|
|
if len(past_key_values2) > self.layer_idx: |
|
|
last_state2 = past_key_values2[self.layer_idx] |
|
|
batch_size, q_len, _ = hidden_states.size() |
|
|
cu_seqlens = kwargs.get('cu_seqlens', None) |
|
|
max_seqlen = kwargs.get('max_seqlen', q_len) |
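        # Test-time-training router: prelude layers only initialise the router state,
        # while deeper layers update it from the concept projections and use it to gate
        # the attention and delta-rule branches (and the special-token mask).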
|
|
if self.ttt: |
|
|
flag = True |
|
|
if self.layer_idx < self.num_prelude: |
|
|
                if flag:
|
|
params = rnn_router.init_params_as_logits(batch_size, q_len) |
|
|
flag = False |
|
|
mask = torch.ones(batch_size, q_len, self.num_heads, 2, device=hidden_states.device).to(hidden_states.dtype) |
|
|
h_new = None |
|
|
special_mask_attn = torch.zeros(batch_size, q_len, 1, device=hidden_states.device).to(hidden_states.dtype) |
|
|
else: |
|
|
concept_qkv = self.concept_proj(hidden_states) |
|
|
concept_q, concept_k, concept_v = concept_qkv.chunk(3, dim=-1) |
|
|
lr_linear = F.sigmoid(self.lr1_proj(hidden_states)) * 1e-2 |
|
|
lr_ln = F.sigmoid(self.lr2_proj(hidden_states)) * 1e-2 |
|
|
|
|
|
|
|
|
if rnn_router is not None: |
|
|
params = rnn_router.learn(concept_k, concept_v, params, lr_linear, lr_ln) |
|
|
|
|
|
h_new = rnn_router.predict(concept_q, params) |
|
|
t = F.sigmoid(self.t_proj(h_new)) |
|
|
t_b = 1 - t |
|
|
|
|
|
input_router = self.router2(h_new) |
|
|
|
|
|
input_router = F.sigmoid(input_router) |
|
|
special_mask = self.router3(h_new) |
|
|
|
|
|
bias = torch.zeros_like(special_mask) |
|
|
bias[..., 0] = 2.0 |
|
|
if self.use_bias: |
|
|
bias[..., 0] = 2.0 + self.learnable_bias0 |
|
|
special_mask = F.gumbel_softmax(special_mask + bias, tau=0.1, hard=True) |
|
|
special_mask_attn = special_mask[:, :, 1].unsqueeze(-1) |
|
|
|
|
|
mask = input_router |
|
|
mask = mask.reshape(batch_size, q_len, self.num_heads, 2) |
|
|
|
|
|
|
|
|
if self.use_short_conv: |
|
|
conv_state_q, conv_state_k, conv_state_v = None, None, None |
|
|
if last_state2 is not None: |
|
|
conv_state_q, conv_state_k, conv_state_v = last_state2['conv_state'] |
|
|
conv_mask = attention_mask[:, -hidden_states.shape[1]:] if attention_mask is not None else None |
|
|
|
|
|
q_shared = self.q_proj(hidden_states) |
|
|
k_shared = self.k_proj(hidden_states) |
|
|
v_shared = self.v_proj(hidden_states) |
|
|
q, conv_state_q = self.q_conv1d(x=q_shared, |
|
|
mask=conv_mask, |
|
|
cache=conv_state_q, |
|
|
output_final_state=use_cache, |
|
|
cu_seqlens = cu_seqlens |
|
|
) |
|
|
k, conv_state_k = self.k_conv1d(x=k_shared, |
|
|
mask=conv_mask, |
|
|
cache=conv_state_k, |
|
|
output_final_state=use_cache, |
|
|
cu_seqlens = cu_seqlens |
|
|
) |
|
|
v, conv_state_v = self.v_conv1d(x=v_shared, |
|
|
mask=conv_mask, |
|
|
cache=conv_state_v, |
|
|
output_final_state=use_cache, |
|
|
cu_seqlens = cu_seqlens |
|
|
) |
|
|
else: |
|
|
q = self.silu(self.q_proj(hidden_states)) |
|
|
k = self.silu(self.k_proj(hidden_states)) |
|
|
v = self.silu(self.v_proj(hidden_states)) |
|
|
|
|
|
if self.layer_idx >= self.num_prelude: |
|
|
hidden_states_attn = hidden_states + special_mask_attn * self.special_mask.reshape(1, 1, -1) |
|
|
else: |
|
|
hidden_states_attn = hidden_states |
|
|
q_attn = self.q_lora(hidden_states_attn) + q_shared |
|
|
k_attn = self.k_lora(hidden_states_attn) + k_shared |
|
|
v_attn = self.v_lora(hidden_states_attn) + v_shared |
|
|
|
|
|
|
|
|
q_attn, k_attn, v_attn = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (q_attn, k_attn, v_attn)) |
|
|
|
|
|
|
|
|
seqlen_offset = 0 |
|
|
|
|
|
if all_past_key_values is not None: |
|
|
seqlen_offset = past_key_values1.get_seq_length(self.layer_idx) |
|
|
max_seqlen = q_attn.shape[1] + seqlen_offset |
|
|
|
|
|
if attention_mask is not None: |
|
|
|
|
|
seqlen_offset = (seqlen_offset + attention_mask.sum(-1) - attention_mask.shape[-1]).clamp(min=0) |
|
|
max_seqlen = q_attn.shape[1] + max(seqlen_offset) |
|
|
|
|
|
if self.max_position_embeddings is not None: |
|
|
max_seqlen_rotary = max(max_seqlen, self.max_position_embeddings) |
|
|
else: |
|
|
max_seqlen_rotary = max_seqlen |
|
|
q_attn, k_attn = self.rotary(q_attn, k_attn, seqlen_offset=seqlen_offset, max_seqlen=max_seqlen_rotary, cu_seqlens=cu_seqlens) |
|
|
if all_past_key_values is not None: |
|
|
k_attn, v_attn = past_key_values1.update( |
|
|
attn_state=(k_attn.flatten(-2, -1), v_attn.flatten(-2, -1)), |
|
|
layer_idx=self.layer_idx, |
|
|
offset=q_len, |
|
|
cache_kwargs=dict(window_size=self.window_size) |
|
|
)['attn_state'] |
|
|
k_attn = rearrange(k_attn, '... (h d) -> ... h d', h=self.num_heads) |
|
|
v_attn = rearrange(v_attn, '... (h d) -> ... h d', h=self.num_heads) |
|
|
if flash_attn_func is None: |
|
|
raise ImportError("Please install Flash Attention via `pip install flash-attn --no-build-isolation` first") |
|
|
|
|
|
|
|
|
if attention_mask is not None: |
|
|
q_attn, k_attn, v_attn, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(q_attn, k_attn, v_attn, attention_mask, q_len) |
|
|
cu_seqlens_q, cu_seqlens_k = cu_seq_lens |
|
|
max_seqlen_q, max_seqlen_k = max_seq_lens |
|
|
o_attn = flash_attn_varlen_func( |
|
|
q_attn, k_attn, v_attn, |
|
|
cu_seqlens_q=cu_seqlens_q, |
|
|
cu_seqlens_k=cu_seqlens_k, |
|
|
max_seqlen_q=max_seqlen_q, |
|
|
max_seqlen_k=max_seqlen_k, |
|
|
causal=True, |
|
|
window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) |
|
|
) |
|
|
o_attn = pad_input(o_attn, indices_q, batch_size, q_len) |
|
|
elif cu_seqlens is not None: |
|
|
o_attn = flash_attn_varlen_func( |
|
|
q_attn.squeeze(0), k_attn.squeeze(0), v_attn.squeeze(0), |
|
|
cu_seqlens_q=cu_seqlens, |
|
|
cu_seqlens_k=cu_seqlens, |
|
|
max_seqlen_q=max_seqlen, |
|
|
max_seqlen_k=max_seqlen, |
|
|
causal=True, |
|
|
window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) |
|
|
).unsqueeze(0) |
|
|
else: |
|
|
o_attn = flash_attn_func( |
|
|
q_attn, k_attn, v_attn, |
|
|
causal=True, |
|
|
window_size=(-1, -1) if self.window_size is None else (self.window_size-1, 0) |
|
|
) |
|
|
if batch_size > 1: |
|
|
            o_attn = o_attn.reshape(batch_size, q_len, self.num_heads, self.head_v_dim)
|
|
|
|
|
if self.layer_idx >= self.num_prelude: |
|
|
o_attn = torch.einsum("bnh,bnhd->bnhd", mask[:, :, :, 0], o_attn) |
|
|
|
|
|
o_attn = o_attn.reshape(batch_size, q_len, self.value_dim) |
|
|
|
|
|
o_attn = self.o_proj_attn(o_attn) |
|
|
|
|
|
k, v = map(lambda x: rearrange(x, 'b t (h d) -> b t h d', h=self.num_heads), (k, v)) |
|
|
|
|
|
beta = self.b_proj(hidden_states).sigmoid() |
|
|
g = -self.A_log.float().exp() * F.softplus(self.a_proj(hidden_states).float() + self.dt_bias) |
|
|
|
|
|
if attention_mask is not None: |
|
|
beta = beta.mul(attention_mask[:, -beta.shape[-2]:, None]) |
|
|
g = g.mul(attention_mask[:, -g.shape[-2]:, None]) |
|
|
recurrent_state = last_state2['recurrent_state'] if last_state2 is not None else None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
q = rearrange(q, 'b t (h d) -> b t h d', h=self.num_heads) |
|
|
|
|
|
if mode == 'chunk': |
|
|
o, recurrent_state = chunk_gated_delta_rule( |
|
|
q=q, |
|
|
k=k, |
|
|
v=v, |
|
|
g=g, |
|
|
beta=beta, |
|
|
initial_state=recurrent_state, |
|
|
output_final_state=use_cache, |
|
|
cu_seqlens=cu_seqlens, |
|
|
head_first=False, |
|
|
use_qk_l2norm_in_kernel=True |
|
|
) |
|
|
elif mode == 'fused_recurrent': |
|
|
o, recurrent_state = fused_recurrent_gated_delta_rule( |
|
|
q=q, |
|
|
k=k, |
|
|
v=v, |
|
|
g=g, |
|
|
beta=beta, |
|
|
initial_state=recurrent_state, |
|
|
output_final_state=use_cache, |
|
|
cu_seqlens=cu_seqlens, |
|
|
|
|
|
use_qk_l2norm_in_kernel=True |
|
|
) |
|
|
if all_past_key_values is not None: |
|
|
past_key_values2.update( |
|
|
recurrent_state=recurrent_state, |
|
|
conv_state=(conv_state_q, conv_state_k, conv_state_v) if self.use_short_conv else None, |
|
|
layer_idx=self.layer_idx, |
|
|
offset=q.shape[1] |
|
|
) |
|
|
if self.use_gate: |
|
|
g = rearrange(self.g_proj(hidden_states), '... (h d) -> ... h d', h=self.num_heads) |
|
|
o = self.o_norm(o, g) |
|
|
else: |
|
|
o = self.o_norm(o) |
|
|
if self.layer_idx >= self.num_prelude: |
|
|
o = torch.einsum("bnh,bnhd->bnhd", mask[:, :, :, 1], o) |
|
|
o_gated_delta = rearrange(o, 'b t h d -> b t (h d)') |
|
|
o_gated_delta = self.o_proj(o_gated_delta) |
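        # Combine the two branches: prelude layers use a plain sum, deeper layers
        # interpolate with the learned gate t and add a conditional correction scaled
        # by t * (1 - t).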
|
|
|
|
|
|
|
|
if self.layer_idx < self.num_prelude: |
|
|
o = o_gated_delta + o_attn |
|
|
else: |
|
|
o = t_b * o_gated_delta + t * o_attn |
|
|
noise_std = t_b * t |
|
|
noise = self.condition_interpolation(o_gated_delta, o_attn, h_new) * noise_std |
|
|
o = o + noise |
|
|
|
|
|
if all_past_key_values is not None: |
|
|
all_past_key_values = (past_key_values1, past_key_values2) |
|
|
return o, None, None, all_past_key_values, h_new, params |
|
|
|
|
|
def _upad_input(self, q, k, v, attention_mask, q_len): |
|
|
seqlens = attention_mask.sum(-1, dtype=torch.int32) |
|
|
indices_k = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
|
|
max_seqlen_k = seqlens.max().item() |
|
|
cu_seqlens_k = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)) |
|
|
batch_size, seq_len, num_key_value_heads, head_dim = k.shape |
|
|
|
|
|
k = index_first_axis(k.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k) |
|
|
v = index_first_axis(v.reshape(batch_size * seq_len, num_key_value_heads, head_dim), indices_k) |
|
|
if q_len == seq_len: |
|
|
q = index_first_axis(q.reshape(batch_size * seq_len, self.num_heads, head_dim), indices_k) |
|
|
cu_seqlens_q = cu_seqlens_k |
|
|
max_seqlen_q = max_seqlen_k |
|
|
indices_q = indices_k |
|
|
elif q_len == 1: |
|
|
max_seqlen_q = 1 |
|
|
|
|
|
cu_seqlens_q = torch.arange(batch_size + 1, dtype=torch.int32, device=q.device) |
|
|
indices_q = cu_seqlens_q[:-1] |
|
|
q = q.squeeze(1) |
|
|
else: |
|
|
|
|
|
attention_mask = attention_mask[:, -q_len:] |
|
|
q, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(q, attention_mask) |
|
|
|
|
|
return q, k, v, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) |
|
|
|
|
|
if __name__ == "__main__":
    # Shape-level smoke test (a sketch): actually running the forward pass requires a
    # CUDA device with flash-attn installed and an `rnn_router` module providing
    # `init_params_as_logits`, `learn` and `predict`.
    gated_delta_net_attention = Task_Aware_Delta_Net(hidden_size=2048, layer_idx=0)
    hidden_states = torch.randn(2, 70, 2048)
    print(hidden_states.shape)
    o, _, _, all_past_key_values, h_new, params = gated_delta_net_attention(
        hidden_states=hidden_states
    )
    print(o.shape)